From bee28658b9a5812e83bfbcb51032c0793a1d04f6 Mon Sep 17 00:00:00 2001
From: PiotrBLL
Date: Tue, 5 Nov 2024 01:27:31 +0100
Subject: [PATCH] Import the habana-frameworks library conditionally

---
 tests/test_timing.py  |  3 ++-
 whisper/__init__.py   | 23 ++++++++++++++---------
 whisper/hpu_utils.py  | 17 +++++------------
 whisper/timing.py     | 13 +++++++++++--
 whisper/transcribe.py | 13 ++++++++-----
 5 files changed, 40 insertions(+), 29 deletions(-)

diff --git a/tests/test_timing.py b/tests/test_timing.py
index 58a2812..47ced8c 100644
--- a/tests/test_timing.py
+++ b/tests/test_timing.py
@@ -3,6 +3,7 @@ import pytest
 import scipy.ndimage
 import torch
 
+from whisper.hpu_utils import get_x_hpu
 from whisper.timing import dtw_cpu, dtw_cuda, median_filter, dtw_hpu
 
 sizes = [
@@ -100,7 +101,7 @@ def test_median_filter_equivalence(shape):
 @pytest.mark.parametrize("N, M", sizes)
 def test_dtw_hpu_equivalence(N: int, M: int):
     x_numpy = np.random.randn(N, M).astype(np.float32)
-    x_hpu = torch.from_numpy(x_numpy).to("hpu")
+    x_hpu = get_x_hpu(x_numpy)
 
     trace_cpu = dtw_cpu(x_numpy)
     trace_hpu = dtw_hpu(x_hpu)
diff --git a/whisper/__init__.py b/whisper/__init__.py
index 7bfdbfd..e66e78e 100644
--- a/whisper/__init__.py
+++ b/whisper/__init__.py
@@ -6,7 +6,6 @@ import warnings
 from typing import List, Optional, Union
 
 import torch
-from habana_frameworks.torch.utils.library_loader import load_habana_module
 from tqdm import tqdm
 
 from .audio import load_audio, log_mel_spectrogram, pad_or_trim
@@ -15,11 +14,6 @@ from .model import ModelDimensions, Whisper
 from .transcribe import transcribe
 from .version import __version__
 
-load_habana_module()  # important to load torch.hpu
-
-if torch.hpu.is_available():
-    from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-
 _MODELS = {
     "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
     "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
@@ -153,7 +147,13 @@
     with (
         io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
     ) as fp:
-        checkpoint = torch.load(fp, map_location=device)
+        if device == "hpu":
+            # For HPU, load the checkpoint onto the CPU first; the model
+            # is moved to the HPU device below, once the Habana modules
+            # have been loaded.
+            checkpoint = torch.load(fp, map_location="cpu")
+        else:
+            checkpoint = torch.load(fp, map_location=device)
     del checkpoint_file
 
     dims = ModelDimensions(**checkpoint["dims"])
@@ -163,7 +163,12 @@
     if alignment_heads is not None:
         model.set_alignment_heads(alignment_heads)
 
-    if torch.hpu.is_available() or device == "hpu":
-        return wrap_in_hpu_graph(model)
+    if device == "hpu":
+        from habana_frameworks.torch.utils.library_loader import load_habana_module
+
+        load_habana_module()  # must run before torch.hpu is usable
+        if torch.hpu.is_available():
+            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+            return wrap_in_hpu_graph(model.to("hpu"))
 
     return model.to(device)
diff --git a/whisper/hpu_utils.py b/whisper/hpu_utils.py
index 7a66aa5..5ec464a 100644
--- a/whisper/hpu_utils.py
+++ b/whisper/hpu_utils.py
@@ -1,15 +1,8 @@
-import warnings
-
 import torch
+from habana_frameworks.torch.utils.library_loader import load_habana_module
 
+load_habana_module()  # must run before torch.hpu is usable
 
-def load_default_hpu() -> str:
-    """
-    Load HPU if available, otherwise use CUDA or CPU.
- """ - - if not torch.hpu.is_available(): - warnings.warn("HPU is not available; trying to use CUDA instead.") - return "cuda" if torch.cuda.is_available() else "cpu" - - return "hpu" +def get_x_hpu(x_numpy): + x_hpu = torch.from_numpy(x_numpy).to("hpu") + return x_hpu diff --git a/whisper/timing.py b/whisper/timing.py index af58c39..dec2081 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -175,8 +175,17 @@ def dtw_hpu(x, BLOCK_SIZE=1024): def dtw(x: torch.Tensor) -> np.ndarray: - if torch.hpu.is_available(): - return dtw_hpu(x) + try: + from habana_frameworks.torch.utils.library_loader import load_habana_module + load_habana_module() + + if torch.hpu.is_available(): + return dtw_hpu(x) + except (ImportError, subprocess.CalledProcessError): + warnings.warn( + "Failed to import Habana modules, likely due to missing Habana libraries; " + ) + if x.is_cuda: try: return dtw_cuda(x) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index f3ceffd..9110afb 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -18,7 +18,6 @@ from .audio import ( pad_or_trim, ) from .decoding import DecodingOptions, DecodingResult -from .hpu_utils import load_default_hpu from .timing import add_word_timestamps from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer from .utils import ( @@ -515,8 +514,12 @@ def cli(): if device_name == "cuda" and not torch.cuda.is_available(): warnings.warn("CUDA is not available; using CPU instead") device_name = "cpu" - if device_name == "hpu" and not torch.hpu.is_available(): - warnings.warn("HPU is not available; using CPU instead") + if device_name == "hpu": + from habana_frameworks.torch.utils.library_loader import load_habana_module + + load_habana_module() + if not torch.hpu.is_available(): + warnings.warn("HPU is not available; using CPU instead") device_name = "cpu" return device_name @@ -526,7 +529,7 @@ def cli(): parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") - parser.add_argument("--device", default=load_default_hpu(), type=valid_device, help="device to use for PyTorch inference (hpu/cuda/cpu)") + parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", type=valid_device, help="device to use for PyTorch inference (hpu/cuda/cpu)") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")