fix memory leak: ensure hooks are cleaned up on exception in find_alignment

Viraj 2025-10-27 19:34:28 -05:00
parent c0d2f624c0
commit 1aae730936


@@ -193,15 +193,16 @@ def find_alignment(
     from .model import disable_sdpa
 
-    with torch.no_grad(), disable_sdpa():
-        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
-        sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
-        token_probs = sampled_logits.softmax(dim=-1)
-        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
-        text_token_probs = text_token_probs.tolist()
-
-    for hook in hooks:
-        hook.remove()
+    try:
+        with torch.no_grad(), disable_sdpa():
+            logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
+            sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
+            token_probs = sampled_logits.softmax(dim=-1)
+            text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
+            text_token_probs = text_token_probs.tolist()
+    finally:
+        for hook in hooks:
+            hook.remove()
 
     # heads * tokens * frames
     weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
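
For context, the pattern introduced here is a standard try/finally cleanup around PyTorch forward hooks: the handles returned by register_forward_hook must be removed even if the forward pass raises, otherwise the hooks (and the tensors they capture) stay attached to the model. Below is a minimal, self-contained sketch of that pattern; the toy model, the captured list, and the hook body are illustrative stand-ins, not Whisper's actual alignment-head capture code.

    # Minimal sketch of the try/finally hook-cleanup pattern the diff applies.
    # The model, `captured`, and the hook body are stand-ins for illustration.
    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
    captured = []  # plays the role of QKs: tensors saved by the hooks
    hooks = []

    for layer in model:
        # register_forward_hook returns a RemovableHandle; keep it so the
        # hook can be detached later.
        hooks.append(
            layer.register_forward_hook(
                lambda _module, _inputs, output: captured.append(output.detach())
            )
        )

    try:
        with torch.no_grad():
            model(torch.randn(1, 4))  # the forward pass may raise
    finally:
        # Without this finally, an exception during the forward pass would
        # leave the hooks registered on the model indefinitely, which is the
        # leak described in the commit message.
        for hook in hooks:
            hook.remove()

The same cleanup could also be expressed with contextlib.ExitStack (registering each handle's remove as a callback); the diff opts for the plainer try/finally.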