mirror of
https://github.com/openai/whisper.git
synced 2025-11-27 07:48:45 +00:00
Refactor and optimize model code for readability and efficiency
This commit is contained in:
parent
31a1c816eb
commit
168306fd3b
@ -34,9 +34,7 @@ class LayerNorm(nn.LayerNorm):
|
|||||||
|
|
||||||
class Linear(nn.Linear):
|
class Linear(nn.Linear):
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return F.linear(
|
return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype))
|
||||||
x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Conv1d(nn.Conv1d):
|
class Conv1d(nn.Conv1d):
|
||||||
@ -62,9 +60,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
self.value = Linear(n_state, n_state)
|
self.value = Linear(n_state, n_state)
|
||||||
self.out = Linear(n_state, n_state)
|
self.out = Linear(n_state, n_state)
|
||||||
|
|
||||||
def forward(
|
def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
|
||||||
self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None
|
|
||||||
):
|
|
||||||
q = self.query(x)
|
q = self.query(x)
|
||||||
|
|
||||||
if kv_cache is None or xa is None or self.key not in kv_cache:
|
if kv_cache is None or xa is None or self.key not in kv_cache:
|
||||||
@ -107,9 +103,7 @@ class ResidualAttentionBlock(nn.Module):
|
|||||||
)
|
)
|
||||||
self.mlp_ln = LayerNorm(n_state)
|
self.mlp_ln = LayerNorm(n_state)
|
||||||
|
|
||||||
def forward(
|
def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
|
||||||
self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None
|
|
||||||
):
|
|
||||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
||||||
if self.cross_attn:
|
if self.cross_attn:
|
||||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
||||||
@ -163,10 +157,7 @@ class TextDecoder(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
||||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||||
x = (
|
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
|
||||||
self.token_embedding(x)
|
|
||||||
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
|
||||||
)
|
|
||||||
x = x.to(xa.dtype)
|
x = x.to(xa.dtype)
|
||||||
|
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
@ -189,9 +180,7 @@ class Whisper(nn.Module):
|
|||||||
self.dims.n_vocab, self.dims.n_text_ctx, self.dims.n_text_state,
|
self.dims.n_vocab, self.dims.n_text_ctx, self.dims.n_text_state,
|
||||||
self.dims.n_text_head, self.dims.n_text_layer
|
self.dims.n_text_head, self.dims.n_text_layer
|
||||||
)
|
)
|
||||||
all_heads = torch.zeros(
|
all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
|
||||||
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
|
|
||||||
)
|
|
||||||
all_heads[self.dims.n_text_layer // 2 :] = True
|
all_heads[self.dims.n_text_layer // 2 :] = True
|
||||||
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
|
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user