mirror of https://github.com/openai/whisper.git

Add HPU support in transcribe and timing + tests

parent 23651574df
commit 9269b2ac35
tests/test_timing.py:

@@ -3,7 +3,7 @@ import pytest
 import scipy.ndimage
 import torch
 
-from whisper.timing import dtw_cpu, dtw_cuda, median_filter
+from whisper.timing import dtw_cpu, dtw_cuda, median_filter, dtw_hpu
 
 sizes = [
     (10, 20),
@@ -94,3 +94,15 @@ def test_median_filter_equivalence(shape):
     filtered_gpu = median_filter(x.cuda(), filter_width).cpu()
 
     assert np.allclose(filtered_cpu, filtered_gpu)
+
+
+@pytest.mark.requires_hpu
+@pytest.mark.parametrize("N, M", sizes)
+def test_dtw_hpu_equivalence(N: int, M: int):
+    x_numpy = np.random.randn(N, M).astype(np.float32)
+    x_hpu = torch.from_numpy(x_numpy).to("hpu")
+
+    trace_cpu = dtw_cpu(x_numpy)
+    trace_hpu = dtw_hpu(x_hpu)
+
+    assert np.allclose(trace_cpu, trace_hpu)
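The @pytest.mark.requires_hpu marker used above only takes effect if it is registered and mapped to a skip condition, and that wiring is not part of the hunks shown on this page. A minimal conftest.py sketch of what it could look like; the habana_frameworks import and the skip logic below are assumptions, not code from this commit:

# conftest.py -- hypothetical sketch; not part of this commit.
import pytest
import torch


def pytest_configure(config):
    # Register the custom marker so pytest does not warn about unknown marks.
    config.addinivalue_line("markers", "requires_hpu: test needs an Intel Gaudi (HPU) device")


def _hpu_available() -> bool:
    try:
        import habana_frameworks.torch.core  # noqa: F401  (assumption: exposes torch.hpu)

        return torch.hpu.is_available()
    except (ImportError, AttributeError):
        return False


def pytest_collection_modifyitems(config, items):
    # Skip HPU tests on machines without an HPU device.
    if _hpu_available():
        return
    skip_hpu = pytest.mark.skip(reason="HPU not available")
    for item in items:
        if "requires_hpu" in item.keywords:
            item.add_marker(skip_hpu)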
tests/test_transcribe.py:

@@ -10,7 +10,7 @@ from whisper.tokenizer import get_tokenizer
 @pytest.mark.parametrize("model_name", whisper.available_models())
 def test_transcribe(model_name: str):
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    model = whisper.load_model(model_name).to(device)
+    model = whisper.load_model(model_name, device=device)
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
 
     language = "en" if model_name.endswith(".en") else None
whisper/timing.py:

@@ -138,7 +138,45 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
     return backtrace(trace.cpu().numpy())
 
 
+def dtw_hpu(x, BLOCK_SIZE=1024):
+    """
+    DTW implementation for HPU.
+    """
+    M, N = x.shape
+    assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
+
+    x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
+    x_skew = x_skew.T.contiguous()
+
+    # Initialize cost and trace matrices with high values for comparison
+    cost = torch.ones(N + M + 2, M + 2, device="hpu") * np.inf
+    cost[0, 0] = 0  # Start point for DTW
+    trace = torch.zeros_like(cost, dtype=torch.int32, device="hpu")
+
+    for k in range(1, N + M + 1):
+        p0 = cost[k - 1, :M]
+        p1 = cost[k, :M]
+        p2 = cost[k, 1:M + 1]
+
+        c0 = p0.clone()
+        c1 = p1.clone()
+        c2 = p2.clone()
+
+        x_row = x_skew[k - 1, :M]
+
+        cost_row = x_row + torch.min(torch.min(c0, c1), c2)
+        cost[k + 1, 1:M + 1] = cost_row
+
+        # Track path by storing traces
+        trace[k + 1, 1:M + 1] = 2 * (c2 <= c0) * (c2 <= c1) + 1 * (c1 <= c0) * (c1 <= c2) + 0 * (c0 <= c1) * (c0 <= c2)
+
+    trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, : N + 1]
+    return backtrace(trace.cpu().numpy())
+
+
 def dtw(x: torch.Tensor) -> np.ndarray:
+    if torch.hpu.is_available():
+        return dtw_hpu(x)
     if x.is_cuda:
         try:
             return dtw_cuda(x)
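Like dtw_cuda above it, dtw_hpu evaluates the DTW recurrence one anti-diagonal at a time: the F.pad/reshape "skew" turns each anti-diagonal of the cost matrix into a contiguous row, so all cells on a diagonal are updated by a single vectorized torch.min, and the Python loop runs only N + M times instead of N * M. (Note that torch.hpu exists only after Habana's PyTorch bridge has been imported, so the new torch.hpu.is_available() check in dtw assumes that import happens elsewhere.) For reference, the same recurrence in its plain unskewed form; a NumPy sketch written here for illustration, equivalent in spirit to the library's dtw_cpu but not code from this commit:

import numpy as np


def dtw_reference(x: np.ndarray) -> np.ndarray:
    """Plain O(M * N) DTW over cost matrix x; returns the 2 x L alignment path."""
    M, N = x.shape
    cost = np.full((M + 1, N + 1), np.inf)
    trace = np.zeros((M + 1, N + 1), dtype=np.int32)
    cost[0, 0] = 0.0

    for i in range(1, M + 1):
        for j in range(1, N + 1):
            # The three predecessors of cell (i, j): diagonal, up, left.
            c, t = min((cost[i - 1, j - 1], 0), (cost[i - 1, j], 1), (cost[i, j - 1], 2))
            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    # Backtrack from the bottom-right corner; the border values force
    # any walk that reaches an edge to continue toward (0, 0).
    trace[0, :] = 2
    trace[:, 0] = 1
    i, j, path = M, N, []
    while i > 0 or j > 0:
        path.append((i - 1, j - 1))
        if trace[i, j] == 0:
            i, j = i - 1, j - 1
        elif trace[i, j] == 1:
            i -= 1
        else:
            j -= 1
    return np.array(path)[::-1].T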
whisper/transcribe.py:

@@ -18,6 +18,7 @@ from .audio import (
     pad_or_trim,
 )
 from .decoding import DecodingOptions, DecodingResult
+from .hpu_utils import load_default_hpu
 from .timing import add_word_timestamps
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import (
@@ -125,6 +126,8 @@ def transcribe(
         if dtype == torch.float16:
             warnings.warn("FP16 is not supported on CPU; using FP32 instead")
             dtype = torch.float32
+    if model.device == torch.device("hpu") and torch.cuda.is_available():
+        warnings.warn("Performing inference on HPU when CUDA is available")
 
     if dtype == torch.float32:
         decode_options["fp16"] = False
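With the changes above, transcription runs on HPU whenever the model has been placed on that device. An end-to-end usage sketch, assuming Habana's habana_frameworks PyTorch bridge is installed; that import, not shown anywhere in this commit, is what registers the "hpu" device and torch.hpu:

import torch
import whisper

# Assumption: the Gaudi software stack is installed; importing the bridge
# registers the "hpu" device with PyTorch.
import habana_frameworks.torch.core  # noqa: F401

device = "hpu" if torch.hpu.is_available() else "cpu"
model = whisper.load_model("turbo", device=device)
result = model.transcribe("audio.mp3", word_timestamps=True)
print(result["text"])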
whisper/transcribe.py:

@@ -508,12 +511,22 @@ def cli():
             f"model should be one of {available_models()} or path to a model checkpoint"
         )
 
+    def valid_device(device_name):
+        if device_name == "cuda" and not torch.cuda.is_available():
+            warnings.warn("CUDA is not available; using CPU instead")
+            device_name = "cpu"
+        if device_name == "hpu" and not torch.hpu.is_available():
+            warnings.warn("HPU is not available; using CPU instead")
+            device_name = "cpu"
+
+        return device_name
+
     # fmt: off
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
     parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
-    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+    parser.add_argument("--device", default=load_default_hpu(), type=valid_device, help="device to use for PyTorch inference (hpu/cuda/cpu)")
     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
     parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
     parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
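The CLI now defaults --device to load_default_hpu(), imported from the new whisper/hpu_utils.py; that file's hunk does not appear on this page. Given how the function is used, a plausible sketch of what it would need to do; the body below is an assumption, not the commit's code:

# whisper/hpu_utils.py -- sketch; the real file is not visible in the hunks above.
import torch


def load_default_hpu() -> str:
    """Pick the default CLI device: prefer "hpu", then "cuda", then "cpu"."""
    try:
        import habana_frameworks.torch.core  # noqa: F401  (registers the "hpu" device)

        if torch.hpu.is_available():
            return "hpu"
    except ImportError:
        pass
    return "cuda" if torch.cuda.is_available() else "cpu"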