mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 14:35:57 +00:00
Add initial imports and usage of wrap_in_hpu_graph
This commit is contained in:
parent
25639fc17d
commit
b7069e579d
@ -6,6 +6,7 @@ import warnings
|
|||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from habana_frameworks.torch.utils.library_loader import load_habana_module
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||||
@ -14,6 +15,11 @@ from .model import ModelDimensions, Whisper
|
|||||||
from .transcribe import transcribe
|
from .transcribe import transcribe
|
||||||
from .version import __version__
|
from .version import __version__
|
||||||
|
|
||||||
|
load_habana_module() # important to load torch.hpu
|
||||||
|
|
||||||
|
if torch.hpu.is_available():
|
||||||
|
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
|
||||||
|
|
||||||
_MODELS = {
|
_MODELS = {
|
||||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||||
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
||||||
@ -157,4 +163,7 @@ def load_model(
|
|||||||
if alignment_heads is not None:
|
if alignment_heads is not None:
|
||||||
model.set_alignment_heads(alignment_heads)
|
model.set_alignment_heads(alignment_heads)
|
||||||
|
|
||||||
|
if torch.hpu.is_available() or device == "hpu":
|
||||||
|
return wrap_in_hpu_graph(model)
|
||||||
|
|
||||||
return model.to(device)
|
return model.to(device)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user