mirror of https://github.com/openai/whisper.git
synced 2025-11-24 22:45:52 +00:00

Merge a84d7ea904359cdf54ec76468fb53217b97ede18 into ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab
This commit is contained in: commit dabf80050f
.github/workflows/test.yml (vendored): 2 changes
@@ -53,4 +53,4 @@ jobs:
       - uses: actions/checkout@v3
       - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
       - run: pip install .["dev"]
-      - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
+      - run: pytest --durations=0 -vv -k 'not (test_transcribe or test_decode) or test_transcribe[tiny] or test_transcribe[tiny.en] or test_decode[tiny] or test_decode[tiny.en]' -m 'not requires_cuda'
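Note: the widened -k expression keeps the CI job fast while covering the new test: everything except test_transcribe and test_decode still runs, and those two run only in their tiny and tiny.en parametrizations. A sketch of the same selection through pytest's Python entry point (running the pytest command above directly works just as well):

    import pytest

    # Mirrors the CI step: run all tests except test_transcribe/test_decode,
    # plus only the tiny and tiny.en parametrizations of those two, and skip
    # anything marked requires_cuda.
    pytest.main([
        "--durations=0", "-vv",
        "-k", "not (test_transcribe or test_decode) or test_transcribe[tiny]"
        " or test_transcribe[tiny.en] or test_decode[tiny] or test_decode[tiny.en]",
        "-m", "not requires_cuda",
    ])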
tests/test_decode.py (new file): 52 lines
@@ -0,0 +1,52 @@
+import os
+
+import pytest
+import torch
+
+import whisper
+
+
+@pytest.mark.parametrize("model_name", whisper.available_models())
+def test_decode(model_name: str):
+    # Regression test: batch_size and beam_size should work together
+    beam_size = 2
+    batch_size = 2
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model = whisper.load_model(model_name).to(device)
+    audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
+
+    language = "en" if model_name.endswith(".en") else None
+
+    options = whisper.DecodingOptions(language=language, beam_size=beam_size)
+
+    audio = whisper.load_audio(audio_path)
+    audio = whisper.pad_or_trim(audio)
+    mel = whisper.log_mel_spectrogram(audio).to(device)
+
+    # Create a small batch
+    batch_mel = mel.unsqueeze(0).repeat(batch_size, 1, 1)
+
+    results = model.decode(batch_mel, options)
+
+    # Since both examples are the same, results should be identical
+    assert len(results) == batch_size
+    assert results[0].text == results[1].text
+
+    decoded_text = results[0].text.lower()
+    assert "my fellow americans" in decoded_text
+    assert "your country" in decoded_text
+    assert "do for you" in decoded_text
+
+    timing_checked = False
+    if hasattr(results[0], "segments"):
+        for segment in results[0].segments:
+            for timing in segment["words"]:
+                assert timing["start"] < timing["end"]
+                if timing["word"].strip(" ,") == "Americans":
+                    assert timing["start"] <= 1.8
+                    assert timing["end"] >= 1.8
+                    timing_checked = True
+
+    if hasattr(results[0], "segments"):
+        assert timing_checked
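For context, the scenario this test guards: decoding a batch of more than one audio example with beam_size > 1. A minimal sketch of that call (assumptions: the tiny model, the repository's tests/jfk.flac fixture addressed relative to the working directory, ffmpeg installed; fp16=False so it also runs on CPU):

    import whisper

    model = whisper.load_model("tiny")
    audio = whisper.pad_or_trim(whisper.load_audio("tests/jfk.flac"))
    mel = whisper.log_mel_spectrogram(audio)
    batch_mel = mel.unsqueeze(0).repeat(2, 1, 1)  # two identical examples
    # Before the decoding.py change below, only the token tensors were
    # repeated by the beam group size, so this batched beam-search call
    # could fail with a batch-dimension mismatch in cross-attention.
    options = whisper.DecodingOptions(language="en", beam_size=2, fp16=False)
    results = model.decode(batch_mel, options)
    print(results[0].text)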
whisper/decoding.py
@@ -731,6 +731,9 @@ class DecodingTask:
             ]
 
         # repeat text tensors by the group size, for beam search or best-of-n sampling
+        audio_features = audio_features.repeat_interleave(self.n_group, dim=0).to(
+            audio_features.device
+        )
         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
 
         # call the main sampling loop
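Why the added lines are needed: the tokens are expanded to n_audio * n_group rows for beam search, and the decoder's cross-attention matches the query batch against the audio features batch, which broadcasting only papers over when n_audio == 1. A shape sketch with illustrative dimensions (the n_ctx/n_state values are assumptions, not taken from the repo):

    import torch

    n_audio, n_group = 2, 5        # two inputs, beam size five
    n_ctx, n_state = 1500, 384     # illustrative encoder output shape
    audio_features = torch.randn(n_audio, n_ctx, n_state)
    tokens = torch.zeros(n_audio, 3, dtype=torch.long)

    # Beam search expands the text tensors by the group size...
    tokens = tokens.repeat_interleave(n_group, dim=0)                  # (10, 3)
    # ...so the audio features must be expanded the same way; otherwise
    # queries with batch 10 meet keys/values with batch 2.
    audio_features = audio_features.repeat_interleave(n_group, dim=0)  # (10, 1500, 384)
    assert tokens.shape[0] == audio_features.shape[0]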