Merge d00d4868116b5cacc84efd1427a3a3fb2d534678 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2

commit 17f43117f7
Eleanor Green, 2025-06-25 20:17:44 -05:00 (committed by GitHub)
2 changed files with 72 additions and 27 deletions


@@ -60,11 +60,20 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     if os.path.exists(download_target) and not os.path.isfile(download_target):
         raise RuntimeError(f"{download_target} exists and is not a regular file")
 
+    def compute_sha256(file_path: str) -> str:
+        sha256 = hashlib.sha256()
+        with open(file_path, "rb") as f:
+            for chunk in iter(lambda: f.read(8192), b""):
+                sha256.update(chunk)
+        return sha256.hexdigest()
+
     if os.path.isfile(download_target):
-        with open(download_target, "rb") as f:
-            model_bytes = f.read()
-        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
-            return model_bytes if in_memory else download_target
+        if compute_sha256(download_target) == expected_sha256:
+            if in_memory:
+                with open(download_target, "rb") as f:
+                    return f.read()
+            else:
+                return download_target
         else:
             warnings.warn(
                 f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
@@ -86,13 +95,16 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
                 output.write(buffer)
                 loop.update(len(buffer))
 
-    model_bytes = open(download_target, "rb").read()
-    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
+    if compute_sha256(download_target) != expected_sha256:
         raise RuntimeError(
             "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
         )
 
-    return model_bytes if in_memory else download_target
+    if in_memory:
+        with open(download_target, "rb") as f:
+            return f.read()
+    else:
+        return download_target
 
 
 def available_models() -> List[str]:
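
Note: with both hunks applied, the cached-file path and the freshly downloaded file go through the same verify-then-return flow. A rough sketch of that flow as a self-contained helper, reusing the sha256_of_file sketch above; the function name, the urllib-based download, and the delete-on-mismatch behaviour are illustrative and not part of this diff:

import os
import urllib.request

def fetch_verified(url: str, target: str, expected_sha256: str, in_memory: bool = False):
    # Reuse a cached copy only if its digest matches; otherwise discard and re-download.
    if os.path.isfile(target) and sha256_of_file(target) != expected_sha256:
        os.remove(target)
    if not os.path.isfile(target):
        urllib.request.urlretrieve(url, target)
        if sha256_of_file(target) != expected_sha256:
            raise RuntimeError("downloaded file failed SHA256 verification")
    # Mirror the in_memory contract: return raw bytes or the on-disk path.
    if in_memory:
        with open(target, "rb") as f:
            return f.read()
    return target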


@@ -224,31 +224,64 @@ class TextDecoder(nn.Module):
         mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
 
-    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
-        """
-        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
-            the text tokens
-        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
-            the encoded audio features to be attended on
-        """
-        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-        x = (
-            self.token_embedding(x)
-            + self.positional_embedding[offset : offset + x.shape[-1]]
-        )
-        x = x.to(xa.dtype)
+        # Optimisation: pre-compute and register the mask in CUDA if available
+        if torch.cuda.is_available():
+            self.register_buffer("mask_cuda", mask.cuda(), persistent=False)
+
+    def forward(self, tokens: Tensor, audio_features: Tensor, kv_cache: Optional[dict] = None) -> Tensor:
+        """
+        Args:
+            tokens: (n_batch, n_token)
+            audio_features: (n_batch, n_audio_ctx, n_audio_state)
+            kv_cache: Optional cache for key/value tensors
+
+        Returns:
+            logits: (n_batch, n_token, n_vocab)
+        """
+        n_batch, n_token = tokens.shape
+
+        # Get the dtype of audio_features to ensure consistency
+        dtype = audio_features.dtype
+
+        # Handle kv_cache for token embedding offset
+        if kv_cache is not None:
+            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+            x = self.token_embedding(tokens) + self.positional_embedding[offset:offset + tokens.shape[1]]
+        else:
+            x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
+
+        # Convert to the same dtype as audio_features
+        x = x.to(dtype)
+
+        # Optimisation: Move audio_features to GPU once here.
+        if torch.cuda.is_available():
+            audio_features = audio_features.cuda()
 
+        # Process through attention blocks
         for block in self.blocks:
-            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+            x = block(x, audio_features, kv_cache=kv_cache)
 
         x = self.ln(x)
-        logits = (
-            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+
+        # Ensure consistent dtype for matrix multiplication
+        # Convert token_embedding weight to the same dtype as x
+        embedding_weights = self.token_embedding.weight.to(x.dtype)
+        logits = x @ embedding_weights.T
+
+        # Apply mask if not using kv_cache (inference)
+        if kv_cache is None:
+            # Optimisation: Apply the precomputed CUDA mask if available.
+            if torch.cuda.is_available():
+                mask = self.mask_cuda[:n_token, :n_token]
+            else:
+                mask = self.mask[:n_token, :n_token]
+            logits = logits + mask
 
         return logits
 
+# The Whisper class has been moved outside of TextDecoder and is now a top-level class
 
 class Whisper(nn.Module):
     def __init__(self, dims: ModelDimensions):
         super().__init__()
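
Note: the mask registered in __init__ is the standard additive causal mask, an upper-triangular matrix of -inf that gets sliced to the current token count at each step. A toy sketch of that precompute-and-slice pattern applied to attention scores, independent of Whisper's modules (the class and tensor names are illustrative, not part of this diff):

import torch
from torch import nn, Tensor

class ToyCausalScores(nn.Module):
    def __init__(self, n_ctx: int):
        super().__init__()
        # -inf above the diagonal, 0 elsewhere; register_buffer keeps the mask on the
        # module so .to(device)/.cuda() moves it along with the parameters.
        mask = torch.empty(n_ctx, n_ctx).fill_(float("-inf")).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, q: Tensor, k: Tensor) -> Tensor:
        # q, k: (n_batch, n_token, n_state)
        n_token = q.shape[1]
        scores = (q @ k.transpose(-1, -2)) / (q.shape[-1] ** 0.5)
        # Slice the precomputed mask to the current length and add it to the scores,
        # so each position can only attend to itself and earlier positions.
        return (scores + self.mask[:n_token, :n_token]).softmax(dim=-1)

# Illustrative usage:
# attn = ToyCausalScores(n_ctx=448)
# weights = attn(torch.randn(1, 10, 64), torch.randn(1, 10, 64))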