diff --git a/whisper/model.py b/whisper/model.py
index 37eb5d0..dae750a 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -300,6 +300,15 @@ class Whisper(nn.Module):
     ) -> Dict[str, torch.Tensor]:
         return self.decoder(tokens, self.encoder(mel))
 
+    def get_attention_weights(self):
+        """Retrieve stored attention weights from the decoder's layers."""
+        attn_maps = []
+        for block in self.decoder.blocks:
+            if block.attn.attn_weights is not None:
+                attn_maps.append(block.attn.attn_weights.detach().cpu().numpy())
+
+        return attn_maps
+
     def plot_attention_on_padded(self, seq_length: int = 100):
         """Plots attention weights focusing on padded regions."""
         attn_maps = self.get_attention_weights()