mirror of
https://github.com/openai/whisper.git
synced 2025-03-30 14:28:27 +00:00
need to pass in cut region to audio encoder too
This commit is contained in:
parent
c29813b560
commit
97bbe4b70f
@ -206,7 +206,7 @@ class AudioEncoderTokenPruner():
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
def __init__(
|
||||
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, ext_feat_flag: bool = False, cut_region: Tuple[int, int]=[750, 1000]
|
||||
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, ext_feat_flag: bool = False, cut_region: Tuple[int, int]=None
|
||||
):
|
||||
super().__init__()
|
||||
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||
@ -301,6 +301,7 @@ class Whisper(nn.Module):
|
||||
self.dims.n_audio_head,
|
||||
self.dims.n_audio_layer,
|
||||
ext_feat_flag=ext_feat_flag,
|
||||
cut_region=cut_region,
|
||||
)
|
||||
self.decoder = TextDecoder(
|
||||
self.dims.n_vocab,
|
||||
|
Loading…
x
Reference in New Issue
Block a user