diff --git a/whisper/__init__.py b/whisper/__init__.py
--- a/whisper/__init__.py
+++ b/whisper/__init__.py
@@ -14,6 +14,19 @@ from .model import ModelDimensions, Whisper
 from .transcribe import transcribe
 from .version import __version__
 
+# Habana Gaudi (HPU) support is optional. Probe for the vendor package so the
+# module still imports on machines where habana_frameworks is not installed.
+try:
+    from habana_frameworks.torch.utils.library_loader import load_habana_module
+except ImportError:
+    _HPU_AVAILABLE = False
+else:
+    load_habana_module()  # important to load torch.hpu
+    _HPU_AVAILABLE = torch.hpu.is_available()
+
+if _HPU_AVAILABLE:
+    from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+
 _MODELS = {
     "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
     "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
@@ -157,4 +170,9 @@ def load_model(
     if alignment_heads is not None:
         model.set_alignment_heads(alignment_heads)
 
+    # Wrap in an HPU graph only when the caller asked for an HPU device and the
+    # runtime is present; the model is moved to that device before wrapping.
+    if _HPU_AVAILABLE and str(device).startswith("hpu"):
+        return wrap_in_hpu_graph(model.to(device))
+
     return model.to(device)