Add support for Intel GPUs

Requires Intel Extension for PyTorch v1.13.120+xpu
https://intel.github.io/intel-extension-for-pytorch/xpu/latest/tutorials/installation.html

Tested on an Intel Arc A770 (16 GB VRAM) with the large model
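To verify the XPU backend is wired up before running Whisper, a quick check along these lines should work (a minimal sketch; it assumes the extension above is installed and that torch.xpu mirrors the torch.cuda API, as the IPEX documentation describes):

from importlib.util import find_spec

import torch

# Importing the extension registers the "xpu" backend on torch.
if find_spec("intel_extension_for_pytorch") is not None:
    import intel_extension_for_pytorch  # noqa: F401

if find_spec("torch.xpu") is not None and torch.xpu.is_available():
    # get_device_name is assumed to mirror torch.cuda.get_device_name
    print("XPU available:", torch.xpu.get_device_name(0))
else:
    print("XPU not available; CUDA/CPU will be used instead")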
leuc 2023-05-18 17:10:33 +02:00
parent 248b6cb124
commit 299a658f5e
2 changed files with 14 additions and 2 deletions


@@ -4,8 +4,11 @@ import os
 import urllib
 import warnings
 from typing import List, Optional, Union
+from importlib.util import find_spec
 
 import torch
+if find_spec("intel_extension_for_pytorch") is not None:
+    import intel_extension_for_pytorch
 from tqdm import tqdm
 
 from .audio import load_audio, log_mel_spectrogram, pad_or_trim
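The find_spec guard keeps the extension optional: the import only runs when the package is installed, so stock CUDA/CPU installs are unaffected. An equivalent try/except form would behave the same way (a sketch for comparison, not part of this diff):

try:
    import intel_extension_for_pytorch  # noqa: F401  # registers torch.xpu
except ImportError:
    pass  # extension not installed; CUDA/CPU paths are unaffected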
@@ -122,7 +125,13 @@ def load_model(
     """
 
     if device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+        if torch.cuda.is_available():
+            device = "cuda"
+        elif find_spec('torch.xpu') is not None and torch.xpu.is_available():
+            device = "xpu"
+        else:
+            device = "cpu"
+
     if download_root is None:
         default = os.path.join(os.path.expanduser("~"), ".cache")
         download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
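The new selection order is CUDA first, then XPU, then CPU. Factored out, the logic above amounts to this helper (a sketch mirroring the diff; the function name is illustrative, not part of the commit):

from importlib.util import find_spec

import torch

def pick_default_device() -> str:
    # Same precedence as load_model above: CUDA, then Intel XPU, then CPU.
    if torch.cuda.is_available():
        return "cuda"
    if find_spec("torch.xpu") is not None and torch.xpu.is_available():
        return "xpu"
    return "cpu"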


@@ -2,6 +2,7 @@ import argparse
 import os
 import warnings
 from typing import TYPE_CHECKING, Optional, Tuple, Union
+from importlib.util import find_spec
 
 import numpy as np
 import torch
@@ -110,6 +111,8 @@ def transcribe(
     if model.device == torch.device("cpu"):
         if torch.cuda.is_available():
             warnings.warn("Performing inference on CPU when CUDA is available")
+        if find_spec('torch.xpu') is not None and torch.xpu.is_available():
+            warnings.warn("Performing inference on CPU when XPU is available")
         if dtype == torch.float16:
             warnings.warn("FP16 is not supported on CPU; using FP32 instead")
             dtype = torch.float32
@@ -379,7 +382,7 @@ def cli():
     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
     parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
-    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+    parser.add_argument("--device", default=None, help="device to use for PyTorch inference")
     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
     parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")