diff --git a/whisper/model.py b/whisper/model.py
index dae750a..f20c4e0 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -309,43 +309,39 @@ class Whisper(nn.Module):
         return attn_maps
 
-    def plot_attention_on_padded(self, seq_length: int = 100):
-        """Plots attention weights focusing on padded regions."""
+    def plot_attention_distribution(self, seq_length: int = 100):
+        """Plots attention distribution over sequence length."""
         attn_maps = self.get_attention_weights()
 
         if not attn_maps:
             print("No attention weights found!")
             return
 
-        # Convert list to NumPy array
+        # Convert to numpy array and print shape
         attn_maps = np.array(attn_maps)
+        print(f"Attention Maps Shape: {attn_maps.shape}")
 
-        # Print debug info
-        print(f"Attention Maps Shape (Before Averaging): {attn_maps.shape}")
+        # Average over layers and heads to get per-token attention
+        avg_attn = np.mean(attn_maps, axis=(0, 2))  # Shape: (batch, seq_len, seq_len)
 
-        # Average over layers and heads
-        avg_attn = np.mean(attn_maps, axis=(0, 2))  # Shape: (batch, ?, seq_len)
+        print(f"Averaged Attention Shape: {avg_attn.shape}")
 
-        print(f"Averaged Attention Shape (Before Squeeze): {avg_attn.shape}")
+        # Remove batch dim if present
+        avg_attn = np.squeeze(avg_attn)
 
-        # Squeeze to remove any extra singleton dimensions
-        avg_attn = np.squeeze(avg_attn)  # Removes batch dim
+        # Average over query positions to get the attention each position receives
+        token_attention = np.mean(avg_attn, axis=0)  # Shape: (seq_len,)
 
-        print(f"Averaged Attention Shape (After Squeeze): {avg_attn.shape}")
+        # Ensure we only plot up to `seq_length`
+        token_attention = token_attention[:seq_length]
+        x_positions = np.arange(len(token_attention))
 
-        # Ensure correct shape
-        if avg_attn.ndim == 1:  # If still incorrect (seq_len,)
-            avg_attn = avg_attn.reshape((1, -1))  # Force into 2D shape for heatmap
-
-        # Ensure shape matches seq_length
-        avg_attn = avg_attn[:seq_length, :seq_length]  # Truncate to fit expected size
-
-        # Plot heatmap
-        plt.figure(figsize=(8, 6))
-        sns.heatmap(avg_attn, cmap="Blues", annot=False)
-        plt.xlabel("Input Positions")
-        plt.ylabel("Output Positions")
-        plt.title("Attention Weights on Padded Regions")
+        # Plot the attention distribution as a bar chart
+        plt.figure(figsize=(12, 4))
+        plt.bar(x_positions, token_attention, width=1.5, alpha=0.7)
+        plt.xlabel("Token Position")
+        plt.ylabel("Attention Score")
+        plt.title("Attention Distribution Over Sequence")
         plt.show()
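For reference, a minimal usage sketch of the renamed helper, assuming attention weights are recorded during a forward pass by the hooks behind get_attention_weights(); the checkpoint name and audio path below are placeholders, while load_model and transcribe are the standard whisper entry points:

    import whisper

    # Load a model and run a forward pass so attention weights are captured
    # ("base" and "audio.mp3" are placeholder choices).
    model = whisper.load_model("base")
    model.transcribe("audio.mp3")

    # Plot the averaged per-token attention for the first 100 positions.
    model.plot_attention_distribution(seq_length=100)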