The current `transcribe` API only outputs progress and transcribed text on stdout. Callers can only access the result after the whole transcription is done, and they have to hijack the `tqdm` interface to get realtime transcription progress. This commit adds a simple interface that can be passed as a parameter to `transcribe`, so API users don't need to fall back to the hacks above or to low-level APIs for this need. Signed-off-by: Austin Chang <austin880625@gmail.com>
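As a usage sketch, assuming the `TranscribeProgressReceiver` base class, its hook names (`start`, `update_line`, `update`), and the `progress_receiver` parameter exactly as exercised by the test below, a caller could stream segments as they are decoded; the printing receiver, model name, and audio path here are only illustrative:

import whisper


class PrintingReceiver(whisper.TranscribeProgressReceiver):
    def start(self, total: int):
        # called once with the total amount of work before decoding begins
        self.total = total
        return self

    def update_line(self, start: float, end: float, text: str):
        # called for each transcribed segment as soon as it is available
        print(f"[{start:7.2f} --> {end:7.2f}]{text}")

    def update(self, n):
        # called as decoding advances by n units, similar to tqdm.update
        pass


model = whisper.load_model("base")
model.transcribe("jfk.flac", progress_receiver=PrintingReceiver())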
62 lines
2.1 KiB
Python
import os

import pytest
import torch

import whisper
from whisper.tokenizer import get_tokenizer


class TestingProgressReceiver(whisper.TranscribeProgressReceiver):
    # Collects the text and progress counts reported through the receiver hooks
    # so the test can compare them against the final transcription result.
    def start(self, total: int):
        self.result = ""
        self.total = total
        self.progress = 0
        return self

    def update_line(self, start: float, end: float, text: str):
        self.result += text

    def update(self, n):
        self.progress += n

    def get_result(self):
        return self.result

    def verify_total(self):
        return self.total == self.progress


@pytest.mark.parametrize("model_name", whisper.available_models())
def test_transcribe(model_name: str):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisper.load_model(model_name).to(device)
    audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
    receiver = TestingProgressReceiver()

    language = "en" if model_name.endswith(".en") else None
    result = model.transcribe(
        audio_path, language=language, temperature=0.0, word_timestamps=True,
        progress_receiver=receiver
    )

    # the receiver must have seen all of the progress and the full transcribed text
    assert receiver.verify_total()
    assert result["language"] == "en"
    assert result["text"] == "".join([s["text"] for s in result["segments"]])
    assert result["text"] == receiver.get_result()

    transcription = result["text"].lower()
    assert "my fellow americans" in transcription
    assert "your country" in transcription
    assert "do for you" in transcription

    tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
    all_tokens = [t for s in result["segments"] for t in s["tokens"]]
    assert tokenizer.decode(all_tokens) == result["text"]
    assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")

    # word-level timestamps should bracket the word "Americans" around 1.8 s
    timing_checked = False
    for segment in result["segments"]:
        for timing in segment["words"]:
            assert timing["start"] < timing["end"]
            if timing["word"].strip(" ,") == "Americans":
                assert timing["start"] <= 1.8
                assert timing["end"] >= 1.8
                timing_checked = True

    assert timing_checked