diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index 599221a..39108d7 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -4,7 +4,9 @@ import pytest import torch import whisper +from whisper.audio import CHUNK_LENGTH from whisper.tokenizer import get_tokenizer +from whisper.transcribe import Transcriber @pytest.mark.parametrize("model_name", whisper.available_models()) @@ -40,3 +42,79 @@ def test_transcribe(model_name: str): timing_checked = True assert timing_checked + + +class MockTokenizer: + def __init__(self, language, **kw): + self.language, self._kw = language, kw + for k, v in kw.items(): + setattr(self, k, v) + + def encode(self, prompt): + return [self.language, self, prompt] + + +class OnDemand: + def __init__(self, seq=(), relative=True): + self.seq, self.relative = seq, relative + self.prev, self.given = 0, 0 + + def __getitem__(self, key): + _key = self.given if self.relative else key + self.prev = ( + self.seq[_key] + if _key < len(self.seq) + else int(input(f"lang @ {_key}: ") or self.prev) + ) + self.given += 1 + return self.prev + + def __len__(self): + return CHUNK_LENGTH + 2 if self.relative else len(self.seq) + + +class TranscriberTest(Transcriber): + sample = object() + dtype = torch.float32 + model = type( + "MockModel", + (), + {"is_multilingual": True, "num_languages": None, "device": torch.device("cpu")}, + )() + _seek = 0 + + def __init__(self, seq=None): + super().__init__(self.model, initial_prompt="") + self.seq = OnDemand(seq or ()) + self.result = [] + self.latest = torch.zeros((0,)) + for i in range(len(self.seq)): + self._seek = i + self.frame_offset = max(0, i + 1 - CHUNK_LENGTH) + res = self.initial_prompt_tokens + assert res[0] == self.seq.prev + self.result.append(res[1:]) + if seq is None: + print(res) + + def detect_language(self, mel=None): + self.result.append([self.sample, mel]) + return self.seq[self._seek] + + def get_tokenizer(self, multilingual, language, **kw): + return MockTokenizer(language, **{"multilingual": multilingual, **kw}) + + @property + def rle(self): + res = [] + for i, *j in self.result: + if i is self.sample: + res.append(0) + else: + res[-1] += 1 + return res + + +def test_language(): + res = TranscriberTest([0, 0, 1, 0, 0, 0, 0, 0, 0]).rle + assert res == [1, 2, 1, 1, 2, 4, 8, 11, 2] diff --git a/whisper/batching.py b/whisper/batching.py new file mode 100644 index 0000000..d8e0e12 --- /dev/null +++ b/whisper/batching.py @@ -0,0 +1,239 @@ +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable +from typing import Generic, TypeVar, Union + +import numpy as np +import torch + +A = TypeVar("A", bound=Union[np.ndarray, torch.Tensor]) + + +class ArrayWrapper(Generic[A]): + pass + + +ArrayTypes = Union[A, ArrayWrapper[A]] + + +class LoopbackIterator(Generic[A]): + async def iter(self): + raise NotImplementedError + + def __aiter__(self): + self._iter = self.iter() + return self + + async def __anext__(self) -> ArrayTypes: + if not hasattr(self, "_iter"): + self.__aiter__() + return await anext(self._iter) + + +async def empty(): + return + yield + + +class Unwrap(LoopbackIterator): + _initial: Union[ArrayTypes, Awaitable[ArrayTypes]] + started: bool + iterator: AsyncIterable[ArrayTypes] + + def __init__(self, iterator: AsyncIterable[ArrayTypes]): + while isinstance(iterator, PassthroughTransform): + iterator = iterator.handoff() + if isinstance(iterator, Unwrap): + self._initial, self.started = iterator.initial(), iterator.started + self.iterator = iterator.iterator + 
return + elif not isinstance(iterator, AsyncIterator): + iterator = aiter(iterator) + try: + self._initial = anext(iterator) + self.iterator, self.started = iterator, False + except StopAsyncIteration: + self.iterator, self.started = empty(), True + + async def initial(self) -> ArrayTypes: + while isinstance(self._initial, Awaitable): + self._initial = await self._initial + return self._initial + + async def iter(self) -> AsyncIterator[ArrayTypes]: + if not self.started: + self.started = True + yield await self.initial() + async for i in self.iterator: + yield i + + async def prop(self, key: str, default): + if hasattr(self, "initial"): + return getattr(await self.initial(), key) + else: + return default + + @property + def shape(self): + return self.prop("shape", ()) + + @property + def dtype(self): + return self.prop("dtype", None) + + @property + async def concat(self): + return np.concatenate if isinstance(await self.dtype, np.dtype) else torch.cat + + +class PassthroughTransform(LoopbackIterator): + def handoff(self) -> AsyncIterable[ArrayTypes]: + raise NotImplementedError + + +class BoxedIterator(PassthroughTransform): + def __init__(self, iterator): + self.iterator = iterator + self.flag = object() + + def handoff(self) -> AsyncIterable[ArrayTypes]: + self.flag = None + return self.iterator + + async def iter(self) -> AsyncIterator[ArrayTypes]: + if self.flag is None: + raise Exception("iterator source removed") + self.flag = flag = object() + async for i in self.iterator: + yield i + if self.flag != flag: + raise Exception("source can only be used by one iterator") + + +def LookAlong(axis: int): + assert axis >= 0 + empties = (slice(None),) * axis + + class LookAlong(ArrayWrapper): + def __init__(self, value: A): + self.value = value + + @property + def shape(self): + return self.value.shape[axis] + + def __getitem__(self, idx): + return self.value[empties + (idx,)] + + def __next__(self): + return self.value + + return LookAlong + + +class PassthroughMap(PassthroughTransform): + def __init__(self, apply: Callable[[A], ArrayTypes], iterator: AsyncIterator[A]): + self.iterator, self.apply = iterator, apply + + def handoff(self) -> AsyncIterator[A]: + return self.iterator + + async def iter(self) -> AsyncIterator[ArrayTypes]: + async for i in self.iterator: + yield self.apply(i) + + +class Group: + def __init__(self, concat, axis=-1): + self.concat = concat + self.holding = [] + self.consumed = 0 + self.shape = 0 + + def add(self, value): + self.holding.append(value) + self.shape += value.shape + + def take(self, amount, exact=True): + assert amount > 0 and amount <= self.shape + self.shape -= amount + taking, start = -self.consumed, self.consumed + for i, x in enumerate(self.holding): + taking += x.shape + if taking >= amount: + self.consumed = amount - taking + x.shape + break + if taking == amount or not exact: + self.shape += amount - taking + self.consumed = 0 + res = self.concat( + [self.holding[0][start:]] + [i.value for i in self.holding[1 : i + 1]] + ) + self.holding = self.holding[i + 1 :] + return res + if i == 0: + return self.holding[0][start : self.consumed] + res = self.concat( + [self.holding[0][start:]] + + [i.value for i in self.holding[1:i]] + + [self.holding[i][: self.consumed]] + ) + self.holding = self.holding[i:] + return res + + def all(self): + res = self.concat([i.value for i in self.holding]) + self.shape = 0 + self.consumed = 0 + self.holding = [] + return res + + +class Taken: + def take(self, *a, **kw): + raise Exception("batch queue moved") + + +class 
Batcher(PassthroughTransform): + def __init__(self, iterator, size, axis=-1, exact=False): + assert isinstance(size, int) and size > 0 + self.size, self._axis, self.exact = size, axis, exact + if isinstance(iterator, Unwrap) and hasattr(iterator, "group"): + self.group = iterator.group + self.preview = Unwrap(iterator) + + async def concat(self): + f = await self.preview.concat + return lambda tensors: f(tensors, self.axis) + + _iterator = None + + async def iterator(self): + if self._iterator is None: + self.axis = ( + len(await self.preview.shape) + self._axis + if self._axis < 0 + else self._axis + ) + if not hasattr(self, "group"): + self.group = Group(await self.concat()) + self._iterator = PassthroughMap( + LookAlong(self.axis), BoxedIterator(self.preview) + ) + return self._iterator + + def handoff(self): + self.group = Taken() + return self.preview if self._iterator is None else self._iterator + + def __aiter__(self): + return self + + async def __anext__(self): + iterator = aiter(await self.iterator()) + while self.group.shape < self.size: + try: + self.group.add(await anext(iterator)) + except StopAsyncIteration: + if self.group.shape > 0: + return self.group.all() + raise + return self.group.take(self.size, self.exact) diff --git a/whisper/buffer.py b/whisper/buffer.py new file mode 100644 index 0000000..5229ad8 --- /dev/null +++ b/whisper/buffer.py @@ -0,0 +1,309 @@ +import asyncio +import json +import subprocess +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Coroutine +from typing import IO, Optional, Union + +import numpy as np +import torch + +from .audio import HOP_LENGTH, N_FFT, N_FRAMES, SAMPLE_RATE, mel_filters +from .batching import Batcher +from .utils import PathType, ceildiv + + +class AudioSink: + def __init__(self, *, rate: int = SAMPLE_RATE, **kw): + super().__init__(**kw) + self.rate = rate + + def read(self): + raise NotImplementedError + + def write(self, data): + raise NotImplementedError + + +class ArrayStream(AudioSink): + q: asyncio.Queue + + def __init__( + self, + *, + device: Optional[Union[str, torch.device]] = None, + batch: int = 1, + n_mels: int = 80, + capacity: int = 1_000_000, + **kw, + ): + super().__init__(**kw) + self.q = asyncio.Queue(capacity) + self.finished = asyncio.Event() + self.device, self.batch, self.n_mels = device, batch, n_mels + self.sees = self.zeros((0,)) + self.spectogram = self.zeros((n_mels, 0)) + self.hann = torch.hann_window(N_FFT).to(self.sees.device) + self.filters = mel_filters(self.sees.device, n_mels) + + def zeros(self, shape): + return torch.zeros(shape, dtype=torch.float32, device=self.device) + + write_blockable: bool = True + + def write(self, data: bytes) -> Optional[Coroutine]: + if self.write_blockable: + return self.q.put(data) + else: + self.q.put_nowait(data) + return None + + def load(self, data: bytes) -> np.ndarray: + return np.frombuffer(data, np.int16).flatten().astype(np.float32) / 32768.0 + + async def loader(self, iterator: AsyncIterable[bytes]) -> AsyncIterator[np.ndarray]: + async for data in iterator: + yield self.load(data) + + async def buffer(self) -> AsyncIterator[bytes]: + waiter = asyncio.create_task(self.finished.wait()) + while not self.finished.is_set(): + getter = asyncio.create_task(self.q.get()) + done, pending = await asyncio.wait( + (waiter, getter), return_when=asyncio.FIRST_COMPLETED + ) + if getter in done: + yield getter.result() + while not self.q.empty(): + yield self.q.get_nowait() + + async def buffer_nowait(self) -> AsyncIterator[bytes]: + try: + while 
True: + yield self.q.get_nowait() + except asyncio.QueueEmpty: + pass + + loading: Optional[Batcher] = None + + async def fft_offset( + self, iterator: AsyncIterable[bytes] + ) -> AsyncIterator[np.ndarray]: + init = self.loader(iterator) if self.loading is None else self.loading + self.loading = Batcher(init, HOP_LENGTH) + _iterator = aiter(self.loading) + window = np.zeros((0,), dtype=np.float32) + while window.size < ceildiv(N_FFT, 2): + try: + window = np.concatenate((window, await anext(_iterator))) + except StopAsyncIteration: + return + window = np.pad(window, (N_FFT // 2, 0), "reflect") + yield window + async for data in _iterator: + yield data + # for _ in range(N_FFT // HOP_LENGTH): + # yield np.zeros((HOP_LENGTH,), dtype=np.float32) + # (done by runoff) + + def seeing(self, sees: torch.Tensor) -> torch.Tensor: + hopped = ((sees.shape[0] - N_FFT) // HOP_LENGTH + 1) * HOP_LENGTH + return sees[hopped:] + + async def window( + self, iterator: AsyncIterable[bytes] + ) -> AsyncIterator[torch.Tensor]: + _iterator = self.fft_offset(iterator) + async for data in _iterator: + _data = torch.from_numpy(data) + prev = self.sees.shape[0] - N_FFT + while (_data.shape[0] + prev) // HOP_LENGTH < self.batch - 1: + try: + adding = torch.from_numpy(await anext(_iterator)) + except StopAsyncIteration: + break + _data = torch.cat((_data, adding)) + if self.device is not None: + _data.to(self.device) + res = torch.cat((self.sees, _data)) + self.sees = self.seeing(res) + yield self.transform(self.dft(res)) + + def dft(self, amp: torch.Tensor) -> torch.Tensor: + return torch.stft( + amp, N_FFT, HOP_LENGTH, window=self.hann, center=False, return_complex=True + ) + + log_spec_bound: Optional[torch.Tensor] = None + + def transform(self, stft: torch.Tensor) -> torch.Tensor: + magnitudes = stft.abs() ** 2 + mel_spec = self.filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + # causes values to not precisely match the original + self.log_spec_bound = ( + log_spec.max() + if self.log_spec_bound is None + else torch.maximum(log_spec.max(), self.log_spec_bound) + ) + log_spec = torch.maximum(log_spec, self.log_spec_bound - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec + + def padding(self, content_frames: int) -> int: + return N_FRAMES + + # dft_pad: add ending content frames to match padding from a centered STFT + dft_pad: bool = False + + def runoff(self, dft_pad: Optional[bool] = None) -> torch.Tensor: + dft_pad = self.dft_pad if dft_pad is None else dft_pad + if dft_pad: + overrun = (ceildiv(N_FFT, HOP_LENGTH) - 1) * HOP_LENGTH + spectogram = torch.cat((self.sees, self.zeros(overrun))) + if spectogram.shape[-1] >= N_FFT: + spectogram = self.transform(self.dft(spectogram)) + else: + spectogram = torch.zeros(0) + padding = self.padding(self.spectogram.shape[-1] + spectogram.shape[-1]) + pad = self.zeros((self.n_mels, max(0, padding))) + spectogram = torch.cat((self.spectogram, spectogram, pad), -1) + return spectogram if padding >= 0 else spectogram[-padding:] + + offset: int = 0 + + async def pull(self) -> torch.Tensor: + context = self.spectogram.shape[-1] + iterator = self.window(self.buffer_nowait()) + async for frame in iterator: + self.spectogram = torch.cat((self.spectogram, frame), -1) + cutoff = min(context, max(self.spectogram.shape[-1] - N_FRAMES, 0)) + self.offset += cutoff + self.spectogram = self.spectogram[:, cutoff:] + return self.runoff() + + staging: Optional[Batcher] = None + + async def _push( + self, sec: float, exact: bool = False + ) -> 
AsyncIterator[torch.Tensor]: + batching = int(sec * SAMPLE_RATE // HOP_LENGTH) + init = self.window(self.buffer()) if self.staging is None else self.staging + self.staging = Batcher(init, batching, exact=exact) + async for frame in self.staging: + batched = batching if exact else frame.shape[-1] + cutoff = max(self.spectogram.shape[-1] + batched - N_FRAMES, 0) + self.offset += cutoff + self.spectogram = torch.cat((self.spectogram[:, cutoff:], frame), -1) + yield self.runoff() + + reader: Optional[Awaitable] = None + + def start(self, **kw) -> None: + if self.reader is None: + self.reader = asyncio.create_task(self.read(**kw)) + + async def push( + self, sec: float, exact: bool = False, **kw + ) -> AsyncIterator[torch.Tensor]: + self.start(**kw) + async for i in self._push(sec, exact): + yield i + assert self.reader is not None + await self.reader + + async def request(self, sec: float, exact: bool = True, **kw) -> torch.Tensor: + try: + return await anext(self.push(sec, exact)) + except StopAsyncIteration: + if self.reader is not None: + await self.reader + return self.zeros((self.n_mels, 0)) + + async def full(self, **kw) -> torch.Tensor: + await self.read(**kw) + return await self.pull() + + def sequential(self, **kw) -> torch.Tensor: + return asyncio.run(self.full(**kw)) + + async def amplitudes(self, **kw) -> np.ndarray: + self.start(**kw) + res = [] + async for data in self.loader(self.buffer()): + res.append(data) + assert self.reader is not None + await self.reader + return np.concatenate(res) + + def all_amplitudes(self, **kw) -> np.ndarray: + return asyncio.run(self.amplitudes(**kw)) + + +class RawAudioFile(ArrayStream): + def __init__(self, *, period: int = HOP_LENGTH, fname: PathType = "out.raw", **kw): + super().__init__(**kw) + self.fname = fname + self.period = period + + fp: Optional[IO[bytes]] = None + + async def read(self) -> None: + fp = open(self.fname, "rb") if self.fp is None else self.fp + data = fp.read(self.period) + while len(data) != 0: + io_hold = self.write(data) + assert io_hold is not None and self.write_blockable is True + await io_hold + data = fp.read(self.period) + self.finished.set() + + +class AudioFile(RawAudioFile): + def __init__(self, *, period: int = SAMPLE_RATE, fname: PathType = "out.wav", **kw): + assert not subprocess.run( + ["which", "ffmpeg"], stdout=subprocess.PIPE + ).returncode + super().__init__(period=period or -1, fname=fname, **kw) + + async def read(self) -> None: + cmd = [ + "ffmpeg", + "-nostdin", + "-threads", + "0", + "-i", + self.fname, + "-f", + "s16le", + "-ac", + "1", + "-acodec", + "pcm_s16le", + "-ar", + str(self.rate), + "-", + ] + ps = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + self.fp = ps.stdout + await super().read() + _, stderr = ps.communicate() + if ps.returncode not in (None, 0): + raise RuntimeError(f"Failed to load audio: {stderr.decode()}") + + @property + def duration(self): + cmd = [ + "ffprobe", + "-hide_banner", + "-show_format", + "-of", + "json", + "-i", + self.fname, + ] + ps = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = ps.communicate() + if ps.returncode not in (None, 0): + raise RuntimeError(f"Failed to load audio: {stderr.decode()}") + return float(json.loads(stdout)["format"]["duration"]) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 8e1240b..3887433 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -1,26 +1,31 @@ import argparse +import asyncio import os import traceback import warnings -from 
typing import TYPE_CHECKING, List, Optional, Tuple, Union +from dataclasses import dataclass +from math import ceil +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import numpy as np import torch import tqdm from .audio import ( + CHUNK_LENGTH, FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, - N_SAMPLES, SAMPLE_RATE, - log_mel_spectrogram, pad_or_trim, ) +from .buffer import ArrayStream, AudioFile from .decoding import DecodingOptions, DecodingResult from .timing import add_word_timestamps -from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer +from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, Tokenizer, get_tokenizer from .utils import ( + PassthroughProperty, + PassthroughPropertyDefaults, exact_div, format_timestamp, get_end, @@ -35,24 +40,802 @@ if TYPE_CHECKING: from .model import Whisper -def transcribe( - model: "Whisper", - audio: Union[str, np.ndarray, torch.Tensor], - *, - verbose: Optional[bool] = None, - temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), - compression_ratio_threshold: Optional[float] = 2.4, - logprob_threshold: Optional[float] = -1.0, - no_speech_threshold: Optional[float] = 0.6, - condition_on_previous_text: bool = True, - initial_prompt: Optional[str] = None, - word_timestamps: bool = False, - prepend_punctuations: str = "\"'“¿([{-", - append_punctuations: str = "\"'.。,,!!??::”)]}、", - clip_timestamps: Union[str, List[float]] = "0", - hallucination_silence_threshold: Optional[float] = None, - **decode_options, -): +@dataclass +class LanguageHypothesis: + language: Optional[str] = None + since: int = 0 + evidence: int = 0 + last: int = -1 + + +class Transcriber(metaclass=PassthroughPropertyDefaults): + prefix: str = """"'\u201c\u00bf([{-""" + postfix: str = """"'.\u3002,\uff0c!\uff01?\uff1f:\uff1a\u201d)]}\u3001""" + punctuation: str = prefix + postfix + + verbose: Optional[bool] = None + + _decode_options: dict = {} + decode_props: Tuple[str, ...] 
= ("fp16", "language", "task") + + @property + def decode_options(self) -> dict: + for k in self.decode_props: + self._decode_options[k] = getattr(self, k) + return self._decode_options + + @decode_options.setter + def decode_options(self, value: dict) -> None: + self._decode_options = value + for k in self.decode_props: + if k in value: + setattr(self, k, value[k]) + + dtype: torch.dtype = torch.float16 + + @property + def fp16(self) -> bool: + return self.dtype == torch.float16 + + @fp16.setter + def fp16(self, value: bool) -> None: + self.dtype = torch.float16 if value else torch.float32 + self.fp16device() + + @PassthroughProperty[Optional["Whisper"]](None).setter + def model(self, value: Optional["Whisper"]) -> None: + self._model = value + self.device = None if value is None else value.device + self.input_stride = exact_div( + N_FRAMES, self.model.dims.n_audio_ctx + ) # mel frames per output token: 2 + self.time_precision = ( + self.input_stride * HOP_LENGTH / SAMPLE_RATE + ) # time per output token: 0.02 (seconds) + + @PassthroughProperty[Optional[torch.device]](None).setter + def device(self, value: Optional[torch.device]) -> None: + self._device = value + if value == torch.device("cpu"): + if torch.cuda.is_available(): + warnings.warn("Performing inference on CPU when CUDA is available") + self.fp16device() + + def fp16device(self) -> None: + if self.device == torch.device("cpu") and self.dtype == torch.float16: + warnings.warn("FP16 is not supported on CPU; using FP32 instead") + self.dtype = torch.float32 + + def detect_language(self, mel: Optional[torch.Tensor] = None) -> str: + mel_segment = pad_or_trim(self.latest if mel is None else mel, N_FRAMES) + mel_segment = mel_segment.to(self.device).to(self.dtype) + _, probs = self.model.detect_language(mel_segment) + return max(probs, key=probs.get) + + prev: Optional[torch.Tensor] = None + _latest: Optional[torch.Tensor] = None + + @PassthroughProperty[Optional[torch.Tensor]](None).setter + def latest(self, value: Optional[torch.Tensor]) -> None: + self.prev = self._latest + self._latest = value + + _hypothesis: LanguageHypothesis = LanguageHypothesis() + _language: Optional[str] + _language_detection_warned: bool = False + + @PassthroughProperty[Optional[str]](None).property + def language(self) -> Optional[str]: + if self._language is not None: + return self._language + if not self.model.is_multilingual: + return "en" + if self.verbose and not self._language_detection_warned: + print( + "Detecting language using up to the first 30 seconds." 
+ "Use `--language` to specify the language" + ) + self._language_detection_warned = True + if self.latest is None: + return None + available = self.frame_offset + self.latest.shape[-1] + if available == self._hypothesis.last: + return self._hypothesis.language + self._hypothesis.last = available + if self.frame_offset > 0 or self.latest.shape[-1] == N_FRAMES * 2: + mel = ( + self.latest + if self.prev is None + else torch.cat((self.prev[: self.frame_offset], self.latest), -1) + ) + self._language = self.detect_language(mel) + return self._language + self._hypothesis.since += 1 + if 2**self._hypothesis.evidence > self._hypothesis.since: + return self._hypothesis.language + self._hypothesis.since = 0 + guess = self.detect_language() + if guess == self._hypothesis.language: + self._hypothesis.evidence += 1 + else: + self._hypothesis.language = guess + self._hypothesis.evidence = 0 + return guess + + @PassthroughProperty[Union[str, List[float], Tuple[float]]]((0,)).setter + def clip_timestamps(self, value: Union[str, List[float], Tuple[float]]): + self._seek_clips = None + if isinstance(value, str): + self._clip_timestamps = ( + tuple(map(float, value.split(","))) if value else (0,) + ) + else: + self._clip_timestamps = tuple(value) or (0,) + + _seek_clips: Optional[List[Tuple[int, Optional[int]]]] = None + + @property + def seek_clips(self) -> List[Tuple[int, Optional[int]]]: + if self._seek_clips is None: + seek_points = tuple( + round(ts * FRAMES_PER_SECOND) for ts in self.clip_timestamps + ) + (None,) + self._seek_clips = list(zip(seek_points[::2], seek_points[1::2])) + return self._seek_clips + + _seek: Optional[int] + + @PassthroughProperty[Optional[int]](None).property + def seek(self) -> Optional[int]: + return self.seek_clips[0][0] if self._seek is None else self._seek + + @PassthroughProperty[int](0).setter + def clip_idx(self, value: int): + self._clip_idx = value + clips = self.seek_clips + if value < len(clips): + self.seek = clips[value][0] + + time_offset = property(lambda self: float(self.seek * HOP_LENGTH / SAMPLE_RATE)) + window_end_time = property( + lambda self: float((self.seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE) + ) + + _temperature: Union[Optional[float], Tuple[float, ...]] + + @PassthroughProperty[Union[Optional[float], Tuple[float, ...]]]( + (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) + ).setter + def temperature(self, value: Union[Optional[float], Tuple[float, ...]]): + self._temperature = ( + (value,) + if isinstance(value, (int, float)) + else (Transcriber._temperature if value is None else value) + ) + + @PassthroughProperty("transcribe").setter + def task(self, value: str): + self._task = value + if self.word_timestamps and value == "translate": + warnings.warn("Word-level timestamps on translations may not be reliable.") + + @PassthroughProperty(False).setter + def word_timestamps(self, value: bool): + self._word_timestamps = value + self.task = self.task + + get_tokenizer = staticmethod(get_tokenizer) + _tokenizer: Optional[Tokenizer] = None + _tokenizer_cache: Dict[str, Tokenizer] = {} + + @property + def tokenizer(self) -> Tokenizer: + if self._tokenizer is None: + lang = self.language + if self._language is not None: + if self._language in self._tokenizer_cache: + self._tokenizer = self._tokenizer_cache[self._language] + else: + self._tokenizer = self.get_tokenizer( + self.model.is_multilingual, + num_languages=self.model.num_languages, + language=self.language, + task=self.task, + ) + return self._tokenizer + assert lang is not None + if lang not in 
self._tokenizer_cache: + self._tokenizer_cache[lang] = self.get_tokenizer( + self.model.is_multilingual, + num_languages=self.model.num_languages, + language=lang, + task=self.task, + ) + return self._tokenizer_cache[lang] + return self._tokenizer + + _initial_prompt_tokens: Optional[List[int]] = None + _initial_prompt_cache: Dict[Tokenizer, List[int]] = {} + + @property + def initial_prompt_tokens(self) -> List[int]: + if self._initial_prompt_tokens is None: + if self.initial_prompt is None: + self._initial_prompt_tokens = [] + elif self.language is None: + return [] + else: + tokenizer = self.tokenizer + if tokenizer not in self._initial_prompt_cache: + self._initial_prompt_cache[tokenizer] = tokenizer.encode( + " " + self.initial_prompt.strip() + ) + if self._tokenizer is not None: + self._initial_prompt_tokens = self._initial_prompt_cache[tokenizer] + return self._initial_prompt_cache[tokenizer] + return self._initial_prompt_tokens + + _initial_tokens: int = 0 + _initial_finalized: bool = False + _all_tokens: Optional[list] = None + + @property + def all_tokens(self): + if self._all_tokens is None: + self._all_tokens = [] + if not self._initial_finalized: + initial = self.initial_prompt_tokens + self._all_tokens = initial + self._all_tokens[self._initial_tokens :] + self._initial_tokens = len(initial) + self._initial_finalized = self._initial_prompt_tokens is not None + return self._all_tokens + + prompt_reset_since: int = 0 + last_speech_timestamp: float = 0.0 + frame_offset: int = 0 + all_segments: List[dict] + + def __init__( + self, + model: "Whisper", + *, + verbose: Optional[bool] = None, + temperature: Union[Optional[float], Tuple[float, ...]] = None, + compression_ratio_threshold: Optional[float] = 2.4, + logprob_threshold: Optional[float] = -1.0, + no_speech_threshold: Optional[float] = 0.6, + condition_on_previous_text: bool = True, + initial_prompt: Optional[str] = None, + word_timestamps: bool = False, + prepend_punctuations: str = prefix, + append_punctuations: str = postfix, + clip_timestamps: Union[str, List[float]] = "0", + hallucination_silence_threshold: Optional[float] = None, + **decode_options, + ): + """ + Transcribe an audio file using Whisper + + Parameters + ---------- + model: Whisper + The Whisper model instance + + verbose: bool + Whether to display the text being decoded to the console. If True, + displays all the details, If False, displays minimal details. If + None, does not display anything + + temperature: Union[float, Tuple[float, ...]] + Temperature for sampling. It can be a tuple of temperatures, which + will be successively used upon failures according to either + `compression_ratio_threshold` or `logprob_threshold`. + + compression_ratio_threshold: float + If the gzip compression ratio is above this value, treat as failed + + logprob_threshold: float + If the average log probability over sampled tokens is below this + value, treat as failed + + no_speech_threshold: float + If the no_speech probability is higher than this value AND the + average log probability over sampled tokens is below + `logprob_threshold`, consider the segment as silent + + condition_on_previous_text: bool + if True, the previous output of the model is provided as a prompt + for the next window; disabling may make the text inconsistent across + windows, but the model becomes less prone to getting stuck in a + failure loop, such as repetition looping or timestamps going out of + sync. 
+ + word_timestamps: bool + Extract word-level timestamps using the cross-attention pattern and + dynamic time warping, and include the timestamps for each word in + each segment. + + prepend_punctuations: str + If word_timestamps is True, merge these punctuation symbols with the + next word + + append_punctuations: str + If word_timestamps is True, merge these punctuation symbols with the + previous word + + initial_prompt: Optional[str] + Optional text to provide as a prompt for the first window. This can + be used to provide, or "prompt-engineer" a context for + transcription, e.g. custom vocabularies or proper nouns to make it + more likely to predict those word correctly. + + decode_options: dict + Keyword arguments to construct `DecodingOptions` instances + + clip_timestamps: Union[str, List[float]] + Comma-separated list start,end,start,end,... timestamps (in seconds) + of clips to process. The last end timestamp defaults to the end of + the file. + + hallucination_silence_threshold: Optional[float] + When word_timestamps is True, skip silent periods longer than this + threshold (in seconds) when a possible hallucination is detected + """ + self.model = model + self.verbose = verbose + self.temperature = temperature + self.compression_ratio_threshold = compression_ratio_threshold + self.logprob_threshold = logprob_threshold + self.no_speech_threshold = no_speech_threshold + self.condition_on_previous_text = condition_on_previous_text + self.initial_prompt = initial_prompt + self.word_timestamps = word_timestamps + self.prepend_punctuations = prepend_punctuations + self.append_punctuations = append_punctuations + self.clip_timestamps = clip_timestamps + self.hallucination_silence_threshold = hallucination_silence_threshold + self.decode_options = decode_options + + self.all_segments = [] + + def decode_with_fallback(self, segment: torch.Tensor) -> DecodingResult: + decode_result = None + for t in self.temperature: + kw = {**self.decode_options, "temperature": t} + if t > 0: + # disable beam_size and patience when t > 0 + kw.pop("beam_size", None) + kw.pop("patience", None) + else: + # disable best_of when t == 0 + kw.pop("best_of", None) + decode_result = self.model.decode(segment, DecodingOptions(**kw)) + + needs_fallback = False + if self.compression_ratio_threshold is not None and ( + decode_result.compression_ratio > self.compression_ratio_threshold + ): + needs_fallback = True # too repetitive + if self.logprob_threshold is not None and ( + decode_result.avg_logprob < self.logprob_threshold + ): + needs_fallback = True # average log probability is too low + if self.no_speech_threshold is not None and ( + decode_result.no_speech_prob > self.no_speech_threshold + ): + needs_fallback = False # silence + if not needs_fallback: + break + assert decode_result is not None + return decode_result + + def new_segment( + self, *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult + ) -> dict: + _tokens = tokens.tolist() + text_tokens = [token for token in _tokens if token < self.tokenizer.eot] + return { + "seek": self.seek, + "start": start, + "end": end, + "text": self.tokenizer.decode(text_tokens), + "tokens": _tokens, + "temperature": result.temperature, + "avg_logprob": result.avg_logprob, + "compression_ratio": result.compression_ratio, + "no_speech_prob": result.no_speech_prob, + } + + # anomalous words are very long/short/improbable + @staticmethod + def word_anomaly_score(word: dict) -> float: + probability = word.get("probability", 0.0) + duration = 
word["end"] - word["start"] + score = 0.0 + if probability < 0.15: + score += 1.0 + if duration < 0.133: + score += (0.133 - duration) * 15 + if duration > 2.0: + score += duration - 2.0 + return score + + def is_segment_anomaly(self, segment: Optional[dict]) -> bool: + if segment is None or not segment["words"]: + return False + words = [w for w in segment["words"] if w["word"] not in self.punctuation][:8] + score = sum(self.word_anomaly_score(w) for w in words) + return score >= 3 or score + 0.01 >= len(words) + + @staticmethod + def next_words_segment(segments: List[dict]) -> Optional[dict]: + return next((s for s in segments if s["words"]), None) + + def reseek( + self, + current_segments: List[dict], + segment_size: int, + single_timestamp_ending: bool, + tokens: torch.Tensor, + timestamp_tokens: torch.Tensor, + result: DecodingResult, + ): + consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + consecutive.add_(1) + if len(consecutive) > 0: + # if the output contains two consecutive timestamp tokens + slices = consecutive.tolist() + if single_timestamp_ending: + slices.append(len(tokens)) + + last_slice = 0 + for current_slice in slices: + sliced_tokens = tokens[last_slice:current_slice] + start_timestamp_pos = ( + sliced_tokens[0].item() - self.tokenizer.timestamp_begin + ) + end_timestamp_pos = ( + sliced_tokens[-1].item() - self.tokenizer.timestamp_begin + ) + current_segments.append( + self.new_segment( + start=self.time_offset + + start_timestamp_pos * self.time_precision, + end=self.time_offset + end_timestamp_pos * self.time_precision, + tokens=sliced_tokens, + result=result, + ) + ) + last_slice = current_slice + + if single_timestamp_ending: + # single timestamp at the end means no speech after the last + # timestamp. + self.seek += segment_size + else: + # otherwise, ignore the unfinished segment and seek to the last + # timestamp + last_timestamp_pos = ( + tokens[last_slice - 1].item() - self.tokenizer.timestamp_begin + ) + self.seek += last_timestamp_pos * self.input_stride + else: + duration = segment_size * HOP_LENGTH / SAMPLE_RATE + timestamps = tokens[timestamp_tokens.nonzero().flatten()] + if ( + len(timestamps) > 0 + and timestamps[-1].item() != self.tokenizer.timestamp_begin + ): + # no consecutive timestamps but it has a timestamp; use the last + # one. 
+ last_timestamp_pos = ( + timestamps[-1].item() - self.tokenizer.timestamp_begin + ) + duration = last_timestamp_pos * self.time_precision + + current_segments.append( + self.new_segment( + start=self.time_offset, + end=self.time_offset + duration, + tokens=tokens, + result=result, + ) + ) + self.seek += segment_size + + def timestamp( + self, + current_segments: List[dict], + segment_size: int, + single_timestamp_ending: bool, + mel_segment: torch.Tensor, + previous_seek: int, + content_frames: int, + ) -> bool: + add_word_timestamps( + segments=current_segments, + model=self.model, + tokenizer=self.tokenizer, + mel=mel_segment, + num_frames=segment_size, + prepend_punctuations=self.prepend_punctuations, + append_punctuations=self.append_punctuations, + last_speech_timestamp=self.last_speech_timestamp, + ) + + if not single_timestamp_ending: + last_word_end = get_end(current_segments) + if last_word_end is not None and last_word_end > self.time_offset: + self.seek = round(last_word_end * FRAMES_PER_SECOND) + + # skip silence before possible hallucinations + if self.hallucination_silence_threshold is not None: + threshold = self.hallucination_silence_threshold + if not single_timestamp_ending: + last_word_end = get_end(current_segments) + if last_word_end is not None and last_word_end > self.time_offset: + remaining_duration = self.window_end_time - last_word_end + if remaining_duration > threshold: + self.seek = round(last_word_end * FRAMES_PER_SECOND) + else: + self.seek = previous_seek + segment_size + + # if first segment might be a hallucination, skip leading silence + first_segment = self.next_words_segment(current_segments) + if first_segment is not None and self.is_segment_anomaly(first_segment): + gap = first_segment["start"] - self.time_offset + if gap > threshold: + self.seek = previous_seek + round(gap * FRAMES_PER_SECOND) + return True + + # skip silence before any possible hallucination that is + # surrounded by silence or more hallucinations + hal_last_end = self.last_speech_timestamp + content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE) + for si in range(len(current_segments)): + segment = current_segments[si] + if not segment["words"]: + continue + if self.is_segment_anomaly(segment): + next_segment = self.next_words_segment(current_segments[si + 1 :]) + if next_segment is not None: + hal_next_start = next_segment["words"][0]["start"] + else: + hal_next_start = ( + self.time_offset + segment_size * HOP_LENGTH / SAMPLE_RATE + ) + silence_before = ( + segment["start"] - hal_last_end > threshold + or segment["start"] < threshold + or segment["start"] - self.time_offset < 2.0 + ) + silence_after = ( + hal_next_start - segment["end"] > threshold + or self.is_segment_anomaly(next_segment) + or self.window_end_time - segment["end"] < 2.0 + ) + if silence_before and silence_after: + self.seek = round( + max(self.time_offset + 1, segment["start"]) + * FRAMES_PER_SECOND + ) + if content_duration - segment["end"] < threshold: + self.seek = content_frames + current_segments[si:] = [] + break + hal_last_end = segment["end"] + + last_word_end = get_end(current_segments) + if last_word_end is not None: + self.last_speech_timestamp = last_word_end + return False + + def __call__( + self, mel: torch.Tensor, offset: int = 0, single_pass: bool = False + ) -> dict: + self.latest, self.frame_offset = mel, offset + content_frames = mel.shape[-1] - N_FRAMES + offset + # NOTE: This loop is obscurely flattened to make the diff readable. 
+ # A later commit should turn this into a simpler nested loop. + # for seek_clip_start, seek_clip_end in seek_clips: + # while seek < seek_clip_end + while self.clip_idx < len(self.seek_clips): + seek_clip_start, seek_clip_end = self.seek_clips[self.clip_idx] + seek_clip_end = content_frames if seek_clip_end is None else seek_clip_end + if self.seek < seek_clip_start: + self.seek = seek_clip_start + if self.seek >= seek_clip_end: + if self.clip_idx == len(self.seek_clips) - 1: + break + self.clip_idx += 1 + continue + segment_size = min( + N_FRAMES, content_frames - self.seek, seek_clip_end - self.seek + ) + mel_segment = mel[:, self.seek - offset : self.seek + segment_size - offset] + mel_segment = ( + pad_or_trim(mel_segment, N_FRAMES).to(self.device).to(self.dtype) + ) + + self.decode_options["prompt"] = self.all_tokens[self.prompt_reset_since :] + result: DecodingResult = self.decode_with_fallback(mel_segment) + + if self.no_speech_threshold is not None: + # no voice activity check + should_skip = result.no_speech_prob > self.no_speech_threshold + if ( + self.logprob_threshold is not None + and result.avg_logprob > self.logprob_threshold + ): + # don't skip if the logprob is high enough, despite the + # no_speech_prob + should_skip = False + + if should_skip: + # fast-forward to the next segment boundary + self.seek += segment_size + continue + + previous_seek = self.seek + current_segments: List[dict] = [] + + tokens = torch.tensor(result.tokens) + timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin) + single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] + + self.reseek( + current_segments, + segment_size, + single_timestamp_ending, + tokens, + timestamp_tokens, + result, + ) + + if self.word_timestamps: + if self.timestamp( + current_segments, + segment_size, + single_timestamp_ending, + mel_segment, + previous_seek, + content_frames, + ): + continue + + if self.verbose: + for segment in current_segments: + start, end = segment["start"], segment["end"] + text = segment["text"] + line = ( + f"[{format_timestamp(start)} --> " + f"{format_timestamp(end)}] {text}" + ) + print(make_safe(line)) + + # if a segment is instantaneous or does not contain text, clear it + for i, segment in enumerate(current_segments): + if segment["start"] == segment["end"] or segment["text"].strip() == "": + segment["text"] = "" + segment["tokens"] = [] + segment["words"] = [] + + self.all_segments.extend( + [ + {"id": i, **segment} + for i, segment in enumerate( + current_segments, start=len(self.all_segments) + ) + ] + ) + self.all_tokens.extend( + [token for segment in current_segments for token in segment["tokens"]] + ) + + if not self.condition_on_previous_text or result.temperature > 0.5: + # do not feed the prompt tokens if a high temperature was used + self.prompt_reset_since = len(self.all_tokens) + + self.reporthook() + + if single_pass: + break + + self.result = dict( + segments=self.all_segments, + language=self.language, + text=self.tokenizer.decode( + self.all_tokens[len(self.initial_prompt_tokens) :] + ), + ) + self.latest = None + return self.result + + def reporthook(self) -> None: + pass + + +class InMemoryAudio(AudioFile): + dft_pad = True + + +def audio_tensor(audio: Union[str, np.ndarray, torch.Tensor]) -> torch.Tensor: + if isinstance(audio, str): + return InMemoryAudio(fname=audio).sequential() + if isinstance(audio, np.ndarray): + return torch.from_numpy(audio) + return audio + + +class MinimalTranscriber(Transcriber): + exact: bool = True + # 
amount of time per chunk that is considered in-context + contextualized: float = CHUNK_LENGTH + + async def process(self, stream: ArrayStream, **kw) -> dict: + data = await stream.request(CHUNK_LENGTH, self.exact) + while data.shape[-1] > 0: + self(data, stream.offset, True) + t = ( + self.contextualized + - (stream.offset + data.shape[-1] - self.seek) / FRAMES_PER_SECOND + + CHUNK_LENGTH + ) + data = await stream.request(t, self.exact) + return self.result + + +class ProgressTranscriber(MinimalTranscriber): + def __init__(self, *a, duration: Optional[float] = None, **kw): + super().__init__(*a, **kw) + self.duration, self.progress = duration, 0 + + def __call__(self, *a, **kw) -> dict: + if self._pbar is None: + try: + return super().__call__(*a, **kw) + finally: + self.close() + else: + return super().__call__(*a, **kw) + + @PassthroughProperty(None).property + def pbar(self): + if self._pbar is None: + n = ( + self.latest.shape[-1] + if self.duration is None + else ceil(self.duration * FRAMES_PER_SECOND) + ) + # show the progress bar when verbose is False + # (if True, transcribed text will be printed) + self._pbar = tqdm.tqdm( + total=n, unit="frames", disable=self.verbose is not False + ) + self._pbar.__enter__() + return self._pbar + + def reporthook(self) -> None: + update_to = min(self._seek, self.frame_offset + self.latest.shape[-1]) + self.pbar.update(update_to - self.progress) + self.progress = update_to + + def close(self): + self.pbar.__exit__(None, None, None) + + async def process(self, stream: ArrayStream, **kw) -> dict: + self.pbar + try: + return await super().process(stream, **kw) + finally: + self.close() + + async def progressive(self, stream: AudioFile, **kw) -> dict: + self.duration = stream.duration + return await self.process(stream, **kw) + + +def transcribe(model: "Whisper", audio: Union[str, np.ndarray, torch.Tensor], **kw): """ Transcribe an audio file using Whisper @@ -64,438 +847,18 @@ def transcribe( audio: Union[str, np.ndarray, torch.Tensor] The path to the audio file to open, or the audio waveform - verbose: bool - Whether to display the text being decoded to the console. If True, displays all the details, - If False, displays minimal details. If None, does not display anything - - temperature: Union[float, Tuple[float, ...]] - Temperature for sampling. It can be a tuple of temperatures, which will be successively used - upon failures according to either `compression_ratio_threshold` or `logprob_threshold`. - - compression_ratio_threshold: float - If the gzip compression ratio is above this value, treat as failed - - logprob_threshold: float - If the average log probability over sampled tokens is below this value, treat as failed - - no_speech_threshold: float - If the no_speech probability is higher than this value AND the average log probability - over sampled tokens is below `logprob_threshold`, consider the segment as silent - - condition_on_previous_text: bool - if True, the previous output of the model is provided as a prompt for the next window; - disabling may make the text inconsistent across windows, but the model becomes less prone to - getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. - - word_timestamps: bool - Extract word-level timestamps using the cross-attention pattern and dynamic time warping, - and include the timestamps for each word in each segment. 
- - prepend_punctuations: str - If word_timestamps is True, merge these punctuation symbols with the next word - - append_punctuations: str - If word_timestamps is True, merge these punctuation symbols with the previous word - - initial_prompt: Optional[str] - Optional text to provide as a prompt for the first window. This can be used to provide, or - "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns - to make it more likely to predict those word correctly. - - decode_options: dict - Keyword arguments to construct `DecodingOptions` instances - - clip_timestamps: Union[str, List[float]] - Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process. - The last end timestamp defaults to the end of the file. - - hallucination_silence_threshold: Optional[float] - When word_timestamps is True, skip silent periods longer than this threshold (in seconds) - when a possible hallucination is detected - Returns ------- - A dictionary containing the resulting text ("text") and segment-level details ("segments"), and - the spoken language ("language"), which is detected when `decode_options["language"]` is None. + A dictionary containing the resulting text ("text") and segment-level + details ("segments"), and the spoken language ("language"), which is + detected when `decode_options["language"]` is None. """ - dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 - if model.device == torch.device("cpu"): - if torch.cuda.is_available(): - warnings.warn("Performing inference on CPU when CUDA is available") - if dtype == torch.float16: - warnings.warn("FP16 is not supported on CPU; using FP32 instead") - dtype = torch.float32 + return ProgressTranscriber(model, **kw)(audio_tensor(audio)) - if dtype == torch.float32: - decode_options["fp16"] = False - # Pad 30-seconds of silence to the input audio, for slicing - mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES) - content_frames = mel.shape[-1] - N_FRAMES - content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE) - - if decode_options.get("language", None) is None: - if not model.is_multilingual: - decode_options["language"] = "en" - else: - if verbose: - print( - "Detecting language using up to the first 30 seconds. 
Use `--language` to specify the language" - ) - mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) - _, probs = model.detect_language(mel_segment) - decode_options["language"] = max(probs, key=probs.get) - if verbose is not None: - print( - f"Detected language: {LANGUAGES[decode_options['language']].title()}" - ) - - language: str = decode_options["language"] - task: str = decode_options.get("task", "transcribe") - tokenizer = get_tokenizer( - model.is_multilingual, - num_languages=model.num_languages, - language=language, - task=task, - ) - - if isinstance(clip_timestamps, str): - clip_timestamps = [ - float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else []) - ] - seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps] - if len(seek_points) == 0: - seek_points.append(0) - if len(seek_points) % 2 == 1: - seek_points.append(content_frames) - seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2])) - - punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、" - - if word_timestamps and task == "translate": - warnings.warn("Word-level timestamps on translations may not be reliable.") - - def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: - temperatures = ( - [temperature] if isinstance(temperature, (int, float)) else temperature - ) - decode_result = None - - for t in temperatures: - kwargs = {**decode_options} - if t > 0: - # disable beam_size and patience when t > 0 - kwargs.pop("beam_size", None) - kwargs.pop("patience", None) - else: - # disable best_of when t == 0 - kwargs.pop("best_of", None) - - options = DecodingOptions(**kwargs, temperature=t) - decode_result = model.decode(segment, options) - - needs_fallback = False - if ( - compression_ratio_threshold is not None - and decode_result.compression_ratio > compression_ratio_threshold - ): - needs_fallback = True # too repetitive - if ( - logprob_threshold is not None - and decode_result.avg_logprob < logprob_threshold - ): - needs_fallback = True # average log probability is too low - if ( - no_speech_threshold is not None - and decode_result.no_speech_prob > no_speech_threshold - ): - needs_fallback = False # silence - if not needs_fallback: - break - - return decode_result - - clip_idx = 0 - seek = seek_clips[clip_idx][0] - input_stride = exact_div( - N_FRAMES, model.dims.n_audio_ctx - ) # mel frames per output token: 2 - time_precision = ( - input_stride * HOP_LENGTH / SAMPLE_RATE - ) # time per output token: 0.02 (seconds) - all_tokens = [] - all_segments = [] - prompt_reset_since = 0 - - if initial_prompt is not None: - initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip()) - all_tokens.extend(initial_prompt_tokens) - else: - initial_prompt_tokens = [] - - def new_segment( - *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult - ): - tokens = tokens.tolist() - text_tokens = [token for token in tokens if token < tokenizer.eot] - return { - "seek": seek, - "start": start, - "end": end, - "text": tokenizer.decode(text_tokens), - "tokens": tokens, - "temperature": result.temperature, - "avg_logprob": result.avg_logprob, - "compression_ratio": result.compression_ratio, - "no_speech_prob": result.no_speech_prob, - } - - # show the progress bar when verbose is False (if True, transcribed text will be printed) - with tqdm.tqdm( - total=content_frames, unit="frames", disable=verbose is not False - ) as pbar: - last_speech_timestamp = 0.0 - # NOTE: This loop is obscurely flattened to make the diff readable. 
- # A later commit should turn this into a simpler nested loop. - # for seek_clip_start, seek_clip_end in seek_clips: - # while seek < seek_clip_end - while clip_idx < len(seek_clips): - seek_clip_start, seek_clip_end = seek_clips[clip_idx] - if seek < seek_clip_start: - seek = seek_clip_start - if seek >= seek_clip_end: - clip_idx += 1 - if clip_idx < len(seek_clips): - seek = seek_clips[clip_idx][0] - continue - time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) - window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE) - segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek) - mel_segment = mel[:, seek : seek + segment_size] - segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE - mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) - - decode_options["prompt"] = all_tokens[prompt_reset_since:] - result: DecodingResult = decode_with_fallback(mel_segment) - tokens = torch.tensor(result.tokens) - - if no_speech_threshold is not None: - # no voice activity check - should_skip = result.no_speech_prob > no_speech_threshold - if ( - logprob_threshold is not None - and result.avg_logprob > logprob_threshold - ): - # don't skip if the logprob is high enough, despite the no_speech_prob - should_skip = False - - if should_skip: - seek += segment_size # fast-forward to the next segment boundary - continue - - previous_seek = seek - current_segments = [] - - # anomalous words are very long/short/improbable - def word_anomaly_score(word: dict) -> float: - probability = word.get("probability", 0.0) - duration = word["end"] - word["start"] - score = 0.0 - if probability < 0.15: - score += 1.0 - if duration < 0.133: - score += (0.133 - duration) * 15 - if duration > 2.0: - score += duration - 2.0 - return score - - def is_segment_anomaly(segment: Optional[dict]) -> bool: - if segment is None or not segment["words"]: - return False - words = [w for w in segment["words"] if w["word"] not in punctuation] - words = words[:8] - score = sum(word_anomaly_score(w) for w in words) - return score >= 3 or score + 0.01 >= len(words) - - def next_words_segment(segments: List[dict]) -> Optional[dict]: - return next((s for s in segments if s["words"]), None) - - timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) - single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] - - consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] - consecutive.add_(1) - if len(consecutive) > 0: - # if the output contains two consecutive timestamp tokens - slices = consecutive.tolist() - if single_timestamp_ending: - slices.append(len(tokens)) - - last_slice = 0 - for current_slice in slices: - sliced_tokens = tokens[last_slice:current_slice] - start_timestamp_pos = ( - sliced_tokens[0].item() - tokenizer.timestamp_begin - ) - end_timestamp_pos = ( - sliced_tokens[-1].item() - tokenizer.timestamp_begin - ) - current_segments.append( - new_segment( - start=time_offset + start_timestamp_pos * time_precision, - end=time_offset + end_timestamp_pos * time_precision, - tokens=sliced_tokens, - result=result, - ) - ) - last_slice = current_slice - - if single_timestamp_ending: - # single timestamp at the end means no speech after the last timestamp. 
- seek += segment_size - else: - # otherwise, ignore the unfinished segment and seek to the last timestamp - last_timestamp_pos = ( - tokens[last_slice - 1].item() - tokenizer.timestamp_begin - ) - seek += last_timestamp_pos * input_stride - else: - duration = segment_duration - timestamps = tokens[timestamp_tokens.nonzero().flatten()] - if ( - len(timestamps) > 0 - and timestamps[-1].item() != tokenizer.timestamp_begin - ): - # no consecutive timestamps but it has a timestamp; use the last one. - last_timestamp_pos = ( - timestamps[-1].item() - tokenizer.timestamp_begin - ) - duration = last_timestamp_pos * time_precision - - current_segments.append( - new_segment( - start=time_offset, - end=time_offset + duration, - tokens=tokens, - result=result, - ) - ) - seek += segment_size - - if word_timestamps: - add_word_timestamps( - segments=current_segments, - model=model, - tokenizer=tokenizer, - mel=mel_segment, - num_frames=segment_size, - prepend_punctuations=prepend_punctuations, - append_punctuations=append_punctuations, - last_speech_timestamp=last_speech_timestamp, - ) - - if not single_timestamp_ending: - last_word_end = get_end(current_segments) - if last_word_end is not None and last_word_end > time_offset: - seek = round(last_word_end * FRAMES_PER_SECOND) - - # skip silence before possible hallucinations - if hallucination_silence_threshold is not None: - threshold = hallucination_silence_threshold - if not single_timestamp_ending: - last_word_end = get_end(current_segments) - if last_word_end is not None and last_word_end > time_offset: - remaining_duration = window_end_time - last_word_end - if remaining_duration > threshold: - seek = round(last_word_end * FRAMES_PER_SECOND) - else: - seek = previous_seek + segment_size - - # if first segment might be a hallucination, skip leading silence - first_segment = next_words_segment(current_segments) - if first_segment is not None and is_segment_anomaly(first_segment): - gap = first_segment["start"] - time_offset - if gap > threshold: - seek = previous_seek + round(gap * FRAMES_PER_SECOND) - continue - - # skip silence before any possible hallucination that is surrounded - # by silence or more hallucinations - hal_last_end = last_speech_timestamp - for si in range(len(current_segments)): - segment = current_segments[si] - if not segment["words"]: - continue - if is_segment_anomaly(segment): - next_segment = next_words_segment( - current_segments[si + 1 :] - ) - if next_segment is not None: - hal_next_start = next_segment["words"][0]["start"] - else: - hal_next_start = time_offset + segment_duration - silence_before = ( - segment["start"] - hal_last_end > threshold - or segment["start"] < threshold - or segment["start"] - time_offset < 2.0 - ) - silence_after = ( - hal_next_start - segment["end"] > threshold - or is_segment_anomaly(next_segment) - or window_end_time - segment["end"] < 2.0 - ) - if silence_before and silence_after: - seek = round( - max(time_offset + 1, segment["start"]) - * FRAMES_PER_SECOND - ) - if content_duration - segment["end"] < threshold: - seek = content_frames - current_segments[si:] = [] - break - hal_last_end = segment["end"] - - last_word_end = get_end(current_segments) - if last_word_end is not None: - last_speech_timestamp = last_word_end - - if verbose: - for segment in current_segments: - start, end, text = segment["start"], segment["end"], segment["text"] - line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}" - print(make_safe(line)) - - # if a segment is instantaneous or does not 
contain text, clear it - for i, segment in enumerate(current_segments): - if segment["start"] == segment["end"] or segment["text"].strip() == "": - segment["text"] = "" - segment["tokens"] = [] - segment["words"] = [] - - all_segments.extend( - [ - {"id": i, **segment} - for i, segment in enumerate( - current_segments, start=len(all_segments) - ) - ] - ) - all_tokens.extend( - [token for segment in current_segments for token in segment["tokens"]] - ) - - if not condition_on_previous_text or result.temperature > 0.5: - # do not feed the prompt tokens if a high temperature was used - prompt_reset_since = len(all_tokens) - - # update progress bar - pbar.update(min(content_frames, seek) - previous_seek) - - return dict( - text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), - segments=all_segments, - language=language, - ) +def buffered_transcribe(model: "Whisper", audio: str, **kw): + transcriber = ProgressTranscriber(model, **kw) + return asyncio.run(transcriber.progressive(AudioFile(fname=audio))) def cli(): @@ -546,6 +909,7 @@ def cli(): parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file") parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected") + parser.add_argument("--buffered", type=str2bool, default=False, help="whether to load the audio data on demand instead of all at once") # fmt: on args = parser.parse_args().__dict__ @@ -575,6 +939,7 @@ def cli(): from . import load_model model = load_model(model_name, device=device, download_root=model_dir) + transcriber = buffered_transcribe if args.pop("buffered") else transcribe writer = get_writer(output_format, output_dir) word_options = [ @@ -594,7 +959,7 @@ def cli(): writer_args = {arg: args.pop(arg) for arg in word_options} for audio_path in args.pop("audio"): try: - result = transcribe(model, audio_path, temperature=temperature, **args) + result = transcriber(model, audio_path, temperature=temperature, **args) writer(result, audio_path, **writer_args) except Exception as e: traceback.print_exc() diff --git a/whisper/utils.py b/whisper/utils.py index 9b9b138..27a0465 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -1,9 +1,11 @@ import json import os +import pathlib import re import sys +import time import zlib -from typing import Callable, List, Optional, TextIO +from typing import Any, Callable, Generic, List, Optional, TextIO, TypeVar, Union system_encoding = sys.getdefaultencoding() @@ -21,11 +23,19 @@ else: return string +PathType = Union[str, pathlib.Path] + + def exact_div(x, y): assert x % y == 0 return x // y +# https://stackoverflow.com/a/17511341/3476782 +def ceildiv(a: Union[int, float], b: Union[int, float]) -> int: + return int(-(a // -b)) + + def str2bool(string): str2val = {"True": True, "False": False} if string in str2val: @@ -68,6 +78,20 @@ def format_timestamp( ) +def hms(sec: float) -> str: + trim = sec < 3600 + h = "" if trim else str(int(sec) // 3600) + ":" + m_fill = " " if trim else "0" + m = " " if sec < 60 else str(int(sec) // 60 % 60).rjust(2, m_fill) + ":" + s = str(int(sec) % 60).rjust(2, "0") + "." 
+ c = str(round((sec % 1) * 100)).rjust(2, "0") + return h + m + s + c + + +def tod(seconds: float) -> str: + return time.strftime("%H:%M:%S", time.localtime(seconds)) + + def get_start(segments: List[dict]) -> Optional[float]: return next( (w["start"] for s in segments for w in s["words"]), @@ -314,3 +338,47 @@ def get_writer( return write_all return writers[output_format](output_dir) + + +T = TypeVar("T") + + +# boilerplate for property with _{name} storage and passthrough getter/setter +class PassthroughProperty(Generic[T]): + def __init__(self, default: T): + self.value = default + + f: Optional[Callable[[Any, T], None]] = None + + def setter(self, f: Callable[[Any, T], None]): + self.f = f + return self + + g: Optional[property] = None + + def property(self, g: Callable[[Any], T]): + self.g = property(g) + return self + + +class PassthroughPropertyDefaults(type): + def __new__(cls, clsname, bases, attrs): + def closure(f, v): + def prop(self): + return getattr(self, v) + + def setter(self, value): + setattr(self, v, value) + + prop.__name__ = setter.__name__ = f + return property(prop), setter + + updates = {} + for k, v in attrs.items(): + if not isinstance(v, PassthroughProperty): + continue + private = "_" + k + updates[private] = v.value + getter, setter = closure(k, private) + updates[k] = (v.g or getter).setter(v.f or setter) + return super().__new__(cls, clsname, bases, {**attrs, **updates})
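Not part of the patch: a minimal sketch of how the `PassthroughProperty` helper and `PassthroughPropertyDefaults` metaclass added above are intended to be used, mirroring the decorator pattern in `Transcriber`. The `Example` class and its attribute names are hypothetical.

    from whisper.utils import PassthroughProperty, PassthroughPropertyDefaults

    class Example(metaclass=PassthroughPropertyDefaults):
        # default passthrough getter/setter backed by the generated self._plain slot
        plain = PassthroughProperty[int](0)

        # custom setter; the metaclass still generates the _clamped slot and getter
        @PassthroughProperty[int](0).setter
        def clamped(self, value: int) -> None:
            self._clamped = max(0, value)

    e = Example()
    e.plain = 3      # stored in e._plain by the generated setter
    e.clamped = -5   # custom setter clamps negative values to 0
    assert (e.plain, e.clamped) == (3, 0)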
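Also not part of the patch: a usage sketch of the two entry points this diff leaves in whisper/transcribe.py. It assumes a local "audio.wav" and an ffmpeg binary on PATH (which `AudioFile` checks for); the model name is only an example.

    import whisper
    from whisper.transcribe import buffered_transcribe, transcribe

    model = whisper.load_model("base")

    # In-memory path: audio is decoded and mel-transformed up front,
    # then handed to ProgressTranscriber in a single call.
    result = transcribe(model, "audio.wav")
    print(result["language"], result["text"])

    # Buffered path (the new --buffered CLI flag): AudioFile streams PCM from
    # ffmpeg and the transcriber pulls roughly CHUNK_LENGTH seconds at a time.
    result = buffered_transcribe(model, "audio.wav")
    print(result["text"])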