mirror of https://github.com/openai/whisper.git
synced 2025-11-24 14:35:57 +00:00
language detection patch and test
This commit is contained in:
parent 610f82ffba
commit 247391a2af
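Summary of the change: the language-hypothesis backoff in Transcriber gets its inverted comparison fixed and a proper else branch, initial-prompt tokenization is deferred until a language hypothesis exists, all_tokens becomes a lazy property that re-splices the finalized prompt, a truncating ceiling idiom is replaced with math.ceil, and a mock-driven test_language exercises the re-detection schedule.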
@@ -4,7 +4,9 @@ import pytest
 import torch
 
 import whisper
+from whisper.audio import CHUNK_LENGTH
 from whisper.tokenizer import get_tokenizer
+from whisper.transcribe import Transcriber
 
 
 @pytest.mark.parametrize("model_name", whisper.available_models())
@@ -40,3 +42,79 @@ def test_transcribe(model_name: str):
                 timing_checked = True
 
     assert timing_checked
+
+
+class MockTokenizer:
+    def __init__(self, language, **kw):
+        self.language, self._kw = language, kw
+        for k, v in kw.items():
+            setattr(self, k, v)
+
+    def encode(self, prompt):
+        return [self.language, self, prompt]
+
+
+class OnDemand:
+    def __init__(self, seq=(), relative=True):
+        self.seq, self.relative = seq, relative
+        self.prev, self.given = 0, 0
+
+    def __getitem__(self, key):
+        _key = self.given if self.relative else key
+        self.prev = (
+            self.seq[_key]
+            if _key < len(self.seq)
+            else int(input(f"lang @ {_key}: ") or self.prev)
+        )
+        self.given += 1
+        return self.prev
+
+    def __len__(self):
+        return CHUNK_LENGTH + 2 if self.relative else len(self.seq)
+
+
+class TranscriberTest(Transcriber):
+    sample = object()
+    dtype = torch.float32
+    model = type(
+        "MockModel",
+        (),
+        {"is_multilingual": True, "num_languages": None, "device": torch.device("cpu")},
+    )()
+    _seek = 0
+
+    def __init__(self, seq=None):
+        super().__init__(self.model, initial_prompt="")
+        self.seq = OnDemand(seq or ())
+        self.result = []
+        self.latest = torch.zeros((0,))
+        for i in range(len(self.seq)):
+            self._seek = i
+            self.frame_offset = max(0, i + 1 - CHUNK_LENGTH)
+            res = self.initial_prompt_tokens
+            assert res[0] == self.seq.prev
+            self.result.append(res[1:])
+            if seq is None:
+                print(res)
+
+    def detect_language(self, mel=None):
+        self.result.append([self.sample, mel])
+        return self.seq[self._seek]
+
+    def get_tokenizer(self, multilingual, language, **kw):
+        return MockTokenizer(language, **{"multilingual": multilingual, **kw})
+
+    @property
+    def rle(self):
+        res = []
+        for i, *j in self.result:
+            if i is self.sample:
+                res.append(0)
+            else:
+                res[-1] += 1
+        return res
+
+
+def test_language():
+    res = TranscriberTest([0, 0, 1, 0, 0, 0, 0, 0, 0]).rle
+    assert res == [1, 2, 1, 1, 2, 4, 8, 11, 2]
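The test drives initial_prompt_tokens once per simulated frame and logs whether each access triggered a fresh detect_language call (the sentinel sample entry) or was served from the cached hypothesis; rle then compresses that log into run lengths. A minimal sketch of the compression, using a hypothetical trace where "det" and "enc" stand in for the sentinel and tokenizer entries:

# Hypothetical trace: "det" marks a detect_language call, "enc" an access
# answered by the cached hypothesis (names invented for illustration).
log = ["det", "enc", "enc", "det", "enc"]
rle = []
for entry in log:
    if entry == "det":
        rle.append(1)   # a detection opens a new group (counting itself)
    else:
        rle[-1] += 1    # each cached access extends the current group
assert rle == [3, 2]

In test_language, the expected [1, 2, 1, 1, 2, 4, 8, 11, 2] shows the interval between detections roughly doubling while the guesses agree; the final groups (11, then 2) suggest the hypothesis is also refreshed once frame_offset becomes nonzero, i.e. once the chunk window starts sliding.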
@@ -133,7 +133,6 @@ class ArrayStream(AudioSink):
             amp, N_FFT, HOP_LENGTH, window=self.hann, center=False, return_complex=True
         )
 
-    # https://github.com/openai/whisper/blob/c5d4256/whisper/audio.py#L149
     log_spec_bound: Optional[torch.Tensor] = None
 
    def transform(self, stft: torch.Tensor) -> torch.Tensor:
@@ -4,6 +4,7 @@ import os
 import traceback
 import warnings
 from dataclasses import dataclass
+from math import ceil
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -147,15 +148,16 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             return self._language
         self._hypothesis.last = self._seek or 0
         self._hypothesis.since += 1
-        if 2**self._hypothesis.evidence < self._hypothesis.since:
+        if 2**self._hypothesis.evidence > self._hypothesis.since:
             return self._hypothesis.language
         self._hypothesis.since = 0
         guess = self.detect_language()
         if guess == self._hypothesis.language:
             self._hypothesis.evidence += 1
-        self._hypothesis.language = guess
-        self._hypothesis.evidence = 1
-        return None
+        else:
+            self._hypothesis.language = guess
+            self._hypothesis.evidence = 0
+        return guess
 
     @PassthroughProperty[Union[str, List[float], Tuple[float]]]((0,)).setter
     def clip_timestamps(self, value: Union[str, List[float], Tuple[float]]):
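The corrected comparison makes the hypothesis cache behave as exponential backoff: after evidence agreements in a row, 2**evidence accesses are served from the cache before the next re-detection, and a disagreement resets the schedule. A self-contained sketch of just the counter logic, where the Hypothesis record and the detect callable are stand-ins for the real _hypothesis state and detect_language:

from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class Hypothesis:
    # Stand-in for Transcriber._hypothesis (fields per the patched property).
    language: Optional[int] = None
    evidence: int = 0
    since: int = 0

def language(hyp: Hypothesis, detect: Callable[[], int]) -> Optional[int]:
    hyp.since += 1
    if 2**hyp.evidence > hyp.since:
        return hyp.language          # still within the trusted interval
    hyp.since = 0
    guess = detect()
    if guess == hyp.language:
        hyp.evidence += 1            # agreement doubles the interval
    else:
        hyp.language = guess
        hyp.evidence = 0             # disagreement: re-check on next access
    return guess

Feeding this the test's detection sequence [0, 0, 1, 0, 0, ...] reproduces the leading run lengths 1, 2, 1, 1, 2, 4, 8 of test_language's expected output.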
@@ -257,18 +259,34 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         if self._initial_prompt_tokens is None:
             if self.initial_prompt is None:
                 self._initial_prompt_tokens = []
+            elif self.language is None:
+                return []
             else:
                 tokenizer = self.tokenizer
-                if tokenizer is None:
-                    return []
                 if tokenizer not in self._initial_prompt_cache:
                     self._initial_prompt_cache[tokenizer] = tokenizer.encode(
                         " " + self.initial_prompt.strip()
                     )
-                self._initial_prompt_tokens = self._initial_prompt_cache[tokenizer]
+                if self._tokenizer is not None:
+                    self._initial_prompt_tokens = self._initial_prompt_cache[tokenizer]
                 return self._initial_prompt_cache[tokenizer]
         return self._initial_prompt_tokens
 
+    _initial_tokens: int = 0
+    _initial_finalized: bool = False
+    _all_tokens: Optional[list] = None
+
+    @property
+    def all_tokens(self):
+        if self._all_tokens is None:
+            self._all_tokens = []
+        if not self._initial_finalized:
+            initial = self.initial_prompt_tokens
+            self._all_tokens = initial + self._all_tokens[self._initial_tokens :]
+            self._initial_tokens = len(initial)
+            self._initial_finalized = self._initial_prompt_tokens is not None
+        return self._all_tokens
+
     prompt_reset_since: int = 0
     last_speech_timestamp: float = 0.0
     frame_offset: int = 0
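The new elif path hands out an empty provisional prompt until a language hypothesis exists, and _initial_prompt_tokens is only pinned once the real tokenizer is known; all_tokens then swaps the finalized prompt in front of whatever has accumulated. A small sketch of that splice, with invented token values:

# Invented values; only the splice mirrors the all_tokens property.
all_tokens = []                 # provisional prompt was [] (language unknown)
initial_tokens = 0              # prompt tokens currently prepended
all_tokens += [101, 102]        # decoded tokens accumulate meanwhile

final_prompt = [7, 8, 9]        # prompt re-encoded once the language is known
all_tokens = final_prompt + all_tokens[initial_tokens:]
initial_tokens = len(final_prompt)
assert all_tokens == [7, 8, 9, 101, 102]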
@@ -375,7 +393,6 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         self.hallucination_silence_threshold = hallucination_silence_threshold
         self.decode_options = decode_options
 
-        self.all_tokens = self.initial_prompt_tokens[:]
         self.all_segments = []
 
     def decode_with_fallback(self, segment: torch.Tensor) -> DecodingResult:
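With all_tokens now computed lazily (previous hunk), the eager copy of initial_prompt_tokens at construction time would force tokenization before any language is detected, so it is dropped.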
@@ -784,7 +801,7 @@ class ProgressTranscriber(MinimalTranscriber):
         n = (
             self.latest.shape[-1]
             if self.duration is None
-            else -int(self.duration * -FRAMES_PER_SECOND)
+            else ceil(self.duration * FRAMES_PER_SECOND)
        )
         # show the progress bar when verbose is False
         # (if True, transcribed text will be printed)
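Beyond readability, this fixes an off-by-one: int() truncates toward zero, so the negate-twice trick that yields a ceiling for integer floor division does not survive the float multiplication. A quick check, with illustrative numbers:

from math import ceil

duration, fps = 2.504, 100           # illustrative values only
v = duration * fps                   # 250.4 frames
assert -int(duration * -fps) == 250  # int() truncates: old code rounds down
assert ceil(v) == 251                # math.ceil rounds the partial frame up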