Mirror of https://github.com/openai/whisper.git
Synced 2025-11-27 15:54:00 +00:00
Refactor and optimize model code for readability and efficiency

parent 31a1c816eb
commit 168306fd3b
@@ -34,9 +34,7 @@ class LayerNorm(nn.LayerNorm):
 
 class Linear(nn.Linear):
     def forward(self, x: Tensor) -> Tensor:
-        return F.linear(
-            x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
-        )
+        return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype))
 
 
 class Conv1d(nn.Conv1d):
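For context on why this Linear casts at all (rationale not stated in the diff, but consistent with the surrounding code): the model keeps parameters in fp32 while activations may run in half precision, so weight and bias are cast to the input's dtype at call time. A minimal runnable sketch under that assumption, using the hypothetical name CastingLinear:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class CastingLinear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        # Cast parameters to the activation dtype per call; the stored
        # fp32 weights are left untouched.
        return F.linear(
            x,
            self.weight.to(x.dtype),
            None if self.bias is None else self.bias.to(x.dtype),
        )

layer = CastingLinear(4, 4)        # parameters are created in fp32
y = layer(torch.randn(2, 4))       # fp32 input: behaves exactly like nn.Linear
print(y.dtype)                     # torch.float32
# On GPU, layer(x.half()) would run the matmul in fp16 without converting
# the module itself.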
@@ -62,9 +60,7 @@ class MultiHeadAttention(nn.Module):
         self.value = Linear(n_state, n_state)
         self.out = Linear(n_state, n_state)
 
-    def forward(
-        self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None
-    ):
+    def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
         q = self.query(x)
 
         if kv_cache is None or xa is None or self.key not in kv_cache:
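The condition in the last context line is the heart of the kv cache: in cross-attention the keys and values depend only on the encoder output xa, so they are computed once and then looked up by module on later decoding steps. A runnable sketch of that short-circuit (resolve_kv is a hypothetical helper; in the real model a forward hook populates the cache):

from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor

def resolve_kv(key_mod: nn.Module, value_mod: nn.Module, x: Tensor,
               xa: Optional[Tensor], kv_cache: Optional[dict]):
    if kv_cache is None or xa is None or key_mod not in kv_cache:
        # Self-attention, or the first cross-attention step: project fresh
        # keys/values from x (self) or from the encoder output xa (cross).
        k = key_mod(x if xa is None else xa)
        v = value_mod(x if xa is None else xa)
    else:
        # Later cross-attention steps: xa never changes, so reuse the
        # projections stored under the modules themselves.
        k = kv_cache[key_mod]
        v = kv_cache[value_mod]
    return k, v

key, value = nn.Linear(4, 4), nn.Linear(4, 4)
xa = torch.randn(1, 7, 4)
cache: dict = {}
k1, v1 = resolve_kv(key, value, torch.randn(1, 1, 4), xa, cache)
cache[key], cache[value] = k1, v1      # stands in for the model's hook
k2, v2 = resolve_kv(key, value, torch.randn(1, 1, 4), xa, cache)
assert torch.equal(k1, k2)             # the second step reuses the cached keys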
@@ -107,9 +103,7 @@ class ResidualAttentionBlock(nn.Module):
         )
         self.mlp_ln = LayerNorm(n_state)
 
-    def forward(
-        self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None
-    ):
+    def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
         x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
         if self.cross_attn:
             x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
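The reflowed forward keeps the pre-norm residual pattern: each sub-layer sees a LayerNormed input, its output is added back to the stream, and attn returns an (output, qk) pair, hence the [0]. A runnable toy of the same pattern (ToyBlock is hypothetical and substitutes nn.MultiheadAttention for the model's own attention):

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, n_state: int):
        super().__init__()
        self.attn_ln = nn.LayerNorm(n_state)
        self.attn = nn.MultiheadAttention(n_state, num_heads=2, batch_first=True)
        self.mlp_ln = nn.LayerNorm(n_state)
        self.mlp = nn.Sequential(
            nn.Linear(n_state, 4 * n_state), nn.GELU(), nn.Linear(4 * n_state, n_state)
        )

    def forward(self, x):
        h = self.attn_ln(x)                 # normalize before the sub-layer
        x = x + self.attn(h, h, h)[0]       # [0]: keep output, drop attn weights
        x = x + self.mlp(self.mlp_ln(x))    # second residual branch
        return x

out = ToyBlock(8)(torch.randn(1, 3, 8))
print(out.shape)  # torch.Size([1, 3, 8])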
@@ -163,10 +157,7 @@ class TextDecoder(nn.Module):
 
     def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
         offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-        x = (
-            self.token_embedding(x)
-            + self.positional_embedding[offset : offset + x.shape[-1]]
-        )
+        x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
         x = x.to(xa.dtype)
 
         for block in self.blocks:
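The offset line is what makes incremental decoding work: when a kv_cache is present, x holds only the newest token(s), so their positions must start where the cache ends. A toy illustration with assumed shapes:

import torch

n_ctx, n_state = 8, 4
positional_embedding = torch.randn(n_ctx, n_state)

kv_cache = {"some_module": torch.zeros(1, 5, n_state)}  # 5 positions already decoded
x = torch.randint(0, 100, (1, 1))                       # one new token this step

offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
pos = positional_embedding[offset : offset + x.shape[-1]]
print(offset, tuple(pos.shape))  # 5 (1, 4): the embedding for position 5, not 0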
@@ -189,9 +180,7 @@ class Whisper(nn.Module):
             self.dims.n_vocab, self.dims.n_text_ctx, self.dims.n_text_state,
             self.dims.n_text_head, self.dims.n_text_layer
         )
-        all_heads = torch.zeros(
-            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
-        )
+        all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
         all_heads[self.dims.n_text_layer // 2 :] = True
         self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
 
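By default every head in the upper half of the decoder layers is marked as a candidate alignment head (used for word-level timestamps), and the mask is registered as a sparse, non-persistent buffer so it stays out of the state_dict. A small sketch with illustrative dimensions (6 layers and 8 heads are assumptions, not a real model's sizes):

import torch

n_text_layer, n_text_head = 6, 8
all_heads = torch.zeros(n_text_layer, n_text_head, dtype=torch.bool)
all_heads[n_text_layer // 2 :] = True   # mark layers 3..5, the upper half

print(int(all_heads.sum()))         # 24 = 3 layers * 8 heads
sparse = all_heads.to_sparse()      # sparse COO, matching the registered buffer
print(sparse.shape)                 # torch.Size([6, 8])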