MPS (Mac acceleration) by default if available

Dwarkesh Patel 2022-10-20 19:46:19 -05:00
parent 9f70a352f9
commit bdd0d79b8e
2 changed files with 8 additions and 1 deletion

whisper/__init__.py

@@ -88,7 +88,12 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
     """
     if device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+        if torch.cuda.is_available():
+            device = "cuda"
+        elif torch.backends.mps.is_available():
+            device = "mps"
+        else:
+            device = "cpu"
     if download_root is None:
         download_root = os.getenv(
             "XDG_CACHE_HOME",

whisper/transcribe.py

@@ -74,6 +74,8 @@ def transcribe(
     if model.device == torch.device("cpu"):
         if torch.cuda.is_available():
             warnings.warn("Performing inference on CPU when CUDA is available")
+        if torch.backends.mps.is_available():
+            warnings.warn("Performing inference on CPU when MPS is available")
         if dtype == torch.float16:
             warnings.warn("FP16 is not supported on CPU; using FP32 instead")
             dtype = torch.float32
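
Correspondingly, explicitly forcing CPU on a machine where MPS is available now emits the new warning during transcription. A hedged sketch of that path (the audio filename is a placeholder):

import whisper

model = whisper.load_model("base", device="cpu")
# On an MPS-capable Mac this warns: "Performing inference on CPU when MPS is available"
result = model.transcribe("audio.mp3")
print(result["text"])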