Mirror of https://github.com/openai/whisper.git (synced 2025-11-24 06:26:03 +00:00)

Commit 97b718e75c: update plot
Parent: a85bcadd43
@@ -309,43 +309,39 @@ class Whisper(nn.Module):
         return attn_maps
 
-    def plot_attention_on_padded(self, seq_length: int = 100):
-        """Plots attention weights focusing on padded regions."""
+    def plot_attention_distribution(self, seq_length: int = 100):
+        """Plots attention distribution over sequence length."""
         attn_maps = self.get_attention_weights()
 
         if not attn_maps:
             print("No attention weights found!")
             return
 
-        # Convert list to NumPy array
+        # Convert to numpy array and print shape
         attn_maps = np.array(attn_maps)
-
-        # Print debug info
-        print(f"Attention Maps Shape (Before Averaging): {attn_maps.shape}")
-
-        # Average over layers and heads
-        avg_attn = np.mean(attn_maps, axis=(0, 2))  # Shape: (batch, ?, seq_len)
-
-        print(f"Averaged Attention Shape (Before Squeeze): {avg_attn.shape}")
-
-        # Squeeze to remove any extra singleton dimensions
-        avg_attn = np.squeeze(avg_attn)  # Removes batch dim
-
-        print(f"Averaged Attention Shape (After Squeeze): {avg_attn.shape}")
-
-        # Ensure correct shape
-        if avg_attn.ndim == 1:  # If still incorrect (seq_len,)
-            avg_attn = avg_attn.reshape((1, -1))  # Force into 2D shape for heatmap
-
-        # Ensure shape matches seq_length
-        avg_attn = avg_attn[:seq_length, :seq_length]  # Truncate to fit expected size
-
-        # Plot heatmap
-        plt.figure(figsize=(8, 6))
-        sns.heatmap(avg_attn, cmap="Blues", annot=False)
-        plt.xlabel("Input Positions")
-        plt.ylabel("Output Positions")
-        plt.title("Attention Weights on Padded Regions")
+        print(f"Attention Maps Shape: {attn_maps.shape}")
+
+        # Average over layers and heads to get per-token attention
+        avg_attn = np.mean(attn_maps, axis=(0, 2))  # Shape: (batch, seq_len, seq_len)
+
+        print(f"Averaged Attention Shape: {avg_attn.shape}")
+
+        # Remove batch dim if present
+        avg_attn = np.squeeze(avg_attn)
+
+        # Extract the attention scores to the first token in each position
+        token_attention = np.mean(avg_attn, axis=0)  # Shape: (seq_len,)
+
+        # Ensure we only plot up to `seq_length`
+        token_attention = token_attention[:seq_length]
+        x_positions = np.arange(len(token_attention))
+
+        # Plot attention distribution as spikes
+        plt.figure(figsize=(12, 4))
+        plt.bar(x_positions, token_attention, width=1.5, alpha=0.7)
+        plt.xlabel("Token Position")
+        plt.ylabel("Attention Score")
+        plt.title("Attention Distribution Over Sequence")
         plt.show()
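For context, a minimal standalone sketch of the shape bookkeeping the new plot_attention_distribution performs, assuming get_attention_weights() returns one array of shape (batch, heads, seq_len, seq_len) per layer; the layer count, head count, and random data below are illustrative stand-ins, not values taken from the model:

import numpy as np

# Illustrative dimensions only; real values come from the model's stored attention maps.
n_layers, batch, n_heads, seq_len = 4, 1, 8, 120

# Stand-in for self.get_attention_weights(): one map per layer.
attn_maps = [np.random.rand(batch, n_heads, seq_len, seq_len) for _ in range(n_layers)]

attn_maps = np.array(attn_maps)              # (layers, batch, heads, seq_len, seq_len)
avg_attn = np.mean(attn_maps, axis=(0, 2))   # average layers and heads -> (batch, seq_len, seq_len)
avg_attn = np.squeeze(avg_attn)              # drop batch dim when batch == 1 -> (seq_len, seq_len)
token_attention = np.mean(avg_attn, axis=0)  # average over query positions -> (seq_len,)

print(token_attention[:100].shape)           # the 1-D values the method plots as bars

Under those assumptions, once a forward pass has populated the stored attention maps, producing the plot is just model.plot_attention_distribution(seq_length=100).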