commit 48151029b0 (parent 34db988568)
Author: Amal Jacob
Date: 2025-01-26 21:30:54 -08:00


@@ -112,7 +112,7 @@ class MultiHeadAttention(nn.Module):
         wv, qk = self.qkv_attention(q, k, v, mask)
-        self.attention_scores = qk.detach().cpu().numpy() if qk is not None else None
+        self.attention_scores = qk.softmax(dim=-1).detach().cpu().numpy() if qk is not None else None
         return self.out(wv), qk
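
With this change, attention_scores stores softmax-normalized attention weights instead of raw query-key logits. A minimal sketch of the difference, assuming qk is a raw score tensor shaped (batch, heads, n_query, n_key); the tensor below is random and purely illustrative:

import torch

# hypothetical raw attention logits, shaped (batch, heads, n_query, n_key)
qk = torch.randn(1, 2, 4, 4)

raw_scores = qk.detach().cpu().numpy()                   # old behaviour: unnormalized logits
attn_probs = qk.softmax(dim=-1).detach().cpu().numpy()   # new behaviour: each row sums to 1

print(attn_probs.sum(axis=-1))  # ~1.0 for every (batch, head, query) position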
@@ -190,7 +190,7 @@ class AudioEncoder(nn.Module):
         )
         self.ln_post = LayerNorm(n_state)
-        self.all_attention_scores = []
+        self.all_attention_scores = None

     def forward(self, x: Tensor):
         """
@@ -204,13 +204,21 @@ class AudioEncoder(nn.Module):
         assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
         x = (x + self.positional_embedding).to(x.dtype)

+        attention_list = []
         for block in self.blocks:
             x = block(x)
             if block.attn.attention_scores is not None:
-                print(f"Captured attention scores from layer {len(self.all_attention_scores)}")
-                self.all_attention_scores.append(block.attn.attention_scores)
+                print(f"Captured attention scores from layer {len(attention_list)}")
+                attention_list.append(block.attn.attention_scores)

         x = self.ln_post(x)
+
+        if attention_list:
+            self.all_attention_scores = np.array(attention_list)
+        else:
+            print("❌ Warning: No attention scores captured. Adding debug placeholder.")
+            self.all_attention_scores = np.zeros((1, 1))
+
         return x
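
Once the encoder has run, the stacked per-layer attention maps can be read back from the module. A usage sketch, assuming `model` is a loaded Whisper model whose encoder is this AudioEncoder and `mel` is a log-mel spectrogram tensor of the shape the encoder expects (both names are illustrative, not part of the commit):

import torch

with torch.no_grad():
    _ = model.encoder(mel)  # forward() fills all_attention_scores as a side effect

scores = model.encoder.all_attention_scores
if scores is None or scores.shape == (1, 1):
    print("no attention scores were captured")
else:
    # expected shape: (n_layers, batch, n_heads, n_query, n_key), one slice per block
    print(scores.shape)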