diff --git a/whisper/__init__.py b/whisper/__init__.py
index f284ec0..24a1be2 100644
--- a/whisper/__init__.py
+++ b/whisper/__init__.py
@@ -128,7 +128,12 @@ def load_model(
     """

     if device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+        if torch.cuda.is_available():
+            device = "cuda"
+        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            device = "mps"
+        else:
+            device = "cpu"
     if download_root is None:
         default = os.path.join(os.path.expanduser("~"), ".cache")
         download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 0a4cc36..6b87b59 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -128,6 +128,8 @@ def transcribe(
     if model.device == torch.device("cpu"):
         if torch.cuda.is_available():
             warnings.warn("Performing inference on CPU when CUDA is available")
+        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            warnings.warn("Performing inference on CPU when MPS is available")
         if dtype == torch.float16:
             warnings.warn("FP16 is not supported on CPU; using FP32 instead")
             dtype = torch.float32