From 10ecf6ac344e18aa9c109a41aa1abd721820e69c Mon Sep 17 00:00:00 2001
From: Amal Jacob
Date: Fri, 31 Jan 2025 13:44:54 -0800
Subject: [PATCH] try again

---
 whisper/model.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/whisper/model.py b/whisper/model.py
index cbd1449..bef41ee 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -309,7 +309,7 @@ class Whisper(nn.Module):
 
         return attn_maps
 
-    def plot_attention_distribution(self, seq_length: int = 100):
+    def plot_attention_distribution(self, seq_length: int = 1500):
         """Plots attention distribution over sequence length."""
         attn_maps = self.get_attention_weights()
@@ -322,24 +322,23 @@ class Whisper(nn.Module):
         print(f"Attention Maps Shape: {attn_maps.shape}")
 
         # Average over layers and heads
-        avg_attn = np.mean(attn_maps, axis=(0, 2))  # Expected shape: (batch, ?, seq_len)
+        avg_attn = np.mean(attn_maps, axis=(0, 2))  # Expected shape: (batch, seq_len, seq_len)
         print(f"Averaged Attention Shape (Before Squeeze): {avg_attn.shape}")
 
         # Remove batch and singleton dimensions
         avg_attn = np.squeeze(avg_attn)  # Shape should now be (seq_len, seq_len)
         print(f"Averaged Attention Shape (After Squeeze): {avg_attn.shape}")
 
-        # Check if the array is 2D (seq_len, seq_len)
-        if avg_attn.ndim == 1:  # If still incorrect (seq_len,)
-            print("Warning: Attention map has only 1 dimension, reshaping for visualization.")
-            avg_attn = avg_attn.reshape((1, -1))  # Force into 2D shape
+        # Check the real sequence length
+        real_seq_length = avg_attn.shape[0]
+        print(f"Real Sequence Length Detected: {real_seq_length}")
 
-        # Extract the mean attention for each token
+        # Extract mean attention for each token
         token_attention = np.mean(avg_attn, axis=0)  # Shape: (seq_len,)
         print(f"Token Attention Shape: {token_attention.shape}")
 
-        # Ensure we only plot up to `seq_length`
-        seq_length = min(seq_length, token_attention.shape[0])  # Prevent out-of-bounds
+        # Ensure we plot the full available sequence length
+        seq_length = min(seq_length, real_seq_length)  # Prevent truncation
         token_attention = token_attention[:seq_length]
 
         x_positions = np.arange(len(token_attention))
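
Note for reviewers: below is a standalone sketch, not part of the patch, of what the changed lines compute. It averages a stack of attention maps over layers and heads, takes the per-position mean attention, and clamps the plotted range to the real sequence length. The (layers, batch, heads, seq_len, seq_len) layout is an assumption about what get_attention_weights() returns, and the toy sequence length of 100 stands in for the encoder's 1500 output frames that motivate the new default.

import numpy as np
import matplotlib.pyplot as plt

# Hypothetical attention stack: (layers, batch, heads, seq_len, seq_len).
attn_maps = np.random.rand(4, 1, 6, 100, 100)

avg_attn = np.mean(attn_maps, axis=(0, 2))   # average over layers and heads -> (batch, seq_len, seq_len)
avg_attn = np.squeeze(avg_attn)              # drop the batch dimension -> (seq_len, seq_len)

real_seq_length = avg_attn.shape[0]
token_attention = np.mean(avg_attn, axis=0)  # mean attention each position receives

seq_length = min(1500, real_seq_length)      # plot only what is actually available
plt.plot(np.arange(seq_length), token_attention[:seq_length])
plt.xlabel("Token position")
plt.ylabel("Mean attention")
plt.show()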