Mirror of https://github.com/openai/whisper.git, synced 2025-03-30 14:28:27 +00:00
plot cross attention now instead of self attention
Parent: 5e4fcc115b
Commit: 9be139c1ec
@@ -175,7 +175,6 @@ class ResidualAttentionBlock(nn.Module):
             cross_out, cross_attn_weights = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
             x = x + cross_out
             self.cross_attn.attn_weights = cross_attn_weights # Store weights
-            print(self.cross_attn.attn_weights)

         x = x + self.mlp(self.mlp_ln(x))
         return x
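
The weights cached on self.cross_attn.attn_weights above are what get_attention_weights() later reads back; that helper's body is not part of this diff. A minimal sketch of what such a collector could look like, assuming each decoder block exposes the attribute set in this hunk (the function name and the detach/cpu conversion are assumptions, not code from this commit):

    def collect_cross_attention(model):
        """Hypothetical helper: gather the cross-attention weights cached on
        each decoder block during the most recent forward pass."""
        maps = []
        for block in model.decoder.blocks:
            weights = getattr(block.cross_attn, "attn_weights", None)
            if weights is not None:
                # Drop the autograd graph and move to CPU before NumPy conversion.
                maps.append(weights.detach().cpu().numpy())
        return maps

Stacking the returned list with np.array is what produces the (layers, batch, heads, 1, audio_seq_len) array that the plotting code below expects.
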
@@ -315,7 +314,7 @@ class Whisper(nn.Module):
         return attn_maps

     def plot_attention_distribution(self, seq_length: int = 1500):
-        """Plots decoder cross-attention distribution over sequence length."""
+        """Plots decoder cross-attention distribution over the full 1500-token audio sequence."""
         attn_maps = self.get_attention_weights()

         if not attn_maps:
@@ -324,34 +323,26 @@ class Whisper(nn.Module):

         # Convert to NumPy array
         attn_maps = np.array(attn_maps)
-        print(f"Cross-Attention Maps Shape: {attn_maps.shape}") # (layers, batch, heads, seq_len, audio_seq_len)
+        print(f"Cross-Attention Maps Shape: {attn_maps.shape}") # (layers, batch, heads, 1, audio_seq_len)

-        # Average over layers and heads
-        avg_attn = np.mean(attn_maps, axis=(0, 2)) # Expected shape: (batch, seq_len, audio_seq_len)
+        # Average over layers and heads, but **keep** last two dimensions
+        avg_attn = np.mean(attn_maps, axis=(0, 2)) # Shape: (batch, 1, audio_seq_len)
         print(f"Averaged Cross-Attention Shape (Before Squeeze): {avg_attn.shape}")

-        # Remove batch and singleton dimensions
-        avg_attn = np.squeeze(avg_attn)
+        # Remove singleton dimensions
+        avg_attn = np.squeeze(avg_attn, axis=(0, 1)) # Shape: (audio_seq_len,)
         print(f"Averaged Cross-Attention Shape (After Squeeze): {avg_attn.shape}")

-        # Get attention over **audio sequence (1500 tokens)**
-        real_seq_length = avg_attn.shape[-1] # Ensure we're using full audio sequence
+        # Ensure correct sequence length
+        real_seq_length = avg_attn.shape[0]
         print(f"Real Sequence Length Detected: {real_seq_length}")

-        # Extract mean attention for each audio token
-        token_attention = np.mean(avg_attn, axis=0) # Shape: (audio_seq_len,)
-        print(f"Token Attention Shape: {token_attention.shape}")
-
-        if token_attention.ndim == 0: # Prevents empty scalar error
-            print("Error: token_attention is a scalar. Fixing shape issue.")
-            token_attention = avg_attn.mean(axis=-1) # Alternative averaging
-
-        # Ensure we plot the full available sequence length
+        # Prevent out-of-bounds errors
         seq_length = min(seq_length, real_seq_length)
-        token_attention = token_attention[:seq_length]
-        x_positions = np.arange(len(token_attention))
+        token_attention = avg_attn[:seq_length]

-        # Plot the cross-attention distribution
+        # Plot the attention distribution
+        x_positions = np.arange(len(token_attention))
         plt.figure(figsize=(12, 4))
         plt.bar(x_positions, token_attention, width=1.5, alpha=0.7)
         plt.xlabel("Audio Token Position")
@@ -364,6 +355,7 @@ class Whisper(nn.Module):



+
     @property
     def device(self):
         return next(self.parameters()).device
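
End to end, the new plot could be driven roughly as follows. load_model and transcribe are the standard Whisper entry points; the audio file name is a placeholder, and whether the cached weights survive a full transcription loop depends on the rest of this fork, so treat this as a sketch rather than a verified recipe:

    import whisper

    # Requires the modified model.py from this commit to be on the import path.
    model = whisper.load_model("base")

    # A decoder forward pass populates cross_attn.attn_weights on every block
    # (see the first hunk above).
    model.transcribe("audio.wav")

    # Bar chart of mean decoder cross-attention over the audio token positions.
    model.plot_attention_distribution(seq_length=1500)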