mirror of
https://github.com/openai/whisper.git
synced 2025-11-23 22:15:58 +00:00
Merge 1aae730936f820a2480bdbf54f8606208b475b0f into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2
This commit is contained in:
commit
f5c0c7f31f
@ -192,16 +192,17 @@ def find_alignment(
|
|||||||
]
|
]
|
||||||
|
|
||||||
from .model import disable_sdpa
|
from .model import disable_sdpa
|
||||||
|
|
||||||
with torch.no_grad(), disable_sdpa():
|
try:
|
||||||
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
with torch.no_grad(), disable_sdpa():
|
||||||
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
||||||
token_probs = sampled_logits.softmax(dim=-1)
|
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
||||||
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
|
token_probs = sampled_logits.softmax(dim=-1)
|
||||||
text_token_probs = text_token_probs.tolist()
|
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
|
||||||
|
text_token_probs = text_token_probs.tolist()
|
||||||
for hook in hooks:
|
finally:
|
||||||
hook.remove()
|
for hook in hooks:
|
||||||
|
hook.remove()
|
||||||
|
|
||||||
# heads * tokens * frames
|
# heads * tokens * frames
|
||||||
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
|
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user