Mirror of https://github.com/openai/whisper.git (synced 2025-11-24 14:35:57 +00:00)
commit a85bcadd43
parent 2b86f24780

    add back getter
@@ -300,6 +300,15 @@ class Whisper(nn.Module):
     ) -> Dict[str, torch.Tensor]:
         return self.decoder(tokens, self.encoder(mel))
 
+    def get_attention_weights(self):
+        """Retrieve stored attention weights from the decoder's layers."""
+        attn_maps = []
+        for block in self.decoder.blocks:
+            if block.attn.attn_weights is not None:
+                attn_maps.append(block.attn.attn_weights.detach().cpu().numpy())
+
+        return attn_maps
+
     def plot_attention_on_padded(self, seq_length: int = 100):
         """Plots attention weights focusing on padded regions."""
         attn_maps = self.get_attention_weights()
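For context, a minimal usage sketch follows. It assumes this fork's MultiHeadAttention saves its softmax weights to an `attn_weights` attribute during forward() (upstream openai/whisper does not retain them, in which case the getter returns an empty list); the dummy mel/token inputs and the "tiny" checkpoint are illustrative placeholders, not part of the commit.

# Minimal usage sketch. Assumption: decoder blocks populate
# `block.attn.attn_weights` during forward(), as this fork's getter expects.
import torch
import whisper

model = whisper.load_model("tiny")
model.eval()

# Dummy inputs: a silent log-mel spectrogram and one placeholder token.
mel = torch.zeros(1, model.dims.n_mels, 3000)  # (batch, n_mels, n_frames)
tokens = torch.zeros(1, 1, dtype=torch.long)   # arbitrary valid token id

with torch.no_grad():
    model(mel, tokens)  # forward pass lets each block record its weights

# One NumPy array per decoder block that recorded weights; blocks whose
# `attn_weights` stayed None are skipped by the getter.
attn_maps = model.get_attention_weights()
print(f"collected {len(attn_maps)} attention maps")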