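"""Tests for Whisper transcription: an end-to-end run over jfk.flac plus
unit-style checks of Transcriber language detection using mock components."""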
import os

import pytest
import torch

import whisper
from whisper.audio import CHUNK_LENGTH
from whisper.tokenizer import get_tokenizer
from whisper.transcribe import Transcriber


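# End-to-end check: transcribe the bundled jfk.flac with every available model size,
# then verify the detected language, the text/segment consistency, the tokenizer
# round-trip, and the word-level timestamps around "Americans".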
@pytest.mark.parametrize("model_name", whisper.available_models())
def test_transcribe(model_name: str):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisper.load_model(model_name).to(device)
    audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

    language = "en" if model_name.endswith(".en") else None
    result = model.transcribe(
        audio_path, language=language, temperature=0.0, word_timestamps=True
    )
    assert result["language"] == "en"
    assert result["text"] == "".join([s["text"] for s in result["segments"]])

    transcription = result["text"].lower()
    assert "my fellow americans" in transcription
    assert "your country" in transcription
    assert "do for you" in transcription

    tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
    all_tokens = [t for s in result["segments"] for t in s["tokens"]]
    assert tokenizer.decode(all_tokens) == result["text"]
    assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")

    timing_checked = False
    for segment in result["segments"]:
        for timing in segment["words"]:
            assert timing["start"] < timing["end"]
            if timing["word"].strip(" ,") == "Americans":
                assert timing["start"] <= 1.8
                assert timing["end"] >= 1.8
                timing_checked = True

    assert timing_checked


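# Minimal tokenizer stand-in: remembers the language and keyword arguments it was
# constructed with, and encode() simply echoes them back so callers can inspect them.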
class MockTokenizer:
    def __init__(self, language, **kw):
        self.language, self._kw = language, kw
        for k, v in kw.items():
            setattr(self, k, v)

    def encode(self, prompt):
        return [self.language, self, prompt]


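# Scripted source of language-detection results: values are served from `seq` while
# it lasts, then requested interactively (falling back to the previous value), and
# its reported length is CHUNK_LENGTH + 2 in relative mode.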
class OnDemand:
    def __init__(self, seq=(), relative=True):
        self.seq, self.relative = seq, relative
        self.prev, self.given = 0, 0

    def __getitem__(self, key):
        _key = self.given if self.relative else key
        self.prev = (
            self.seq[_key]
            if _key < len(self.seq)
            else int(input(f"lang @ {_key}: ") or self.prev)
        )
        self.given += 1
        return self.prev

    def __len__(self):
        return CHUNK_LENGTH + 2 if self.relative else len(self.seq)


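# Transcriber wired to a mock model: language detection is scripted through OnDemand,
# every detect_language call and initial_prompt_tokens read is recorded in
# self.result, and the constructor replays a sequence of seek positions.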
class TranscriberTest(Transcriber):
    sample = object()
    dtype = torch.float32
    model = type(
        "MockModel",
        (),
        {"is_multilingual": True, "num_languages": None, "device": torch.device("cpu")},
    )()
    _seek = 0

    def __init__(self, seq=None):
        super().__init__(self.model, initial_prompt="")
        self.seq = OnDemand(seq or ())
        self.result = []
        self.latest = torch.zeros((0,))
        for i in range(len(self.seq)):
            self._seek = i
            self.frame_offset = max(0, i + 1 - CHUNK_LENGTH)
            res = self.initial_prompt_tokens
            assert res[0] == self.seq.prev
            self.result.append(res[1:])
            if seq is None:
                print(res)

    def detect_language(self, mel=None):
        self.result.append([self.sample, mel])
        return self.seq[self._seek]

    def get_tokenizer(self, multilingual, language, **kw):
        return MockTokenizer(language, **{"multilingual": multilingual, **kw})

    @property
    def rle(self):
        res = []
        for i, *j in self.result:
            if i is self.sample:
                res.append(0)
            else:
                res[-1] += 1
        return res


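# rle groups the recorded calls into run lengths: each entry counts how many
# initial_prompt_tokens reads followed a detect_language call, so the expected list
# pins down how often language detection is expected to be re-run as seeking advances.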
def test_language():
    res = TranscriberTest([0, 0, 1, 0, 0, 0, 0, 0, 0]).rle
    assert res == [1, 2, 1, 1, 2, 4, 8, 11, 2]