try again

This commit is contained in:
Amal Jacob 2025-01-31 13:44:54 -08:00
parent b6e326d1fe
commit 10ecf6ac34


@@ -309,7 +309,7 @@ class Whisper(nn.Module):
         return attn_maps
-    def plot_attention_distribution(self, seq_length: int = 100):
+    def plot_attention_distribution(self, seq_length: int = 1500):
         """Plots attention distribution over sequence length."""
         attn_maps = self.get_attention_weights()
@@ -322,24 +322,23 @@ class Whisper(nn.Module):
         print(f"Attention Maps Shape: {attn_maps.shape}")
         # Average over layers and heads
-        avg_attn = np.mean(attn_maps, axis=(0, 2)) # Expected shape: (batch, ?, seq_len)
+        avg_attn = np.mean(attn_maps, axis=(0, 2)) # Expected shape: (batch, seq_len, seq_len)
         print(f"Averaged Attention Shape (Before Squeeze): {avg_attn.shape}")
         # Remove batch and singleton dimensions
         avg_attn = np.squeeze(avg_attn) # Shape should now be (seq_len, seq_len)
         print(f"Averaged Attention Shape (After Squeeze): {avg_attn.shape}")
         # Check if the array is 2D (seq_len, seq_len)
         if avg_attn.ndim == 1: # If still incorrect (seq_len,)
             print("Warning: Attention map has only 1 dimension, reshaping for visualization.")
             avg_attn = avg_attn.reshape((1, -1)) # Force into 2D shape
         # Check the real sequence length
         real_seq_length = avg_attn.shape[0]
         print(f"Real Sequence Length Detected: {real_seq_length}")
-        # Extract the mean attention for each token
+        # Extract mean attention for each token
         token_attention = np.mean(avg_attn, axis=0) # Shape: (seq_len,)
         print(f"Token Attention Shape: {token_attention.shape}")
-        # Ensure we only plot up to `seq_length`
-        seq_length = min(seq_length, token_attention.shape[0]) # Prevent out-of-bounds
+        # Ensure we plot the full available sequence length
+        seq_length = min(seq_length, real_seq_length) # Prevent truncation
         token_attention = token_attention[:seq_length]
         x_positions = np.arange(len(token_attention))
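
For context, the sketch below reproduces the averaging-and-plotting logic this commit touches as a standalone function. It assumes the attention weights arrive as a NumPy array of shape (layers, batch, heads, seq_len, seq_len) and that matplotlib is available; the function name plot_token_attention and the random demo input are illustrative only, not part of the repository.

import numpy as np
import matplotlib.pyplot as plt

def plot_token_attention(attn_maps: np.ndarray, seq_length: int = 1500) -> None:
    """Plot mean attention per token position (sketch of the patched method)."""
    # Average over layers (axis 0) and heads (axis 2) -> (batch, seq_len, seq_len)
    avg_attn = np.mean(attn_maps, axis=(0, 2))
    # Drop the singleton batch dimension -> (seq_len, seq_len)
    avg_attn = np.squeeze(avg_attn)
    if avg_attn.ndim == 1:
        # Fall back to a 2D shape so the row-wise mean below still works
        avg_attn = avg_attn.reshape((1, -1))
    # Mean attention each token position receives -> (seq_len,)
    token_attention = np.mean(avg_attn, axis=0)
    # Cap at the number of available positions so slicing never goes out of bounds
    seq_length = min(seq_length, token_attention.shape[0])
    token_attention = token_attention[:seq_length]
    plt.figure(figsize=(10, 3))
    plt.bar(np.arange(seq_length), token_attention, width=1.0)
    plt.xlabel("Token position")
    plt.ylabel("Mean attention")
    plt.title("Attention distribution over sequence length")
    plt.tight_layout()
    plt.show()

# Example call with random weights: 4 layers, batch of 1, 6 heads, 128 positions
plot_token_attention(np.random.rand(4, 1, 6, 128, 128), seq_length=1500)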