From c29813b560c797bcc85a53eede5a2a3511e4d04a Mon Sep 17 00:00:00 2001 From: Amal Jacob Date: Fri, 7 Feb 2025 13:13:50 -0800 Subject: [PATCH] pass in cut region --- whisper/model.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/whisper/model.py b/whisper/model.py index 3d8d51e..b646f7d 100644 --- a/whisper/model.py +++ b/whisper/model.py @@ -182,21 +182,22 @@ class ResidualAttentionBlock(nn.Module): class AudioEncoderTokenPruner(): def __init__(self, n_extension: int, cut_region: Tuple[int, int]): self.n_extension = n_extension + self.cut_region = cut_region - def prune(self, x: Tensor, positional_embedding: Tensor, cut_region: Tuple[int, int]=[750, 1000]): + def prune(self, x: Tensor, positional_embedding: Tensor): # audio_length = int((x.shape[1] + 1) // 2) # [0-950, -----, 1300-1500] pos_emb = torch.concat(( - positional_embedding[:cut_region[0], :], - torch.zeros_like(positional_embedding[cut_region[0]:cut_region[1], :], device=x.device), - positional_embedding[cut_region[1]:,:]), dim=0, + positional_embedding[:self.cut_region[0], :], + torch.zeros_like(positional_embedding[self.cut_region[0]:self.cut_region[1], :], device=x.device), + positional_embedding[self.cut_region[1]:,:]), dim=0, ) x = torch.concat(( - x[:, :cut_region[0], :], - torch.zeros_like(x[:, cut_region[0]:cut_region[1], :], device=x.device), - x[:, cut_region[1]:,:]), dim=1, + x[:, :self.cut_region[0], :], + torch.zeros_like(x[:, self.cut_region[0]:self.cut_region[1], :], device=x.device), + x[:, self.cut_region[1]:,:]), dim=1, ) x = (x + pos_emb).to(x.dtype)