Fix indentation

This commit is contained in:
eleanorTurintech 2025-03-11 10:42:32 +00:00
parent 52649452a8
commit 7a552cb5cc

View File

@ -228,7 +228,6 @@ class TextDecoder(nn.Module):
if torch.cuda.is_available():
self.register_buffer("mask_cuda", mask.cuda(), persistent=False)
<<<<<<< Updated upstream
def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
"""
@ -267,14 +266,6 @@ class TextDecoder(nn.Module):
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
=======
def forward(
self,
tokens: Tensor,
audio_features: Tensor,
kv_cache: Optional[dict] = None
) -> Tensor:
>>>>>>> Stashed changes
"""
Args:
tokens: (n_batch, n_token) or x tensor
@ -322,31 +313,33 @@ class TextDecoder(nn.Module):
return logits
class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions):
    """Assemble the model from the hyper-parameters held in ``dims``.

    Stores ``dims`` on the instance, builds ``self.encoder`` and
    ``self.decoder`` from its audio and text fields respectively, and
    registers the non-persistent ``alignment_heads`` buffer.
    """
    super().__init__()
    self.dims = dims

    # Positional argument order must match the AudioEncoder signature.
    audio_args = (
        dims.n_mels,
        dims.n_audio_ctx,
        dims.n_audio_state,
        dims.n_audio_head,
        dims.n_audio_layer,
    )
    self.encoder = AudioEncoder(*audio_args)

    # Positional argument order must match the TextDecoder signature.
    text_args = (
        dims.n_vocab,
        dims.n_text_ctx,
        dims.n_text_state,
        dims.n_text_head,
        dims.n_text_layer,
    )
    self.decoder = TextDecoder(*text_args)

    # use the last half among the decoder layers for time alignment by default;
    # to use a specific set of heads, see `set_alignment_heads()` below.
    n_layers, n_heads = dims.n_text_layer, dims.n_text_head
    head_mask = torch.zeros(n_layers, n_heads, dtype=torch.bool)
    head_mask[n_layers // 2 :] = True
    # Stored sparse and non-persistent (kept out of the state dict).
    self.register_buffer("alignment_heads", head_mask.to_sparse(), persistent=False)
# NOTE: Whisper is defined at module level, as a top-level class rather than nested inside TextDecoder.
class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions):
    """Assemble the model from the hyper-parameters held in ``dims``.

    Stores ``dims`` on the instance, builds ``self.encoder`` and
    ``self.decoder`` from its audio and text fields respectively, and
    registers the non-persistent ``alignment_heads`` buffer.
    """
    super().__init__()
    self.dims = dims

    # Positional argument order must match the AudioEncoder signature.
    audio_args = (
        dims.n_mels,
        dims.n_audio_ctx,
        dims.n_audio_state,
        dims.n_audio_head,
        dims.n_audio_layer,
    )
    self.encoder = AudioEncoder(*audio_args)

    # Positional argument order must match the TextDecoder signature.
    text_args = (
        dims.n_vocab,
        dims.n_text_ctx,
        dims.n_text_state,
        dims.n_text_head,
        dims.n_text_layer,
    )
    self.decoder = TextDecoder(*text_args)

    # use the last half among the decoder layers for time alignment by default;
    # to use a specific set of heads, see `set_alignment_heads()` below.
    n_layers, n_heads = dims.n_text_layer, dims.n_text_head
    head_mask = torch.zeros(n_layers, n_heads, dtype=torch.bool)
    head_mask[n_layers // 2 :] = True
    # Stored sparse and non-persistent (kept out of the state dict).
    self.register_buffer("alignment_heads", head_mask.to_sparse(), persistent=False)
def set_alignment_heads(self, dump: bytes):
array = np.frombuffer(
@ -415,4 +408,4 @@ class TextDecoder(nn.Module):
detect_language = detect_language_function
transcribe = transcribe_function
decode = decode_function
decode = decode_function