mirror of
https://github.com/openai/whisper.git
synced 2025-11-23 22:15:58 +00:00
Merge d15213d5616a6098732c5094b49c20d5e8a32f49 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2
This commit is contained in:
commit
194f12c3bc
@ -4,8 +4,11 @@ import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
from importlib.util import find_spec
|
||||
|
||||
import torch
|
||||
if find_spec("intel_extension_for_pytorch") is not None:
|
||||
import intel_extension_for_pytorch
|
||||
from tqdm import tqdm
|
||||
|
||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
@ -128,7 +131,13 @@ def load_model(
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
elif find_spec('torch.xpu') is not None and torch.xpu.is_available():
|
||||
device = "xpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
if download_root is None:
|
||||
default = os.path.join(os.path.expanduser("~"), ".cache")
|
||||
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
|
||||
|
||||
@ -3,6 +3,7 @@ import os
|
||||
import traceback
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
from importlib.util import find_spec
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -128,6 +129,8 @@ def transcribe(
|
||||
if model.device == torch.device("cpu"):
|
||||
if torch.cuda.is_available():
|
||||
warnings.warn("Performing inference on CPU when CUDA is available")
|
||||
if find_spec('torch.xpu') is not None and torch.xpu.is_available():
|
||||
warnings.warn("Performing inference on CPU when XPU is available")
|
||||
if dtype == torch.float16:
|
||||
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
|
||||
dtype = torch.float32
|
||||
@ -529,7 +532,7 @@ def cli():
|
||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||
parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
|
||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||
parser.add_argument("--device", default=None, help="device to use for PyTorch inference")
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user