Amal Jacob 2025-01-31 12:30:45 -08:00
parent 54edc865e6
commit 59b1fa70fe


@@ -1,5 +1,8 @@
import base64
import gzip
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Tuple
@@ -300,6 +303,34 @@ class Whisper(nn.Module):
    ) -> Dict[str, torch.Tensor]:
        return self.decoder(tokens, self.encoder(mel))

    def get_attention_weights(self):
        """Retrieve stored attention weights from the decoder's layers."""
        attn_maps = []
        for block in self.decoder.blocks:
            if block.attn.attn_weights is not None:
                attn_maps.append(block.attn.attn_weights.detach().cpu().numpy())
        # List of per-layer arrays, each of shape (batch, heads, target, source)
        return attn_maps

    def plot_attention_on_padded(self, seq_length: int = 100):
        """Plot attention weights over the first seq_length positions, where padding typically sits."""
        attn_maps = self.get_attention_weights()
        if not attn_maps:
            print("No attention weights found!")
            return
        # Average over layers; stacking assumes every layer's weights share one shape
        avg_attn = np.mean(attn_maps, axis=0)
        # Crop to the region of interest before plotting
        avg_attn = avg_attn[:, :, :seq_length, :seq_length]
        plt.figure(figsize=(8, 6))
        # Average over heads, then drop the batch dimension (assumes batch size 1)
        sns.heatmap(avg_attn.mean(axis=1).squeeze(), cmap="Blues", annot=False)
        plt.xlabel("Input Positions")
        plt.ylabel("Output Positions")
        plt.title("Attention Weights on Padded Regions")
        plt.show()
    @property
    def device(self):
        return next(self.parameters()).device
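
These helpers read block.attn.attn_weights, an attribute that the stock Whisper attention modules do not populate, so the commit only works if the decoder's attention layers cache their softmax weights on each forward pass. Below is a minimal sketch of that caching pattern; CachedAttention is a hypothetical stand-in for the decoder's attention module, not code from this repository:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CachedAttention(nn.Module):
    """Hypothetical self-attention module that caches its softmax weights."""

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.qkv = nn.Linear(n_state, 3 * n_state)
        self.out = nn.Linear(n_state, n_state)
        self.attn_weights = None  # what get_attention_weights() reads back

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, c = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (batch, time, state) -> (batch, heads, time, head_dim)
        q, k, v = (z.view(b, t, self.n_head, -1).transpose(1, 2) for z in (q, k, v))
        w = F.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
        self.attn_weights = w.detach()  # (batch, heads, time, time)
        return self.out((w @ v).transpose(1, 2).reshape(b, t, c))

With weights cached that way, a forward pass followed by the new helper produces the heatmap; mel and tokens below are placeholder inputs:

logits = model(mel, tokens)  # each decoder block caches its attn_weights
model.plot_attention_on_padded(seq_length=50)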