mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 14:35:57 +00:00
test
This commit is contained in:
parent
34db988568
commit
48151029b0
@ -112,7 +112,7 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||
|
||||
self.attention_scores = qk.detach().cpu().numpy() if qk is not None else None
|
||||
self.attention_scores = qk.softmax(dim=-1).detach().cpu().numpy() if qk is not None else None
|
||||
|
||||
return self.out(wv), qk
|
||||
|
||||
@ -190,7 +190,7 @@ class AudioEncoder(nn.Module):
|
||||
)
|
||||
self.ln_post = LayerNorm(n_state)
|
||||
|
||||
self.all_attention_scores = []
|
||||
self.all_attention_scores = None
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
"""
|
||||
@ -204,13 +204,21 @@ class AudioEncoder(nn.Module):
|
||||
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
||||
x = (x + self.positional_embedding).to(x.dtype)
|
||||
|
||||
attention_list = []
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
if block.attn.attention_scores is not None:
|
||||
print(f"Captured attention scores from layer {len(self.all_attention_scores)}")
|
||||
self.all_attention_scores.append(block.attn.attention_scores)
|
||||
attention_list.append(block.attn.attention_scores)
|
||||
|
||||
x = self.ln_post(x)
|
||||
|
||||
if attention_list:
|
||||
self.all_attention_scores = np.array(attention_list)
|
||||
else:
|
||||
print("❌ Warning: No attention scores captured. Adding debug placeholder.")
|
||||
self.all_attention_scores = np.zeros((1, 1))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user