Author: Amal Jacob
Date:   2025-01-31 13:33:13 -08:00
Parent: b047b5c031
Commit: b6e326d1fe


@@ -309,8 +309,8 @@ class Whisper(nn.Module):
         return attn_maps
-    def plot_cross_attention_distribution(self, seq_length: int = 1500):
-        """Plots cross-attention from the decoder, focusing on attention paid to padded regions."""
+    def plot_attention_distribution(self, seq_length: int = 100):
+        """Plots attention distribution over sequence length."""
         attn_maps = self.get_attention_weights()
         if not attn_maps:
@@ -318,39 +318,37 @@ class Whisper(nn.Module):
             return
         # Convert to NumPy array
-        attn_maps = np.array(attn_maps)  # Shape: (layers, batch, heads, seq_len, seq_len)
+        attn_maps = np.array(attn_maps)
         print(f"Attention Maps Shape: {attn_maps.shape}")
-        # We are interested in **cross-attention** (decoder attending to encoder output)
-        avg_attn = np.mean(attn_maps, axis=0)  # Average over layers
-        avg_attn = np.squeeze(avg_attn)  # Remove batch dim if present
-        print(f"Averaged Attention Shape: {avg_attn.shape}")
+        # Average over layers and heads
+        avg_attn = np.mean(attn_maps, axis=(0, 2))  # Expected shape: (batch, ?, seq_len)
+        print(f"Averaged Attention Shape (Before Squeeze): {avg_attn.shape}")
-        # Get attention for each token (rows) across all input positions (columns)
-        token_attention = np.mean(avg_attn, axis=1)  # Shape: (seq_len,)
+        # Remove batch and singleton dimensions
+        avg_attn = np.squeeze(avg_attn)  # Shape should now be (seq_len, seq_len)
+        print(f"Averaged Attention Shape (After Squeeze): {avg_attn.shape}")
+        # Check if the array is 2D (seq_len, seq_len)
+        if avg_attn.ndim == 1:  # If still incorrect (seq_len,)
+            print("Warning: Attention map has only 1 dimension, reshaping for visualization.")
+            avg_attn = avg_attn.reshape((1, -1))  # Force into 2D shape
+        # Extract the mean attention for each token
+        token_attention = np.mean(avg_attn, axis=0)  # Shape: (seq_len,)
         print(f"Token Attention Shape: {token_attention.shape}")
         # Ensure we only plot up to `seq_length`
-        seq_length = min(seq_length, token_attention.shape[0])
+        seq_length = min(seq_length, token_attention.shape[0])  # Prevent out-of-bounds
         token_attention = token_attention[:seq_length]
         # Generate X-axis positions
         x_positions = np.arange(len(token_attention))
-        # Create figure
+        # Plot the attention distribution
         plt.figure(figsize=(12, 4))
-        # Plot spikes for attention weights
-        for head in range(attn_maps.shape[2]):  # Iterate over heads
-            plt.plot(x_positions, token_attention, alpha=0.5, linewidth=0.7)
-        # Highlight first and last 100 tokens (likely padding)
-        plt.axvspan(0, 100, color='red', alpha=0.2, label="Padding Zone (Start)")
-        plt.axvspan(seq_length - 100, seq_length, color='blue', alpha=0.2, label="Padding Zone (End)")
-        plt.xlabel("Token Position (Sequence)")
+        plt.bar(x_positions, token_attention, width=1.5, alpha=0.7)
+        plt.xlabel("Token Position")
         plt.ylabel("Attention Score")
-        plt.title("Decoder Cross-Attention on Padded Tokens")
-        plt.legend()
+        plt.title("Attention Distribution Over Sequence")
         plt.show()
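
For reference, a minimal, hypothetical usage sketch for the renamed method. It assumes this fork's whisper.load_model() still returns the patched Whisper class, that running inference (here via transcribe()) populates whatever buffer get_attention_weights() reads from, and that numpy and matplotlib are importable where the plotting code lives; none of that wiring is shown in this diff.

# Hypothetical usage of the updated plotting helper; assumes the attention
# maps returned by get_attention_weights() are captured during inference.
import whisper  # this fork, with the patched Whisper class

model = whisper.load_model("base")

# Run inference so the model has attention maps to read back.
result = model.transcribe("audio.wav")
print(result["text"])

# Plot the layer- and head-averaged attention over the first 100 positions,
# the new default for seq_length introduced by this commit.
model.plot_attention_distribution(seq_length=100)

(The removed loop re-plotted the same precomputed token_attention array on every iteration rather than true per-head curves, so no per-head detail is lost by the switch to a single plt.bar call.)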