Add initial imports and usage of wrap_in_hpu_graph

This commit is contained in:
PiotrBLL 2024-10-31 13:54:29 +01:00
parent 25639fc17d
commit b7069e579d

View File

@ -6,6 +6,7 @@ import warnings
from typing import List, Optional, Union
import torch
from habana_frameworks.torch.utils.library_loader import load_habana_module
from tqdm import tqdm
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
@ -14,6 +15,11 @@ from .model import ModelDimensions, Whisper
from .transcribe import transcribe
from .version import __version__
load_habana_module() # important to load torch.hpu
if torch.hpu.is_available():
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
@ -157,4 +163,7 @@ def load_model(
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
if torch.hpu.is_available() or device == "hpu":
return wrap_in_hpu_graph(model)
return model.to(device)