From e0f0221a924755805254501c6aad7075db1876d1 Mon Sep 17 00:00:00 2001
From: NripeshN
Date: Wed, 5 Jul 2023 22:27:48 +0530
Subject: [PATCH] use MPS as the default device when available

---
 whisper/transcribe.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 6e43a22..c2e98b1 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -378,12 +378,20 @@ def transcribe(
 def cli():
     from . import available_models
 
+    # Pick the default device: prefer CUDA, then MPS if available, else CPU.
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
+
     # fmt: off
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
     parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
-    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+    parser.add_argument("--device", default=default_device, help="device to use for PyTorch inference")
     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
     parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
     parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
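
Note (not part of the patch): a minimal standalone sketch of the fallback order the
hunk above adds to cli(), for reviewers who want to sanity-check it in isolation.
The pick_default_device name is illustrative and does not appear in the patch.

    import torch

    def pick_default_device() -> str:
        # Same fallback order as the patch: CUDA, then Apple's MPS, then CPU.
        if torch.cuda.is_available():
            return "cuda"
        # Older PyTorch builds do not expose torch.backends.mps at all, so guard
        # the attribute lookup before asking whether the backend is usable.
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    if __name__ == "__main__":
        # e.g. prints "mps" on an Apple Silicon machine with a recent PyTorch build
        print(pick_default_device())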