mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 14:35:57 +00:00
Add word_timestamps fix in model.transcribe
This commit is contained in:
parent
6770610528
commit
8c4e65929f
@ -51,7 +51,7 @@ def test_transcribe_hpu(model_name: str):
|
||||
|
||||
language = "en" if model_name.endswith(".en") else None
|
||||
result = model.transcribe(
|
||||
audio_path, language=language, temperature=0.0
|
||||
audio_path, language=language, temperature=0.0, word_timestamps=True
|
||||
)
|
||||
assert result["language"] == "en"
|
||||
assert result["text"] == "".join([s["text"] for s in result["segments"]])
|
||||
|
||||
@ -10,6 +10,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
|
||||
from .hpu_utils import is_hpu_device
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -251,7 +252,21 @@ def find_alignment(
|
||||
hook.remove()
|
||||
|
||||
# heads * tokens * frames
|
||||
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
|
||||
# Adjust alignment head indices for HPU
|
||||
weights = []
|
||||
if is_hpu_device(model.device):
|
||||
# Handle dense layout for HPU
|
||||
alignment_heads_dense = model.alignment_heads.to_dense() if model.alignment_heads.is_sparse else model.alignment_heads
|
||||
indices = alignment_heads_dense.nonzero(as_tuple=True)
|
||||
for _l, _h in zip(*indices):
|
||||
weights.append(QKs[_l][_h])
|
||||
else:
|
||||
# Default behavior for non-HPU devices
|
||||
for _l, _h in model.alignment_heads.indices().T:
|
||||
weights.append(QKs[_l][_h])
|
||||
|
||||
# Stack the weights
|
||||
weights = torch.stack(weights)
|
||||
weights = weights[:, :, : num_frames // 2]
|
||||
weights = (weights * qk_scale).softmax(dim=-1)
|
||||
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user