diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py
index 599221a..39108d7 100644
--- a/tests/test_transcribe.py
+++ b/tests/test_transcribe.py
@@ -4,7 +4,9 @@ import pytest
 import torch
 
 import whisper
+from whisper.audio import CHUNK_LENGTH
 from whisper.tokenizer import get_tokenizer
+from whisper.transcribe import Transcriber
 
 
 @pytest.mark.parametrize("model_name", whisper.available_models())
@@ -40,3 +42,79 @@ def test_transcribe(model_name: str):
                 timing_checked = True
 
     assert timing_checked
+
+
+class MockTokenizer:
+    def __init__(self, language, **kw):
+        self.language, self._kw = language, kw
+        for k, v in kw.items():
+            setattr(self, k, v)
+
+    def encode(self, prompt):
+        return [self.language, self, prompt]
+
+
+class OnDemand:
+    def __init__(self, seq=(), relative=True):
+        self.seq, self.relative = seq, relative
+        self.prev, self.given = 0, 0
+
+    def __getitem__(self, key):
+        _key = self.given if self.relative else key
+        self.prev = (
+            self.seq[_key]
+            if _key < len(self.seq)
+            else int(input(f"lang @ {_key}: ") or self.prev)
+        )
+        self.given += 1
+        return self.prev
+
+    def __len__(self):
+        return CHUNK_LENGTH + 2 if self.relative else len(self.seq)
+
+
+class TranscriberTest(Transcriber):
+    sample = object()
+    dtype = torch.float32
+    model = type(
+        "MockModel",
+        (),
+        {"is_multilingual": True, "num_languages": None, "device": torch.device("cpu")},
+    )()
+    _seek = 0
+
+    def __init__(self, seq=None):
+        super().__init__(self.model, initial_prompt="")
+        self.seq = OnDemand(seq or ())
+        self.result = []
+        self.latest = torch.zeros((0,))
+        for i in range(len(self.seq)):
+            self._seek = i
+            self.frame_offset = max(0, i + 1 - CHUNK_LENGTH)
+            res = self.initial_prompt_tokens
+            assert res[0] == self.seq.prev
+            self.result.append(res[1:])
+            if seq is None:
+                print(res)
+
+    def detect_language(self, mel=None):
+        self.result.append([self.sample, mel])
+        return self.seq[self._seek]
+
+    def get_tokenizer(self, multilingual, language, **kw):
+        return MockTokenizer(language, **{"multilingual": multilingual, **kw})
+
+    @property
+    def rle(self):
+        res = []
+        for i, *j in self.result:
+            if i is self.sample:
+                res.append(0)
+            else:
+                res[-1] += 1
+        return res
+
+
+def test_language():
+    res = TranscriberTest([0, 0, 1, 0, 0, 0, 0, 0, 0]).rle
+    assert res == [1, 2, 1, 1, 2, 4, 8, 11, 2]
diff --git a/whisper/buffer.py b/whisper/buffer.py
index 2dd27b4..5229ad8 100644
--- a/whisper/buffer.py
+++ b/whisper/buffer.py
@@ -133,7 +133,6 @@ class ArrayStream(AudioSink):
             amp, N_FFT, HOP_LENGTH, window=self.hann, center=False, return_complex=True
         )
 
-    # https://github.com/openai/whisper/blob/c5d4256/whisper/audio.py#L149
     log_spec_bound: Optional[torch.Tensor] = None
 
     def transform(self, stft: torch.Tensor) -> torch.Tensor:
diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index aaf261a..dbc81bd 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -4,6 +4,7 @@ import os
 import traceback
 import warnings
 from dataclasses import dataclass
+from math import ceil
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -147,15 +148,16 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             return self._language
         self._hypothesis.last = self._seek or 0
         self._hypothesis.since += 1
-        if 2**self._hypothesis.evidence < self._hypothesis.since:
+        if 2**self._hypothesis.evidence > self._hypothesis.since:
             return self._hypothesis.language
         self._hypothesis.since = 0
         guess = self.detect_language()
         if guess == self._hypothesis.language:
             self._hypothesis.evidence += 1
-        self._hypothesis.language = guess
-        self._hypothesis.evidence = 1
-        return None
+        else:
+            self._hypothesis.language = guess
+            self._hypothesis.evidence = 0
+        return guess
 
     @PassthroughProperty[Union[str, List[float], Tuple[float]]]((0,)).setter
     def clip_timestamps(self, value: Union[str, List[float], Tuple[float]]):
@@ -257,18 +259,34 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         if self._initial_prompt_tokens is None:
             if self.initial_prompt is None:
                 self._initial_prompt_tokens = []
+            elif self.language is None:
+                return []
             else:
                 tokenizer = self.tokenizer
-                if tokenizer is None:
-                    return []
                 if tokenizer not in self._initial_prompt_cache:
                     self._initial_prompt_cache[tokenizer] = tokenizer.encode(
                         " " + self.initial_prompt.strip()
                     )
-                self._initial_prompt_tokens = self._initial_prompt_cache[tokenizer]
+                if self._tokenizer is not None:
+                    self._initial_prompt_tokens = self._initial_prompt_cache[tokenizer]
                 return self._initial_prompt_cache[tokenizer]
         return self._initial_prompt_tokens
 
+    _initial_tokens: int = 0
+    _initial_finalized: bool = False
+    _all_tokens: Optional[list] = None
+
+    @property
+    def all_tokens(self):
+        if self._all_tokens is None:
+            self._all_tokens = []
+        if not self._initial_finalized:
+            initial = self.initial_prompt_tokens
+            self._all_tokens = initial + self._all_tokens[self._initial_tokens :]
+            self._initial_tokens = len(initial)
+            self._initial_finalized = self._initial_prompt_tokens is not None
+        return self._all_tokens
+
     prompt_reset_since: int = 0
     last_speech_timestamp: float = 0.0
     frame_offset: int = 0
@@ -375,7 +393,6 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         self.hallucination_silence_threshold = hallucination_silence_threshold
         self.decode_options = decode_options
 
-        self.all_tokens = self.initial_prompt_tokens[:]
         self.all_segments = []
 
     def decode_with_fallback(self, segment: torch.Tensor) -> DecodingResult:
@@ -784,7 +801,7 @@ class ProgressTranscriber(MinimalTranscriber):
         n = (
             self.latest.shape[-1]
             if self.duration is None
-            else -int(self.duration * -FRAMES_PER_SECOND)
+            else ceil(self.duration * FRAMES_PER_SECOND)
        )
         # show the progress bar when verbose is False
         # (if True, transcribed text will be printed)
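
Commentary on the language-detection change above (not part of the patch): `_hypothesis.evidence` counts consecutive confirmations of the current language guess, and detection is skipped while `_hypothesis.since` is still below `2**evidence`, so the gap between detections doubles with each confirmation and collapses back to every-frame detection when the guess changes. The following is a standalone sketch of that schedule; the names `detection_frames`, `guesses`, and `frames` are illustrative, not part of the library.

    def detection_frames(guesses, frames):
        """Frame indices at which detect_language would run."""
        language, evidence, since = None, 0, 0
        ran, it = [], iter(guesses)
        for frame in range(frames):
            since += 1
            if 2**evidence > since:
                continue  # hypothesis still trusted; skip detection
            since = 0
            ran.append(frame)
            guess = next(it)
            if guess == language:
                evidence += 1  # confirmed: next gap doubles
            else:
                language, evidence = guess, 0  # changed: restart backoff
        return ran

    # With the guess sequence from test_language, detections land at
    # frames [0, 1, 3, 4, 5, 7, 11, 19], i.e. run lengths 1, 2, 1, 1, 2,
    # 4, 8, ... matching the expected rle prefix (the test's final
    # detection comes from the buffer sliding past the hypothesis frame,
    # which this sketch does not model).
    print(detection_frames([0, 0, 1, 0, 0, 0, 0, 0], 30))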
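Similarly, the new `all_tokens` property replaces the eager `self.all_tokens = self.initial_prompt_tokens[:]` so that tokens decoded before the language is known are kept while the initial-prompt prefix is spliced in once it becomes final. A minimal sketch of that splice, assuming the same finalization rule as the patch; the `Prompted` class and its attribute names are hypothetical:

    class Prompted:
        def __init__(self):
            self._all = []
            self._ninitial = 0   # length of the current prompt prefix
            self._final = False  # has the prefix stopped changing?
            self._prompt = None  # None until the language is known

        @property
        def all_tokens(self):
            if not self._final:
                prefix = self._prompt or []
                # swap the prefix while preserving everything decoded so far
                self._all = prefix + self._all[self._ninitial:]
                self._ninitial = len(prefix)
                self._final = self._prompt is not None
            return self._all

    p = Prompted()
    p.all_tokens.extend([101, 102])  # decoded before the prompt is known
    p._prompt = [7, 8]               # language detected: prefix now available
    assert p.all_tokens == [7, 8, 101, 102]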