Patch summary (text before the first "diff --git" header is commentary):
  * tests/test_transcribe.py: test_transcribe_hpu now calls model.transcribe
    with word_timestamps=True.
  * whisper/timing.py: imports is_hpu_device from .hpu_utils; find_alignment
    gathers the alignment-head QK weights from the nonzero entries of a dense
    model.alignment_heads when model.device is an HPU, and from
    model.alignment_heads.indices() otherwise, before stacking them.

diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py
index 175be91..c5a7e6e 100644
--- a/tests/test_transcribe.py
+++ b/tests/test_transcribe.py
@@ -51,7 +51,7 @@ def test_transcribe_hpu(model_name: str):
 
     language = "en" if model_name.endswith(".en") else None
     result = model.transcribe(
-        audio_path, language=language, temperature=0.0
+        audio_path, language=language, temperature=0.0, word_timestamps=True
     )
     assert result["language"] == "en"
     assert result["text"] == "".join([s["text"] for s in result["segments"]])
diff --git a/whisper/timing.py b/whisper/timing.py
index dec2081..7e9763d 100644
--- a/whisper/timing.py
+++ b/whisper/timing.py
@@ -10,6 +10,7 @@ import torch
 import torch.nn.functional as F
 
 from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
+from .hpu_utils import is_hpu_device
 from .tokenizer import Tokenizer
 
 if TYPE_CHECKING:
@@ -251,7 +252,21 @@ def find_alignment(
         hook.remove()
 
     # heads * tokens * frames
-    weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
+    # Adjust alignment head indices for HPU
+    weights = []
+    if is_hpu_device(model.device):
+        # Handle dense layout for HPU
+        alignment_heads_dense = model.alignment_heads.to_dense() if model.alignment_heads.is_sparse else model.alignment_heads
+        indices = alignment_heads_dense.nonzero(as_tuple=True)
+        for _l, _h in zip(*indices):
+            weights.append(QKs[_l][_h])
+    else:
+        # Default behavior for non-HPU devices
+        for _l, _h in model.alignment_heads.indices().T:
+            weights.append(QKs[_l][_h])
+
+    # Stack the weights
+    weights = torch.stack(weights)
     weights = weights[:, :, : num_frames // 2]
     weights = (weights * qk_scale).softmax(dim=-1)
     std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)