Mirror of https://github.com/openai/whisper.git (synced 2025-11-24 06:26:03 +00:00)

Merge d15213d5616a6098732c5094b49c20d5e8a32f49 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2
Commit 194f12c3bc

Summary: this merge adds optional Intel XPU support. load_model() and transcribe() gain a cuda > xpu > cpu device fallback when no device is specified, intel_extension_for_pytorch is imported when it is installed, and the CLI's --device default becomes None so device selection is deferred to load_model().
diff --git a/whisper/__init__.py b/whisper/__init__.py
@@ -4,8 +4,11 @@ import os
 import urllib
 import warnings
 from typing import List, Optional, Union
+from importlib.util import find_spec
 
 import torch
+if find_spec("intel_extension_for_pytorch") is not None:
+    import intel_extension_for_pytorch
 from tqdm import tqdm
 
 from .audio import load_audio, log_mel_spectrogram, pad_or_trim
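The guarded import above relies on importlib.util.find_spec, which probes for a module without importing it. A minimal standalone sketch of the same pattern; the fallback message is illustrative and not part of the diff:

    from importlib.util import find_spec

    # find_spec() returns a ModuleSpec if the package is resolvable, else None,
    # so the heavy import only runs when it can succeed (no try/except needed).
    if find_spec("intel_extension_for_pytorch") is not None:
        # Imported for its side effects: on setups where torch has no built-in
        # XPU support, IPEX extends torch with Intel-GPU backends.
        import intel_extension_for_pytorch  # noqa: F401
    else:
        # Purely illustrative fallback; the diff itself prints nothing here.
        print("intel_extension_for_pytorch not installed; XPU support unavailable")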
@@ -128,7 +131,13 @@ def load_model(
     """
 
     if device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+        if torch.cuda.is_available():
+            device = "cuda"
+        elif find_spec('torch.xpu') is not None and torch.xpu.is_available():
+            device = "xpu"
+        else:
+            device = "cpu"
+
     if download_root is None:
         default = os.path.join(os.path.expanduser("~"), ".cache")
         download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
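Taken together with the import hunk, the effect is that device=None now resolves to the best available backend inside load_model(). A hedged usage sketch; the model name and audio path are placeholders:

    import whisper

    # With device=None (the default), load_model walks cuda > xpu > cpu itself,
    # so callers need no torch.cuda.is_available() boilerplate of their own.
    model = whisper.load_model("turbo")     # device resolved automatically
    print(model.device)                     # e.g. device(type='xpu') on an Intel GPU
    result = model.transcribe("audio.wav")  # "audio.wav" is a placeholder path
    print(result["text"])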
diff --git a/whisper/transcribe.py b/whisper/transcribe.py
@@ -3,6 +3,7 @@ import os
 import traceback
 import warnings
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from importlib.util import find_spec
 
 import numpy as np
 import torch
@@ -128,6 +129,8 @@ def transcribe(
     if model.device == torch.device("cpu"):
         if torch.cuda.is_available():
             warnings.warn("Performing inference on CPU when CUDA is available")
+        if find_spec('torch.xpu') is not None and torch.xpu.is_available():
+            warnings.warn("Performing inference on CPU when XPU is available")
         if dtype == torch.float16:
             warnings.warn("FP16 is not supported on CPU; using FP32 instead")
             dtype = torch.float32
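A sketch of when the new warning fires, assuming whisper is installed and an idle Intel GPU is present; the audio path is a placeholder:

    import warnings
    import whisper

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Forcing CPU on a machine with an available accelerator now warns
        # instead of silently running slow; FP16-on-CPU still downgrades to
        # FP32 with its own warning, as before.
        model = whisper.load_model("tiny", device="cpu")
        model.transcribe("audio.wav")  # placeholder path
    for w in caught:
        print(w.message)  # e.g. "Performing inference on CPU when XPU is available"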
@@ -529,7 +532,7 @@ def cli():
     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
     parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
-    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+    parser.add_argument("--device", default=None, help="device to use for PyTorch inference")
     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
     parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
     parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
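The default=None change matters because the parsed value is handed to load_model(): None falls through to the cuda > xpu > cpu resolution added above, while an explicit --device still wins. A standalone sketch of that flow; the parser here is a stub, not the real CLI:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default=None, help="device to use for PyTorch inference")

    args = parser.parse_args([])                   # no --device on the command line
    assert args.device is None                     # ...so load_model() auto-detects
    args = parser.parse_args(["--device", "xpu"])
    assert args.device == "xpu"                    # explicit choice passes through unchanged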