This commit is contained in:
Elijah Melton 2025-01-31 17:02:57 -08:00
parent 564d0889d7
commit 67e68af114
2 changed files with 18 additions and 13 deletions

View File

@ -2,4 +2,4 @@ import whisper
model = whisper.load_model('tiny.en')
print(model.transcribe('test_data/30s/out001.wav')['text'])
print(model.transcribe('test_data/5/out001.wav')['text'])

View File

@ -205,20 +205,25 @@ class AudioEncoder(nn.Module):
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
n_extension = 200
audio_length = int((x.shape[2] + 1) // 2)
pos_emb = torch.concat((
self.positional_embedding[:audio_length + n_extension, :],
self.positional_embedding[-n_extension:, :]), dim=0,
)
FEAT = True
# along dim=1: keep x's first audio_length + n_extension entries, then append its last n_extension entries
x = torch.concat((
x[:, :audio_length + n_extension, :],
x[:, -n_extension:, :]), dim=1,
)
if FEAT:
n_extension = 200
audio_length = int((x.shape[2] + 1) // 2)
pos_emb = torch.concat((
self.positional_embedding[:audio_length + n_extension, :],
self.positional_embedding[-n_extension:, :]), dim=0,
)
x = (x + pos_emb).to(x.dtype)
# along dim=1: keep x's first audio_length + n_extension entries, then append its last n_extension entries
x = torch.concat((
x[:, :audio_length + n_extension, :],
x[:, -n_extension:, :]), dim=1,
)
x = (x + pos_emb).to(x.dtype)
else:
x = (x + self.positional_embedding).to(x.dtype)
for block in self.blocks:
x = block(x)