whisper/tests/test_decode.py
Xabier de Zuazo a84d7ea904 Fix beam search with batch processing in Whisper decoding
* Ensures that audio features are correctly duplicated across beams for each batch item.
* Adds a test for `decode()` that includes a regression test for this.
* Updates `.github/workflows/test.yml` to run the new `decode()` test with the tiny model.
* This issue was introduced in PR #1483.
2024-06-01 09:37:32 +02:00

53 lines
1.6 KiB
Python

import os
import pytest
import torch
import whisper
@pytest.mark.parametrize("model_name", whisper.available_models())
def test_decode(model_name: str):
    """Decode a batch of identical audio clips with beam search.

    Regression test: ``batch_size`` > 1 and ``beam_size`` > 1 must work
    together. Every batch item is fed the same spectrogram, so if audio
    features were mixed up between batch items when expanded across beams,
    the items would no longer decode to the same text.
    """
    beam_size = 2
    batch_size = 2

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisper.load_model(model_name).to(device)
    audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

    # English-only checkpoints (".en") get the language pinned; for the
    # others it is left as None so the model detects it.
    language = "en" if model_name.endswith(".en") else None
    options = whisper.DecodingOptions(language=language, beam_size=beam_size)

    audio = whisper.load_audio(audio_path)
    audio = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(audio).to(device)

    # Build a small batch: the same spectrogram repeated batch_size times
    # along a new leading batch dimension.
    batch_mel = mel.unsqueeze(0).repeat(batch_size, 1, 1)
    results = model.decode(batch_mel, options)

    # Every item received the same input, so every item must decode to the
    # same text. Compare all of them, not just the first two.
    assert len(results) == batch_size
    texts = [result.text for result in results]
    assert all(text == texts[0] for text in texts), texts

    decoded_text = texts[0].lower()
    assert "my fellow americans" in decoded_text
    assert "your country" in decoded_text
    assert "do for you" in decoded_text