mirror of
https://github.com/openai/whisper.git
synced 2025-03-30 14:28:27 +00:00
pass in cut region
This commit is contained in:
parent
e08bd26fce
commit
c29813b560
@ -182,21 +182,22 @@ class ResidualAttentionBlock(nn.Module):
|
||||
class AudioEncoderTokenPruner():
|
||||
def __init__(self, n_extension: int, cut_region: Tuple[int, int]):
|
||||
self.n_extension = n_extension
|
||||
self.cut_region = cut_region
|
||||
|
||||
def prune(self, x: Tensor, positional_embedding: Tensor, cut_region: Tuple[int, int]=[750, 1000]):
|
||||
def prune(self, x: Tensor, positional_embedding: Tensor):
|
||||
# audio_length = int((x.shape[1] + 1) // 2)
|
||||
# [0-950, -----, 1300-1500]
|
||||
|
||||
pos_emb = torch.concat((
|
||||
positional_embedding[:cut_region[0], :],
|
||||
torch.zeros_like(positional_embedding[cut_region[0]:cut_region[1], :], device=x.device),
|
||||
positional_embedding[cut_region[1]:,:]), dim=0,
|
||||
positional_embedding[:self.cut_region[0], :],
|
||||
torch.zeros_like(positional_embedding[self.cut_region[0]:self.cut_region[1], :], device=x.device),
|
||||
positional_embedding[self.cut_region[1]:,:]), dim=0,
|
||||
)
|
||||
|
||||
x = torch.concat((
|
||||
x[:, :cut_region[0], :],
|
||||
torch.zeros_like(x[:, cut_region[0]:cut_region[1], :], device=x.device),
|
||||
x[:, cut_region[1]:,:]), dim=1,
|
||||
x[:, :self.cut_region[0], :],
|
||||
torch.zeros_like(x[:, self.cut_region[0]:self.cut_region[1], :], device=x.device),
|
||||
x[:, self.cut_region[1]:,:]), dim=1,
|
||||
)
|
||||
|
||||
x = (x + pos_emb).to(x.dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user