mirror of
https://github.com/openai/whisper.git
synced 2025-03-30 14:28:27 +00:00
revert
This commit is contained in:
parent
b047b5c031
commit
b6e326d1fe
@ -309,8 +309,8 @@ class Whisper(nn.Module):
|
||||
|
||||
return attn_maps
|
||||
|
||||
def plot_cross_attention_distribution(self, seq_length: int = 1500):
|
||||
"""Plots cross-attention from the decoder, focusing on attention paid to padded regions."""
|
||||
def plot_attention_distribution(self, seq_length: int = 100):
|
||||
"""Plots attention distribution over sequence length."""
|
||||
attn_maps = self.get_attention_weights()
|
||||
|
||||
if not attn_maps:
|
||||
@ -318,39 +318,37 @@ class Whisper(nn.Module):
|
||||
return
|
||||
|
||||
# Convert to NumPy array
|
||||
attn_maps = np.array(attn_maps) # Shape: (layers, batch, heads, seq_len, seq_len)
|
||||
attn_maps = np.array(attn_maps)
|
||||
print(f"Attention Maps Shape: {attn_maps.shape}")
|
||||
|
||||
# We are interested in **cross-attention** (decoder attending to encoder output)
|
||||
avg_attn = np.mean(attn_maps, axis=0) # Average over layers
|
||||
avg_attn = np.squeeze(avg_attn) # Remove batch dim if present
|
||||
print(f"Averaged Attention Shape: {avg_attn.shape}")
|
||||
# Average over layers and heads
|
||||
avg_attn = np.mean(attn_maps, axis=(0, 2)) # Expected shape: (batch, ?, seq_len)
|
||||
print(f"Averaged Attention Shape (Before Squeeze): {avg_attn.shape}")
|
||||
|
||||
# Get attention for each token (rows) across all input positions (columns)
|
||||
token_attention = np.mean(avg_attn, axis=1) # Shape: (seq_len,)
|
||||
# Remove batch and singleton dimensions
|
||||
avg_attn = np.squeeze(avg_attn) # Shape should now be (seq_len, seq_len)
|
||||
print(f"Averaged Attention Shape (After Squeeze): {avg_attn.shape}")
|
||||
|
||||
# Check if the array is 2D (seq_len, seq_len)
|
||||
if avg_attn.ndim == 1: # If still incorrect (seq_len,)
|
||||
print("Warning: Attention map has only 1 dimension, reshaping for visualization.")
|
||||
avg_attn = avg_attn.reshape((1, -1)) # Force into 2D shape
|
||||
|
||||
# Extract the mean attention for each token
|
||||
token_attention = np.mean(avg_attn, axis=0) # Shape: (seq_len,)
|
||||
print(f"Token Attention Shape: {token_attention.shape}")
|
||||
|
||||
# Ensure we only plot up to `seq_length`
|
||||
seq_length = min(seq_length, token_attention.shape[0])
|
||||
seq_length = min(seq_length, token_attention.shape[0]) # Prevent out-of-bounds
|
||||
token_attention = token_attention[:seq_length]
|
||||
|
||||
# Generate X-axis positions
|
||||
x_positions = np.arange(len(token_attention))
|
||||
|
||||
# Create figure
|
||||
# Plot the attention distribution
|
||||
plt.figure(figsize=(12, 4))
|
||||
|
||||
# Plot spikes for attention weights
|
||||
for head in range(attn_maps.shape[2]): # Iterate over heads
|
||||
plt.plot(x_positions, token_attention, alpha=0.5, linewidth=0.7)
|
||||
|
||||
# Highlight first and last 100 tokens (likely padding)
|
||||
plt.axvspan(0, 100, color='red', alpha=0.2, label="Padding Zone (Start)")
|
||||
plt.axvspan(seq_length - 100, seq_length, color='blue', alpha=0.2, label="Padding Zone (End)")
|
||||
|
||||
plt.xlabel("Token Position (Sequence)")
|
||||
plt.bar(x_positions, token_attention, width=1.5, alpha=0.7)
|
||||
plt.xlabel("Token Position")
|
||||
plt.ylabel("Attention Score")
|
||||
plt.title("Decoder Cross-Attention on Padded Tokens")
|
||||
plt.legend()
|
||||
plt.title("Attention Distribution Over Sequence")
|
||||
plt.show()
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user