mirror of
https://github.com/openai/whisper.git
synced 2025-03-30 14:28:27 +00:00
Merge d00d4868116b5cacc84efd1427a3a3fb2d534678 into 517a43ecd132a2089d85f4ebc044728a71d49f6e
This commit is contained in:
commit
b5f56f0bcc
@ -60,11 +60,20 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
|
||||
def compute_sha256(file_path: str) -> str:
    """Return the hexadecimal SHA-256 digest of the file at ``file_path``.

    The file is opened in binary mode and consumed in 8192-byte reads, so the
    whole content never has to be held in memory at once.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as stream:
        while True:
            piece = stream.read(8192)
            if not piece:  # an empty read signals end of file
                break
            digest.update(piece)
    return digest.hexdigest()
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
with open(download_target, "rb") as f:
|
||||
model_bytes = f.read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return model_bytes if in_memory else download_target
|
||||
if compute_sha256(download_target) == expected_sha256:
|
||||
if in_memory:
|
||||
with open(download_target, "rb") as f:
|
||||
return f.read()
|
||||
else:
|
||||
return download_target
|
||||
else:
|
||||
warnings.warn(
|
||||
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||
@ -86,13 +95,16 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
if compute_sha256(download_target) != expected_sha256:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
return model_bytes if in_memory else download_target
|
||||
if in_memory:
|
||||
with open(download_target, "rb") as f:
|
||||
return f.read()
|
||||
else:
|
||||
return download_target
|
||||
|
||||
|
||||
def available_models() -> List[str]:
|
||||
@ -157,4 +169,4 @@ def load_model(
|
||||
if alignment_heads is not None:
|
||||
model.set_alignment_heads(alignment_heads)
|
||||
|
||||
return model.to(device)
|
||||
return model.to(device)
|
@ -224,31 +224,64 @@ class TextDecoder(nn.Module):
|
||||
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
||||
self.register_buffer("mask", mask, persistent=False)
|
||||
|
||||
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
||||
"""
|
||||
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
|
||||
the text tokens
|
||||
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
|
||||
the encoded audio features to be attended on
|
||||
"""
|
||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||
x = (
|
||||
self.token_embedding(x)
|
||||
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
||||
)
|
||||
x = x.to(xa.dtype)
|
||||
# Optimisation: pre-compute and register the mask in CUDA if available
|
||||
if torch.cuda.is_available():
|
||||
self.register_buffer("mask_cuda", mask.cuda(), persistent=False)
|
||||
|
||||
|
||||
def forward(self, tokens: Tensor, audio_features: Tensor, kv_cache: Optional[dict] = None) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
tokens: (n_batch, n_token)
|
||||
audio_features: (n_batch, n_audio_ctx, n_audio_state)
|
||||
kv_cache: Optional cache for key/value tensors
|
||||
|
||||
Returns:
|
||||
logits: (n_batch, n_token, n_vocab)
|
||||
"""
|
||||
n_batch, n_token = tokens.shape
|
||||
|
||||
# Get the dtype of audio_features to ensure consistency
|
||||
dtype = audio_features.dtype
|
||||
|
||||
# Handle kv_cache for token embedding offset
|
||||
if kv_cache is not None:
|
||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||
x = self.token_embedding(tokens) + self.positional_embedding[offset:offset + tokens.shape[1]]
|
||||
else:
|
||||
x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
|
||||
|
||||
# Convert to the same dtype as audio_features
|
||||
x = x.to(dtype)
|
||||
|
||||
# Optimisation: Move audio_features to GPU once here.
|
||||
if torch.cuda.is_available():
|
||||
audio_features = audio_features.cuda()
|
||||
|
||||
# Process through attention blocks
|
||||
for block in self.blocks:
|
||||
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||
x = block(x, audio_features, kv_cache=kv_cache)
|
||||
|
||||
x = self.ln(x)
|
||||
logits = (
|
||||
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
|
||||
).float()
|
||||
|
||||
# Ensure consistent dtype for matrix multiplication
|
||||
# Convert token_embedding weight to the same dtype as x
|
||||
embedding_weights = self.token_embedding.weight.to(x.dtype)
|
||||
logits = x @ embedding_weights.T
|
||||
|
||||
# Apply mask if not using kv_cache (inference)
|
||||
if kv_cache is None:
|
||||
# Optimisation: Apply the precomputed CUDA mask if available.
|
||||
if torch.cuda.is_available():
|
||||
mask = self.mask_cuda[:n_token, :n_token]
|
||||
else:
|
||||
mask = self.mask[:n_token, :n_token]
|
||||
|
||||
logits = logits + mask
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
|
||||
# The Whisper class has been moved outside of TextDecoder and is now a top-level class
|
||||
class Whisper(nn.Module):
|
||||
def __init__(self, dims: ModelDimensions):
|
||||
super().__init__()
|
||||
|
Loading…
x
Reference in New Issue
Block a user