diff --git a/tests/test_timing.py b/tests/test_timing.py index 9bab838..58a2812 100644 --- a/tests/test_timing.py +++ b/tests/test_timing.py @@ -3,7 +3,7 @@ import pytest import scipy.ndimage import torch -from whisper.timing import dtw_cpu, dtw_cuda, median_filter +from whisper.timing import dtw_cpu, dtw_cuda, median_filter, dtw_hpu sizes = [ (10, 20), @@ -94,3 +94,15 @@ def test_median_filter_equivalence(shape): filtered_gpu = median_filter(x.cuda(), filter_width).cpu() assert np.allclose(filtered_cpu, filtered_gpu) + + +@pytest.mark.requires_hpu +@pytest.mark.parametrize("N, M", sizes) +def test_dtw_hpu_equivalence(N: int, M: int): + x_numpy = np.random.randn(N, M).astype(np.float32) + x_hpu = torch.from_numpy(x_numpy).to("hpu") + + trace_cpu = dtw_cpu(x_numpy) + trace_hpu = dtw_hpu(x_hpu) + + assert np.allclose(trace_cpu, trace_hpu) diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index 599221a..49be253 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -10,7 +10,7 @@ from whisper.tokenizer import get_tokenizer @pytest.mark.parametrize("model_name", whisper.available_models()) def test_transcribe(model_name: str): device = "cuda" if torch.cuda.is_available() else "cpu" - model = whisper.load_model(model_name).to(device) + model = whisper.load_model(model_name, device=device) audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac") language = "en" if model_name.endswith(".en") else None diff --git a/whisper/timing.py b/whisper/timing.py index e563414..af58c39 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -138,7 +138,45 @@ def dtw_cuda(x, BLOCK_SIZE=1024): return backtrace(trace.cpu().numpy()) +def dtw_hpu(x, BLOCK_SIZE=1024): + """ + DTW implementation for HPU. + """ + M, N = x.shape + assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}" + + x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M) + x_skew = x_skew.T.contiguous() + + # Initialize cost and trace matrices with high values for comparison + cost = torch.ones(N + M + 2, M + 2, device="hpu") * np.inf + cost[0, 0] = 0 # Start point for DTW + trace = torch.zeros_like(cost, dtype=torch.int32, device="hpu") + + for k in range(1, N + M + 1): + p0 = cost[k - 1, :M] + p1 = cost[k, :M] + p2 = cost[k, 1:M + 1] + + c0 = p0.clone() + c1 = p1.clone() + c2 = p2.clone() + + x_row = x_skew[k - 1, :M] + + cost_row = x_row + torch.min(torch.min(c0, c1), c2) + cost[k + 1, 1:M + 1] = cost_row + + # Track path by storing traces + trace[k + 1, 1:M + 1] = 2 * (c2 <= c0) * (c2 <= c1) + 1 * (c1 <= c0) * (c1 <= c2) + 0 * (c0 <= c1) * (c0 <= c2) + + trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, : N + 1] + return backtrace(trace.cpu().numpy()) + + def dtw(x: torch.Tensor) -> np.ndarray: + if torch.hpu.is_available(): + return dtw_hpu(x) if x.is_cuda: try: return dtw_cuda(x) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 8e1240b..f3ceffd 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -18,6 +18,7 @@ from .audio import ( pad_or_trim, ) from .decoding import DecodingOptions, DecodingResult +from .hpu_utils import load_default_hpu from .timing import add_word_timestamps from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer from .utils import ( @@ -125,6 +126,8 @@ def transcribe( if dtype == torch.float16: warnings.warn("FP16 is not supported on CPU; using FP32 instead") dtype = torch.float32 + if model.device == torch.device("hpu") and torch.hpu.is_available(): + warnings.warn("Performing inference on HPU when CUDA is available") if dtype == torch.float32: decode_options["fp16"] = False @@ -508,12 +511,22 @@ def cli(): f"model should be one of {available_models()} or path to a model checkpoint" ) + def valid_device(device_name): + if device_name == "cuda" and not torch.cuda.is_available(): + warnings.warn("CUDA is not available; using CPU instead") + device_name = "cpu" + if device_name == "hpu" and not torch.hpu.is_available(): + warnings.warn("HPU is not available; using CPU instead") + device_name = "cpu" + + return device_name + # fmt: off parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") - parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") + parser.add_argument("--device", default=load_default_hpu(), type=valid_device, help="device to use for PyTorch inference (hpu/cuda/cpu)") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")