mirror of
https://github.com/openai/whisper.git
synced 2025-03-30 14:28:27 +00:00
FEAT
This commit is contained in:
parent
564d0889d7
commit
67e68af114
2
main.py
2
main.py
@ -2,4 +2,4 @@ import whisper
|
||||
|
||||
model = whisper.load_model('tiny.en')
|
||||
|
||||
print(model.transcribe('test_data/30s/out001.wav')['text'])
|
||||
print(model.transcribe('test_data/5/out001.wav')['text'])
|
||||
|
@ -205,20 +205,25 @@ class AudioEncoder(nn.Module):
|
||||
|
||||
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
||||
|
||||
n_extension = 200
|
||||
audio_length = int((x.shape[2] + 1) // 2)
|
||||
pos_emb = torch.concat((
|
||||
self.positional_embedding[:audio_length + n_extension, :],
|
||||
self.positional_embedding[-n_extension:, :]), dim=0,
|
||||
)
|
||||
FEAT = True
|
||||
|
||||
# extend the x's first dimension by n_extension
|
||||
x = torch.concat((
|
||||
x[:, :audio_length + n_extension, :],
|
||||
x[:, -n_extension:, :]), dim=1,
|
||||
)
|
||||
if FEAT:
|
||||
n_extension = 200
|
||||
audio_length = int((x.shape[2] + 1) // 2)
|
||||
pos_emb = torch.concat((
|
||||
self.positional_embedding[:audio_length + n_extension, :],
|
||||
self.positional_embedding[-n_extension:, :]), dim=0,
|
||||
)
|
||||
|
||||
x = (x + pos_emb).to(x.dtype)
|
||||
# extend the x's first dimension by n_extension
|
||||
x = torch.concat((
|
||||
x[:, :audio_length + n_extension, :],
|
||||
x[:, -n_extension:, :]), dim=1,
|
||||
)
|
||||
|
||||
x = (x + pos_emb).to(x.dtype)
|
||||
else:
|
||||
x = (x + self.positional_embedding).to(x.dtype)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user