From 168306fd3b613225f78fc561119018de158a99a4 Mon Sep 17 00:00:00 2001
From: San <99511815+sanowl@users.noreply.github.com>
Date: Sat, 22 Jun 2024 18:29:56 +0300
Subject: [PATCH] Refactor and optimize model code for readability and efficiency

---
 whisper/model.py | 21 +++++----------------
 1 file changed, 5 insertions(+), 16 deletions(-)

diff --git a/whisper/model.py b/whisper/model.py
index 90fa906..f6881f8 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -34,9 +34,7 @@ class LayerNorm(nn.LayerNorm):
 
 class Linear(nn.Linear):
     def forward(self, x: Tensor) -> Tensor:
-        return F.linear(
-            x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
-        )
+        return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype))
 
 
 class Conv1d(nn.Conv1d):
@@ -62,9 +60,7 @@ class MultiHeadAttention(nn.Module):
         self.value = Linear(n_state, n_state)
         self.out = Linear(n_state, n_state)
 
-    def forward(
-        self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None
-    ):
+    def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
         q = self.query(x)
 
         if kv_cache is None or xa is None or self.key not in kv_cache:
@@ -107,9 +103,7 @@ class ResidualAttentionBlock(nn.Module):
         )
         self.mlp_ln = LayerNorm(n_state)
 
-    def forward(
-        self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None
-    ):
+    def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
         x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
         if self.cross_attn:
             x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
@@ -163,10 +157,7 @@ class TextDecoder(nn.Module):
 
     def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
         offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-        x = (
-            self.token_embedding(x)
-            + self.positional_embedding[offset : offset + x.shape[-1]]
-        )
+        x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
         x = x.to(xa.dtype)
 
         for block in self.blocks:
@@ -189,9 +180,7 @@ class Whisper(nn.Module):
             self.dims.n_vocab, self.dims.n_text_ctx, self.dims.n_text_state, self.dims.n_text_head, self.dims.n_text_layer
         )
 
-        all_heads = torch.zeros(
-            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
-        )
+        all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
         all_heads[self.dims.n_text_layer // 2 :] = True
         self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
 
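
Note (not part of the patch): the hunks above only join multi-line statements onto single lines, so the computed values should be unchanged. Below is a minimal sanity-check sketch, assuming only a local PyTorch install and using stand-in tensors and dimensions (tok, pos, n_text_layer, etc.) in place of the real module attributes, to confirm the collapsed one-liners compute the same results as the multi-line originals they replace.

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Linear.forward: weight/bias cast to the input dtype, one line vs. several.
# (float64 parameters and a float32 input stand in for the real model weights.)
x = torch.randn(2, 4)
w = torch.randn(3, 4, dtype=torch.float64)
b = torch.randn(3, dtype=torch.float64)
multi_line = F.linear(
    x, w.to(x.dtype), None if b is None else b.to(x.dtype)
)
one_line = F.linear(x, w.to(x.dtype), None if b is None else b.to(x.dtype))
assert torch.equal(multi_line, one_line)

# TextDecoder.forward: token embedding plus an offset slice of the positional
# embedding, parenthesized block vs. single expression.
tok = torch.randn(1, 5, 8)   # stand-in for self.token_embedding(x)
pos = torch.randn(448, 8)    # stand-in for self.positional_embedding
offset = 3
block_form = (
    tok
    + pos[offset : offset + tok.shape[1]]
)
flat_form = tok + pos[offset : offset + tok.shape[1]]
assert torch.equal(block_form, flat_form)

# Whisper.__init__: boolean alignment-head mask built in a single call,
# then sparsified exactly as in the patched code.
n_text_layer, n_text_head = 4, 6  # stand-in dims
all_heads = torch.zeros(n_text_layer, n_text_head, dtype=torch.bool)
all_heads[n_text_layer // 2 :] = True
print(all_heads.to_sparse())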