mirror of
https://github.com/openai/whisper.git
synced 2025-03-30 14:28:27 +00:00
test
This commit is contained in:
parent
517a43ecd1
commit
9d8d372a4d
@ -89,6 +89,8 @@ class MultiHeadAttention(nn.Module):
|
||||
self.value = Linear(n_state, n_state)
|
||||
self.out = Linear(n_state, n_state)
|
||||
|
||||
self.attention_scores = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
@ -109,6 +111,9 @@ class MultiHeadAttention(nn.Module):
|
||||
v = kv_cache[self.value]
|
||||
|
||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||
|
||||
self.attention_scores = qk.detach().cpu().numpy() if qk is not None else None
|
||||
|
||||
return self.out(wv), qk
|
||||
|
||||
def qkv_attention(
|
||||
@ -185,6 +190,8 @@ class AudioEncoder(nn.Module):
|
||||
)
|
||||
self.ln_post = LayerNorm(n_state)
|
||||
|
||||
self.all_attention_scores = []
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
"""
|
||||
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
||||
@ -199,6 +206,8 @@ class AudioEncoder(nn.Module):
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
if block.attn.attention_scores is not None:
|
||||
self.all_attention_scores.append(block.attn.attention_scores)
|
||||
|
||||
x = self.ln_post(x)
|
||||
return x
|
||||
|
Loading…
x
Reference in New Issue
Block a user