Amal Jacob 2025-01-31 12:30:45 -08:00
parent 54edc865e6
commit 59b1fa70fe


@@ -1,5 +1,7 @@
 import base64
 import gzip
+import matplotlib.pyplot as plt
+import seaborn as sns
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Dict, Iterable, Optional, Tuple
@@ -110,7 +112,7 @@ class MultiHeadAttention(nn.Module):
             v = kv_cache[self.value]
         wv, qk = self.qkv_attention(q, k, v, mask)
         self.attn_weights = qk
         print(qk)
@@ -299,6 +301,34 @@ class Whisper(nn.Module):
         self, mel: torch.Tensor, tokens: torch.Tensor
     ) -> Dict[str, torch.Tensor]:
         return self.decoder(tokens, self.encoder(mel))
+
+    def get_attention_weights(self):
+        """Retrieve the attention weights stored in each decoder block."""
+        attn_maps = []
+        for block in self.decoder.blocks:
+            # attn_weights only exists after a forward pass, so guard with getattr
+            if getattr(block.attn, "attn_weights", None) is not None:
+                attn_maps.append(block.attn.attn_weights.detach().cpu().numpy())
+        return attn_maps  # list of per-layer arrays, each (batch, heads, n_ctx, n_ctx)
+
+    def plot_attention_on_padded(self, seq_length: int = 100):
+        """Plot attention weights, focusing on padded regions of the sequence."""
+        attn_maps = self.get_attention_weights()
+        if not attn_maps:
+            print("No attention weights found!")
+            return
+        # Average over layers, then crop to the first seq_length positions
+        avg_attn = np.mean(attn_maps, axis=0)
+        avg_attn = avg_attn[:, :, :seq_length, :seq_length]
+        plt.figure(figsize=(8, 6))
+        # Average over heads; squeeze assumes a batch size of 1
+        sns.heatmap(avg_attn.mean(axis=1).squeeze(), cmap="Blues", annot=False)
+        plt.xlabel("Input Positions")
+        plt.ylabel("Output Positions")
+        plt.title("Attention Weights on Padded Regions")
+        plt.show()

     @property
     def device(self):
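
A rough usage sketch (not part of the commit): the new helpers assume a prior forward pass, since attn_weights is only populated inside MultiHeadAttention.forward. The model-loading call and the dummy mel/token tensors below are illustrative placeholders, assuming the fork keeps the upstream whisper.load_model entry point.

import torch
import whisper  # assumes this fork is installed in place of the upstream package

model = whisper.load_model("base")         # hypothetical checkpoint choice
mel = torch.randn(1, 80, 3000)             # placeholder mel spectrogram, batch size 1
tokens = torch.randint(0, 50000, (1, 10))  # placeholder token ids, 10 positions

with torch.no_grad():
    model(mel, tokens)  # forward pass stores qk in every decoder block

attn_maps = model.get_attention_weights()
print(f"{len(attn_maps)} layers, each {attn_maps[0].shape}")
model.plot_attention_on_padded(seq_length=10)

Since the decoder self-attention here runs over the 10 supplied tokens, seq_length is capped at 10 in this sketch; with real padded inputs a larger window would show how much attention mass lands on the pad positions.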