bad optimization

Elijah Melton 2025-01-31 17:01:37 -08:00
parent e9ef1b5772
commit 564d0889d7
3 changed files with 21 additions and 1 deletion

.gitignore

@@ -9,3 +9,4 @@ thumbs.db
 .DS_Store
 .idea
+test_data/

main.py (new file)

@@ -0,0 +1,5 @@
+import whisper
+model = whisper.load_model('tiny.en')
+print(model.transcribe('test_data/30s/out001.wav')['text'])
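For context, main.py is a minimal smoke-test script: it loads the tiny English Whisper model and prints the transcription of a 30-second clip kept under test_data/ (which the .gitignore change above excludes from the repo). A hypothetical way to check whether the encoder change below actually speeds anything up is to wrap the same call in a timer; the model name and file path are taken from main.py, everything else here is an assumption rather than part of the commit:

import time
import whisper

model = whisper.load_model('tiny.en')

start = time.perf_counter()
result = model.transcribe('test_data/30s/out001.wav')
elapsed = time.perf_counter() - start

# print wall-clock time alongside the transcript for a rough before/after comparison
print(f"{elapsed:.2f}s  {result['text']}")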

whisper/model.py

@@ -204,7 +204,21 @@ class AudioEncoder(nn.Module):
         x = x.permute(0, 2, 1)
         assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
-        x = (x + self.positional_embedding).to(x.dtype)
+        n_extension = 200
+        audio_length = int((x.shape[2] + 1) // 2)
+        pos_emb = torch.concat((
+            self.positional_embedding[:audio_length + n_extension, :],
+            self.positional_embedding[-n_extension:, :]), dim=0,
+        )
+        # extend the x's first dimension by n_extension
+        x = torch.concat((
+            x[:, :audio_length + n_extension, :],
+            x[:, -n_extension:, :]), dim=1,
+        )
+        x = (x + pos_emb).to(x.dtype)

         for block in self.blocks:
             x = block(x)
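On my reading, the hunk above tries to avoid feeding all 1500 padded frames to the transformer blocks: it keeps a leading slice of the sequence (roughly the real audio plus a 200-frame margin) and the trailing 200 frames, and pairs them with the matching rows of the positional embedding. The sketch below reproduces that head-plus-tail slicing on dummy tensors so the resulting shapes are easy to inspect. The sizes (n_ctx = 1500, n_state = 384) correspond to Whisper tiny but are assumptions here, and audio_length is a made-up stand-in: the committed audio_length = int((x.shape[2] + 1) // 2) reads the feature dimension rather than the time dimension, since the permute has already moved time to dim 1.

import torch

# Assumed sizes for Whisper tiny's audio encoder.
n_ctx, n_state = 1500, 384
n_extension = 200
audio_length = 300  # hypothetical number of frames covered by real (non-padded) audio

x = torch.randn(1, n_ctx, n_state)   # stands in for the activations after the conv stack + permute
pos = torch.randn(n_ctx, n_state)    # stands in for self.positional_embedding

# Keep the leading audio frames (plus a margin) and the trailing n_extension frames,
# slicing x and the positional embedding identically so they still line up row for row.
pos_emb = torch.cat((pos[:audio_length + n_extension], pos[-n_extension:]), dim=0)
x_short = torch.cat((x[:, :audio_length + n_extension], x[:, -n_extension:]), dim=1)

x_short = x_short + pos_emb          # broadcasts over the batch dimension
print(x_short.shape)                 # torch.Size([1, 700, 384]) instead of [1, 1500, 384]

Whether the pretrained encoder tolerates such a spliced, shortened sequence is a separate question; the commit title suggests the author already judged the result a "bad optimization".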