Add HPU support in transcribe and timing + tests

This commit is contained in:
PiotrBLL 2024-10-31 13:56:46 +01:00
parent 23651574df
commit 9269b2ac35
4 changed files with 66 additions and 3 deletions

View File

@ -3,7 +3,7 @@ import pytest
import scipy.ndimage
import torch
from whisper.timing import dtw_cpu, dtw_cuda, median_filter
from whisper.timing import dtw_cpu, dtw_cuda, median_filter, dtw_hpu
sizes = [
(10, 20),
@ -94,3 +94,15 @@ def test_median_filter_equivalence(shape):
filtered_gpu = median_filter(x.cuda(), filter_width).cpu()
assert np.allclose(filtered_cpu, filtered_gpu)
@pytest.mark.requires_hpu
@pytest.mark.parametrize("N, M", sizes)
def test_dtw_hpu_equivalence(N: int, M: int):
x_numpy = np.random.randn(N, M).astype(np.float32)
x_hpu = torch.from_numpy(x_numpy).to("hpu")
trace_cpu = dtw_cpu(x_numpy)
trace_hpu = dtw_hpu(x_hpu)
assert np.allclose(trace_cpu, trace_hpu)

View File

@ -10,7 +10,7 @@ from whisper.tokenizer import get_tokenizer
@pytest.mark.parametrize("model_name", whisper.available_models())
def test_transcribe(model_name: str):
device = "cuda" if torch.cuda.is_available() else "cpu"
model = whisper.load_model(model_name).to(device)
model = whisper.load_model(model_name, device=device)
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
language = "en" if model_name.endswith(".en") else None

View File

@ -138,7 +138,45 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
return backtrace(trace.cpu().numpy())
def dtw_hpu(x, BLOCK_SIZE=1024):
"""
DTW implementation for HPU.
"""
M, N = x.shape
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
x_skew = x_skew.T.contiguous()
# Initialize cost and trace matrices with high values for comparison
cost = torch.ones(N + M + 2, M + 2, device="hpu") * np.inf
cost[0, 0] = 0 # Start point for DTW
trace = torch.zeros_like(cost, dtype=torch.int32, device="hpu")
for k in range(1, N + M + 1):
p0 = cost[k - 1, :M]
p1 = cost[k, :M]
p2 = cost[k, 1:M + 1]
c0 = p0.clone()
c1 = p1.clone()
c2 = p2.clone()
x_row = x_skew[k - 1, :M]
cost_row = x_row + torch.min(torch.min(c0, c1), c2)
cost[k + 1, 1:M + 1] = cost_row
# Track path by storing traces
trace[k + 1, 1:M + 1] = 2 * (c2 <= c0) * (c2 <= c1) + 1 * (c1 <= c0) * (c1 <= c2) + 0 * (c0 <= c1) * (c0 <= c2)
trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, : N + 1]
return backtrace(trace.cpu().numpy())
def dtw(x: torch.Tensor) -> np.ndarray:
if torch.hpu.is_available():
return dtw_hpu(x)
if x.is_cuda:
try:
return dtw_cuda(x)

View File

@ -18,6 +18,7 @@ from .audio import (
pad_or_trim,
)
from .decoding import DecodingOptions, DecodingResult
from .hpu_utils import load_default_hpu
from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import (
@ -125,6 +126,8 @@ def transcribe(
if dtype == torch.float16:
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
dtype = torch.float32
if model.device == torch.device("hpu") and torch.hpu.is_available():
warnings.warn("Performing inference on HPU when CUDA is available")
if dtype == torch.float32:
decode_options["fp16"] = False
@ -508,12 +511,22 @@ def cli():
f"model should be one of {available_models()} or path to a model checkpoint"
)
def valid_device(device_name):
if device_name == "cuda" and not torch.cuda.is_available():
warnings.warn("CUDA is not available; using CPU instead")
device_name = "cpu"
if device_name == "hpu" and not torch.hpu.is_available():
warnings.warn("HPU is not available; using CPU instead")
device_name = "cpu"
return device_name
# fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--device", default=load_default_hpu(), type=valid_device, help="device to use for PyTorch inference (hpu/cuda/cpu)")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")