Mirror of https://github.com/openai/whisper.git (synced 2025-03-30 14:28:27 +00:00)

Commit 54edc865e6 ("found attention")
Parent: 60e1a31a9e
@@ -88,8 +88,7 @@ class MultiHeadAttention(nn.Module):
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

        self.attention_scores = None
        self.attn_weights = None

    def forward(
        self,
@@ -111,12 +110,9 @@ class MultiHeadAttention(nn.Module):
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention(q, k, v, mask)

        if qk is not None:
            print(f"✅ Attention shape: {qk.shape}")  # Should print the shape of attention weights
            self.attention_scores = qk.softmax(dim=-1).detach().cpu().numpy()
        else:
            print("❌ No attention computed in MultiHeadAttention!")

        self.attn_weights = qk
        print(qk)

        return self.out(wv), qk
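The same per-layer maps can also be collected from outside the model with forward hooks, which avoids leaving print statements in forward. The sketch below is illustrative only: it assumes the upstream whisper package's load_model and the (output, qk) return signature shown in this hunk, and qk will be None whenever the scaled_dot_product_attention fast path is taken (see the qkv_attention hunk further down).

# Minimal sketch: capture encoder attention via forward hooks (assumes whisper is installed).
import torch
import whisper

model = whisper.load_model("base")
captured = {}  # layer index -> attention tensor

def make_hook(idx):
    def hook(module, inputs, output):
        # Per this diff, MultiHeadAttention.forward returns (projected_output, qk);
        # qk may be None when the SDPA code path is used.
        _, qk = output
        if qk is not None:
            captured[idx] = qk.detach().cpu()
    return hook

handles = [
    block.attn.register_forward_hook(make_hook(i))
    for i, block in enumerate(model.encoder.blocks)
]
# ... run model.transcribe(...) or model.encoder(mel) here, then read `captured` ...
for h in handles:
    h.remove()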
@@ -133,8 +129,8 @@ class MultiHeadAttention(nn.Module):
            a = scaled_dot_product_attention(
                q, k, v, is_causal=mask is not None and n_ctx > 1
            )
            qk = (q @ k.transpose(-1, -2)).softmax(dim=-1)
            out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
            qk = None
        else:
            qk = (q * scale) @ (k * scale).transpose(-1, -2)
            if mask is not None:
@@ -143,8 +139,8 @@ class MultiHeadAttention(nn.Module):

            w = F.softmax(qk, dim=-1).to(q.dtype)
            out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
            qk = qk.detach()

        self.attn_weights = qk.detach()
        return out, qk
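torch's scaled_dot_product_attention returns only the attention-weighted values, not the weight matrix, which is why any attention map on that branch has to be recomputed explicitly from q and k. A standalone sketch of the manual path shown in this hunk, with explicit shapes (the dimension values are illustrative; Whisper's real dims depend on the checkpoint):

import torch
import torch.nn.functional as F

n_batch, n_ctx, n_state, n_head = 1, 10, 512, 8
scale = (n_state // n_head) ** -0.25

q = torch.randn(n_batch, n_ctx, n_state)
k = torch.randn(n_batch, n_ctx, n_state)
v = torch.randn(n_batch, n_ctx, n_state)

# split into heads: (batch, head, ctx, head_dim)
q = q.view(n_batch, n_ctx, n_head, -1).permute(0, 2, 1, 3)
k = k.view(n_batch, n_ctx, n_head, -1).permute(0, 2, 1, 3)
v = v.view(n_batch, n_ctx, n_head, -1).permute(0, 2, 1, 3)

qk = (q * scale) @ (k * scale).transpose(-1, -2)         # (batch, head, ctx, ctx) logits
w = F.softmax(qk.float(), dim=-1).to(q.dtype)            # attention weights, each row sums to 1
out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)   # back to (batch, ctx, n_state)
print(qk.shape, w.shape, out.shape)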
@@ -194,8 +190,6 @@ class AudioEncoder(nn.Module):
        )
        self.ln_post = LayerNorm(n_state)

        self.all_attention_scores = None

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
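Keeping every layer's softmaxed encoder attention is memory-hungry: the encoder context is 1500 frames per 30-second window, so each layer stores n_head × 1500 × 1500 floats per item in the batch. A rough estimate, assuming float32, batch size 1, and dimensions roughly matching the published base and large checkpoints (hedged numbers, not measured):

# Back-of-the-envelope size of the stacked encoder attention maps this patch stores.
def attn_bytes(n_layer, n_head, n_ctx=1500, bytes_per=4):
    return n_layer * n_head * n_ctx * n_ctx * bytes_per

print(attn_bytes(6, 8) / 2**20, "MiB")    # ~412 MiB for a base-sized encoder (6 layers, 8 heads)
print(attn_bytes(32, 20) / 2**30, "GiB")  # ~5.4 GiB for a large-sized encoder (32 layers, 20 heads)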
@@ -208,24 +202,10 @@ class AudioEncoder(nn.Module):
        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = (x + self.positional_embedding).to(x.dtype)

        attention_list = []
        for layer_idx, block in enumerate(self.blocks):
        for block in self.blocks:
            x = block(x)

            if block.attn.attention_scores is not None:
                print(f"✅ Captured attention scores from layer {layer_idx}")
                attention_list.append(block.attn.attention_scores)
            else:
                print(f"❌ No attention captured at layer {layer_idx}!")

        x = self.ln_post(x)
        # ✅ Debug: if no attention scores were captured, add a debug placeholder
        if attention_list:
            self.all_attention_scores = np.array(attention_list)
        else:
            print("❌ Warning: no attention scores captured at any layer. Adding debug placeholder.")
            self.all_attention_scores = np.zeros((1, 1))

        return x
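A minimal way to consume the stacked maps after this patch, assuming the attribute names added here plus the standard whisper helpers (load_model, load_audio, pad_or_trim, log_mel_spectrogram); the audio path is a placeholder:

import torch
import whisper

model = whisper.load_model("base")
audio = whisper.pad_or_trim(whisper.load_audio("sample.wav"))  # placeholder path
mel = whisper.log_mel_spectrogram(audio).unsqueeze(0).to(model.device)

with torch.no_grad():
    model.encoder(mel)

# Expected shape: (n_layer, batch, n_head, n_ctx, n_ctx), since each layer
# appends a softmaxed qk of shape (batch, n_head, n_ctx, n_ctx).
scores = model.encoder.all_attention_scores
print(scores.shape)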
@@ -367,4 +347,4 @@ class Whisper(nn.Module):

    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function
    decode = decode_function