From 7a552cb5cc047744b3f991b757f95743e21de738 Mon Sep 17 00:00:00 2001 From: eleanorTurintech Date: Tue, 11 Mar 2025 10:42:32 +0000 Subject: [PATCH] Fix indentation --- whisper/model.py | 63 +++++++++++++++++++++--------------------------- 1 file changed, 28 insertions(+), 35 deletions(-) diff --git a/whisper/model.py b/whisper/model.py index 612e6b9..5a247e3 100644 --- a/whisper/model.py +++ b/whisper/model.py @@ -228,7 +228,6 @@ class TextDecoder(nn.Module): if torch.cuda.is_available(): self.register_buffer("mask_cuda", mask.cuda(), persistent=False) -<<<<<<< Updated upstream def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor: """ @@ -267,14 +266,6 @@ class TextDecoder(nn.Module): def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): -======= - def forward( - self, - tokens: Tensor, - audio_features: Tensor, - kv_cache: Optional[dict] = None - ) -> Tensor: ->>>>>>> Stashed changes """ Args: tokens: (n_batch, n_token) or x tensor @@ -322,31 +313,33 @@ class TextDecoder(nn.Module): return logits - class Whisper(nn.Module): - def __init__(self, dims: ModelDimensions): - super().__init__() - self.dims = dims - self.encoder = AudioEncoder( - self.dims.n_mels, - self.dims.n_audio_ctx, - self.dims.n_audio_state, - self.dims.n_audio_head, - self.dims.n_audio_layer, - ) - self.decoder = TextDecoder( - self.dims.n_vocab, - self.dims.n_text_ctx, - self.dims.n_text_state, - self.dims.n_text_head, - self.dims.n_text_layer, - ) - # use the last half among the decoder layers for time alignment by default; - # to use a specific set of heads, see `set_alignment_heads()` below. - all_heads = torch.zeros( - self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool - ) - all_heads[self.dims.n_text_layer // 2 :] = True - self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False) + +# The Whisper class has been moved outside of TextDecoder and is now a top-level class +class Whisper(nn.Module): + def __init__(self, dims: ModelDimensions): + super().__init__() + self.dims = dims + self.encoder = AudioEncoder( + self.dims.n_mels, + self.dims.n_audio_ctx, + self.dims.n_audio_state, + self.dims.n_audio_head, + self.dims.n_audio_layer, + ) + self.decoder = TextDecoder( + self.dims.n_vocab, + self.dims.n_text_ctx, + self.dims.n_text_state, + self.dims.n_text_head, + self.dims.n_text_layer, + ) + # use the last half among the decoder layers for time alignment by default; + # to use a specific set of heads, see `set_alignment_heads()` below. + all_heads = torch.zeros( + self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool + ) + all_heads[self.dims.n_text_layer // 2 :] = True + self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False) def set_alignment_heads(self, dump: bytes): array = np.frombuffer( @@ -415,4 +408,4 @@ class TextDecoder(nn.Module): detect_language = detect_language_function transcribe = transcribe_function - decode = decode_function \ No newline at end of file + decode = decode_function