Mirror of https://github.com/openai/whisper.git, synced 2025-03-30 14:28:27 +00:00
Fix forward method overload in TextDecoder
This commit is contained in:
parent 7a552cb5cc
commit d00d486811
@@ -229,91 +229,58 @@ class TextDecoder(nn.Module):
        self.register_buffer("mask_cuda", mask.cuda(), persistent=False)

    def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
    def forward(self, tokens: Tensor, audio_features: Tensor, kv_cache: Optional[dict] = None) -> Tensor:
        """
        Args:
            tokens: (n_batch, n_token)
            audio_features: (n_batch, n_audio_ctx, n_audio_state)
            kv_cache: Optional cache for key/value tensors

        Returns:
            logits: (n_batch, n_token, n_vocab)
        """
        n_batch, n_token = tokens.shape
        n_audio_ctx, n_audio_state = audio_features.shape[1:]

        x = self.token_embedding(tokens) + self.positional_embedding[:n_token]

        # Get the dtype of audio_features to ensure consistency
        dtype = audio_features.dtype

        # Handle kv_cache for token embedding offset
        if kv_cache is not None:
            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
            x = self.token_embedding(tokens) + self.positional_embedding[offset:offset + tokens.shape[1]]
        else:
            x = self.token_embedding(tokens) + self.positional_embedding[:n_token]

        # Convert to the same dtype as audio_features
        x = x.to(dtype)

        # Optimisation: Move audio_features to GPU once here.
        if torch.cuda.is_available():
            audio_features = audio_features.cuda()

        # Process through attention blocks
        for block in self.blocks:
            x = block(x, audio_features)
            x = block(x, audio_features, kv_cache=kv_cache)

        x = self.ln(x)
        logits = x @ self.token_embedding.weight.T

        # Optimisation: Apply the precomputed CUDA mask if available.
        if torch.cuda.is_available():
            mask = self.mask_cuda[:n_token, :n_token]
        else:
            mask = self.mask[:n_token, :n_token]

        logits = logits + mask

        return logits
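The kv_cache branch in the forward above derives its positional offset from the length of any cached key/value tensor, so tokens decoded one at a time still receive the correct positions. A minimal standalone sketch of that offset logic follows; the cache key name, batch size, and dimensions are illustrative assumptions, not values from this commit:

import torch

# Illustrative sizes only (hypothetical): 384-dim decoder state, 448-token context.
n_state, n_ctx = 384, 448
positional_embedding = torch.zeros(n_ctx, n_state)

def positional_slice(kv_cache, tokens):
    # Offset = number of tokens already held in the cache along the token axis.
    offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
    return positional_embedding[offset : offset + tokens.shape[-1]]

# First call: nothing cached, so positions 0..2 are used for a 3-token prompt.
prompt = torch.tensor([[1, 2, 3]])
print(positional_slice({}, prompt).shape)        # torch.Size([3, 384])

# Later call: 3 tokens already cached, so the single new token gets position 3.
cache = {"decoder.blocks.0.attn.key": torch.zeros(1, 3, n_state)}
next_token = torch.tensor([[4]])
print(positional_slice(cache, next_token).shape) # torch.Size([1, 384])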
    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        Args:
            tokens: (n_batch, n_token) or x tensor
            audio_features: (n_batch, n_audio_ctx, n_audio_state) or xa tensor
            kv_cache: Optional cache for key/value tensors
        """
        if kv_cache is not None:
            # Handle the kv_cache case
            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
            x = (
                self.token_embedding(tokens)
                + self.positional_embedding[offset : offset + tokens.shape[-1]]
            )
            x = x.to(audio_features.dtype)

            for block in self.blocks:
                x = block(x, audio_features, mask=self.mask, kv_cache=kv_cache)

            x = self.ln(x)
            logits = (
                x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
            ).float()

            return logits
        else:
            # Handle the non-kv_cache case
            n_batch, n_token = tokens.shape
            x = self.token_embedding(tokens) + self.positional_embedding[:n_token]

            if torch.cuda.is_available():
                audio_features = audio_features.cuda()

            for block in self.blocks:
                x = block(x, audio_features)

            x = self.ln(x)
            logits = x @ self.token_embedding.weight.T
            # Ensure consistent dtype for matrix multiplication
            # Convert token_embedding weight to the same dtype as x
            embedding_weights = self.token_embedding.weight.to(x.dtype)
            logits = x @ embedding_weights.T

            # Apply mask if not using kv_cache (inference)
            if kv_cache is None:
                # Optimisation: Apply the precomputed CUDA mask if available.
                if torch.cuda.is_available():
                    mask = self.mask_cuda[:n_token, :n_token]
                else:
                    mask = self.mask[:n_token, :n_token]

                logits = logits + mask

            return logits

        return logits

# The Whisper class has been moved outside of TextDecoder and is now a top-level class
class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
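For context on where that kv_cache dict comes from: in upstream openai/whisper it is filled by forward hooks that the Whisper class installs on each attention layer's key and value projections (install_kv_cache_hooks), while the decoder's forward only reads it. Assuming this fork keeps that mechanism, the shape the offset computation relies on can be reproduced with plain PyTorch hooks; the module layout and sizes below are stand-ins, not this fork's classes:

import torch
import torch.nn as nn

# Two stand-in key/value projections, as one decoder block's self-attention would have.
key = nn.Linear(8, 8, bias=False)
value = nn.Linear(8, 8)

cache, hooks = {}, []

def save_to_cache(layer, _inputs, output):
    # First call stores the projection; later calls append along the token axis (dim=1).
    if layer not in cache:
        cache[layer] = output.detach()
    else:
        cache[layer] = torch.cat([cache[layer], output], dim=1).detach()

for layer in (key, value):
    hooks.append(layer.register_forward_hook(save_to_cache))

key(torch.randn(1, 3, 8)); value(torch.randn(1, 3, 8))   # 3 prompt tokens
key(torch.randn(1, 1, 8)); value(torch.randn(1, 1, 8))   # 1 newly decoded token

# The decoder's forward derives its positional offset from exactly this shape[1].
print(next(iter(cache.values())).shape)   # torch.Size([1, 4, 8])

for hook in hooks:
    hook.remove()

Because the hook appends along the token axis, next(iter(kv_cache.values())).shape[1] always equals the number of tokens already decoded, which is the offset applied to the positional embedding in the forward shown in this diff.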