found attention

Amal Jacob 2025-01-31 12:22:51 -08:00
parent 60e1a31a9e
commit 54edc865e6


@@ -88,8 +88,7 @@ class MultiHeadAttention(nn.Module):
         self.key = Linear(n_state, n_state, bias=False)
         self.value = Linear(n_state, n_state)
         self.out = Linear(n_state, n_state)
-        self.attention_scores = None
+        self.attn_weights = None

     def forward(
         self,
@@ -111,12 +110,9 @@ class MultiHeadAttention(nn.Module):
             v = kv_cache[self.value]

         wv, qk = self.qkv_attention(q, k, v, mask)
-        if qk is not None:
-            print(f"✅ Attention shape: {qk.shape}")  # Should print the shape of attention weights
-            self.attention_scores = qk.softmax(dim=-1).detach().cpu().numpy()
-        else:
-            print("❌ No attention computed in MultiHeadAttention!")
+        self.attn_weights = qk
+        print(qk)

         return self.out(wv), qk
@@ -133,8 +129,8 @@ class MultiHeadAttention(nn.Module):
             a = scaled_dot_product_attention(
                 q, k, v, is_causal=mask is not None and n_ctx > 1
             )
+            qk = (q @ k.transpose(-1, -2)).softmax(dim=-1)
             out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
-            qk = None
         else:
             qk = (q * scale) @ (k * scale).transpose(-1, -2)
             if mask is not None:
@@ -143,8 +139,8 @@ class MultiHeadAttention(nn.Module):
             w = F.softmax(qk, dim=-1).to(q.dtype)
             out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
-            qk = qk.detach()
+            self.attn_weights = qk.detach()

         return out, qk
@@ -194,8 +190,6 @@ class AudioEncoder(nn.Module):
         )
         self.ln_post = LayerNorm(n_state)
-        self.all_attention_scores = None

     def forward(self, x: Tensor):
         """
         x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
@@ -208,24 +202,10 @@ class AudioEncoder(nn.Module):
         assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
         x = (x + self.positional_embedding).to(x.dtype)

-        attention_list = []
-        for layer_idx, block in enumerate(self.blocks):
+        for block in self.blocks:
             x = block(x)
-            if block.attn.attention_scores is not None:
-                print(f"✅ Captured attention scores from layer {layer_idx}")
-                attention_list.append(block.attn.attention_scores)
-            else:
-                print(f"❌ No attention captured at layer {layer_idx}!")

         x = self.ln_post(x)

-        # ✅ Debug: If no attention scores were captured, add a debug placeholder
-        if attention_list:
-            self.all_attention_scores = np.array(attention_list)
-        else:
-            print("❌ Warning: No attention scores captured at all layers. Adding debug placeholder.")
-            self.all_attention_scores = np.zeros((1, 1))

         return x
@@ -367,4 +347,4 @@ class Whisper(nn.Module):
     detect_language = detect_language_function
     transcribe = transcribe_function
-    decode = decode_function
+    decode = decode_function
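
Net effect of the change: the layer-by-layer debug printing and the all_attention_scores aggregation are dropped, and each MultiHeadAttention instead keeps its latest attention matrix on attn_weights, with qk now computed explicitly even on the scaled_dot_product_attention path. Below is a minimal sketch of reading those weights back after an encoder pass, assuming this fork is importable as the usual whisper package and keeps the standard load_model/encoder API; the dummy mel input and variable names are illustrative, not part of the commit.

import torch
import whisper  # assumption: this fork installs under the usual "whisper" package name

# A small model on CPU keeps the per-layer (n_head, 1500, 1500) matrices manageable.
model = whisper.load_model("tiny", device="cpu")
model.eval()

# Dummy mel input of shape (batch, n_mels, 2 * n_audio_ctx); a real input would come
# from whisper.log_mel_spectrogram() on 30 seconds of padded audio.
mel = torch.randn(1, model.dims.n_mels, model.dims.n_audio_ctx * 2)

with torch.no_grad():
    model.encoder(mel)  # note: the print(qk) left in forward() will log each layer's qk

# Each block's attention module now holds its last qk matrix in attn_weights
# (the attribute added by this commit); stack them per layer.
per_layer = [blk.attn.attn_weights for blk in model.encoder.blocks]
attention = torch.stack([w.float().cpu() for w in per_layer if w is not None])
print(attention.shape)  # roughly (n_layers, batch, n_head, n_ctx, n_ctx)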