Merge d15213d5616a6098732c5094b49c20d5e8a32f49 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2

This commit is contained in:
leuc 2025-11-16 00:53:19 +01:00 committed by GitHub
commit 194f12c3bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 2 deletions

View File

@ -4,8 +4,11 @@ import os
import urllib import urllib
import warnings import warnings
from typing import List, Optional, Union from typing import List, Optional, Union
from importlib.util import find_spec
import torch import torch
if find_spec("intel_extension_for_pytorch") is not None:
import intel_extension_for_pytorch
from tqdm import tqdm from tqdm import tqdm
from .audio import load_audio, log_mel_spectrogram, pad_or_trim from .audio import load_audio, log_mel_spectrogram, pad_or_trim
@ -128,7 +131,13 @@ def load_model(
""" """
if device is None: if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu" if torch.cuda.is_available():
device = "cuda"
elif find_spec('torch.xpu') is not None and torch.xpu.is_available():
device = "xpu"
else:
device = "cpu"
if download_root is None: if download_root is None:
default = os.path.join(os.path.expanduser("~"), ".cache") default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper") download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

View File

@ -3,6 +3,7 @@ import os
import traceback import traceback
import warnings import warnings
from typing import TYPE_CHECKING, List, Optional, Tuple, Union from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from importlib.util import find_spec
import numpy as np import numpy as np
import torch import torch
@ -128,6 +129,8 @@ def transcribe(
if model.device == torch.device("cpu"): if model.device == torch.device("cpu"):
if torch.cuda.is_available(): if torch.cuda.is_available():
warnings.warn("Performing inference on CPU when CUDA is available") warnings.warn("Performing inference on CPU when CUDA is available")
if find_spec('torch.xpu') is not None and torch.xpu.is_available():
warnings.warn("Performing inference on CPU when XPU is available")
if dtype == torch.float16: if dtype == torch.float16:
warnings.warn("FP16 is not supported on CPU; using FP32 instead") warnings.warn("FP16 is not supported on CPU; using FP32 instead")
dtype = torch.float32 dtype = torch.float32
@ -529,7 +532,7 @@ def cli():
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use") parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") parser.add_argument("--device", default=None, help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced") parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")