diff --git a/whisper/model.py b/whisper/model.py
index 5a247e3..072f0a4 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -229,91 +229,58 @@ class TextDecoder(nn.Module):
         self.register_buffer("mask_cuda", mask.cuda(), persistent=False)
 
-    def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
+    def forward(self, tokens: Tensor, audio_features: Tensor, kv_cache: Optional[dict] = None) -> Tensor:
         """
         Args:
             tokens: (n_batch, n_token)
             audio_features: (n_batch, n_audio_ctx, n_audio_state)
+            kv_cache: Optional cache for key/value tensors
 
         Returns:
             logits: (n_batch, n_token, n_vocab)
         """
         n_batch, n_token = tokens.shape
-        n_audio_ctx, n_audio_state = audio_features.shape[1:]
-
-        x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
+
+        # Get the dtype of audio_features to ensure consistency
+        dtype = audio_features.dtype
+
+        # Handle kv_cache for token embedding offset
+        if kv_cache is not None:
+            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+            x = self.token_embedding(tokens) + self.positional_embedding[offset : offset + tokens.shape[1]]
+        else:
+            x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
+
+        # Convert to the same dtype as audio_features
+        x = x.to(dtype)
 
         # Optimisation: Move audio_features to GPU once here.
         if torch.cuda.is_available():
             audio_features = audio_features.cuda()
-
+        # Process through attention blocks
         for block in self.blocks:
-            x = block(x, audio_features)
+            x = block(x, audio_features, kv_cache=kv_cache)
 
         x = self.ln(x)
-        logits = x @ self.token_embedding.weight.T
-
-        # Optimisation: Apply the precomputed CUDA mask if available.
-        if torch.cuda.is_available():
-            mask = self.mask_cuda[:n_token, :n_token]
-        else:
-            mask = self.mask[:n_token, :n_token]
-        logits = logits + mask
-
-        return logits
-
-
-    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
-        """
-        Args:
-            tokens: (n_batch, n_token) or x tensor
-            audio_features: (n_batch, n_audio_ctx, n_audio_state) or xa tensor
-            kv_cache: Optional cache for key/value tensors
-        """
-        if kv_cache is not None:
-            # Handle the kv_cache case
-            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-            x = (
-                self.token_embedding(tokens)
-                + self.positional_embedding[offset : offset + tokens.shape[-1]]
-            )
-            x = x.to(audio_features.dtype)
-
-            for block in self.blocks:
-                x = block(x, audio_features, mask=self.mask, kv_cache=kv_cache)
-
-            x = self.ln(x)
-            logits = (
-                x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-            ).float()
-
-            return logits
-        else:
-            # Handle the non-kv_cache case
-            n_batch, n_token = tokens.shape
-            x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
-
-            if torch.cuda.is_available():
-                audio_features = audio_features.cuda()
-
-            for block in self.blocks:
-                x = block(x, audio_features)
-
-            x = self.ln(x)
-            logits = x @ self.token_embedding.weight.T
-
-            return logits
-
-
+        # Ensure consistent dtype for matrix multiplication:
+        # convert token_embedding weight to the same dtype as x
+        embedding_weights = self.token_embedding.weight.to(x.dtype)
+        logits = x @ embedding_weights.T
 
+        # Apply mask if not using kv_cache (inference)
+        if kv_cache is None:
+            # Optimisation: Apply the precomputed CUDA mask if available.
+            if torch.cuda.is_available():
+                mask = self.mask_cuda[:n_token, :n_token]
+            else:
+                mask = self.mask[:n_token, :n_token]
+            logits = logits + mask
+
+        return logits
+
+# The Whisper class has been moved outside of TextDecoder and is now a top-level class
 class Whisper(nn.Module):
     def __init__(self, dims: ModelDimensions):
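
A note on the kv_cache offset expression used in the merged forward(): the sequence length of any cached key/value tensor tells the decoder how many positions have already been consumed, so newly decoded tokens pick up positional embeddings starting at that offset. The standalone sketch below illustrates just that mechanism; the cache key name, context size, and state width are illustrative assumptions, not values taken from this diff.

import torch

# Hypothetical sizes for illustration only.
n_ctx, n_state = 448, 512
positional_embedding = torch.randn(n_ctx, n_state)

# Suppose 5 positions have already been decoded and cached.
kv_cache = {"decoder.block0.key": torch.randn(1, 5, n_state)}

# Same expression as in the diff: the offset is the length of any cached tensor.
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0

tokens = torch.tensor([[50364]])  # one new token per incremental decoding step
pos = positional_embedding[offset : offset + tokens.shape[1]]
print(offset, pos.shape)  # 5 torch.Size([1, 512])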
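
Similarly, the `.to(x.dtype)` cast on the tied embedding weight guards the final projection when the decoder runs in half precision while the embedding table remains float32, since PyTorch matmul requires matching operand dtypes. A minimal sketch under that assumption (shapes are made up, and it assumes a backend that supports fp16 matmul):

import torch

vocab, n_state = 1000, 64
token_embedding = torch.nn.Embedding(vocab, n_state)  # weight is float32 by default

# Decoder activations in half precision, as when audio_features are fp16.
x = torch.randn(2, 3, n_state, dtype=torch.float16)

# x @ token_embedding.weight.T would raise a dtype-mismatch RuntimeError here;
# casting the weight first, as the diff does, keeps the matmul well-typed.
logits = x @ token_embedding.weight.to(x.dtype).T
print(logits.shape, logits.dtype)  # torch.Size([2, 3, 1000]) torch.float16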