From 48151029b01c7e08c607bc2deccd276ba2ade60d Mon Sep 17 00:00:00 2001
From: Amal Jacob
Date: Sun, 26 Jan 2025 21:30:54 -0800
Subject: [PATCH] Capture per-layer encoder attention scores in AudioEncoder

---
 whisper/model.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/whisper/model.py b/whisper/model.py
index ab7df23..542dae4 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -112,7 +112,7 @@ class MultiHeadAttention(nn.Module):
 
         wv, qk = self.qkv_attention(q, k, v, mask)
 
-        self.attention_scores = qk.detach().cpu().numpy() if qk is not None else None
+        self.attention_scores = qk.softmax(dim=-1).detach().cpu().numpy() if qk is not None else None
 
         return self.out(wv), qk
 
@@ -190,7 +190,7 @@ class AudioEncoder(nn.Module):
         )
         self.ln_post = LayerNorm(n_state)
 
-        self.all_attention_scores = []
+        self.all_attention_scores = None
 
     def forward(self, x: Tensor):
         """
@@ -204,13 +204,21 @@ class AudioEncoder(nn.Module):
         assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
         x = (x + self.positional_embedding).to(x.dtype)
 
+        attention_list = []
         for block in self.blocks:
             x = block(x)
             if block.attn.attention_scores is not None:
-                print(f"Captured attention scores from layer {len(self.all_attention_scores)}")
-                self.all_attention_scores.append(block.attn.attention_scores)
+                print(f"Captured attention scores from layer {len(attention_list)}")
+                attention_list.append(block.attn.attention_scores)
 
         x = self.ln_post(x)
+
+        if attention_list:
+            self.all_attention_scores = np.array(attention_list)
+        else:
+            print("❌ Warning: No attention scores captured. Adding debug placeholder.")
+            self.all_attention_scores = np.zeros((1, 1))
+
         return x
 
 
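
A minimal usage sketch of reading the captured scores back after a forward
pass, assuming this patch is applied to a local whisper checkout.
load_model, load_audio, pad_or_trim, and log_mel_spectrogram are whisper's
standard helpers; "audio.wav" is a hypothetical input file. Note that
qkv_attention may return qk=None (e.g. when a fused SDPA kernel is used),
in which case the encoder stores the (1, 1) zeros placeholder rather than
real scores.

    import torch
    import whisper

    # Load on CPU so the mel spectrogram and the model weights share a
    # dtype without extra casting.
    model = whisper.load_model("base", device="cpu")

    # Standard whisper preprocessing: 30 s of audio -> (80, 3000) mel.
    audio = whisper.pad_or_trim(whisper.load_audio("audio.wav"))
    mel = whisper.log_mel_spectrogram(audio).to(model.device)

    # Run only the encoder; the patch stores scores as a side effect.
    with torch.no_grad():
        model.encoder(mel.unsqueeze(0))

    # When scores were captured, the stacked array has shape
    # (n_layers, batch, n_heads, n_audio_ctx, n_audio_ctx).
    print(model.encoder.all_attention_scores.shape)

Stacking with np.array requires every layer to produce same-shaped maps,
which holds here because all encoder blocks share n_audio_ctx and n_heads.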