mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 14:35:57 +00:00
bad optimization
This commit is contained in:
parent
e9ef1b5772
commit
564d0889d7
1
.gitignore
vendored
1
.gitignore
vendored
@ -9,3 +9,4 @@ thumbs.db
|
||||
.DS_Store
|
||||
.idea
|
||||
|
||||
test_data/
|
||||
|
||||
5
main.py
Normal file
5
main.py
Normal file
@ -0,0 +1,5 @@
|
||||
import whisper
|
||||
|
||||
model = whisper.load_model('tiny.en')
|
||||
|
||||
print(model.transcribe('test_data/30s/out001.wav')['text'])
|
||||
@ -204,7 +204,21 @@ class AudioEncoder(nn.Module):
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
||||
x = (x + self.positional_embedding).to(x.dtype)
|
||||
|
||||
n_extension = 200
|
||||
audio_length = int((x.shape[2] + 1) // 2)
|
||||
pos_emb = torch.concat((
|
||||
self.positional_embedding[:audio_length + n_extension, :],
|
||||
self.positional_embedding[-n_extension:, :]), dim=0,
|
||||
)
|
||||
|
||||
# extend the x's first dimension by n_extension
|
||||
x = torch.concat((
|
||||
x[:, :audio_length + n_extension, :],
|
||||
x[:, -n_extension:, :]), dim=1,
|
||||
)
|
||||
|
||||
x = (x + pos_emb).to(x.dtype)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user