Fix importing habana-frameworks library conditionally

commit bee28658b9 (parent 52062dd798)
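The fix defers every habana_frameworks import until an HPU is actually requested, so importing whisper no longer fails on machines without the Habana software stack. A minimal sketch of the guard pattern the diff below applies in several places (the helper name hpu_is_usable and the warning text are illustrative, not part of the commit):

import subprocess
import warnings

import torch


def hpu_is_usable() -> bool:
    """Illustrative helper: load the Habana backend lazily, never at import time."""
    try:
        # Importing habana_frameworks without the Habana stack raises
        # ImportError; loading its kernels can fail with CalledProcessError.
        from habana_frameworks.torch.utils.library_loader import load_habana_module

        load_habana_module()  # registers the torch.hpu backend
        return torch.hpu.is_available()
    except (ImportError, subprocess.CalledProcessError):
        warnings.warn("Habana libraries not found; HPU support disabled.")
        return False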
tests/test_timing.py
@@ -3,6 +3,7 @@ import pytest
 import scipy.ndimage
 import torch
 
+from whisper.hpu_utils import get_x_hpu
 from whisper.timing import dtw_cpu, dtw_cuda, median_filter, dtw_hpu
 
 sizes = [
@@ -100,7 +101,7 @@ def test_median_filter_equivalence(shape):
 @pytest.mark.parametrize("N, M", sizes)
 def test_dtw_hpu_equivalence(N: int, M: int):
     x_numpy = np.random.randn(N, M).astype(np.float32)
-    x_hpu = torch.from_numpy(x_numpy).to("hpu")
+    x_hpu = get_x_hpu(x_numpy)
 
     trace_cpu = dtw_cpu(x_numpy)
     trace_hpu = dtw_hpu(x_hpu)
whisper/__init__.py
@@ -6,7 +6,6 @@ import warnings
 from typing import List, Optional, Union
 
 import torch
-from habana_frameworks.torch.utils.library_loader import load_habana_module
 from tqdm import tqdm
 
 from .audio import load_audio, log_mel_spectrogram, pad_or_trim
@@ -15,11 +14,6 @@ from .model import ModelDimensions, Whisper
 from .transcribe import transcribe
 from .version import __version__
 
-load_habana_module()  # important to load torch.hpu
-
-if torch.hpu.is_available():
-    from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-
 _MODELS = {
     "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
     "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
@@ -153,6 +147,12 @@ def load_model(
     with (
         io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
     ) as fp:
-        checkpoint = torch.load(fp, map_location=device)
+        if device == "hpu":
+            """If the device is HPU,
+            the model should be loaded on CPU first
+            and then moved to HPU."""
+            checkpoint = torch.load(fp, map_location="cpu")
+        else:
+            checkpoint = torch.load(fp, map_location=device)
     del checkpoint_file
 
@@ -163,7 +163,12 @@ def load_model(
     if alignment_heads is not None:
         model.set_alignment_heads(alignment_heads)
 
-    if torch.hpu.is_available() or device == "hpu":
-        return wrap_in_hpu_graph(model)
+    if device == "hpu":
+        from habana_frameworks.torch.utils.library_loader import load_habana_module
+        from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+
+        load_habana_module()
+        if torch.hpu.is_available():
+            return wrap_in_hpu_graph(model)
 
     return model.to(device)
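With the two load_model hunks above, HPU handling is self-contained: a checkpoint destined for HPU is materialized on CPU first, and the model is wrapped with wrap_in_hpu_graph only after the Habana module loads and reports an available device. A hedged usage sketch (assumes a Gaudi machine with the Habana PyTorch bridge installed; the model name is only an example):

import whisper

# device="hpu" now works end to end: torch.load(..., map_location="cpu")
# runs first, then wrap_in_hpu_graph(model) once torch.hpu.is_available().
# Machines without the Habana libraries keep using "cuda" or "cpu" as before.
model = whisper.load_model("tiny", device="hpu")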
whisper/hpu_utils.py
@@ -1,15 +1,8 @@
-import warnings
-
 import torch
+from habana_frameworks.torch.utils.library_loader import load_habana_module
 
+load_habana_module()
 
-def load_default_hpu() -> str:
-    """
-    Load HPU if available, otherwise use CUDA or CPU.
-    """
-
-    if not torch.hpu.is_available():
-        warnings.warn("HPU is not available; trying to use CUDA instead.")
-        return "cuda" if torch.cuda.is_available() else "cpu"
-
-    return "hpu"
+def get_x_hpu(x_numpy):
+    x_hpu = torch.from_numpy(x_numpy).to("hpu")
+    return x_hpu
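The rewritten whisper/hpu_utils.py is now a thin HPU-only helper: it loads the Habana module at import time and exposes get_x_hpu for moving NumPy arrays onto the device, so only HPU-specific code paths (such as the test above) should import it. A quick check of the helper (requires a Habana-enabled install; .to("hpu") raises elsewhere):

import numpy as np

from whisper.hpu_utils import get_x_hpu  # importing this loads the Habana module

x = np.random.randn(4, 8).astype(np.float32)
x_hpu = get_x_hpu(x)
print(x_hpu.device)  # expected: hpu:0 on a Gaudi machine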
whisper/timing.py
@@ -175,8 +175,17 @@ def dtw_hpu(x, BLOCK_SIZE=1024):
 
 
 def dtw(x: torch.Tensor) -> np.ndarray:
-    if torch.hpu.is_available():
-        return dtw_hpu(x)
+    try:
+        from habana_frameworks.torch.utils.library_loader import load_habana_module
+
+        load_habana_module()
+        if torch.hpu.is_available():
+            return dtw_hpu(x)
+    except (ImportError, subprocess.CalledProcessError):
+        warnings.warn(
+            "Failed to import Habana modules, likely due to missing Habana libraries; "
+        )
 
     if x.is_cuda:
         try:
             return dtw_cuda(x)
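The dtw dispatcher above degrades gracefully: the HPU path is attempted inside a try/except, and on ImportError or subprocess.CalledProcessError it falls through to the existing CUDA and CPU branches (both warnings and subprocess are already imported at the top of upstream timing.py). A small sanity check of the fallback, hedged: it assumes a machine with neither Habana libraries nor CUDA, where dtw should warn and then match the CPU kernel:

import numpy as np
import torch

from whisper.timing import dtw, dtw_cpu

# Without Habana libraries, dtw() warns and takes the CPU path; the
# recovered alignment trace should match dtw_cpu on the same input.
x = np.random.randn(10, 15).astype(np.float32)
trace_fallback = dtw(torch.from_numpy(x))
trace_cpu = dtw_cpu(x)
assert np.allclose(trace_fallback, trace_cpu)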
whisper/transcribe.py
@@ -18,7 +18,6 @@ from .audio import (
     pad_or_trim,
 )
 from .decoding import DecodingOptions, DecodingResult
-from .hpu_utils import load_default_hpu
 from .timing import add_word_timestamps
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import (
@@ -515,7 +514,11 @@ def cli():
     if device_name == "cuda" and not torch.cuda.is_available():
         warnings.warn("CUDA is not available; using CPU instead")
         device_name = "cpu"
-    if device_name == "hpu" and not torch.hpu.is_available():
+    if device_name == "hpu":
+        from habana_frameworks.torch.utils.library_loader import load_habana_module
+
+        load_habana_module()
+        if not torch.hpu.is_available():
             warnings.warn("HPU is not available; using CPU instead")
             device_name = "cpu"
 
@@ -526,7 +529,7 @@ def cli():
     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
     parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
-    parser.add_argument("--device", default=load_default_hpu(), type=valid_device, help="device to use for PyTorch inference (hpu/cuda/cpu)")
+    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", type=valid_device, help="device to use for PyTorch inference (hpu/cuda/cpu)")
     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
     parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
     parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")