mirror of
https://github.com/openai/whisper.git
synced 2025-11-28 16:14:00 +00:00
Fix forward method overload in TextDecoder
This commit is contained in:
parent
7a552cb5cc
commit
d00d486811
@ -229,91 +229,58 @@ class TextDecoder(nn.Module):
|
|||||||
self.register_buffer("mask_cuda", mask.cuda(), persistent=False)
|
self.register_buffer("mask_cuda", mask.cuda(), persistent=False)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
def forward(self, tokens: Tensor, audio_features: Tensor, kv_cache: Optional[dict] = None) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
tokens: (n_batch, n_token)
|
tokens: (n_batch, n_token)
|
||||||
audio_features: (n_batch, n_audio_ctx, n_audio_state)
|
audio_features: (n_batch, n_audio_ctx, n_audio_state)
|
||||||
|
kv_cache: Optional cache for key/value tensors
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
logits: (n_batch, n_token, n_vocab)
|
logits: (n_batch, n_token, n_vocab)
|
||||||
"""
|
"""
|
||||||
n_batch, n_token = tokens.shape
|
n_batch, n_token = tokens.shape
|
||||||
n_audio_ctx, n_audio_state = audio_features.shape[1:]
|
|
||||||
|
# Get the dtype of audio_features to ensure consistency
|
||||||
x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
|
dtype = audio_features.dtype
|
||||||
|
|
||||||
|
# Handle kv_cache for token embedding offset
|
||||||
|
if kv_cache is not None:
|
||||||
|
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||||
|
x = self.token_embedding(tokens) + self.positional_embedding[offset:offset + tokens.shape[1]]
|
||||||
|
else:
|
||||||
|
x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
|
||||||
|
|
||||||
|
# Convert to the same dtype as audio_features
|
||||||
|
x = x.to(dtype)
|
||||||
|
|
||||||
# Optimisation: Move audio_features to GPU once here.
|
# Optimisation: Move audio_features to GPU once here.
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
audio_features = audio_features.cuda()
|
audio_features = audio_features.cuda()
|
||||||
|
|
||||||
|
# Process through attention blocks
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x = block(x, audio_features)
|
x = block(x, audio_features, kv_cache=kv_cache)
|
||||||
|
|
||||||
x = self.ln(x)
|
x = self.ln(x)
|
||||||
logits = x @ self.token_embedding.weight.T
|
|
||||||
|
|
||||||
# Optimisation: Apply the precomputed CUDA mask if available.
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
mask = self.mask_cuda[:n_token, :n_token]
|
|
||||||
else:
|
|
||||||
mask = self.mask[:n_token, :n_token]
|
|
||||||
|
|
||||||
logits = logits + mask
|
# Ensure consistent dtype for matrix multiplication
|
||||||
|
# Convert token_embedding weight to the same dtype as x
|
||||||
return logits
|
embedding_weights = self.token_embedding.weight.to(x.dtype)
|
||||||
|
logits = x @ embedding_weights.T
|
||||||
|
|
||||||
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
tokens: (n_batch, n_token) or x tensor
|
|
||||||
audio_features: (n_batch, n_audio_ctx, n_audio_state) or xa tensor
|
|
||||||
kv_cache: Optional cache for key/value tensors
|
|
||||||
"""
|
|
||||||
if kv_cache is not None:
|
|
||||||
# Handle the kv_cache case
|
|
||||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
|
||||||
x = (
|
|
||||||
self.token_embedding(tokens)
|
|
||||||
+ self.positional_embedding[offset : offset + tokens.shape[-1]]
|
|
||||||
)
|
|
||||||
x = x.to(audio_features.dtype)
|
|
||||||
|
|
||||||
for block in self.blocks:
|
|
||||||
x = block(x, audio_features, mask=self.mask, kv_cache=kv_cache)
|
|
||||||
|
|
||||||
x = self.ln(x)
|
|
||||||
logits = (
|
|
||||||
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
|
|
||||||
).float()
|
|
||||||
|
|
||||||
return logits
|
|
||||||
else:
|
|
||||||
# Handle the non-kv_cache case
|
|
||||||
n_batch, n_token = tokens.shape
|
|
||||||
x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
audio_features = audio_features.cuda()
|
|
||||||
|
|
||||||
for block in self.blocks:
|
|
||||||
x = block(x, audio_features)
|
|
||||||
|
|
||||||
x = self.ln(x)
|
|
||||||
logits = x @ self.token_embedding.weight.T
|
|
||||||
|
|
||||||
|
# Apply mask if not using kv_cache (inference)
|
||||||
|
if kv_cache is None:
|
||||||
|
# Optimisation: Apply the precomputed CUDA mask if available.
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
mask = self.mask_cuda[:n_token, :n_token]
|
mask = self.mask_cuda[:n_token, :n_token]
|
||||||
else:
|
else:
|
||||||
mask = self.mask[:n_token, :n_token]
|
mask = self.mask[:n_token, :n_token]
|
||||||
|
|
||||||
logits = logits + mask
|
logits = logits + mask
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
# The Whisper class has been moved outside of TextDecoder and is now a top-level class
|
# The Whisper class has been moved outside of TextDecoder and is now a top-level class
|
||||||
class Whisper(nn.Module):
|
class Whisper(nn.Module):
|
||||||
def __init__(self, dims: ModelDimensions):
|
def __init__(self, dims: ModelDimensions):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user