need to pass in cut region to audio encoder too

This commit is contained in:
Amal Jacob 2025-02-07 13:42:03 -08:00
parent c29813b560
commit 97bbe4b70f

View File

@ -206,7 +206,7 @@ class AudioEncoderTokenPruner():
class AudioEncoder(nn.Module):
def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, ext_feat_flag: bool = False, cut_region: Tuple[int, int]=[750, 1000]
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, ext_feat_flag: bool = False, cut_region: Tuple[int, int]=None
):
super().__init__()
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
@ -301,6 +301,7 @@ class Whisper(nn.Module):
self.dims.n_audio_head,
self.dims.n_audio_layer,
ext_feat_flag=ext_feat_flag,
cut_region=cut_region,
)
self.decoder = TextDecoder(
self.dims.n_vocab,