Merge 1aae730936f820a2480bdbf54f8606208b475b0f into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2

This commit is contained in:
Viraj 2025-10-27 19:35:09 -05:00 committed by GitHub
commit f5c0c7f31f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -193,13 +193,14 @@ def find_alignment(
from .model import disable_sdpa from .model import disable_sdpa
try:
with torch.no_grad(), disable_sdpa(): with torch.no_grad(), disable_sdpa():
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0] logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot] sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
token_probs = sampled_logits.softmax(dim=-1) token_probs = sampled_logits.softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens] text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
text_token_probs = text_token_probs.tolist() text_token_probs = text_token_probs.tolist()
finally:
for hook in hooks: for hook in hooks:
hook.remove() hook.remove()