diff --git a/.gitignore b/.gitignore index 7ae8fab..24f6669 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ thumbs.db .DS_Store .idea +test_data/ diff --git a/main.py b/main.py new file mode 100644 index 0000000..5f14500 --- /dev/null +++ b/main.py @@ -0,0 +1,5 @@ +import whisper + +model = whisper.load_model('tiny.en') + +print(model.transcribe('test_data/30s/out001.wav')['text']) diff --git a/whisper/model.py b/whisper/model.py index 00404d2..6a7fcb8 100644 --- a/whisper/model.py +++ b/whisper/model.py @@ -204,7 +204,21 @@ class AudioEncoder(nn.Module): x = x.permute(0, 2, 1) assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" - x = (x + self.positional_embedding).to(x.dtype) + + n_extension = 200 + audio_length = int((x.shape[2] + 1) // 2) + pos_emb = torch.concat(( + self.positional_embedding[:audio_length + n_extension, :], + self.positional_embedding[-n_extension:, :]), dim=0, + ) + + # extend the x's first dimension by n_extension + x = torch.concat(( + x[:, :audio_length + n_extension, :], + x[:, -n_extension:, :]), dim=1, + ) + + x = (x + pos_emb).to(x.dtype) for block in self.blocks: x = block(x)