Fix forward method overload in TextDecoder

TextDecoder.forward was defined twice, so the second definition silently shadowed the first; the two are merged here into a single kv_cache-aware method.

eleanorTurintech 2025-03-11 10:45:40 +00:00
parent 7a552cb5cc
commit d00d486811

@@ -229,91 +229,58 @@ class TextDecoder(nn.Module):
         self.register_buffer("mask_cuda", mask.cuda(), persistent=False)
 
-    def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
+    def forward(self, tokens: Tensor, audio_features: Tensor, kv_cache: Optional[dict] = None) -> Tensor:
         """
         Args:
             tokens: (n_batch, n_token)
             audio_features: (n_batch, n_audio_ctx, n_audio_state)
+            kv_cache: Optional cache of key/value tensors from earlier decoding steps
         Returns:
             logits: (n_batch, n_token, n_vocab)
         """
         n_batch, n_token = tokens.shape
-        n_audio_ctx, n_audio_state = audio_features.shape[1:]
-        x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
-        # Optimisation: Move audio_features to GPU once here.
-        if torch.cuda.is_available():
-            audio_features = audio_features.cuda()
+        # Handle kv_cache for the token-embedding offset: when a cache is
+        # present, only the newly generated tokens are passed in, so the
+        # positional embedding must start after the cached positions.
+        if kv_cache is not None:
+            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+            x = self.token_embedding(tokens) + self.positional_embedding[offset : offset + n_token]
+        else:
+            x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
+        # Convert to the same dtype as audio_features (e.g. for fp16 inference)
+        x = x.to(audio_features.dtype)
+        # Optimisation: use the precomputed CUDA mask when running on GPU. The
+        # causal mask is applied inside the attention blocks, not added to the
+        # logits, since its (n_ctx, n_ctx) shape cannot broadcast against
+        # (n_batch, n_token, n_vocab).
+        mask = self.mask_cuda if x.is_cuda else self.mask
         # Process through attention blocks
         for block in self.blocks:
-            x = block(x, audio_features)
+            x = block(x, audio_features, mask=mask, kv_cache=kv_cache)
         x = self.ln(x)
-        logits = x @ self.token_embedding.weight.T
-        # Optimisation: Apply the precomputed CUDA mask if available.
-        if torch.cuda.is_available():
-            mask = self.mask_cuda[:n_token, :n_token]
-        else:
-            mask = self.mask[:n_token, :n_token]
-        logits = logits + mask
-        return logits
-
-    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
-        """
-        Args:
-            tokens: (n_batch, n_token) or x tensor
-            audio_features: (n_batch, n_audio_ctx, n_audio_state) or xa tensor
-            kv_cache: Optional cache for key/value tensors
-        """
-        if kv_cache is not None:
-            # Handle the kv_cache case
-            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-            x = (
-                self.token_embedding(tokens)
-                + self.positional_embedding[offset : offset + tokens.shape[-1]]
-            )
-            x = x.to(audio_features.dtype)
-            for block in self.blocks:
-                x = block(x, audio_features, mask=self.mask, kv_cache=kv_cache)
-            x = self.ln(x)
-            logits = (
-                x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-            ).float()
-            return logits
-        else:
-            # Handle the non-kv_cache case
-            n_batch, n_token = tokens.shape
-            x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
-            if torch.cuda.is_available():
-                audio_features = audio_features.cuda()
-            for block in self.blocks:
-                x = block(x, audio_features)
-            x = self.ln(x)
-            logits = x @ self.token_embedding.weight.T
-            # Optimisation: Apply the precomputed CUDA mask if available.
-            if torch.cuda.is_available():
-                mask = self.mask_cuda[:n_token, :n_token]
-            else:
-                mask = self.mask[:n_token, :n_token]
-            logits = logits + mask
-            return logits
+        # Ensure a consistent dtype for the output projection, then return
+        # fp32 logits for numerically stable downstream decoding.
+        embedding_weights = self.token_embedding.weight.to(x.dtype)
+        logits = (x @ embedding_weights.T).float()
+        return logits
 # The Whisper class has been moved outside of TextDecoder and is now a top-level class
 class Whisper(nn.Module):
     def __init__(self, dims: ModelDimensions):
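
A note on where the kv_cache dict comes from. In upstream openai/whisper the cache is a plain dict keyed by module, filled by forward hooks on each attention layer's key and value projections (Whisper.install_kv_cache_hooks); this fork presumably inherits that scheme. A minimal sketch of the mechanism, assuming upstream-style modules — the MultiHeadAttention stand-in and the n_text_ctx threshold are illustrative, not taken from this diff:

import torch
from torch import nn
from typing import Optional

class MultiHeadAttention(nn.Module):
    # Minimal stand-in for the model's attention block (illustrative only);
    # the real class also has query/out projections and the attention math.
    def __init__(self, n_state: int):
        super().__init__()
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)

def install_kv_cache_hooks(decoder: nn.Module, n_text_ctx: int = 448, cache: Optional[dict] = None):
    """Fill `cache` (keyed by module) with key/value outputs across forward calls."""
    cache = {**cache} if cache is not None else {}
    hooks = []

    def save_to_cache(module, _, output):
        if module not in cache or output.shape[1] > n_text_ctx:
            # First step, or a cross-attention entry (audio context is longer
            # than the text context): store / overwrite as-is.
            cache[module] = output
        else:
            # Self-attention on later steps: append the new positions along
            # the token axis; forward() reads its offset from this shape[1].
            cache[module] = torch.cat([cache[module], output], dim=1).detach()
        return cache[module]

    def install(layer: nn.Module):
        if isinstance(layer, MultiHeadAttention):
            hooks.append(layer.key.register_forward_hook(save_to_cache))
            hooks.append(layer.value.register_forward_hook(save_to_cache))

    decoder.apply(install)
    return cache, hooks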
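
With the cache in place, incremental decoding feeds only the newest token after the first step, and the offset logic in the merged forward lines the positional embedding up with the cached positions. A usage sketch under the same assumptions — model, mel, sot_token, and eot_token are placeholders, not names from this repository:

model.eval()
with torch.no_grad():
    audio_features = model.encoder(mel)          # (1, n_audio_ctx, n_audio_state)
    cache, hooks = install_kv_cache_hooks(model.decoder)

    tokens = torch.tensor([[sot_token]])         # start-of-transcript id (placeholder)
    for _ in range(224):                         # arbitrary decoding budget
        # After the first step only the newest token is fed in; the cached
        # key/value tensors cover the earlier positions, and forward()
        # offsets the positional embedding by the cache length.
        inputs = tokens if not cache else tokens[:, -1:]
        logits = model.decoder(inputs, audio_features, kv_cache=cache)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
        if next_token.item() == eot_token:
            break

    for hook in hooks:
        hook.remove()                            # detach the cache hooks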