diff --git a/whisper/timing.py b/whisper/timing.py
index 2340000..e1267c8 100644
--- a/whisper/timing.py
+++ b/whisper/timing.py
@@ -192,16 +192,17 @@ def find_alignment(
     ]
 
     from .model import disable_sdpa
-
-    with torch.no_grad(), disable_sdpa():
-        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
-        sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
-        token_probs = sampled_logits.softmax(dim=-1)
-        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
-        text_token_probs = text_token_probs.tolist()
-
-    for hook in hooks:
-        hook.remove()
+
+    try:
+        with torch.no_grad(), disable_sdpa():
+            logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
+            sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
+            token_probs = sampled_logits.softmax(dim=-1)
+            text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
+            text_token_probs = text_token_probs.tolist()
+    finally:
+        for hook in hooks:
+            hook.remove()
 
     # heads * tokens * frames
     weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
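For context, a minimal standalone sketch of the pattern this patch applies: forward hooks registered on a torch module leak if the forward pass raises before they are removed, so the removal belongs in a finally block. The toy module and the names below (model, captured, hook) are illustrative, not taken from whisper/timing.py.

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)
    captured = []

    # Record each forward output, analogous to how find_alignment
    # collects cross-attention QK matrices via hooks.
    hook = model.register_forward_hook(lambda mod, ins, outs: captured.append(outs))

    try:
        with torch.no_grad():
            model(torch.randn(2, 4))
    finally:
        # Runs even if the forward pass raises, so the hook never leaks
        # onto later calls -- the same guarantee the patch adds.
        hook.remove()

    assert len(captured) == 1

Without the try/finally, an exception inside the forward pass (for example, an out-of-memory error) would leave the hook attached, and it would keep firing and accumulating tensors on every subsequent forward call.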