From 1aae730936f820a2480bdbf54f8606208b475b0f Mon Sep 17 00:00:00 2001 From: Viraj <77448246+virajsabhaya23@users.noreply.github.com> Date: Mon, 27 Oct 2025 19:34:28 -0500 Subject: [PATCH] fix memory leak: ensure hooks are cleaned up on exception in find_alignment --- whisper/timing.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/whisper/timing.py b/whisper/timing.py index 2340000..e1267c8 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -192,16 +192,17 @@ def find_alignment( ] from .model import disable_sdpa - - with torch.no_grad(), disable_sdpa(): - logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0] - sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot] - token_probs = sampled_logits.softmax(dim=-1) - text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens] - text_token_probs = text_token_probs.tolist() - - for hook in hooks: - hook.remove() + + try: + with torch.no_grad(), disable_sdpa(): + logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0] + sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot] + token_probs = sampled_logits.softmax(dim=-1) + text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens] + text_token_probs = text_token_probs.tolist() + finally: + for hook in hooks: + hook.remove() # heads * tokens * frames weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])