From b6e326d1feae35b1285ebe8bcbd06ba629eff9bc Mon Sep 17 00:00:00 2001
From: Amal Jacob
Date: Fri, 31 Jan 2025 13:33:13 -0800
Subject: [PATCH] revert

---
 whisper/model.py | 48 +++++++++++++++++++++++-------------------------
 1 file changed, 23 insertions(+), 25 deletions(-)

diff --git a/whisper/model.py b/whisper/model.py
index b652a1d..cbd1449 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -309,8 +309,8 @@ class Whisper(nn.Module):
 
         return attn_maps
 
-    def plot_cross_attention_distribution(self, seq_length: int = 1500):
-        """Plots cross-attention from the decoder, focusing on attention paid to padded regions."""
+    def plot_attention_distribution(self, seq_length: int = 100):
+        """Plots attention distribution over sequence length."""
         attn_maps = self.get_attention_weights()
 
         if not attn_maps:
@@ -318,39 +318,37 @@ class Whisper(nn.Module):
             return
 
         # Convert to NumPy array
-        attn_maps = np.array(attn_maps)  # Shape: (layers, batch, heads, seq_len, seq_len)
+        attn_maps = np.array(attn_maps)
         print(f"Attention Maps Shape: {attn_maps.shape}")
 
-        # We are interested in **cross-attention** (decoder attending to encoder output)
-        avg_attn = np.mean(attn_maps, axis=0)  # Average over layers
-        avg_attn = np.squeeze(avg_attn)  # Remove batch dim if present
-        print(f"Averaged Attention Shape: {avg_attn.shape}")
+        # Average over layers and heads
+        avg_attn = np.mean(attn_maps, axis=(0, 2))  # Expected shape: (batch, ?, seq_len)
+        print(f"Averaged Attention Shape (Before Squeeze): {avg_attn.shape}")
 
-        # Get attention for each token (rows) across all input positions (columns)
-        token_attention = np.mean(avg_attn, axis=1)  # Shape: (seq_len,)
+        # Remove batch and singleton dimensions
+        avg_attn = np.squeeze(avg_attn)  # Shape should now be (seq_len, seq_len)
+        print(f"Averaged Attention Shape (After Squeeze): {avg_attn.shape}")
+
+        # Check if the array is 2D (seq_len, seq_len)
+        if avg_attn.ndim == 1:  # If still incorrect (seq_len,)
+            print("Warning: Attention map has only 1 dimension, reshaping for visualization.")
+            avg_attn = avg_attn.reshape((1, -1))  # Force into 2D shape
+
+        # Extract the mean attention for each token
+        token_attention = np.mean(avg_attn, axis=0)  # Shape: (seq_len,)
+        print(f"Token Attention Shape: {token_attention.shape}")
 
         # Ensure we only plot up to `seq_length`
-        seq_length = min(seq_length, token_attention.shape[0])
+        seq_length = min(seq_length, token_attention.shape[0])  # Prevent out-of-bounds
         token_attention = token_attention[:seq_length]
-
-        # Generate X-axis positions
         x_positions = np.arange(len(token_attention))
 
-        # Create figure
+        # Plot the attention distribution
        plt.figure(figsize=(12, 4))
-
-        # Plot spikes for attention weights
-        for head in range(attn_maps.shape[2]):  # Iterate over heads
-            plt.plot(x_positions, token_attention, alpha=0.5, linewidth=0.7)
-
-        # Highlight first and last 100 tokens (likely padding)
-        plt.axvspan(0, 100, color='red', alpha=0.2, label="Padding Zone (Start)")
-        plt.axvspan(seq_length - 100, seq_length, color='blue', alpha=0.2, label="Padding Zone (End)")
-
-        plt.xlabel("Token Position (Sequence)")
+        plt.bar(x_positions, token_attention, width=1.5, alpha=0.7)
+        plt.xlabel("Token Position")
         plt.ylabel("Attention Score")
-        plt.title("Decoder Cross-Attention on Padded Tokens")
-        plt.legend()
+        plt.title("Attention Distribution Over Sequence")
         plt.show()
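
Usage note (not part of the patch): a minimal sketch of how the restored plot_attention_distribution helper might be called, assuming this fork's get_attention_weights() records attention maps as a side effect of an ordinary transcription pass; the model size and audio path below are placeholders.

    import whisper  # this fork of openai/whisper with the attention-capture helpers

    # Load a model and run a transcription so attention maps get recorded
    # (assumption: get_attention_weights() is populated during transcribe()).
    model = whisper.load_model("base")
    model.transcribe("audio.wav")  # placeholder audio file

    # Plot the averaged attention over the first 100 token positions.
    model.plot_attention_distribution(seq_length=100)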