Merge a84d7ea904359cdf54ec76468fb53217b97ede18 into ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab

Xabier de Zuazo, 2024-06-01 07:45:25 +00:00 (committed by GitHub)
commit dabf80050f
3 changed files with 56 additions and 1 deletion

.github/workflows/test.yml

@@ -53,4 +53,4 @@ jobs:
       - uses: actions/checkout@v3
       - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
       - run: pip install .["dev"]
-      - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
+      - run: pytest --durations=0 -vv -k 'not (test_transcribe or test_decode) or test_transcribe[tiny] or test_transcribe[tiny.en] or test_decode[tiny] or test_decode[tiny.en]' -m 'not requires_cuda'
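For context: pytest -k matches the expression against each test ID as a substring, so the updated filter first drops every test_transcribe and test_decode parametrization and then re-adds only the tiny and tiny.en cases; all other tests still run, and -m 'not requires_cuda' keeps GPU-only tests off the CPU runner.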

tests/test_decode.py (new file, 52 lines)

@@ -0,0 +1,52 @@
import os

import pytest
import torch

import whisper


@pytest.mark.parametrize("model_name", whisper.available_models())
def test_decode(model_name: str):
    # Regression test: batch_size and beam_size should work together
    beam_size = 2
    batch_size = 2

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisper.load_model(model_name).to(device)
    audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

    language = "en" if model_name.endswith(".en") else None
    options = whisper.DecodingOptions(language=language, beam_size=beam_size)

    audio = whisper.load_audio(audio_path)
    audio = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(audio).to(device)

    # Create a small batch
    batch_mel = mel.unsqueeze(0).repeat(batch_size, 1, 1)
    results = model.decode(batch_mel, options)

    # Since both examples are the same, results should be identical
    assert len(results) == batch_size
    assert results[0].text == results[1].text

    decoded_text = results[0].text.lower()
    assert "my fellow americans" in decoded_text
    assert "your country" in decoded_text
    assert "do for you" in decoded_text

    timing_checked = False
    if hasattr(results[0], "segments"):
        for segment in results[0].segments:
            for timing in segment["words"]:
                assert timing["start"] < timing["end"]
                if timing["word"].strip(" ,") == "Americans":
                    assert timing["start"] <= 1.8
                    assert timing["end"] >= 1.8
                    timing_checked = True

    if hasattr(results[0], "segments"):
        assert timing_checked
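To run just the fast variants of the new test locally, something like pytest tests/test_decode.py -k tiny should select the tiny and tiny.en cases, since -k matches substrings of the test ID.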

whisper/decoding.py

@@ -731,6 +731,9 @@ class DecodingTask:
             ]
 
         # repeat text tensors by the group size, for beam search or best-of-n sampling
+        audio_features = audio_features.repeat_interleave(self.n_group, dim=0).to(
+            audio_features.device
+        )
         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
 
         # call the main sampling loop
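For intuition about the decoding.py change: with beam search or best-of-n sampling, each of the n_audio inputs fans out into n_group hypotheses, so the text tokens, and with this fix the audio features as well, must be expanded row-wise from n_audio to n_audio * n_group entries to stay aligned inside the sampling loop. A minimal standalone PyTorch sketch of that shape bookkeeping (the tensor sizes below are illustrative, not taken from whisper):

import torch

n_audio, n_group = 2, 3  # e.g. two audio segments, beam size 3

# illustrative shapes: (n_audio, n_audio_ctx, n_audio_state) and (n_audio, n_tokens)
audio_features = torch.randn(n_audio, 1500, 384)
tokens = torch.zeros(n_audio, 4, dtype=torch.long)

# repeat each row n_group times, so row i * n_group + k holds hypothesis k of input i
audio_features = audio_features.repeat_interleave(n_group, dim=0)
tokens = tokens.repeat_interleave(n_group, dim=0)

print(audio_features.shape)  # torch.Size([6, 1500, 384])
print(tokens.shape)          # torch.Size([6, 4])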