Fix: import the habana-frameworks library conditionally

PiotrBLL 2024-11-05 01:27:31 +01:00
parent 52062dd798
commit bee28658b9
5 changed files with 39 additions and 28 deletions
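
All five files apply the same idea: habana_frameworks is imported only inside the code paths that actually target an HPU, so the package becomes an optional dependency rather than a hard import at module load. A minimal sketch of that pattern, assuming a hypothetical helper named init_hpu (not part of this commit):

import subprocess
import warnings

import torch


def init_hpu() -> bool:
    """Return True only if the Habana stack can be loaded and an HPU is usable."""
    try:
        # Imported lazily so machines without the Habana libraries never touch the package.
        from habana_frameworks.torch.utils.library_loader import load_habana_module

        load_habana_module()  # registers the torch.hpu backend
    except (ImportError, subprocess.CalledProcessError):
        warnings.warn("habana_frameworks could not be loaded; falling back to CUDA/CPU")
        return False
    return torch.hpu.is_available()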


@@ -3,6 +3,7 @@ import pytest
 import scipy.ndimage
 import torch
+from whisper.hpu_utils import get_x_hpu
 from whisper.timing import dtw_cpu, dtw_cuda, median_filter, dtw_hpu

 sizes = [
@@ -100,7 +101,7 @@ def test_median_filter_equivalence(shape):
 @pytest.mark.parametrize("N, M", sizes)
 def test_dtw_hpu_equivalence(N: int, M: int):
     x_numpy = np.random.randn(N, M).astype(np.float32)
-    x_hpu = torch.from_numpy(x_numpy).to("hpu")
+    x_hpu = get_x_hpu(x_numpy)

     trace_cpu = dtw_cpu(x_numpy)
     trace_hpu = dtw_hpu(x_hpu)
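
get_x_hpu centralizes the NumPy-to-HPU transfer, but dtw_hpu still needs real Gaudi hardware to pass. A guard along these lines (not part of this commit, only a sketch) would let the suite run on machines without an HPU:

import pytest
import torch

# Hypothetical guard: torch.hpu only exists once the Habana plugin has been loaded.
HPU_AVAILABLE = hasattr(torch, "hpu") and torch.hpu.is_available()


@pytest.mark.skipif(not HPU_AVAILABLE, reason="requires an HPU device")
@pytest.mark.parametrize("N, M", sizes)
def test_dtw_hpu_equivalence(N: int, M: int):
    ...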


@@ -6,7 +6,6 @@ import warnings
 from typing import List, Optional, Union

 import torch
-from habana_frameworks.torch.utils.library_loader import load_habana_module
 from tqdm import tqdm

 from .audio import load_audio, log_mel_spectrogram, pad_or_trim
@@ -15,11 +14,6 @@ from .model import ModelDimensions, Whisper
 from .transcribe import transcribe
 from .version import __version__

-load_habana_module()  # important to load torch.hpu
-if torch.hpu.is_available():
-    from habana_frameworks.torch.hpu import wrap_in_hpu_graph

 _MODELS = {
     "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
     "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
@@ -153,7 +147,13 @@ def load_model(
     with (
         io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
     ) as fp:
-        checkpoint = torch.load(fp, map_location=device)
+        if device == "hpu":
+            """If the device is HPU,
+            the model should be loaded on CPU first
+            and then moved to HPU."""
+            checkpoint = torch.load(fp, map_location="cpu")
+        else:
+            checkpoint = torch.load(fp, map_location=device)
     del checkpoint_file

     dims = ModelDimensions(**checkpoint["dims"])
@@ -163,7 +163,12 @@ def load_model(
     if alignment_heads is not None:
         model.set_alignment_heads(alignment_heads)

-    if torch.hpu.is_available() or device == "hpu":
-        return wrap_in_hpu_graph(model)
+    if device == "hpu":
+        from habana_frameworks.torch.utils.library_loader import load_habana_module
+        from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+
+        load_habana_module()
+        if torch.hpu.is_available():
+            return wrap_in_hpu_graph(model)

     return model.to(device)
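
With load_model handling the import itself, using Whisper on Gaudi stays a one-liner while other machines never attempt the Habana import. A usage sketch (the model name and device string are just examples):

import whisper

# On a Gaudi host this deserializes the checkpoint on CPU, initializes the Habana
# backend, and returns the model wrapped in an HPU graph; elsewhere the habana
# import is never attempted because the device is not "hpu".
model = whisper.load_model("tiny", device="hpu")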


@@ -1,15 +1,8 @@
-import warnings
 import torch
-from habana_frameworks.torch.utils.library_loader import load_habana_module
-
-load_habana_module()


-def load_default_hpu() -> str:
-    """
-    Load HPU if available, otherwise use CUDA or CPU.
-    """
-    if not torch.hpu.is_available():
-        warnings.warn("HPU is not available; trying to use CUDA instead.")
-        return "cuda" if torch.cuda.is_available() else "cpu"
-    return "hpu"
+def get_x_hpu(x_numpy):
+    x_hpu = torch.from_numpy(x_numpy).to("hpu")
+    return x_hpu


@@ -175,8 +175,17 @@ def dtw_hpu(x, BLOCK_SIZE=1024):

 def dtw(x: torch.Tensor) -> np.ndarray:
-    if torch.hpu.is_available():
-        return dtw_hpu(x)
+    try:
+        from habana_frameworks.torch.utils.library_loader import load_habana_module
+
+        load_habana_module()
+        if torch.hpu.is_available():
+            return dtw_hpu(x)
+    except (ImportError, subprocess.CalledProcessError):
+        warnings.warn(
+            "Failed to import Habana modules, likely due to missing Habana libraries; "
+        )

     if x.is_cuda:
         try:
             return dtw_cuda(x)
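
The except clause reuses subprocess.CalledProcessError and warnings.warn, which timing.py already imports for the existing CUDA/Triton fallback visible just below, so no new top-level imports are needed. In practice dtw is reached through word-level timestamps, so the call below behaves the same whether or not habana_frameworks is installed, at worst emitting the warning above (the audio path is an example):

import whisper

model = whisper.load_model("tiny")
# add_word_timestamps() calls dtw() internally; without the Habana libraries the new
# try/except just warns and falls through to the CUDA or CPU implementation.
result = model.transcribe("audio.mp3", word_timestamps=True)
print(result["segments"][0]["words"])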


@@ -18,7 +18,6 @@ from .audio import (
     pad_or_trim,
 )
 from .decoding import DecodingOptions, DecodingResult
-from .hpu_utils import load_default_hpu
 from .timing import add_word_timestamps
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import (
@@ -515,8 +514,12 @@ def cli():
     if device_name == "cuda" and not torch.cuda.is_available():
         warnings.warn("CUDA is not available; using CPU instead")
         device_name = "cpu"
-    if device_name == "hpu" and not torch.hpu.is_available():
-        warnings.warn("HPU is not available; using CPU instead")
-        device_name = "cpu"
+    if device_name == "hpu":
+        from habana_frameworks.torch.utils.library_loader import load_habana_module
+
+        load_habana_module()
+        if not torch.hpu.is_available():
+            warnings.warn("HPU is not available; using CPU instead")
+            device_name = "cpu"

     return device_name
@@ -526,7 +529,7 @@ def cli():
     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
     parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
-    parser.add_argument("--device", default=load_default_hpu(), type=valid_device, help="device to use for PyTorch inference (hpu/cuda/cpu)")
+    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", type=valid_device, help="device to use for PyTorch inference (hpu/cuda/cpu)")
     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
     parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
     parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")