mirror of https://github.com/openai/whisper.git

Merge pull request #2 from BlueLabelLabs/feat/Add-HPU-support

Add HPU support

commit 1a42019974
1  .dockerignore  Normal file

@@ -0,0 +1 @@
.graph_dumps

34  Dockerfile.hpu  Normal file

@@ -0,0 +1,34 @@
# Use the official Gaudi Docker image with PyTorch
FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest

# Set environment variables for Habana
ENV HABANA_VISIBLE_DEVICES=all
ENV OMPI_MCA_btl_vader_single_copy_mechanism=none
ENV PT_HPU_LAZY_ACC_PAR_MODE=0
ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=1

# Set timezone to UTC and install essential packages
ENV DEBIAN_FRONTEND="noninteractive" TZ=Etc/UTC
RUN apt-get update && apt-get install -y \
    tzdata \
    python3-pip \
    && rm -rf /var/lib/apt/lists/*

# Download and install the static build of ffmpeg
RUN mkdir -p /usr/local/bin/ffmpeg && \
    cd /usr/local/bin/ffmpeg && \
    wget https://johnvansickle.com/ffmpeg/releases/ffmpeg-release-amd64-static.tar.xz && \
    tar -xf ffmpeg-release-amd64-static.tar.xz && \
    cp -a ffmpeg-*-static/ffmpeg /usr/bin/ffmpeg && \
    cp -a ffmpeg-*-static/ffprobe /usr/bin/ffprobe && \
    rm -rf /usr/local/bin/ffmpeg

COPY . /workspace/whisper
WORKDIR /workspace/whisper

# Copy HPU requirements
COPY requirements_hpu.txt /workspace/requirements_hpu.txt

# Install Python packages
RUN pip install --upgrade pip \
    && pip install -r requirements_hpu.txt

59  README.md

@@ -93,6 +93,10 @@ Adding `--task translate` will translate the speech into English:
    whisper japanese.wav --language Japanese --task translate

The following command will transcribe speech in audio files, using the Intel® Gaudi® HPU (`--device hpu` option):

    whisper audio.flac audio.mp3 audio.wav --model turbo --device hpu

Run the following to view all available options:

    whisper --help
@@ -140,6 +144,61 @@ result = whisper.decode(model, mel, options)
print(result.text)
```

## Intel® Gaudi® HPU usage

### Build the Docker Image

```bash
docker build -t whisper_hpu:latest -f Dockerfile.hpu .
```

The `Dockerfile.hpu` uses the `vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest` base image; replace it with the appropriate version for your environment if needed.
See the [PyTorch Docker Images for the Intel® Gaudi® Accelerator](https://developer.habana.ai/catalog/pytorch-container/) for more information.

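If you want to fetch the base image ahead of the build, a minimal sketch (reusing the tag from `Dockerfile.hpu`; adjust it to whichever release you target) is:

```bash
# Pre-pull the Gaudi PyTorch base image referenced in Dockerfile.hpu, then build the Whisper image
docker pull vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
docker build -t whisper_hpu:latest -f Dockerfile.hpu .
```
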
### Run the Container

```bash
docker run -it --runtime=habana whisper_hpu:latest
```

Mapping a volume (`-v`) is optional, but it lets you access your local Whisper repository from within the container.
You can do this by adding `-v /path/to/your/whisper:/workspace/whisper` to the `docker run` command, as shown below.
If you use the mapping, make sure to replace `/path/to/your/whisper` with the path to the Whisper repository on your local machine.

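For example, a run command with the volume mapping included (a sketch assuming your local checkout lives at `/path/to/your/whisper`) looks like:

```bash
docker run -it --runtime=habana \
  -v /path/to/your/whisper:/workspace/whisper \
  whisper_hpu:latest
```
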
### Command-line usage with Intel® Gaudi® HPU

To run the `transcribe` process with an Intel® Gaudi® HPU, use the `--device hpu` option:

```bash
python3 -m whisper.transcribe audio_file.wav --model turbo --device hpu
```

* Note: Replace `audio_file.wav` with the path of the audio file you want to transcribe. (Example files: https://www.kaggle.com/datasets/pavanelisetty/sample-audio-files-for-speech-recognition?resource=download)

To run the `transcribe` tests with an Intel® Gaudi® HPU, make sure the `pytest` package is installed:

```bash
pip install pytest
```

and then run the following command:

```bash
PYTHONPATH=. pytest -s tests/test_transcribe.py::test_transcribe_hpu
```

### Python usage with Intel® Gaudi® HPU

To use an Intel® Gaudi® HPU from Python, specify the device when loading the model:

```python
import whisper

model = whisper.load_model("turbo", device="hpu")
result = model.transcribe("audio.mp3")
print(result["text"])
```

## More examples

Please use the [🙌 Show and tell](https://github.com/openai/whisper/discussions/categories/show-and-tell) category in Discussions for sharing more example usages of Whisper and third-party extensions such as web demos, integrations with other tools, ports for different platforms, etc.

36  notebooks/LibriSpeech.ipynb  generated

@@ -36,10 +36,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "3CqtR2Fi5-vP"
},
"execution_count": null,
"outputs": [],
"source": [
"import os\n",
@@ -56,10 +53,35 @@
"import torchaudio\n",
"\n",
"from tqdm.notebook import tqdm\n",
"\n",
"\n",
"\n"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"# Set the `DEVICE` based on available hardware\n",
"### If you're running on a machine with Intel Gaudi support, you can set the DEVICE to \"hpu\"\n",
"### This can be done by setting the device to \"hpu\" if Gaudi is available.\n",
"### Ensure you have the Habana PyTorch extension installed and properly configured."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# DEVICE = \"hpu\"\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\""
]
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",

30  notebooks/Multilingual_ASR.ipynb  generated

@@ -51,10 +51,33 @@
"\n",
"\n",
"pd.options.display.max_rows = 100\n",
"pd.options.display.max_colwidth = 1000\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\""
"pd.options.display.max_colwidth = 1000\n"
]
},
{
"cell_type": "markdown",
"source": [
"# Set the `DEVICE` based on available hardware\n",
"### If you're running on a machine with Intel Gaudi support, you can set the DEVICE to \"hpu\"\n",
"### This can be done by setting the device to \"hpu\" if Gaudi is available.\n",
"### Ensure you have the Habana PyTorch extension installed and properly configured."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# DEVICE = \"hpu\"\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\""
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"metadata": {
@@ -860,8 +883,7 @@
"text": [
"Importing the dtw module. When using in academic works please cite:\n",
" T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.\n",
" J. Stat. Soft., doi:10.18637/jss.v031.i07.\n",
"\n"
" J. Stat. Soft., doi:10.18637/jss.v031.i07.\n"
]
}
],

requirements.txt

@@ -5,3 +5,5 @@ tqdm
more-itertools
tiktoken
triton>=2.0.0;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2"
scipy
pytest

6  requirements_hpu.txt  Normal file

@@ -0,0 +1,6 @@
optimum-habana==1.14.1
transformers==4.45.2
huggingface-hub==0.26.2
tiktoken==0.8.0
torch-geometric==2.6.1
numba==0.60.0

tests/conftest.py

@@ -6,6 +6,7 @@ import pytest

def pytest_configure(config):
    config.addinivalue_line("markers", "requires_cuda")
    config.addinivalue_line("markers", "requires_hpu")


@pytest.fixture

tests/test_timing.py

@@ -3,7 +3,8 @@ import pytest
import scipy.ndimage
import torch

from whisper.timing import dtw_cpu, dtw_cuda, median_filter
from whisper.hpu_utils import get_x_hpu
from whisper.timing import dtw_cpu, dtw_cuda, median_filter, dtw_hpu

sizes = [
    (10, 20),
@@ -94,3 +95,15 @@ def test_median_filter_equivalence(shape):
    filtered_gpu = median_filter(x.cuda(), filter_width).cpu()

    assert np.allclose(filtered_cpu, filtered_gpu)


@pytest.mark.requires_hpu
@pytest.mark.parametrize("N, M", sizes)
def test_dtw_hpu_equivalence(N: int, M: int):
    x_numpy = np.random.randn(N, M).astype(np.float32)
    x_hpu = get_x_hpu(x_numpy)

    trace_cpu = dtw_cpu(x_numpy)
    trace_hpu = dtw_hpu(x_hpu)

    assert np.allclose(trace_cpu, trace_hpu)

tests/test_transcribe.py

@@ -10,7 +10,7 @@ from whisper.tokenizer import get_tokenizer
@pytest.mark.parametrize("model_name", whisper.available_models())
def test_transcribe(model_name: str):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisper.load_model(model_name).to(device)
    model = whisper.load_model(model_name, device=device)
    audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

    language = "en" if model_name.endswith(".en") else None
@@ -40,3 +40,38 @@
                timing_checked = True

    assert timing_checked


@pytest.mark.requires_hpu
@pytest.mark.parametrize("model_name", whisper.available_models())
def test_transcribe_hpu(model_name: str):
    device = "hpu"
    model = whisper.load_model(model_name, device=device)
    audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

    language = "en" if model_name.endswith(".en") else None
    result = model.transcribe(
        audio_path, language=language, temperature=0.0, word_timestamps=True
    )
    assert result["language"] == "en"
    assert result["text"] == "".join([s["text"] for s in result["segments"]])

    transcription = result["text"].lower()
    assert "my fellow americans" in transcription
    assert "your country" in transcription
    assert "do for you" in transcription

    tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
    all_tokens = [t for s in result["segments"] for t in s["tokens"]]
    assert tokenizer.decode(all_tokens) == result["text"]

    timing_checked = False
    for segment in result["segments"]:
        for timing in segment["words"]:
            assert timing["start"] < timing["end"]
            if timing["word"].strip(" ,") == "Americans":
                assert timing["start"] <= 1.8
                assert timing["end"] >= 1.8
                timing_checked = True

    assert timing_checked

whisper/__init__.py

@@ -147,14 +147,31 @@ def load_model(
    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        if device == "hpu":
            """If the device is HPU,
            the model should be loaded on CPU first
            and then moved to HPU."""
            checkpoint = torch.load(fp, map_location="cpu")
        else:
            checkpoint = torch.load(fp, map_location=device)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model = Whisper(dims, compute_device=torch.device(device))
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    if device == "hpu":
        from habana_frameworks.torch.utils.library_loader import load_habana_module

        load_habana_module()
        if torch.hpu.is_available():
            from habana_frameworks.torch.hpu import wrap_in_hpu_graph

            model = wrap_in_hpu_graph(model)
            model = model.eval().to(torch.device(device))

        return model
    return model.to(device)

whisper/decoding.py

@@ -8,6 +8,7 @@ from torch import Tensor
from torch.distributions import Categorical

from .audio import CHUNK_LENGTH
from .hpu_utils import is_hpu_device
from .tokenizer import Tokenizer, get_tokenizer
from .utils import compression_ratio

@@ -456,6 +457,16 @@ class ApplyTimestampRules(LogitFilter):

        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
        for k in range(tokens.shape[0]):
            if is_hpu_device(tokens.device):
                """
                If tokens are on HPU, `sampled_tokens` is cloned to force evaluation.

                On Habana HPUs, tensors may use lazy execution, which can lead to runtime errors if not explicitly
                evaluated. Cloning `sampled_tokens` ensures it is fully evaluated on the HPU, preventing potential
                synchronization issues.
                """
                sampled_tokens = tokens[k, self.sample_begin :].clone()
            else:
                sampled_tokens = tokens[k, self.sample_begin :]
            seq = [t for t in sampled_tokens.tolist()]
            last_was_timestamp = (

54  whisper/hpu_model_tests.py  Normal file

@@ -0,0 +1,54 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()

        # Define layers
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(32 * 56 * 56, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        # Forward pass through the network
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 32 * 56 * 56)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# Example usage
if __name__ == "__main__":
    # Load Habana module for HPU support
    from habana_frameworks.torch.utils.library_loader import load_habana_module
    import habana_frameworks.torch.hpu as hthpu

    load_habana_module()

    device = None
    # Set device to HPU
    if hthpu.is_available():
        device = torch.device("hpu")
        print("Using HPU")

    if not device:
        print("HPU is not available")
        exit(1)

    # Create model instance and move it to the HPU
    model = SimpleCNN(num_classes=10).to(device)

    # Create a dummy input tensor and move it to the HPU
    input_tensor = torch.rand((64, 3, 224, 224), device=device)  # Batch size of 64

    # Forward pass through the model on HPU
    output = model(input_tensor)

    print("Output shape:", output.shape)  # Should be [64, num_classes]

13  whisper/hpu_utils.py  Normal file

@@ -0,0 +1,13 @@
import torch

def get_x_hpu(x_numpy):
    from habana_frameworks.torch.utils.library_loader import load_habana_module

    load_habana_module()

    x_hpu = torch.from_numpy(x_numpy).to("hpu")
    return x_hpu


def is_hpu_device(device: torch.device):
    return device in (torch.device("hpu:0"), torch.device("hpu"))

whisper/model.py

@@ -11,6 +11,7 @@ from torch import Tensor, nn

from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .hpu_utils import is_hpu_device
from .transcribe import transcribe as transcribe_function

try:
@@ -250,9 +251,12 @@ class TextDecoder(nn.Module):


class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
    def __init__(self, dims: ModelDimensions, compute_device: Optional[torch.device] = None):
        super().__init__()
        self.dims = dims
        self.compute_device = compute_device or (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
@@ -273,7 +277,11 @@ class Whisper(nn.Module):
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
        if not is_hpu_device(self.compute_device):
            # Convert to sparse format if device is not HPU
            all_heads = all_heads.to_sparse()

        self.register_buffer("alignment_heads", all_heads, persistent=False)

    def set_alignment_heads(self, dump: bytes):
        array = np.frombuffer(
@@ -282,7 +290,11 @@ class Whisper(nn.Module):
        mask = torch.from_numpy(array).reshape(
            self.dims.n_text_layer, self.dims.n_text_head
        )
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
        if not is_hpu_device(self.compute_device):
            # Convert to sparse format if device is not HPU
            mask = mask.to_sparse()

        self.register_buffer("alignment_heads", mask, persistent=False)

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder(mel)

whisper/timing.py

@@ -10,6 +10,7 @@ import torch
import torch.nn.functional as F

from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .hpu_utils import is_hpu_device
from .tokenizer import Tokenizer

if TYPE_CHECKING:
@@ -138,7 +139,54 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
    return backtrace(trace.cpu().numpy())


def dtw_hpu(x, BLOCK_SIZE=1024):
    """
    DTW implementation for HPU.
    """
    M, N = x.shape
    assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"

    x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
    x_skew = x_skew.T.contiguous()

    # Initialize cost and trace matrices with high values for comparison
    cost = torch.ones(N + M + 2, M + 2, device="hpu") * np.inf
    cost[0, 0] = 0  # Start point for DTW
    trace = torch.zeros_like(cost, dtype=torch.int32, device="hpu")

    for k in range(1, N + M + 1):
        p0 = cost[k - 1, :M]
        p1 = cost[k, :M]
        p2 = cost[k, 1:M + 1]

        c0 = p0.clone()
        c1 = p1.clone()
        c2 = p2.clone()

        x_row = x_skew[k - 1, :M]

        cost_row = x_row + torch.min(torch.min(c0, c1), c2)
        cost[k + 1, 1:M + 1] = cost_row

        # Track path by storing traces
        trace[k + 1, 1:M + 1] = 2 * (c2 <= c0) * (c2 <= c1) + 1 * (c1 <= c0) * (c1 <= c2) + 0 * (c0 <= c1) * (c0 <= c2)

    trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, : N + 1]
    return backtrace(trace.cpu().numpy())


def dtw(x: torch.Tensor) -> np.ndarray:
    try:
        from habana_frameworks.torch.utils.library_loader import load_habana_module
        load_habana_module()

        if torch.hpu.is_available():
            return dtw_hpu(x)
    except (ImportError, subprocess.CalledProcessError):
        warnings.warn(
            "Failed to import Habana modules, likely due to missing Habana libraries; "
        )

    if x.is_cuda:
        try:
            return dtw_cuda(x)
@@ -204,7 +252,21 @@ def find_alignment(
        hook.remove()

    # heads * tokens * frames
    weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
    # Adjust alignment head indices for HPU
    weights = []
    if is_hpu_device(model.device):
        # Handle dense layout for HPU
        alignment_heads_dense = model.alignment_heads.to_dense() if model.alignment_heads.is_sparse else model.alignment_heads
        indices = alignment_heads_dense.nonzero(as_tuple=True)
        for _l, _h in zip(*indices):
            weights.append(QKs[_l][_h])
    else:
        # Default behavior for non-HPU devices
        for _l, _h in model.alignment_heads.indices().T:
            weights.append(QKs[_l][_h])

    # Stack the weights
    weights = torch.stack(weights)
    weights = weights[:, :, : num_frames // 2]
    weights = (weights * qk_scale).softmax(dim=-1)
    std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)

whisper/transcribe.py

@@ -18,6 +18,7 @@ from .audio import (
    pad_or_trim,
)
from .decoding import DecodingOptions, DecodingResult
from .hpu_utils import is_hpu_device
from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import (
@@ -126,6 +127,11 @@ def transcribe(
            warnings.warn("FP16 is not supported on CPU; using FP32 instead")
            dtype = torch.float32

    if is_hpu_device(model.device):
        if dtype == torch.float16:
            warnings.warn("FP16 is not supported on HPU; using FP32 instead")
            dtype = torch.float32

    if dtype == torch.float32:
        decode_options["fp16"] = False

@@ -508,12 +514,26 @@ def cli():
            f"model should be one of {available_models()} or path to a model checkpoint"
        )

    def valid_device(device_name):
        if device_name == "cuda" and not torch.cuda.is_available():
            warnings.warn("CUDA is not available; using CPU instead")
            device_name = "cpu"
        if device_name == "hpu":
            from habana_frameworks.torch.utils.library_loader import load_habana_module

            load_habana_module()
            if not torch.hpu.is_available():
                warnings.warn("HPU is not available; using CPU instead")
                device_name = "cpu"

        return device_name

    # fmt: off
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
    parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", type=valid_device, help="device to use for PyTorch inference (hpu/cuda/cpu)")
    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")