commit 59b1fa70fe
parent 54edc865e6

    add plot
@@ -1,5 +1,7 @@
 import base64
 import gzip
+import matplotlib.pyplot as plt
+import seaborn as sns
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Dict, Iterable, Optional, Tuple
@@ -110,7 +112,7 @@ class MultiHeadAttention(nn.Module):
             v = kv_cache[self.value]
 
         wv, qk = self.qkv_attention(q, k, v, mask)
 
         self.attn_weights = qk
         print(qk)
 
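For reference, the qk cached here is the score matrix computed inside qkv_attention. Below is a self-contained sketch of that computation, assuming upstream whisper's scaling convention; all sizes and variable names are illustrative, not part of this commit:

import torch

# Illustrative sizes: batch 1, 6 heads, 4 token positions, 64 dims per head.
n_batch, n_head, n_ctx, d_head = 1, 6, 4, 64

q = torch.randn(n_batch, n_head, n_ctx, d_head)
k = torch.randn(n_batch, n_head, n_ctx, d_head)
v = torch.randn(n_batch, n_head, n_ctx, d_head)

# Upstream whisper applies d_head ** -0.25 to q and k separately, which is
# equivalent to the usual 1/sqrt(d_head) scaling of the product.
scale = d_head ** -0.25
qk = (q * scale) @ (k * scale).transpose(-1, -2)  # (batch, head, tgt, src) scores
w = qk.softmax(dim=-1)                            # attention probabilities
wv = w @ v                                        # attended values

print(qk.shape)  # torch.Size([1, 6, 4, 4]); this is what attn_weights holds

Because self.attn_weights = qk and print(qk) run on every forward call, each pass overwrites the previously cached map and emits one score tensor per attention layer to stdout.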
@@ -299,6 +301,34 @@ class Whisper(nn.Module):
         self, mel: torch.Tensor, tokens: torch.Tensor
     ) -> Dict[str, torch.Tensor]:
         return self.decoder(tokens, self.encoder(mel))
 
+    def get_attention_weights(self):
+        """Retrieve stored attention weights from the decoder's layers."""
+        attn_maps = []
+        for block in self.decoder.blocks:
+            if block.attn.attn_weights is not None:
+                attn_maps.append(block.attn.attn_weights.detach().cpu().numpy())
+
+        return attn_maps  # Returns a list of attention weight tensors
+
+    def plot_attention_on_padded(self, seq_length: int = 100):
+        """Plots attention weights focusing on padded regions"""
+        attn_maps = self.get_attention_weights()
+
+        if not attn_maps:
+            print("No attention weights found!")
+            return
+
+        # Average over heads and layers
+        avg_attn = np.mean(attn_maps, axis=0)
+        avg_attn = avg_attn[:, :, :seq_length, :seq_length]
+
+        plt.figure(figsize=(8, 6))
+        sns.heatmap(avg_attn.mean(axis=1).squeeze(), cmap="Blues", annot=False)
+        plt.xlabel("Input Positions")
+        plt.ylabel("Output Positions")
+        plt.title("Attention Weights on Padded Regions")
+        plt.show()
+
     @property
     def device(self):