Merge d00d4868116b5cacc84efd1427a3a3fb2d534678 into 517a43ecd132a2089d85f4ebc044728a71d49f6e

Eleanor Green 2025-03-11 10:55:04 +00:00 committed by GitHub
commit b5f56f0bcc
2 changed files with 72 additions and 27 deletions

whisper/__init__.py

@@ -60,11 +60,20 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     if os.path.exists(download_target) and not os.path.isfile(download_target):
         raise RuntimeError(f"{download_target} exists and is not a regular file")
 
+    def compute_sha256(file_path: str) -> str:
+        sha256 = hashlib.sha256()
+        with open(file_path, "rb") as f:
+            for chunk in iter(lambda: f.read(8192), b""):
+                sha256.update(chunk)
+        return sha256.hexdigest()
+
     if os.path.isfile(download_target):
-        with open(download_target, "rb") as f:
-            model_bytes = f.read()
-        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
-            return model_bytes if in_memory else download_target
+        if compute_sha256(download_target) == expected_sha256:
+            if in_memory:
+                with open(download_target, "rb") as f:
+                    return f.read()
+            else:
+                return download_target
         else:
             warnings.warn(
                 f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
@@ -86,13 +95,16 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
                 output.write(buffer)
                 loop.update(len(buffer))
 
-    model_bytes = open(download_target, "rb").read()
-    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
+    if compute_sha256(download_target) != expected_sha256:
         raise RuntimeError(
             "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
         )
 
-    return model_bytes if in_memory else download_target
+    if in_memory:
+        with open(download_target, "rb") as f:
+            return f.read()
+    else:
+        return download_target
 
 
 def available_models() -> List[str]:
@@ -157,4 +169,4 @@ def load_model(
     if alignment_heads is not None:
         model.set_alignment_heads(alignment_heads)
 
-    return model.to(device)
+    return model.to(device)
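
Note on the new compute_sha256 helper: it streams the checkpoint through the hash in 8 KiB chunks, so peak memory stays constant no matter how large the model file is, whereas the old path read the entire file into model_bytes first. On Python 3.11+ the standard library offers the same chunked pattern built in via hashlib.file_digest. A minimal standalone sketch of both, assuming nothing beyond the standard library (the model.pt path is illustrative only):

import hashlib

def compute_sha256(file_path: str) -> str:
    # Stream the file through the hash in 8 KiB chunks, so only one
    # chunk is ever held in memory at a time.
    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            sha256.update(chunk)
    return sha256.hexdigest()

def compute_sha256_py311(file_path: str) -> str:
    # Equivalent on Python 3.11+: hashlib.file_digest does the
    # chunked reading internally.
    with open(file_path, "rb") as f:
        return hashlib.file_digest(f, "sha256").hexdigest()

print(compute_sha256("model.pt"))  # hypothetical local file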

whisper/model.py

@@ -224,31 +224,64 @@ class TextDecoder(nn.Module):
         mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
 
-    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
-        """
-        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
-            the text tokens
-        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
-            the encoded audio features to be attended on
-        """
-        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-        x = (
-            self.token_embedding(x)
-            + self.positional_embedding[offset : offset + x.shape[-1]]
-        )
-        x = x.to(xa.dtype)
+        # Optimisation: pre-compute and register the mask in CUDA if available
+        if torch.cuda.is_available():
+            self.register_buffer("mask_cuda", mask.cuda(), persistent=False)
+
+    def forward(self, tokens: Tensor, audio_features: Tensor, kv_cache: Optional[dict] = None) -> Tensor:
+        """
+        Args:
+            tokens: (n_batch, n_token)
+            audio_features: (n_batch, n_audio_ctx, n_audio_state)
+            kv_cache: Optional cache for key/value tensors
+
+        Returns:
+            logits: (n_batch, n_token, n_vocab)
+        """
+        n_batch, n_token = tokens.shape
+
+        # Get the dtype of audio_features to ensure consistency
+        dtype = audio_features.dtype
+
+        # Handle kv_cache for token embedding offset
+        if kv_cache is not None:
+            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+            x = self.token_embedding(tokens) + self.positional_embedding[offset:offset + tokens.shape[1]]
+        else:
+            x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
+
+        # Convert to the same dtype as audio_features
+        x = x.to(dtype)
+
+        # Optimisation: Move audio_features to GPU once here.
+        if torch.cuda.is_available():
+            audio_features = audio_features.cuda()
 
+        # Process through attention blocks
         for block in self.blocks:
-            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+            x = block(x, audio_features, kv_cache=kv_cache)
 
         x = self.ln(x)
-        logits = (
-            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+
+        # Ensure consistent dtype for matrix multiplication:
+        # convert token_embedding weight to the same dtype as x
+        embedding_weights = self.token_embedding.weight.to(x.dtype)
+        logits = x @ embedding_weights.T
+
+        # Apply mask if not using kv_cache (inference)
+        if kv_cache is None:
+            # Optimisation: Apply the precomputed CUDA mask if available.
+            if torch.cuda.is_available():
+                mask = self.mask_cuda[:n_token, :n_token]
+            else:
+                mask = self.mask[:n_token, :n_token]
+
+            logits = logits + mask
 
         return logits
 
 
+# The Whisper class has been moved outside of TextDecoder and is now a top-level class
 class Whisper(nn.Module):
     def __init__(self, dims: ModelDimensions):
         super().__init__()
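
Note on the offset logic in the rewritten forward: during incremental decoding the key/value cache already holds offset positions, so a newly decoded token must pick up the positional embedding at index offset rather than 0; the kv_cache-is-None branch is the full-prompt case, where positions simply run from 0 to n_token - 1. A minimal sketch of the indexing with toy dimensions (all names, sizes, and token IDs below are illustrative, not part of the diff):

import torch

n_ctx, n_state = 8, 4
positional_embedding = torch.randn(n_ctx, n_state)

# Full-prompt pass: three tokens take positions 0..2.
prompt = torch.tensor([[50258, 50259, 50359]])
pos = positional_embedding[: prompt.shape[-1]]  # rows 0, 1, 2

# Incremental pass: three keys/values are already cached, so the single
# new token must be embedded at position 3, not position 0.
offset = 3  # plays the role of next(iter(kv_cache.values())).shape[1]
new_token = torch.tensor([[50363]])
pos = positional_embedding[offset : offset + new_token.shape[-1]]  # row 3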