Refactor and optimize model code for readability and efficiency

This commit is contained in:
San 2024-06-22 18:29:56 +03:00
parent 31a1c816eb
commit 168306fd3b

View File

@ -34,9 +34,7 @@ class LayerNorm(nn.LayerNorm):
class Linear(nn.Linear):
def forward(self, x: Tensor) -> Tensor:
return F.linear(
x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
)
return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype))
class Conv1d(nn.Conv1d):
@ -62,9 +60,7 @@ class MultiHeadAttention(nn.Module):
self.value = Linear(n_state, n_state)
self.out = Linear(n_state, n_state)
def forward(
self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None
):
def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
q = self.query(x)
if kv_cache is None or xa is None or self.key not in kv_cache:
@ -107,9 +103,7 @@ class ResidualAttentionBlock(nn.Module):
)
self.mlp_ln = LayerNorm(n_state)
def forward(
self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None
):
def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
@ -163,10 +157,7 @@ class TextDecoder(nn.Module):
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = (
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
x = x.to(xa.dtype)
for block in self.blocks:
@ -189,9 +180,7 @@ class Whisper(nn.Module):
self.dims.n_vocab, self.dims.n_text_ctx, self.dims.n_text_state,
self.dims.n_text_head, self.dims.n_text_layer
)
all_heads = torch.zeros(
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
)
all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
all_heads[self.dims.n_text_layer // 2 :] = True
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)