From bdd0d79b8e533a794e39ed525f59118c0c54513f Mon Sep 17 00:00:00 2001
From: Dwarkesh Patel
Date: Thu, 20 Oct 2022 19:46:19 -0500
Subject: [PATCH 1/3] MPS (Mac acceleration) by default if available

---
 whisper/__init__.py   | 7 ++++++-
 whisper/transcribe.py | 2 ++
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/whisper/__init__.py b/whisper/__init__.py
index 9fbcc79..d08d719 100644
--- a/whisper/__init__.py
+++ b/whisper/__init__.py
@@ -88,7 +88,12 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
     """

     if device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+        if torch.cuda.is_available():
+            device = "cuda"
+        elif torch.backends.mps.is_available():
+            device = "mps"
+        else:
+            device = "cpu"
     if download_root is None:
         download_root = os.getenv(
             "XDG_CACHE_HOME",
diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 654f7b4..8406395 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -74,6 +74,8 @@ def transcribe(
     if model.device == torch.device("cpu"):
         if torch.cuda.is_available():
             warnings.warn("Performing inference on CPU when CUDA is available")
+        if torch.backends.mps.is_available():
+            warnings.warn("Performing inference on CPU when MPS is available")
         if dtype == torch.float16:
             warnings.warn("FP16 is not supported on CPU; using FP32 instead")
             dtype = torch.float32

From 4b77a81c1fa29fcbcf470d16a4fdd7d09cb0fd3a Mon Sep 17 00:00:00 2001
From: Jong Wook Kim
Date: Wed, 18 Jan 2023 14:39:07 -0800
Subject: [PATCH 2/3] hasattr check for torch.backends.mps

---
 whisper/transcribe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 6b5993f..0d89a5d 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -75,7 +75,7 @@ def transcribe(
     if model.device == torch.device("cpu"):
         if torch.cuda.is_available():
             warnings.warn("Performing inference on CPU when CUDA is available")
-        if torch.backends.mps.is_available():
+        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
             warnings.warn("Performing inference on CPU when MPS is available")
         if dtype == torch.float16:
             warnings.warn("FP16 is not supported on CPU; using FP32 instead")

From 51c785f7c91b8c032a1fa79c0e8f862dea81b860 Mon Sep 17 00:00:00 2001
From: Jong Wook Kim
Date: Wed, 18 Jan 2023 14:44:02 -0800
Subject: [PATCH 3/3] add another hasattr check for torch.backends.mps

---
 whisper/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/whisper/__init__.py b/whisper/__init__.py
index 080c303..4f45f99 100644
--- a/whisper/__init__.py
+++ b/whisper/__init__.py
@@ -94,7 +94,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
     if device is None:
         if torch.cuda.is_available():
             device = "cuda"
-        elif torch.backends.mps.is_available():
+        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            device = "mps"
         else:
             device = "cpu"
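
For reference, here is a minimal standalone sketch of the device-selection logic the three patches converge on, pulled out of load_model for illustration. The pick_device helper name is hypothetical (not part of Whisper's API); the torch calls are real PyTorch APIs.

import torch

def pick_device() -> str:
    # Prefer CUDA, then Apple's Metal Performance Shaders (MPS), then CPU.
    if torch.cuda.is_available():
        return "cuda"
    # The hasattr guard from patches 2/3 and 3/3: torch.backends.mps only
    # exists in PyTorch >= 1.12, so older installs would otherwise raise
    # AttributeError before reaching is_available().
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"

# Hypothetical usage (equivalent to passing device=None after patch 1/3):
#   import whisper
#   model = whisper.load_model("base", device=pick_device())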