Mirror of https://github.com/openai/whisper.git (synced 2025-03-30 14:28:27 +00:00)

Commit 10ecf6ac34 ("try again"), parent b6e326d1fe
@@ -309,7 +309,7 @@ class Whisper(nn.Module):
         return attn_maps
 
-    def plot_attention_distribution(self, seq_length: int = 100):
+    def plot_attention_distribution(self, seq_length: int = 1500):
         """Plots attention distribution over sequence length."""
         attn_maps = self.get_attention_weights()
 
@@ -322,24 +322,23 @@ class Whisper(nn.Module):
         print(f"Attention Maps Shape: {attn_maps.shape}")
 
         # Average over layers and heads
-        avg_attn = np.mean(attn_maps, axis=(0, 2))  # Expected shape: (batch, ?, seq_len)
+        avg_attn = np.mean(attn_maps, axis=(0, 2))  # Expected shape: (batch, seq_len, seq_len)
         print(f"Averaged Attention Shape (Before Squeeze): {avg_attn.shape}")
 
         # Remove batch and singleton dimensions
         avg_attn = np.squeeze(avg_attn)  # Shape should now be (seq_len, seq_len)
         print(f"Averaged Attention Shape (After Squeeze): {avg_attn.shape}")
 
         # Check if the array is 2D (seq_len, seq_len)
         if avg_attn.ndim == 1:  # If still incorrect (seq_len,)
             print("Warning: Attention map has only 1 dimension, reshaping for visualization.")
             avg_attn = avg_attn.reshape((1, -1))  # Force into 2D shape
+        # Check the real sequence length
+        real_seq_length = avg_attn.shape[0]
+        print(f"Real Sequence Length Detected: {real_seq_length}")
 
-        # Extract the mean attention for each token
+        # Extract mean attention for each token
         token_attention = np.mean(avg_attn, axis=0)  # Shape: (seq_len,)
         print(f"Token Attention Shape: {token_attention.shape}")
 
-        # Ensure we only plot up to `seq_length`
-        seq_length = min(seq_length, token_attention.shape[0])  # Prevent out-of-bounds
+        # Ensure we plot the full available sequence length
+        seq_length = min(seq_length, real_seq_length)  # Prevent truncation
         token_attention = token_attention[:seq_length]
         x_positions = np.arange(len(token_attention))
 