Mirror of https://github.com/openai/whisper.git, synced 2025-03-30 14:28:27 +00:00
Fix indentation
parent 52649452a8, commit 7a552cb5cc
@@ -228,7 +228,6 @@ class TextDecoder(nn.Module):
        if torch.cuda.is_available():
            self.register_buffer("mask_cuda", mask.cuda(), persistent=False)

<<<<<<< Updated upstream

    def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        """
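A minimal sketch (not from the diff itself) of the mask-buffer pattern this hunk touches. The class name, n_ctx value, and print statements are illustrative assumptions; only the register_buffer(..., persistent=False) usage mirrors the lines above.

import numpy as np
import torch
from torch import nn


class MaskedDecoderStub(nn.Module):
    # Illustrative stand-in, not the TextDecoder from the diff.
    def __init__(self, n_ctx: int = 8):
        super().__init__()
        # Upper-triangular -inf mask: each position may attend only to itself
        # and to earlier positions.
        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)
        # persistent=False keeps buffers out of the state_dict, so adding a
        # pre-moved CUDA copy (as the hunk does) leaves checkpoints unchanged.
        if torch.cuda.is_available():
            self.register_buffer("mask_cuda", mask.cuda(), persistent=False)


stub = MaskedDecoderStub()
print(stub.mask)                 # zeros on/below the diagonal, -inf above
print(list(stub.state_dict()))   # [] -- non-persistent buffers are not saved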
@@ -267,14 +266,6 @@ class TextDecoder(nn.Module):


    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
=======
    def forward(
        self,
        tokens: Tensor,
        audio_features: Tensor,
        kv_cache: Optional[dict] = None
    ) -> Tensor:
>>>>>>> Stashed changes
        """
        Args:
            tokens: (n_batch, n_token) or x tensor
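This hunk still carries conflict markers between a plain forward(tokens, audio_features) signature and a kv_cache-aware one. Below is a minimal, self-contained sketch (not the resolution this commit applies) of how a kv_cache-aware decoder forward typically offsets its positional embeddings; the class, sizes, and the absence of attention layers are assumptions for illustration.

from typing import Optional

import torch
from torch import Tensor, nn


class TinyDecoder(nn.Module):
    # Illustrative sketch, not the TextDecoder from the diff.
    def __init__(self, n_vocab: int = 100, n_ctx: int = 16, n_state: int = 32):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.zeros(n_ctx, n_state))

    def forward(self, tokens: Tensor, audio_features: Tensor,
                kv_cache: Optional[dict] = None) -> Tensor:
        # During incremental decoding the cache already holds `offset` positions,
        # so newly supplied tokens are embedded at the correct positions.
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = self.token_embedding(tokens)
        x = x + self.positional_embedding[offset : offset + tokens.shape[-1]]
        return x  # a real decoder would attend to audio_features and project to logits


decoder = TinyDecoder()
tokens = torch.tensor([[1, 2, 3]])
audio = torch.zeros(1, 4, 32)
print(decoder(tokens, audio).shape)  # torch.Size([1, 3, 32])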
@@ -322,31 +313,33 @@ class TextDecoder(nn.Module):

        return logits

class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        # use the last half among the decoder layers for time alignment by default;
        # to use a specific set of heads, see `set_alignment_heads()` below.
        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)


# The Whisper class has been moved outside of TextDecoder and is now a top-level class
class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        # use the last half among the decoder layers for time alignment by default;
        # to use a specific set of heads, see `set_alignment_heads()` below.
        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

    def set_alignment_heads(self, dump: bytes):
        array = np.frombuffer(
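The alignment_heads default set up in Whisper.__init__ above marks every head in the last half of the decoder layers. A small sketch (not from the diff) of what that boolean buffer looks like, with layer and head counts chosen only for illustration:

import torch

# Illustrative sizes; the real values come from ModelDimensions.
n_text_layer, n_text_head = 4, 6

all_heads = torch.zeros(n_text_layer, n_text_head, dtype=torch.bool)
all_heads[n_text_layer // 2 :] = True   # only the last half of the layers
print(all_heads.int())
# tensor([[0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0],
#         [1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1]], dtype=torch.int32)

# Stored sparse and non-persistent: a compact default that set_alignment_heads()
# can later replace without affecting saved checkpoints.
alignment_heads = all_heads.to_sparse()
print(alignment_heads.indices().shape)   # torch.Size([2, 12])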
@@ -415,4 +408,4 @@ class TextDecoder(nn.Module):

    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function
    decode = decode_function