mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 14:35:57 +00:00
test
This commit is contained in:
parent
34db988568
commit
48151029b0
@ -112,7 +112,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
|
|
||||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||||
|
|
||||||
self.attention_scores = qk.detach().cpu().numpy() if qk is not None else None
|
self.attention_scores = qk.softmax(dim=-1).detach().cpu().numpy() if qk is not None else None
|
||||||
|
|
||||||
return self.out(wv), qk
|
return self.out(wv), qk
|
||||||
|
|
||||||
@ -190,7 +190,7 @@ class AudioEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
self.ln_post = LayerNorm(n_state)
|
self.ln_post = LayerNorm(n_state)
|
||||||
|
|
||||||
self.all_attention_scores = []
|
self.all_attention_scores = None
|
||||||
|
|
||||||
def forward(self, x: Tensor):
|
def forward(self, x: Tensor):
|
||||||
"""
|
"""
|
||||||
@ -204,13 +204,21 @@ class AudioEncoder(nn.Module):
|
|||||||
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
||||||
x = (x + self.positional_embedding).to(x.dtype)
|
x = (x + self.positional_embedding).to(x.dtype)
|
||||||
|
|
||||||
|
attention_list = []
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x = block(x)
|
x = block(x)
|
||||||
if block.attn.attention_scores is not None:
|
if block.attn.attention_scores is not None:
|
||||||
print(f"Captured attention scores from layer {len(self.all_attention_scores)}")
|
print(f"Captured attention scores from layer {len(self.all_attention_scores)}")
|
||||||
self.all_attention_scores.append(block.attn.attention_scores)
|
attention_list.append(block.attn.attention_scores)
|
||||||
|
|
||||||
x = self.ln_post(x)
|
x = self.ln_post(x)
|
||||||
|
|
||||||
|
if attention_list:
|
||||||
|
self.all_attention_scores = np.array(attention_list)
|
||||||
|
else:
|
||||||
|
print("❌ Warning: No attention scores captured. Adding debug placeholder.")
|
||||||
|
self.all_attention_scores = np.zeros((1, 1))
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user