diff --git a/whisper/model.py b/whisper/model.py
index 284b291..9a1c250 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -1,5 +1,7 @@
 import base64
 import gzip
+import matplotlib.pyplot as plt
+import seaborn as sns
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Dict, Iterable, Optional, Tuple
@@ -110,7 +112,7 @@ class MultiHeadAttention(nn.Module):
             v = kv_cache[self.value]
 
         wv, qk = self.qkv_attention(q, k, v, mask)
-
+        self.attn_weights = qk
         return self.out(wv), qk
 
     def qkv_attention(
@@ -299,6 +301,43 @@ class Whisper(nn.Module):
         self, mel: torch.Tensor, tokens: torch.Tensor
     ) -> Dict[str, torch.Tensor]:
         return self.decoder(tokens, self.encoder(mel))
+
+    def get_attention_weights(self):
+        """Return the self-attention scores stored by each decoder block.
+
+        Note: these are the raw pre-softmax qk scores from qkv_attention
+        (with the attention mask already added), not normalized probabilities.
+        """
+        attn_maps = []
+        for block in self.decoder.blocks:
+            # attn_weights only exists after a forward pass has run, and may be
+            # None if the attention matrix was not materialized (e.g. under SDPA).
+            weights = getattr(block.attn, "attn_weights", None)
+            if weights is not None:
+                attn_maps.append(weights.detach().cpu().numpy())
+
+        return attn_maps  # list of (batch, heads, seq, seq) NumPy arrays, one per layer
+
+    def plot_attention_on_padded(self, seq_length: int = 100):
+        """Plot decoder self-attention over the first seq_length positions,
+        averaged across layers, batch, and heads (e.g. to inspect padded regions)."""
+        attn_maps = self.get_attention_weights()
+
+        if not attn_maps:
+            print("No attention weights found; run a forward pass first.")
+            return
+
+        # Average over layers, then crop to the first seq_length positions
+        avg_attn = np.mean(attn_maps, axis=0)
+        avg_attn = avg_attn[:, :, :seq_length, :seq_length]
+
+        # Collapse batch and heads to a 2-D (seq, seq) map
+        plt.figure(figsize=(8, 6))
+        sns.heatmap(avg_attn.mean(axis=(0, 1)), cmap="Blues", annot=False)
+        plt.xlabel("Input Positions")
+        plt.ylabel("Output Positions")
+        plt.title("Attention Weights on Padded Regions")
+        plt.show()
 
     @property
     def device(self):
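
Usage sketch (not part of the patch): assuming an upstream checkpoint loaded with whisper.load_model, dummy mel/token tensors, and a whisper version that exposes whisper.model.disable_sdpa (needed so qkv_attention materializes qk instead of returning None), the new helpers could be exercised roughly like this:

    import torch
    import whisper
    from whisper.model import disable_sdpa  # assumption: present in SDPA-enabled releases

    model = whisper.load_model("base", device="cpu")

    # Dummy inputs: a random mel spectrogram and a short all-zero token sequence.
    mel = torch.randn(1, model.dims.n_mels, 3000)
    tokens = torch.zeros(1, 16, dtype=torch.long)

    with disable_sdpa():            # ensure qk is computed and stored as attn_weights
        _ = model(mel, tokens)      # forward pass populates attn_weights on every block

    attn_maps = model.get_attention_weights()      # one (1, n_heads, 16, 16) array per layer
    model.plot_attention_on_padded(seq_length=16)  # heatmap averaged over layers and heads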