Add word_timestamps fix in model.transcribe

PiotrBLL 2024-11-20 00:58:28 +01:00
parent 6770610528
commit 8c4e65929f
2 changed files with 17 additions and 2 deletions

View File

@@ -51,7 +51,7 @@ def test_transcribe_hpu(model_name: str):
     language = "en" if model_name.endswith(".en") else None
     result = model.transcribe(
-        audio_path, language=language, temperature=0.0
+        audio_path, language=language, temperature=0.0, word_timestamps=True
     )
     assert result["language"] == "en"
     assert result["text"] == "".join([s["text"] for s in result["segments"]])
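For context, a minimal sketch of how the option exercised by this test is typically consumed, assuming the fork keeps upstream Whisper's result schema (with `word_timestamps=True`, each segment carries a `words` list of dicts with `word`, `start`, `end`, and `probability`); the model name and audio file below are placeholders:

```python
import whisper

# Placeholder model/audio; the fork's HPU loading path may differ.
model = whisper.load_model("tiny")
result = model.transcribe(
    "audio.wav", language="en", temperature=0.0, word_timestamps=True
)

for segment in result["segments"]:
    # With word_timestamps=True, upstream Whisper attaches a "words" list
    # to each segment, holding per-word start/end times in seconds.
    for word in segment.get("words", []):
        print(f'{word["start"]:6.2f} -> {word["end"]:6.2f} {word["word"]}')
```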

View File

@@ -10,6 +10,7 @@ import torch
 import torch.nn.functional as F

 from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
+from .hpu_utils import is_hpu_device
 from .tokenizer import Tokenizer

 if TYPE_CHECKING:
@@ -251,7 +252,21 @@ def find_alignment(
         hook.remove()

     # heads * tokens * frames
-    weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
+    # Adjust alignment head indices for HPU
+    weights = []
+    if is_hpu_device(model.device):
+        # Handle dense layout for HPU
+        alignment_heads_dense = model.alignment_heads.to_dense() if model.alignment_heads.is_sparse else model.alignment_heads
+        indices = alignment_heads_dense.nonzero(as_tuple=True)
+        for _l, _h in zip(*indices):
+            weights.append(QKs[_l][_h])
+    else:
+        # Default behavior for non-HPU devices
+        for _l, _h in model.alignment_heads.indices().T:
+            weights.append(QKs[_l][_h])
+
+    # Stack the weights
+    weights = torch.stack(weights)
     weights = weights[:, :, : num_frames // 2]
     weights = (weights * qk_scale).softmax(dim=-1)
     std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
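As a side note on why the HPU branch is equivalent: `model.alignment_heads` is a sparse boolean (layer, head) mask, and iterating its dense form with `nonzero(as_tuple=True)` visits the same pairs as `.indices().T` on the sparse tensor. A toy sketch (not the fork's code; shapes and selected heads chosen arbitrarily) illustrating that equivalence:

```python
import torch

# Build a small boolean (n_layers, n_heads) mask and its sparse counterpart.
n_layers, n_heads = 4, 6
dense_mask = torch.zeros(n_layers, n_heads, dtype=torch.bool)
for layer, head in [(1, 3), (2, 0), (3, 5)]:
    dense_mask[layer, head] = True
sparse_mask = dense_mask.to_sparse()

# Sparse path: indices() is a 2 x nnz tensor; transposing yields (layer, head) rows.
sparse_pairs = [(int(l), int(h)) for l, h in sparse_mask.indices().T]

# Dense path: nonzero(as_tuple=True) returns per-dimension index vectors.
dense_pairs = [(int(l), int(h)) for l, h in zip(*dense_mask.nonzero(as_tuple=True))]

assert sorted(sparse_pairs) == sorted(dense_pairs)
print(sparse_pairs)
```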