mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 06:26:03 +00:00
bad optimization
This commit is contained in:
parent
e9ef1b5772
commit
564d0889d7
1
.gitignore
vendored
1
.gitignore
vendored
@ -9,3 +9,4 @@ thumbs.db
|
|||||||
.DS_Store
|
.DS_Store
|
||||||
.idea
|
.idea
|
||||||
|
|
||||||
|
test_data/
|
||||||
|
|||||||
5
main.py
Normal file
5
main.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
import whisper
|
||||||
|
|
||||||
|
model = whisper.load_model('tiny.en')
|
||||||
|
|
||||||
|
print(model.transcribe('test_data/30s/out001.wav')['text'])
|
||||||
@ -204,7 +204,21 @@ class AudioEncoder(nn.Module):
|
|||||||
x = x.permute(0, 2, 1)
|
x = x.permute(0, 2, 1)
|
||||||
|
|
||||||
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
||||||
x = (x + self.positional_embedding).to(x.dtype)
|
|
||||||
|
n_extension = 200
|
||||||
|
audio_length = int((x.shape[2] + 1) // 2)
|
||||||
|
pos_emb = torch.concat((
|
||||||
|
self.positional_embedding[:audio_length + n_extension, :],
|
||||||
|
self.positional_embedding[-n_extension:, :]), dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# rebuild x along dim=1: keep the first (audio_length + n_extension) positions and append the last n_extension positions
|
||||||
|
x = torch.concat((
|
||||||
|
x[:, :audio_length + n_extension, :],
|
||||||
|
x[:, -n_extension:, :]), dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = (x + pos_emb).to(x.dtype)
|
||||||
|
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x = block(x)
|
x = block(x)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user