mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 06:26:03 +00:00
Refactor and optimize model code for readability and efficiency
This commit is contained in:
parent
ba3f3cd54b
commit
31a1c816eb
109
whisper/model.py
109
whisper/model.py
@ -35,19 +35,13 @@ class LayerNorm(nn.LayerNorm):
|
|||||||
class Linear(nn.Linear):
|
class Linear(nn.Linear):
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return F.linear(
|
return F.linear(
|
||||||
x,
|
x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
|
||||||
self.weight.to(x.dtype),
|
|
||||||
None if self.bias is None else self.bias.to(x.dtype),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Conv1d(nn.Conv1d):
|
class Conv1d(nn.Conv1d):
|
||||||
def _conv_forward(
|
def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
|
||||||
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
|
||||||
) -> Tensor:
|
|
||||||
return super()._conv_forward(
|
|
||||||
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def sinusoids(length, channels, max_timescale=10000):
|
def sinusoids(length, channels, max_timescale=10000):
|
||||||
@ -69,35 +63,26 @@ class MultiHeadAttention(nn.Module):
|
|||||||
self.out = Linear(n_state, n_state)
|
self.out = Linear(n_state, n_state)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None
|
||||||
x: Tensor,
|
|
||||||
xa: Optional[Tensor] = None,
|
|
||||||
mask: Optional[Tensor] = None,
|
|
||||||
kv_cache: Optional[dict] = None,
|
|
||||||
):
|
):
|
||||||
q = self.query(x)
|
q = self.query(x)
|
||||||
|
|
||||||
if kv_cache is None or xa is None or self.key not in kv_cache:
|
if kv_cache is None or xa is None or self.key not in kv_cache:
|
||||||
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
|
||||||
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
|
||||||
k = self.key(x if xa is None else xa)
|
k = self.key(x if xa is None else xa)
|
||||||
v = self.value(x if xa is None else xa)
|
v = self.value(x if xa is None else xa)
|
||||||
else:
|
else:
|
||||||
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
|
||||||
k = kv_cache[self.key]
|
k = kv_cache[self.key]
|
||||||
v = kv_cache[self.value]
|
v = kv_cache[self.value]
|
||||||
|
|
||||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||||
return self.out(wv), qk
|
return self.out(wv), qk
|
||||||
|
|
||||||
def qkv_attention(
|
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
|
||||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
|
||||||
):
|
|
||||||
n_batch, n_ctx, n_state = q.shape
|
n_batch, n_ctx, n_state = q.shape
|
||||||
scale = (n_state // self.n_head) ** -0.25
|
scale = (n_state // self.n_head) ** -0.25
|
||||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
q = q.view(n_batch, n_ctx, self.n_head, -1).permute(0, 2, 1, 3) * scale
|
||||||
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
|
k = k.view(n_batch, n_ctx, self.n_head, -1).permute(0, 2, 3, 1) * scale
|
||||||
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
v = v.view(n_batch, n_ctx, self.n_head, -1).permute(0, 2, 1, 3)
|
||||||
|
|
||||||
qk = q @ k
|
qk = q @ k
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
@ -111,13 +96,9 @@ class MultiHeadAttention(nn.Module):
|
|||||||
class ResidualAttentionBlock(nn.Module):
|
class ResidualAttentionBlock(nn.Module):
|
||||||
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
|
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.attn = MultiHeadAttention(n_state, n_head)
|
self.attn = MultiHeadAttention(n_state, n_head)
|
||||||
self.attn_ln = LayerNorm(n_state)
|
self.attn_ln = LayerNorm(n_state)
|
||||||
|
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
|
||||||
self.cross_attn = (
|
|
||||||
MultiHeadAttention(n_state, n_head) if cross_attention else None
|
|
||||||
)
|
|
||||||
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
||||||
|
|
||||||
n_mlp = n_state * 4
|
n_mlp = n_state * 4
|
||||||
@ -127,11 +108,7 @@ class ResidualAttentionBlock(nn.Module):
|
|||||||
self.mlp_ln = LayerNorm(n_state)
|
self.mlp_ln = LayerNorm(n_state)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None
|
||||||
x: Tensor,
|
|
||||||
xa: Optional[Tensor] = None,
|
|
||||||
mask: Optional[Tensor] = None,
|
|
||||||
kv_cache: Optional[dict] = None,
|
|
||||||
):
|
):
|
||||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
||||||
if self.cross_attn:
|
if self.cross_attn:
|
||||||
@ -141,9 +118,7 @@ class ResidualAttentionBlock(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class AudioEncoder(nn.Module):
|
class AudioEncoder(nn.Module):
|
||||||
def __init__(
|
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
||||||
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||||
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
||||||
@ -155,10 +130,6 @@ class AudioEncoder(nn.Module):
|
|||||||
self.ln_post = LayerNorm(n_state)
|
self.ln_post = LayerNorm(n_state)
|
||||||
|
|
||||||
def forward(self, x: Tensor):
|
def forward(self, x: Tensor):
|
||||||
"""
|
|
||||||
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
|
||||||
the mel spectrogram of the audio
|
|
||||||
"""
|
|
||||||
x = F.gelu(self.conv1(x))
|
x = F.gelu(self.conv1(x))
|
||||||
x = F.gelu(self.conv2(x))
|
x = F.gelu(self.conv2(x))
|
||||||
x = x.permute(0, 2, 1)
|
x = x.permute(0, 2, 1)
|
||||||
@ -174,11 +145,8 @@ class AudioEncoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TextDecoder(nn.Module):
|
class TextDecoder(nn.Module):
|
||||||
def __init__(
|
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
||||||
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
||||||
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
||||||
|
|
||||||
@ -194,12 +162,6 @@ class TextDecoder(nn.Module):
|
|||||||
self.register_buffer("mask", mask, persistent=False)
|
self.register_buffer("mask", mask, persistent=False)
|
||||||
|
|
||||||
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
||||||
"""
|
|
||||||
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
|
|
||||||
the text tokens
|
|
||||||
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
|
|
||||||
the encoded audio features to be attended on
|
|
||||||
"""
|
|
||||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||||
x = (
|
x = (
|
||||||
self.token_embedding(x)
|
self.token_embedding(x)
|
||||||
@ -211,10 +173,7 @@ class TextDecoder(nn.Module):
|
|||||||
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||||
|
|
||||||
x = self.ln(x)
|
x = self.ln(x)
|
||||||
logits = (
|
logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
|
||||||
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
|
|
||||||
).float()
|
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
@ -223,21 +182,13 @@ class Whisper(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.dims = dims
|
self.dims = dims
|
||||||
self.encoder = AudioEncoder(
|
self.encoder = AudioEncoder(
|
||||||
self.dims.n_mels,
|
self.dims.n_mels, self.dims.n_audio_ctx, self.dims.n_audio_state,
|
||||||
self.dims.n_audio_ctx,
|
self.dims.n_audio_head, self.dims.n_audio_layer
|
||||||
self.dims.n_audio_state,
|
|
||||||
self.dims.n_audio_head,
|
|
||||||
self.dims.n_audio_layer,
|
|
||||||
)
|
)
|
||||||
self.decoder = TextDecoder(
|
self.decoder = TextDecoder(
|
||||||
self.dims.n_vocab,
|
self.dims.n_vocab, self.dims.n_text_ctx, self.dims.n_text_state,
|
||||||
self.dims.n_text_ctx,
|
self.dims.n_text_head, self.dims.n_text_layer
|
||||||
self.dims.n_text_state,
|
|
||||||
self.dims.n_text_head,
|
|
||||||
self.dims.n_text_layer,
|
|
||||||
)
|
)
|
||||||
# use the last half among the decoder layers for time alignment by default;
|
|
||||||
# to use a specific set of heads, see `set_alignment_heads()` below.
|
|
||||||
all_heads = torch.zeros(
|
all_heads = torch.zeros(
|
||||||
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
|
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
|
||||||
)
|
)
|
||||||
@ -245,12 +196,8 @@ class Whisper(nn.Module):
|
|||||||
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
|
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
|
||||||
|
|
||||||
def set_alignment_heads(self, dump: bytes):
|
def set_alignment_heads(self, dump: bytes):
|
||||||
array = np.frombuffer(
|
array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
|
||||||
gzip.decompress(base64.b85decode(dump)), dtype=bool
|
mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
|
||||||
).copy()
|
|
||||||
mask = torch.from_numpy(array).reshape(
|
|
||||||
self.dims.n_text_layer, self.dims.n_text_head
|
|
||||||
)
|
|
||||||
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
|
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
|
||||||
|
|
||||||
def embed_audio(self, mel: torch.Tensor):
|
def embed_audio(self, mel: torch.Tensor):
|
||||||
@ -259,9 +206,7 @@ class Whisper(nn.Module):
|
|||||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
||||||
return self.decoder(tokens, audio_features)
|
return self.decoder(tokens, audio_features)
|
||||||
|
|
||||||
def forward(
|
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||||
self, mel: torch.Tensor, tokens: torch.Tensor
|
|
||||||
) -> Dict[str, torch.Tensor]:
|
|
||||||
return self.decoder(tokens, self.encoder(mel))
|
return self.decoder(tokens, self.encoder(mel))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -277,25 +222,11 @@ class Whisper(nn.Module):
|
|||||||
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
|
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
|
||||||
|
|
||||||
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
||||||
"""
|
|
||||||
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
|
||||||
tensors calculated for the previous positions. This method returns a dictionary that stores
|
|
||||||
all caches, and the necessary hooks for the key and value projection modules that save the
|
|
||||||
intermediate tensors to be reused during later calculations.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
cache : Dict[nn.Module, torch.Tensor]
|
|
||||||
A dictionary object mapping the key/value projection modules to its cache
|
|
||||||
hooks : List[RemovableHandle]
|
|
||||||
List of PyTorch RemovableHandle objects to stop the hooks to be called
|
|
||||||
"""
|
|
||||||
cache = {**cache} if cache is not None else {}
|
cache = {**cache} if cache is not None else {}
|
||||||
hooks = []
|
hooks = []
|
||||||
|
|
||||||
def save_to_cache(module, _, output):
|
def save_to_cache(module, _, output):
|
||||||
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
|
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
|
||||||
# save as-is, for the first token or cross attention
|
|
||||||
cache[module] = output
|
cache[module] = output
|
||||||
else:
|
else:
|
||||||
cache[module] = torch.cat([cache[module], output], dim=1).detach()
|
cache[module] = torch.cat([cache[module], output], dim=1).detach()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user