From 9d8d372a4db955b963f435befd4c498be0f84f96 Mon Sep 17 00:00:00 2001 From: Amal Jacob Date: Sun, 26 Jan 2025 21:12:17 -0800 Subject: [PATCH] test --- whisper/model.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/whisper/model.py b/whisper/model.py index e537447..d44b9a1 100644 --- a/whisper/model.py +++ b/whisper/model.py @@ -89,6 +89,8 @@ class MultiHeadAttention(nn.Module): self.value = Linear(n_state, n_state) self.out = Linear(n_state, n_state) + self.attention_scores = None + def forward( self, x: Tensor, @@ -109,6 +111,9 @@ class MultiHeadAttention(nn.Module): v = kv_cache[self.value] wv, qk = self.qkv_attention(q, k, v, mask) + + self.attention_scores = qk.detach().cpu().numpy() if qk is not None else None + return self.out(wv), qk def qkv_attention( @@ -185,6 +190,8 @@ class AudioEncoder(nn.Module): ) self.ln_post = LayerNorm(n_state) + self.all_attention_scores = [] + def forward(self, x: Tensor): """ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) @@ -199,6 +206,8 @@ class AudioEncoder(nn.Module): for block in self.blocks: x = block(x) + if block.attn.attention_scores is not None: + self.all_attention_scores.append(block.attn.attention_scores) x = self.ln_post(x) return x