commit 9d8d372a4d (parent 517a43ecd1)
Author: Amal Jacob
Date:   2025-01-26 21:12:17 -08:00


@@ -89,6 +89,8 @@ class MultiHeadAttention(nn.Module):
         self.value = Linear(n_state, n_state)
         self.out = Linear(n_state, n_state)
 
+        self.attention_scores = None
+
     def forward(
         self,
         x: Tensor,
@@ -109,6 +111,9 @@ class MultiHeadAttention(nn.Module):
             v = kv_cache[self.value]
 
         wv, qk = self.qkv_attention(q, k, v, mask)
+
+        self.attention_scores = qk.detach().cpu().numpy() if qk is not None else None
+
         return self.out(wv), qk
 
     def qkv_attention(
@@ -185,6 +190,8 @@ class AudioEncoder(nn.Module):
         )
         self.ln_post = LayerNorm(n_state)
 
+        self.all_attention_scores = []
+
     def forward(self, x: Tensor):
         """
         x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
@@ -199,6 +206,8 @@ class AudioEncoder(nn.Module):
 
         for block in self.blocks:
             x = block(x)
+            if block.attn.attention_scores is not None:
+                self.all_attention_scores.append(block.attn.attention_scores)
 
         x = self.ln_post(x)
         return x
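
Taken together, the four hunks cache each block's raw (pre-softmax) attention scores — the qk returned by qkv_attention, detached and converted to NumPy — and collect them per layer on the encoder. A minimal usage sketch of what a caller might do with the patched model; it assumes this fork is importable as whisper with upstream's load_model / load_audio / pad_or_trim / log_mel_spectrogram helpers, and "sample.wav" is a placeholder path:

import torch
import whisper  # assumes the patched fork is the installed package

model = whisper.load_model("base")

# Build a 30-second log-mel spectrogram with Whisper's own helpers.
audio = whisper.pad_or_trim(whisper.load_audio("sample.wav"))  # placeholder file
mel = whisper.log_mel_spectrogram(audio).to(model.device)

# all_attention_scores is only ever appended to, so clear it between passes.
model.encoder.all_attention_scores.clear()

with torch.no_grad():
    model.encoder(mel.unsqueeze(0))  # add a batch dimension

# One entry per encoder block: a NumPy array of pre-softmax attention
# scores with shape (batch, n_head, n_audio_ctx, n_audio_ctx).
for layer, scores in enumerate(model.encoder.all_attention_scores):
    print(f"layer {layer}: {scores.shape}")

Note the qk is not None guard in the diff: if an attention path never materializes the score matrix, no entry is appended, so the list stays empty rather than holding None placeholders.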