mirror of https://github.com/openai/whisper.git
synced 2025-11-24 06:26:03 +00:00

Merge d00d4868116b5cacc84efd1427a3a3fb2d534678 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2
This commit is contained in: commit 17f43117f7
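
In summary, the merge makes two changes: `_download` gains a chunked `compute_sha256` helper so checksums are verified without reading the whole model file into memory, and `TextDecoder.forward` is rewritten with renamed arguments (`tokens`, `audio_features`), an explicit kv_cache branch for the positional-embedding offset, and an optionally precomputed CUDA mask.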
@@ -60,11 +60,20 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     if os.path.exists(download_target) and not os.path.isfile(download_target):
         raise RuntimeError(f"{download_target} exists and is not a regular file")
 
+    def compute_sha256(file_path: str) -> str:
+        sha256 = hashlib.sha256()
+        with open(file_path, "rb") as f:
+            for chunk in iter(lambda: f.read(8192), b""):
+                sha256.update(chunk)
+        return sha256.hexdigest()
+
     if os.path.isfile(download_target):
-        with open(download_target, "rb") as f:
-            model_bytes = f.read()
-        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
-            return model_bytes if in_memory else download_target
+        if compute_sha256(download_target) == expected_sha256:
+            if in_memory:
+                with open(download_target, "rb") as f:
+                    return f.read()
+            else:
+                return download_target
         else:
             warnings.warn(
                 f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
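
The new `compute_sha256` helper hashes in 8 KiB chunks, so the whole checkpoint never has to sit in memory just to be verified. Below is a minimal self-contained sketch (not part of the PR; the temp-file data is synthetic) showing the same pattern and checking it against a one-shot digest:

import hashlib
import os
import tempfile

def compute_sha256(file_path: str) -> str:
    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        # iter(callable, sentinel) keeps calling f.read(8192) until it
        # returns b"" at end-of-file, so at most 8 KiB is in memory at once.
        for chunk in iter(lambda: f.read(8192), b""):
            sha256.update(chunk)
    return sha256.hexdigest()

# Sanity check against a one-shot digest over synthetic data.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(os.urandom(100_000))
with open(tmp.name, "rb") as f:
    assert compute_sha256(tmp.name) == hashlib.sha256(f.read()).hexdigest()
os.unlink(tmp.name)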
@@ -86,13 +95,16 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
             output.write(buffer)
             loop.update(len(buffer))
 
-    model_bytes = open(download_target, "rb").read()
-    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
+    if compute_sha256(download_target) != expected_sha256:
         raise RuntimeError(
             "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
         )
 
-    return model_bytes if in_memory else download_target
+    if in_memory:
+        with open(download_target, "rb") as f:
+            return f.read()
+    else:
+        return download_target
 
 
 def available_models() -> List[str]:
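
For context, `_download` is the internal helper behind `whisper.load_model`: with `in_memory=True` it now re-opens the verified file and returns its bytes, otherwise it returns the cached path. A hedged usage sketch (the model name and cache directory are illustrative):

import os
import whisper

# Illustrative only: fetch a checkpoint URL from whisper's model registry
# and download it into a local cache directory.
url = whisper._MODELS["tiny.en"]
root = os.path.expanduser("~/.cache/whisper")

path = whisper._download(url, root, in_memory=False)  # -> str: path to the cached file
blob = whisper._download(url, root, in_memory=True)   # -> bytes: raw model contents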
@@ -224,31 +224,64 @@ class TextDecoder(nn.Module):
         mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
 
-    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
-        """
-        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
-            the text tokens
-        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
-            the encoded audio features to be attended on
-        """
-        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-        x = (
-            self.token_embedding(x)
-            + self.positional_embedding[offset : offset + x.shape[-1]]
-        )
-        x = x.to(xa.dtype)
+        # Optimisation: pre-compute and register the mask in CUDA if available
+        if torch.cuda.is_available():
+            self.register_buffer("mask_cuda", mask.cuda(), persistent=False)
+
+    def forward(self, tokens: Tensor, audio_features: Tensor, kv_cache: Optional[dict] = None) -> Tensor:
+        """
+        Args:
+            tokens: (n_batch, n_token)
+            audio_features: (n_batch, n_audio_ctx, n_audio_state)
+            kv_cache: Optional cache for key/value tensors
+
+        Returns:
+            logits: (n_batch, n_token, n_vocab)
+        """
+        n_batch, n_token = tokens.shape
+
+        # Get the dtype of audio_features to ensure consistency
+        dtype = audio_features.dtype
+
+        # Handle kv_cache for token embedding offset
+        if kv_cache is not None:
+            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+            x = self.token_embedding(tokens) + self.positional_embedding[offset : offset + tokens.shape[1]]
+        else:
+            x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
+
+        # Convert to the same dtype as audio_features
+        x = x.to(dtype)
+
+        # Optimisation: Move audio_features to GPU once here.
+        if torch.cuda.is_available():
+            audio_features = audio_features.cuda()
 
+        # Process through attention blocks
         for block in self.blocks:
-            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+            x = block(x, audio_features, kv_cache=kv_cache)
 
         x = self.ln(x)
-        logits = (
-            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+
+        # Ensure consistent dtype for matrix multiplication
+        # Convert token_embedding weight to the same dtype as x
+        embedding_weights = self.token_embedding.weight.to(x.dtype)
+        logits = x @ embedding_weights.T
+
+        # Apply mask if not using kv_cache (inference)
+        if kv_cache is None:
+            # Optimisation: Apply the precomputed CUDA mask if available.
+            if torch.cuda.is_available():
+                mask = self.mask_cuda[:n_token, :n_token]
+            else:
+                mask = self.mask[:n_token, :n_token]
+
+            logits = logits + mask
 
         return logits
 
 
+# The Whisper class has been moved outside of TextDecoder and is now a top-level class
 class Whisper(nn.Module):
     def __init__(self, dims: ModelDimensions):
         super().__init__()
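
The kv_cache branch above derives the positional offset from the cache itself: any cached key/value tensor's second dimension is the number of tokens already decoded, so new tokens get positional embeddings starting at that index. A standalone sketch of just that computation (shapes and the cache key are illustrative, not whisper's actual hook layout):

import torch

# Illustrative kv_cache: one cached tensor of shape
# (n_batch, n_cached_tokens, n_state); here 7 tokens are already decoded.
kv_cache = {"decoder.blocks.0.attn.key": torch.zeros(1, 7, 384)}

offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
assert offset == 7

tokens = torch.tensor([[42]])  # a single new token for this decoding step
# The new token's positional embedding row would be positional_embedding[7:8].
rows = slice(offset, offset + tokens.shape[-1])
assert (rows.start, rows.stop) == (7, 8)

One note on the mask change: buffers registered with `register_buffer` already follow the module on `model.to(device)`, so the separate `mask_cuda` buffer is an eager precompute rather than something device placement requires.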