language detection patch and test

This commit is contained in:
Kent Slaney 2024-07-22 13:16:53 -07:00
parent 610f82ffba
commit 247391a2af
3 changed files with 104 additions and 10 deletions

View File

@ -4,7 +4,9 @@ import pytest
import torch
import whisper
from whisper.audio import CHUNK_LENGTH
from whisper.tokenizer import get_tokenizer
from whisper.transcribe import Transcriber
@pytest.mark.parametrize("model_name", whisper.available_models())
@ -40,3 +42,79 @@ def test_transcribe(model_name: str):
timing_checked = True
assert timing_checked
class MockTokenizer:
def __init__(self, language, **kw):
self.language, self._kw = language, kw
for k, v in kw.items():
setattr(self, k, v)
def encode(self, prompt):
return [self.language, self, prompt]
class OnDemand:
def __init__(self, seq=(), relative=True):
self.seq, self.relative = seq, relative
self.prev, self.given = 0, 0
def __getitem__(self, key):
_key = self.given if self.relative else key
self.prev = (
self.seq[_key]
if _key < len(self.seq)
else int(input(f"lang @ {_key}: ") or self.prev)
)
self.given += 1
return self.prev
def __len__(self):
return CHUNK_LENGTH + 2 if self.relative else len(self.seq)
class TranscriberTest(Transcriber):
sample = object()
dtype = torch.float32
model = type(
"MockModel",
(),
{"is_multilingual": True, "num_languages": None, "device": torch.device("cpu")},
)()
_seek = 0
def __init__(self, seq=None):
super().__init__(self.model, initial_prompt="")
self.seq = OnDemand(seq or ())
self.result = []
self.latest = torch.zeros((0,))
for i in range(len(self.seq)):
self._seek = i
self.frame_offset = max(0, i + 1 - CHUNK_LENGTH)
res = self.initial_prompt_tokens
assert res[0] == self.seq.prev
self.result.append(res[1:])
if seq is None:
print(res)
def detect_language(self, mel=None):
self.result.append([self.sample, mel])
return self.seq[self._seek]
def get_tokenizer(self, multilingual, language, **kw):
return MockTokenizer(language, **{"multilingual": multilingual, **kw})
@property
def rle(self):
res = []
for i, *j in self.result:
if i is self.sample:
res.append(0)
else:
res[-1] += 1
return res
def test_language():
res = TranscriberTest([0, 0, 1, 0, 0, 0, 0, 0, 0]).rle
assert res == [1, 2, 1, 1, 2, 4, 8, 11, 2]

View File

@ -133,7 +133,6 @@ class ArrayStream(AudioSink):
amp, N_FFT, HOP_LENGTH, window=self.hann, center=False, return_complex=True
)
# https://github.com/openai/whisper/blob/c5d4256/whisper/audio.py#L149
log_spec_bound: Optional[torch.Tensor] = None
def transform(self, stft: torch.Tensor) -> torch.Tensor:

View File

@ -4,6 +4,7 @@ import os
import traceback
import warnings
from dataclasses import dataclass
from math import ceil
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np
@ -147,15 +148,16 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
return self._language
self._hypothesis.last = self._seek or 0
self._hypothesis.since += 1
if 2**self._hypothesis.evidence < self._hypothesis.since:
if 2**self._hypothesis.evidence > self._hypothesis.since:
return self._hypothesis.language
self._hypothesis.since = 0
guess = self.detect_language()
if guess == self._hypothesis.language:
self._hypothesis.evidence += 1
self._hypothesis.language = guess
self._hypothesis.evidence = 1
return None
else:
self._hypothesis.language = guess
self._hypothesis.evidence = 0
return guess
@PassthroughProperty[Union[str, List[float], Tuple[float]]]((0,)).setter
def clip_timestamps(self, value: Union[str, List[float], Tuple[float]]):
@ -257,18 +259,34 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
if self._initial_prompt_tokens is None:
if self.initial_prompt is None:
self._initial_prompt_tokens = []
elif self.language is None:
return []
else:
tokenizer = self.tokenizer
if tokenizer is None:
return []
if tokenizer not in self._initial_prompt_cache:
self._initial_prompt_cache[tokenizer] = tokenizer.encode(
" " + self.initial_prompt.strip()
)
self._initial_prompt_tokens = self._initial_prompt_cache[tokenizer]
if self._tokenizer is not None:
self._initial_prompt_tokens = self._initial_prompt_cache[tokenizer]
return self._initial_prompt_cache[tokenizer]
return self._initial_prompt_tokens
_initial_tokens: int = 0
_initial_finalized: bool = False
_all_tokens: Optional[list] = None
@property
def all_tokens(self):
if self._all_tokens is None:
self._all_tokens = []
if not self._initial_finalized:
initial = self.initial_prompt_tokens
self._all_tokens = initial + self._all_tokens[self._initial_tokens :]
self._initial_tokens = len(initial)
self._initial_finalized = self._initial_prompt_tokens is not None
return self._all_tokens
prompt_reset_since: int = 0
last_speech_timestamp: float = 0.0
frame_offset: int = 0
@ -375,7 +393,6 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
self.hallucination_silence_threshold = hallucination_silence_threshold
self.decode_options = decode_options
self.all_tokens = self.initial_prompt_tokens[:]
self.all_segments = []
def decode_with_fallback(self, segment: torch.Tensor) -> DecodingResult:
@ -784,7 +801,7 @@ class ProgressTranscriber(MinimalTranscriber):
n = (
self.latest.shape[-1]
if self.duration is None
else -int(self.duration * -FRAMES_PER_SECOND)
else ceil(self.duration * FRAMES_PER_SECOND)
)
# show the progress bar when verbose is False
# (if True, transcribed text will be printed)