diff --git a/whisper/model.py b/whisper/model.py index b646f7d..692cba0 100644 --- a/whisper/model.py +++ b/whisper/model.py @@ -206,7 +206,7 @@ class AudioEncoderTokenPruner(): class AudioEncoder(nn.Module): def __init__( - self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, ext_feat_flag: bool = False, cut_region: Tuple[int, int]=[750, 1000] + self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, ext_feat_flag: bool = False, cut_region: Tuple[int, int]=None ): super().__init__() self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) @@ -301,6 +301,7 @@ class Whisper(nn.Module): self.dims.n_audio_head, self.dims.n_audio_layer, ext_feat_flag=ext_feat_flag, + cut_region=cut_region, ) self.decoder = TextDecoder( self.dims.n_vocab,