pass in cut region

This commit is contained in:
Amal Jacob 2025-02-07 13:13:50 -08:00
parent e08bd26fce
commit c29813b560

View File

@ -182,21 +182,22 @@ class ResidualAttentionBlock(nn.Module):
class AudioEncoderTokenPruner():
def __init__(self, n_extension: int, cut_region: Tuple[int, int]):
self.n_extension = n_extension
self.cut_region = cut_region
def prune(self, x: Tensor, positional_embedding: Tensor, cut_region: Tuple[int, int]=[750, 1000]):
def prune(self, x: Tensor, positional_embedding: Tensor):
# audio_length = int((x.shape[1] + 1) // 2)
# [0-950, -----, 1300-1500]
pos_emb = torch.concat((
positional_embedding[:cut_region[0], :],
torch.zeros_like(positional_embedding[cut_region[0]:cut_region[1], :], device=x.device),
positional_embedding[cut_region[1]:,:]), dim=0,
positional_embedding[:self.cut_region[0], :],
torch.zeros_like(positional_embedding[self.cut_region[0]:self.cut_region[1], :], device=x.device),
positional_embedding[self.cut_region[1]:,:]), dim=0,
)
x = torch.concat((
x[:, :cut_region[0], :],
torch.zeros_like(x[:, cut_region[0]:cut_region[1], :], device=x.device),
x[:, cut_region[1]:,:]), dim=1,
x[:, :self.cut_region[0], :],
torch.zeros_like(x[:, self.cut_region[0]:self.cut_region[1], :], device=x.device),
x[:, self.cut_region[1]:,:]), dim=1,
)
x = (x + pos_emb).to(x.dtype)