diff --git a/whisper/model.py b/whisper/model.py
index ca3928e..820d3c1 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -62,6 +62,7 @@ class MultiHeadAttention(nn.Module):
         self.key = Linear(n_state, n_state, bias=False)
         self.value = Linear(n_state, n_state)
         self.out = Linear(n_state, n_state)
+        self.last_qk = None
 
     def forward(
         self,
@@ -96,6 +97,8 @@ class MultiHeadAttention(nn.Module):
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
 
+        self.last_qk = qk.detach()
+
         w = F.softmax(qk.float(), dim=-1).to(q.dtype)
         return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
 