add back getter

Author: Amal Jacob
Date: 2025-01-31 12:50:17 -08:00
parent 2b86f24780
commit a85bcadd43


@@ -300,6 +300,15 @@ class Whisper(nn.Module):
    ) -> Dict[str, torch.Tensor]:
        return self.decoder(tokens, self.encoder(mel))

    def get_attention_weights(self):
        """Retrieve stored attention weights from the decoder's layers."""
        attn_maps = []
        for block in self.decoder.blocks:
            if block.attn.attn_weights is not None:
                attn_maps.append(block.attn.attn_weights.detach().cpu().numpy())
        return attn_maps

    def plot_attention_on_padded(self, seq_length: int = 100):
        """Plots attention weights focusing on padded regions."""
        attn_maps = self.get_attention_weights()
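
A minimal usage sketch of the restored getter (not part of the commit itself): it assumes this fork's MultiHeadAttention caches its most recent attention map in an attn_weights attribute during the forward pass, and that whisper.load_model returns the patched model class. The dummy inputs and printed shapes below are illustrative only.

import torch
import whisper  # this fork, with attn_weights caching in MultiHeadAttention (assumed)

model = whisper.load_model("base")

mel = torch.randn(1, 80, 3000)    # dummy log-mel spectrogram (30 s of audio)
tokens = torch.tensor([[50258]])  # <|startoftranscript|> token as a stand-in prompt

with torch.no_grad():
    model(mel, tokens)            # forward pass populates each decoder block's attn_weights

attn_maps = model.get_attention_weights()
print(len(attn_maps))             # one array per decoder block that stored weights
print(attn_maps[0].shape)         # e.g. (batch, heads, n_tokens, n_ctx); implementation-dependent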