From 52649452a87747718d52dada906c42f0b0bab0e1 Mon Sep 17 00:00:00 2001
From: eleanorTurintech
Date: Mon, 3 Feb 2025 09:49:17 +0000
Subject: [PATCH 1/3] Performance improvements

---
 whisper/__init__.py |  28 +++++---
 whisper/model.py    | 155 ++++++++++++++++++++++++++++++++------------
 2 files changed, 134 insertions(+), 49 deletions(-)

diff --git a/whisper/__init__.py b/whisper/__init__.py
index e210718..bfe0a89 100644
--- a/whisper/__init__.py
+++ b/whisper/__init__.py
@@ -60,11 +60,20 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     if os.path.exists(download_target) and not os.path.isfile(download_target):
         raise RuntimeError(f"{download_target} exists and is not a regular file")

+    def compute_sha256(file_path: str) -> str:
+        sha256 = hashlib.sha256()
+        with open(file_path, "rb") as f:
+            for chunk in iter(lambda: f.read(8192), b""):
+                sha256.update(chunk)
+        return sha256.hexdigest()
+
     if os.path.isfile(download_target):
-        with open(download_target, "rb") as f:
-            model_bytes = f.read()
-        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
-            return model_bytes if in_memory else download_target
+        if compute_sha256(download_target) == expected_sha256:
+            if in_memory:
+                with open(download_target, "rb") as f:
+                    return f.read()
+            else:
+                return download_target
         else:
             warnings.warn(
                 f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
@@ -86,13 +95,16 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
             output.write(buffer)
             loop.update(len(buffer))

-    model_bytes = open(download_target, "rb").read()
-    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
+    if compute_sha256(download_target) != expected_sha256:
         raise RuntimeError(
             "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
         )

-    return model_bytes if in_memory else download_target
+    if in_memory:
+        with open(download_target, "rb") as f:
+            return f.read()
+    else:
+        return download_target


 def available_models() -> List[str]:
@@ -157,4 +169,4 @@ def load_model(
     if alignment_heads is not None:
         model.set_alignment_heads(alignment_heads)

-    return model.to(device)
+    return model.to(device)
\ No newline at end of file
diff --git a/whisper/model.py b/whisper/model.py
index e537447..612e6b9 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -224,56 +224,129 @@ class TextDecoder(nn.Module):
         mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)

-    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
+        # Optimisation: pre-compute and register the mask in CUDA if available
+        if torch.cuda.is_available():
+            self.register_buffer("mask_cuda", mask.cuda(), persistent=False)
+
+<<<<<<< Updated upstream
+
+    def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
         """
-        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
-            the text tokens
-        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
-            the encoded audio features to be attended on
+        Args:
+            tokens: (n_batch, n_token)
+            audio_features: (n_batch, n_audio_ctx, n_audio_state)
+
+        Returns:
+            logits: (n_batch, n_token, n_vocab)
         """
-        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-        x = (
-            self.token_embedding(x)
-            + self.positional_embedding[offset : offset + x.shape[-1]]
-        )
-        x = x.to(xa.dtype)
+        n_batch, n_token = tokens.shape
+        n_audio_ctx, n_audio_state = audio_features.shape[1:]
+
+        x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
+
+        # Optimisation: Move audio_features to GPU once here.
+        if torch.cuda.is_available():
+            audio_features = audio_features.cuda()

         for block in self.blocks:
-            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+            x = block(x, audio_features)

         x = self.ln(x)
-        logits = (
-            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+        logits = x @ self.token_embedding.weight.T
+
+        # Optimisation: Apply the precomputed CUDA mask if available.
+        if torch.cuda.is_available():
+            mask = self.mask_cuda[:n_token, :n_token]
+        else:
+            mask = self.mask[:n_token, :n_token]
+
+        logits = logits + mask

         return logits


-class Whisper(nn.Module):
-    def __init__(self, dims: ModelDimensions):
-        super().__init__()
-        self.dims = dims
-        self.encoder = AudioEncoder(
-            self.dims.n_mels,
-            self.dims.n_audio_ctx,
-            self.dims.n_audio_state,
-            self.dims.n_audio_head,
-            self.dims.n_audio_layer,
-        )
-        self.decoder = TextDecoder(
-            self.dims.n_vocab,
-            self.dims.n_text_ctx,
-            self.dims.n_text_state,
-            self.dims.n_text_head,
-            self.dims.n_text_layer,
-        )
-        # use the last half among the decoder layers for time alignment by default;
-        # to use a specific set of heads, see `set_alignment_heads()` below.
-        all_heads = torch.zeros(
-            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
-        )
-        all_heads[self.dims.n_text_layer // 2 :] = True
-        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
+=======
+    def forward(
+        self,
+        tokens: Tensor,
+        audio_features: Tensor,
+        kv_cache: Optional[dict] = None
+    ) -> Tensor:
+>>>>>>> Stashed changes
+        """
+        Args:
+            tokens: (n_batch, n_token) or x tensor
+            audio_features: (n_batch, n_audio_ctx, n_audio_state) or xa tensor
+            kv_cache: Optional cache for key/value tensors
+        """
+        if kv_cache is not None:
+            # Handle the kv_cache case
+            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+            x = (
+                self.token_embedding(tokens)
+                + self.positional_embedding[offset : offset + tokens.shape[-1]]
+            )
+            x = x.to(audio_features.dtype)
+
+            for block in self.blocks:
+                x = block(x, audio_features, mask=self.mask, kv_cache=kv_cache)
+
+            x = self.ln(x)
+            logits = (
+                x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+            ).float()
+
+            return logits
+        else:
+            # Handle the non-kv_cache case
+            n_batch, n_token = tokens.shape
+            x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
+
+            if torch.cuda.is_available():
+                audio_features = audio_features.cuda()
+
+            for block in self.blocks:
+                x = block(x, audio_features)
+
+            x = self.ln(x)
+            logits = x @ self.token_embedding.weight.T
+
+            if torch.cuda.is_available():
+                mask = self.mask_cuda[:n_token, :n_token]
+            else:
+                mask = self.mask[:n_token, :n_token]
+
+            logits = logits + mask
+
+            return logits
+
+    class Whisper(nn.Module):
+        def __init__(self, dims: ModelDimensions):
+            super().__init__()
+            self.dims = dims
+            self.encoder = AudioEncoder(
+                self.dims.n_mels,
+                self.dims.n_audio_ctx,
+                self.dims.n_audio_state,
+                self.dims.n_audio_head,
+                self.dims.n_audio_layer,
+            )
+            self.decoder = TextDecoder(
+                self.dims.n_vocab,
+                self.dims.n_text_ctx,
+                self.dims.n_text_state,
+                self.dims.n_text_head,
+                self.dims.n_text_layer,
+            )
+            # use the last half among the decoder layers for time alignment by default;
+            # to use a specific set of heads, see `set_alignment_heads()` below.
+            all_heads = torch.zeros(
+                self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
+            )
+            all_heads[self.dims.n_text_layer // 2 :] = True
+            self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

     def set_alignment_heads(self, dump: bytes):
         array = np.frombuffer(
@@ -342,4 +415,4 @@ class Whisper(nn.Module):
     detect_language = detect_language_function
     transcribe = transcribe_function

-    decode = decode_function
+    decode = decode_function
\ No newline at end of file
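Note on the chunked checksum introduced above: hashing the file in 8 KiB blocks bounds peak memory at the block size instead of the full checkpoint size, and the bytes are only read back into memory afterwards when in_memory is requested. A minimal standalone sketch of the same pattern; the chunk_size parameter and the usage lines are illustrative, not part of the patch:

import hashlib

def compute_sha256(file_path: str, chunk_size: int = 8192) -> str:
    # Stream the file through the hash one chunk at a time, so memory use
    # stays O(chunk_size) regardless of file size.
    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        # iter(callable, sentinel) calls f.read(chunk_size) until it returns b""
        for chunk in iter(lambda: f.read(chunk_size), b""):
            sha256.update(chunk)
    return sha256.hexdigest()

# Hypothetical usage: Whisper encodes the expected digest in the model URL,
# so a finished download can be verified with a comparison like this.
# assert compute_sha256(download_target) == expected_sha256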
From 7a552cb5cc047744b3f991b757f95743e21de738 Mon Sep 17 00:00:00 2001
From: eleanorTurintech
Date: Tue, 11 Mar 2025 10:42:32 +0000
Subject: [PATCH 2/3] Fix indentation

---
 whisper/model.py | 63 +++++++++++++++++++++---------------------------
 1 file changed, 28 insertions(+), 35 deletions(-)

diff --git a/whisper/model.py b/whisper/model.py
index 612e6b9..5a247e3 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -228,7 +228,6 @@ class TextDecoder(nn.Module):
         if torch.cuda.is_available():
             self.register_buffer("mask_cuda", mask.cuda(), persistent=False)

-<<<<<<< Updated upstream

     def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
         """
@@ -267,14 +266,6 @@ class TextDecoder(nn.Module):

     def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
-=======
-    def forward(
-        self,
-        tokens: Tensor,
-        audio_features: Tensor,
-        kv_cache: Optional[dict] = None
-    ) -> Tensor:
->>>>>>> Stashed changes
         """
         Args:
             tokens: (n_batch, n_token) or x tensor
@@ -322,31 +313,33 @@ class TextDecoder(nn.Module):

         return logits

-    class Whisper(nn.Module):
-        def __init__(self, dims: ModelDimensions):
-            super().__init__()
-            self.dims = dims
-            self.encoder = AudioEncoder(
-                self.dims.n_mels,
-                self.dims.n_audio_ctx,
-                self.dims.n_audio_state,
-                self.dims.n_audio_head,
-                self.dims.n_audio_layer,
-            )
-            self.decoder = TextDecoder(
-                self.dims.n_vocab,
-                self.dims.n_text_ctx,
-                self.dims.n_text_state,
-                self.dims.n_text_head,
-                self.dims.n_text_layer,
-            )
-            # use the last half among the decoder layers for time alignment by default;
-            # to use a specific set of heads, see `set_alignment_heads()` below.
-            all_heads = torch.zeros(
-                self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
-            )
-            all_heads[self.dims.n_text_layer // 2 :] = True
-            self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+
+# The Whisper class has been moved outside of TextDecoder and is now a top-level class
+class Whisper(nn.Module):
+    def __init__(self, dims: ModelDimensions):
+        super().__init__()
+        self.dims = dims
+        self.encoder = AudioEncoder(
+            self.dims.n_mels,
+            self.dims.n_audio_ctx,
+            self.dims.n_audio_state,
+            self.dims.n_audio_head,
+            self.dims.n_audio_layer,
+        )
+        self.decoder = TextDecoder(
+            self.dims.n_vocab,
+            self.dims.n_text_ctx,
+            self.dims.n_text_state,
+            self.dims.n_text_head,
+            self.dims.n_text_layer,
+        )
+        # use the last half among the decoder layers for time alignment by default;
+        # to use a specific set of heads, see `set_alignment_heads()` below.
+        all_heads = torch.zeros(
+            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
+        )
+        all_heads[self.dims.n_text_layer // 2 :] = True
+        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

     def set_alignment_heads(self, dump: bytes):
         array = np.frombuffer(
@@ -415,4 +408,4 @@ class TextDecoder(nn.Module):
     detect_language = detect_language_function
     transcribe = transcribe_function

-    decode = decode_function
\ No newline at end of file
+    decode = decode_function
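Note on the mask buffers this patch leaves in place: register_buffer ties a tensor to the module, and buffers already follow the module through .to(device), .cuda(), and .half(), which is the standard way to keep a precomputed causal mask on the right device. A minimal sketch of that pattern under assumed names (CausalMaskStub and the n_ctx default are made up for the demo):

import torch
from torch import nn

class CausalMaskStub(nn.Module):
    def __init__(self, n_ctx: int = 8):
        super().__init__()
        # -inf strictly above the diagonal: position i cannot attend to j > i.
        mask = torch.empty(n_ctx, n_ctx).fill_(float("-inf")).triu_(1)
        # A registered buffer moves with the module on .to()/.cuda()/.half(),
        # so a single buffer serves both CPU and GPU execution.
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, n_token: int) -> torch.Tensor:
        # Slice the precomputed mask instead of rebuilding it on every call.
        return self.mask[:n_token, :n_token]

stub = CausalMaskStub()
print(stub(3))  # 3x3 causal mask: zeros on/below the diagonal, -inf above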
From d00d4868116b5cacc84efd1427a3a3fb2d534678 Mon Sep 17 00:00:00 2001
From: eleanorTurintech
Date: Tue, 11 Mar 2025 10:45:40 +0000
Subject: [PATCH 3/3] Fix forward method overload in TextDecoder

---
 whisper/model.py | 87 +++++++++++++++---------------------------------
 1 file changed, 27 insertions(+), 60 deletions(-)

diff --git a/whisper/model.py b/whisper/model.py
index 5a247e3..072f0a4 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -229,91 +229,58 @@ class TextDecoder(nn.Module):
             self.register_buffer("mask_cuda", mask.cuda(), persistent=False)


-    def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
+    def forward(self, tokens: Tensor, audio_features: Tensor, kv_cache: Optional[dict] = None) -> Tensor:
         """
         Args:
             tokens: (n_batch, n_token)
             audio_features: (n_batch, n_audio_ctx, n_audio_state)
+            kv_cache: Optional cache for key/value tensors

         Returns:
             logits: (n_batch, n_token, n_vocab)
         """
         n_batch, n_token = tokens.shape
-        n_audio_ctx, n_audio_state = audio_features.shape[1:]
-
-        x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
+
+        # Get the dtype of audio_features to ensure consistency
+        dtype = audio_features.dtype
+
+        # Handle kv_cache for token embedding offset
+        if kv_cache is not None:
+            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+            x = self.token_embedding(tokens) + self.positional_embedding[offset:offset + tokens.shape[1]]
+        else:
+            x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
+
+        # Convert to the same dtype as audio_features
+        x = x.to(dtype)

         # Optimisation: Move audio_features to GPU once here.
         if torch.cuda.is_available():
             audio_features = audio_features.cuda()
-
+        # Process through attention blocks
         for block in self.blocks:
-            x = block(x, audio_features)
+            x = block(x, audio_features, kv_cache=kv_cache)

         x = self.ln(x)
-        logits = x @ self.token_embedding.weight.T
-
-        # Optimisation: Apply the precomputed CUDA mask if available.
-        if torch.cuda.is_available():
-            mask = self.mask_cuda[:n_token, :n_token]
-        else:
-            mask = self.mask[:n_token, :n_token]
-
-        logits = logits + mask
-
-        return logits
-
-
-    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
-        """
-        Args:
-            tokens: (n_batch, n_token) or x tensor
-            audio_features: (n_batch, n_audio_ctx, n_audio_state) or xa tensor
-            kv_cache: Optional cache for key/value tensors
-        """
-        if kv_cache is not None:
-            # Handle the kv_cache case
-            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-            x = (
-                self.token_embedding(tokens)
-                + self.positional_embedding[offset : offset + tokens.shape[-1]]
-            )
-            x = x.to(audio_features.dtype)
-
-            for block in self.blocks:
-                x = block(x, audio_features, mask=self.mask, kv_cache=kv_cache)
-
-            x = self.ln(x)
-            logits = (
-                x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-            ).float()
-
-            return logits
-        else:
-            # Handle the non-kv_cache case
-            n_batch, n_token = tokens.shape
-            x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
-
-            if torch.cuda.is_available():
-                audio_features = audio_features.cuda()
-
-            for block in self.blocks:
-                x = block(x, audio_features)
-
-            x = self.ln(x)
-            logits = x @ self.token_embedding.weight.T
+        # Ensure consistent dtype for matrix multiplication
+        # Convert token_embedding weight to the same dtype as x
+        embedding_weights = self.token_embedding.weight.to(x.dtype)
+        logits = x @ embedding_weights.T

+        # Apply mask if not using kv_cache (inference)
+        if kv_cache is None:
+            # Optimisation: Apply the precomputed CUDA mask if available.
             if torch.cuda.is_available():
                 mask = self.mask_cuda[:n_token, :n_token]
             else:
                 mask = self.mask[:n_token, :n_token]
-
+
             logits = logits + mask

-            return logits
-
-
+        return logits
+
 # The Whisper class has been moved outside of TextDecoder and is now a top-level class
 class Whisper(nn.Module):
     def __init__(self, dims: ModelDimensions):
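Note on the merged forward(): during incremental decoding, the positional offset is recovered from the sequence length of any cached key/value tensor, so each new token picks up the next positional embedding. An illustrative sketch of just that offset logic; the dimensions, cache-key name, and the embed() helper are made up for the demo:

from typing import Optional

import torch

n_vocab, n_ctx, n_state = 16, 10, 4
token_embedding = torch.nn.Embedding(n_vocab, n_state)
positional_embedding = torch.randn(n_ctx, n_state)

def embed(tokens: torch.Tensor, kv_cache: Optional[dict] = None) -> torch.Tensor:
    # With a cache, offset is the number of positions already decoded,
    # read off the sequence dimension of any cached key/value tensor.
    offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
    return token_embedding(tokens) + positional_embedding[offset : offset + tokens.shape[-1]]

# First call: no cache, so the prompt occupies positions 0..2.
x0 = embed(torch.tensor([[1, 2, 3]]))
# Later call: the cache holds 3 positions, so the new token lands at position 3.
cache = {"layer0.key": torch.zeros(1, 3, n_state)}
x1 = embed(torch.tensor([[4]]), kv_cache=cache)
print(x0.shape, x1.shape)  # torch.Size([1, 3, 4]) torch.Size([1, 1, 4])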