mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 14:35:57 +00:00
Add word_timestamps fix in model.transcribe
This commit is contained in:
parent
6770610528
commit
8c4e65929f
@ -51,7 +51,7 @@ def test_transcribe_hpu(model_name: str):
|
|||||||
|
|
||||||
language = "en" if model_name.endswith(".en") else None
|
language = "en" if model_name.endswith(".en") else None
|
||||||
result = model.transcribe(
|
result = model.transcribe(
|
||||||
audio_path, language=language, temperature=0.0
|
audio_path, language=language, temperature=0.0, word_timestamps=True
|
||||||
)
|
)
|
||||||
assert result["language"] == "en"
|
assert result["language"] == "en"
|
||||||
assert result["text"] == "".join([s["text"] for s in result["segments"]])
|
assert result["text"] == "".join([s["text"] for s in result["segments"]])
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
|
from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
|
||||||
|
from .hpu_utils import is_hpu_device
|
||||||
from .tokenizer import Tokenizer
|
from .tokenizer import Tokenizer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -251,7 +252,21 @@ def find_alignment(
|
|||||||
hook.remove()
|
hook.remove()
|
||||||
|
|
||||||
# heads * tokens * frames
|
# heads * tokens * frames
|
||||||
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
|
# Adjust alignment head indices for HPU
|
||||||
|
weights = []
|
||||||
|
if is_hpu_device(model.device):
|
||||||
|
# Handle dense layout for HPU
|
||||||
|
alignment_heads_dense = model.alignment_heads.to_dense() if model.alignment_heads.is_sparse else model.alignment_heads
|
||||||
|
indices = alignment_heads_dense.nonzero(as_tuple=True)
|
||||||
|
for _l, _h in zip(*indices):
|
||||||
|
weights.append(QKs[_l][_h])
|
||||||
|
else:
|
||||||
|
# Default behavior for non-HPU devices
|
||||||
|
for _l, _h in model.alignment_heads.indices().T:
|
||||||
|
weights.append(QKs[_l][_h])
|
||||||
|
|
||||||
|
# Stack the weights
|
||||||
|
weights = torch.stack(weights)
|
||||||
weights = weights[:, :, : num_frames // 2]
|
weights = weights[:, :, : num_frames // 2]
|
||||||
weights = (weights * qk_scale).softmax(dim=-1)
|
weights = (weights * qk_scale).softmax(dim=-1)
|
||||||
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user