From 4ccbd70012bb0a2b382aee06ac29c263fe7ce017 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sat, 13 Jul 2024 22:28:56 -0600 Subject: [PATCH 01/10] refactor transcribe --- whisper/batching.py | 222 ++++++++++ whisper/buffer.py | 273 ++++++++++++ whisper/transcribe.py | 950 +++++++++++++++++++++++++----------------- whisper/utils.py | 72 +++- 4 files changed, 1124 insertions(+), 393 deletions(-) create mode 100644 whisper/batching.py create mode 100644 whisper/buffer.py diff --git a/whisper/batching.py b/whisper/batching.py new file mode 100644 index 0000000..684ecba --- /dev/null +++ b/whisper/batching.py @@ -0,0 +1,222 @@ +import torch +import numpy as np +from collections.abc import Callable, AsyncIterable, AsyncIterator, Awaitable +from typing import Generic, TypeVar, Union + +A = TypeVar("A", bound=Union[np.ndarray, torch.Tensor]) + +class ArrayWrapper(Generic[A]): + pass + +ArrayTypes = Union[A, ArrayWrapper[A]] + +class LoopbackIterator(Generic[A]): + async def iter(self): + raise NotImplementedError + + def __aiter__(self): + self._iter = self.iter() + return self + + async def __anext__(self) -> ArrayTypes: + if not hasattr(self, "_iter"): + self.__aiter__() + return await anext(self._iter) + +async def empty(): + return + yield + +class Unwrap(LoopbackIterator): + _initial: Union[ArrayTypes, Awaitable[ArrayTypes]] + started: bool + iterator: AsyncIterable[ArrayTypes] + def __init__(self, iterator: AsyncIterable[ArrayTypes]): + while isinstance(iterator, PassthroughTransform): + iterator = iterator.handoff() + if isinstance(iterator, Unwrap): + self._initial, self.started = iterator.initial(), iterator.started + self.iterator = iterator.iterator + return + elif not isinstance(iterator, AsyncIterator): + iterator = aiter(iterator) + try: + self._initial = anext(iterator) + self.iterator, self.started = iterator, False + except StopAsyncIteration: + self.iterator, self.started = empty(), True + + async def initial(self) -> ArrayTypes: + while isinstance(self._initial, Awaitable): + self._initial = await self._initial + return self._initial + + async def iter(self) -> AsyncIterator[ArrayTypes]: + if not self.started: + self.started = True + yield await self.initial() + async for i in self.iterator: + yield i + + async def prop(self, key: str, default): + if hasattr(self, "initial"): + return getattr(await self.initial(), key) + else: + return default + + @property + def shape(self): + return self.prop("shape", ()) + + @property + def dtype(self): + return self.prop("dtype", None) + + @property + async def concat(self): + return np.concatenate if isinstance(await self.dtype, np.dtype) \ + else torch.cat + +class PassthroughTransform(LoopbackIterator): + def handoff(self) -> AsyncIterable[ArrayTypes]: + raise NotImplementedError + +class BoxedIterator(PassthroughTransform): + def __init__(self, iterator): + self.iterator = iterator + self.flag = object() + + def handoff(self) -> AsyncIterable[ArrayTypes]: + self.flag = None + return self.iterator + + async def iter(self) -> AsyncIterator[ArrayTypes]: + if self.flag is None: + raise Exception("iterator source removed") + self.flag = flag = object() + async for i in self.iterator: + yield i + if self.flag != flag: + raise Exception("source can only be used by one iterator") + +def LookAlong(axis: int): + assert axis >= 0 + empties = (slice(None),) * axis + + class LookAlong(ArrayWrapper): + def __init__(self, value: A): + self.value = value + + @property + def shape(self): + return self.value.shape[axis] + + def __getitem__(self, idx): 
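+ # index only along the wrapped axis; leading axes pass through via slice(None)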
+ return self.value[empties + (idx,)] + + def __next__(self): + return self.value + + return LookAlong + +class PassthroughMap(PassthroughTransform): + def __init__( + self, apply: Callable[[A], ArrayTypes], + iterator: AsyncIterator[A]): + self.iterator, self.apply = iterator, apply + + def handoff(self) -> AsyncIterator[A]: + return self.iterator + + async def iter(self) -> AsyncIterator[ArrayTypes]: + async for i in self.iterator: + yield self.apply(i) + +class Group: + def __init__(self, concat, axis=-1): + self.concat = concat + self.holding = [] + self.consumed = 0 + self.shape = 0 + + def add(self, value): + self.holding.append(value) + self.shape += value.shape + + def take(self, amount, exact=True): + assert amount > 0 and amount <= self.shape + self.shape -= amount + taking, start = -self.consumed, self.consumed + for i, x in enumerate(self.holding): + taking += x.shape + if taking >= amount: + self.consumed = amount - taking + x.shape + break + if taking == amount or not exact: + self.shape += amount - taking + self.consumed = 0 + res = self.concat([self.holding[0][start:]] + [ + i.value for i in self.holding[1 : i + 1]]) + self.holding = self.holding[i + 1:] + return res + if i == 0: + return self.holding[0][start:self.consumed] + res = self.concat( + [self.holding[0][start:]] + + [i.value for i in self.holding[1 : i]] + + [self.holding[i][:self.consumed]]) + self.holding = self.holding[i:] + return res + + def all(self): + res = self.concat([i.value for i in self.holding]) + self.shape = 0 + self.consumed = 0 + self.holding = [] + return res + +class Taken: + def take(self, *a, **kw): + raise Exception("batch queue moved") + +class Batcher(PassthroughTransform): + def __init__(self, iterator, size, axis=-1, exact=False): + assert isinstance(size, int) and size > 0 + self.size, self._axis, self.exact = size, axis, exact + if isinstance(iterator, Unwrap) and hasattr(iterator, "group"): + self.group = iterator.group + self.preview = Unwrap(iterator) + + async def concat(self): + f = await self.preview.concat + return lambda tensors: f(tensors, self.axis) + + _iterator = None + async def iterator(self): + if self._iterator is None: + self.axis = len(await self.preview.shape) + self._axis \ + if self._axis < 0 else self._axis + if not hasattr(self, "group"): + self.group = Group(await self.concat()) + self._iterator = PassthroughMap( + LookAlong(self.axis), BoxedIterator(self.preview)) + return self._iterator + + def handoff(self): + self.group = Taken() + return self.preview if self._iterator is None else self._iterator + + def __aiter__(self): + return self + + async def __anext__(self): + iterator = aiter(await self.iterator()) + while self.group.shape < self.size: + try: + self.group.add(await anext(iterator)) + except StopAsyncIteration: + if self.group.shape > 0: + return self.group.all() + raise + return self.group.take(self.size, self.exact) + diff --git a/whisper/buffer.py b/whisper/buffer.py new file mode 100644 index 0000000..3075db1 --- /dev/null +++ b/whisper/buffer.py @@ -0,0 +1,273 @@ +import numpy as np +import asyncio, pathlib, subprocess, torch + +from .audio import ( + SAMPLE_RATE, + N_FFT, + HOP_LENGTH, + N_FRAMES, + mel_filters, +) + +from .utils import PathType, ceildiv +from .batching import Batcher +from typing import Optional, Union, IO, Tuple, Any, Type +from collections.abc import Coroutine, AsyncIterable, AsyncIterator, Awaitable + +class AudioSink: + def __init__(self, *, rate: int = SAMPLE_RATE, **kw): + super().__init__(**kw) + self.rate = rate + + 
def read(self): + raise NotImplementedError + + def write(self, data): + raise NotImplementedError + +class ArrayStream(AudioSink): + q: asyncio.Queue + def __init__( + self, *, device: Optional[Union[str, torch.device]] = None, + batch: int = 1, n_mels: int = 80, capacity: int = 1_000_000, **kw): + super().__init__(**kw) + self.q = asyncio.Queue(capacity) + self.finished = asyncio.Event() + self.device, self.batch, self.n_mels = device, batch, n_mels + self.sees = self.zeros((0,)) + self.spectogram = self.zeros((n_mels, 0)) + self.hann = torch.hann_window(N_FFT).to(self.sees.device) + self.filters = mel_filters(self.sees.device, n_mels) + + def zeros(self, shape): + return torch.zeros(shape, dtype=torch.float32, device=self.device) + + write_blockable: bool = True + def write(self, data: bytes) -> Optional[Coroutine]: + if self.write_blockable: + return self.q.put(data) + else: + self.q.put_nowait(data) + return None + + def load(self, data: bytes) -> np.ndarray: + return np.frombuffer( + data, np.int16).flatten().astype(np.float32) / 32768.0 + + async def loader(self, iterator: AsyncIterable[bytes]) -> \ + AsyncIterator[np.ndarray]: + async for data in iterator: + yield self.load(data) + + async def buffer(self) -> AsyncIterator[bytes]: + waiter = asyncio.create_task(self.finished.wait()) + while not self.finished.is_set(): + getter = asyncio.create_task(self.q.get()) + done, pending = await asyncio.wait( + (waiter, getter), return_when=asyncio.FIRST_COMPLETED) + if getter in done: + yield getter.result() + while not self.q.empty(): + yield self.q.get_nowait() + + async def buffer_nowait(self) -> AsyncIterator[bytes]: + try: + while True: + yield self.q.get_nowait() + except asyncio.QueueEmpty: + pass + + loading: Optional[Batcher] = None + async def fft_offset(self, iterator: AsyncIterable[bytes]) -> \ + AsyncIterator[np.ndarray]: + init = self.loader(iterator) if self.loading is None else self.loading + self.loading = Batcher(init, HOP_LENGTH) + _iterator = aiter(self.loading) + window = np.zeros((0,), dtype=np.float32) + while window.size < ceildiv(N_FFT, 2): + try: + window = np.concatenate((window, await anext(_iterator))) + except StopAsyncIteration: + return + window = np.pad(window, (N_FFT // 2, 0), 'reflect') + yield window + async for data in _iterator: + yield data + # for _ in range(N_FFT // HOP_LENGTH): + # yield np.zeros((HOP_LENGTH,), dtype=np.float32) + # (done by runoff) + + def seeing(self, sees: torch.Tensor) -> torch.Tensor: + hopped = ((sees.shape[0] - N_FFT) // HOP_LENGTH + 1) * HOP_LENGTH + return sees[hopped:] + + async def window(self, iterator: AsyncIterable[bytes]) -> \ + AsyncIterator[torch.Tensor]: + _iterator = self.fft_offset(iterator) + async for data in _iterator: + _data = torch.from_numpy(data) + prev = self.sees.shape[0] - N_FFT + while (_data.shape[0] + prev) // HOP_LENGTH < self.batch - 1: + try: + adding = torch.from_numpy(await anext(_iterator)) + except StopAsyncIteration: + break + _data = torch.cat((_data, adding)) + if self.device is not None: + _data.to(self.device) + res = torch.cat((self.sees, _data)) + self.sees = self.seeing(res) + yield self.transform(self.dft(res)) + + def dft(self, amp: torch.Tensor) -> torch.Tensor: + return torch.stft( + amp, N_FFT, HOP_LENGTH, window=self.hann, center=False, + return_complex=True) + + # https://github.com/openai/whisper/blob/c5d4256/whisper/audio.py#L149 + log_spec_bound: Optional[torch.Tensor] = None + def transform(self, stft: torch.Tensor) -> torch.Tensor: + magnitudes = stft.abs() ** 2 + mel_spec 
= self.filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + # causes values to not precisely match the original + self.log_spec_bound = log_spec.max() if self.log_spec_bound is None \ + else torch.maximum(log_spec.max(), self.log_spec_bound) + log_spec = torch.maximum(log_spec, self.log_spec_bound - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec + + def padding(self, content_frames: int) -> int: + return N_FRAMES + + # dft_pad: add ending content frames to match padding from a centered STFT + dft_pad: bool = False + def runoff(self, dft_pad: Optional[bool] = None) -> torch.Tensor: + dft_pad = self.dft_pad if dft_pad is None else dft_pad + if dft_pad: + overrun = (ceildiv(N_FFT, HOP_LENGTH) - 1) * HOP_LENGTH + spectogram = torch.cat((self.sees, self.zeros(overrun))) + if spectogram.shape[-1] >= N_FFT: + spectogram = self.transform(self.dft(spectogram)) + else: + spectogram = torch.zeros(0) + padding = self.padding(self.spectogram.shape[-1] + spectogram.shape[-1]) + pad = self.zeros((self.n_mels, max(0, padding))) + spectogram = torch.cat((self.spectogram, spectogram, pad), -1) + return spectogram if padding >= 0 else spectogram[-padding:] + + offset: int = 0 + + async def pull(self) -> torch.Tensor: + context = self.spectogram.shape[-1] + iterator = self.window(self.buffer_nowait()) + async for frame in iterator: + self.spectogram = torch.cat((self.spectogram, frame), -1) + cutoff = min(context, max(self.spectogram.shape[-1] - N_FRAMES, 0)) + self.offset += cutoff + self.spectogram = self.spectogram[:, cutoff:] + return self.runoff() + + staging: Optional[Batcher] = None + async def _push(self, sec: float, exact: bool = False) -> \ + AsyncIterator[torch.Tensor]: + batching = int(sec * SAMPLE_RATE // HOP_LENGTH) + init = self.window(self.buffer()) if self.staging is None \ + else self.staging + self.staging = Batcher(init, batching, exact=exact) + async for frame in self.staging: + batched = batching if exact else frame.shape[-1] + cutoff = max(self.spectogram.shape[-1] + batched - N_FRAMES, 0) + self.offset += cutoff + self.spectogram = torch.cat(( + self.spectogram[:, cutoff:], frame), -1) + yield self.runoff() + + reader: Optional[Awaitable] = None + def start(self, **kw) -> None: + if self.reader is None: + self.reader = asyncio.create_task(self.read(**kw)) + + async def push(self, sec: float, exact: bool=False, **kw) -> \ + AsyncIterator[torch.Tensor]: + self.start(**kw) + async for i in self._push(sec, exact): + yield i + assert self.reader is not None + await self.reader + + async def request(self, sec: float, exact: bool=True, **kw) -> torch.Tensor: + try: + return await anext(self.push(sec, exact)) + except StopAsyncIteration: + if self.reader is not None: + await self.reader + return self.zeros((self.n_mels, 0)) + + async def full(self, **kw) -> torch.Tensor: + await self.read(**kw) + return await self.pull() + + def sequential(self, **kw) -> torch.Tensor: + return asyncio.run(self.full(**kw)) + + async def amplitudes(self, **kw) -> np.ndarray: + self.start(**kw) + res = [] + async for data in self.loader(self.buffer()): + res.append(data) + assert self.reader is not None + await self.reader + return np.concatenate(res) + + def all_amplitudes(self, **kw) -> np.ndarray: + return asyncio.run(self.amplitudes(**kw)) + +class RawAudioFile(ArrayStream): + def __init__( + self, *, period: int = HOP_LENGTH, fname: PathType = 'out.raw', + **kw): + super().__init__(**kw) + self.fname = fname + self.period = period + + fp: Optional[IO[bytes]] = None + 
async def read(self) -> None: + fp = open(self.fname, 'rb') if self.fp is None else self.fp + data = fp.read(self.period) + while len(data) != 0: + io_hold = self.write(data) + assert io_hold is not None and self.write_blockable is True + await io_hold + data = fp.read(self.period) + self.finished.set() + +class AudioFile(RawAudioFile): + def __init__( + self, *, period: int = SAMPLE_RATE, fname: PathType = 'out.wav', + **kw): + assert not subprocess.run( + ["which", "ffmpeg"], stdout=subprocess.PIPE).returncode + super().__init__(period=period or -1, fname=fname, **kw) + + async def read(self) -> None: + cmd = [ + "ffmpeg", + "-nostdin", + "-threads", "0", + "-i", self.fname, + "-f", "s16le", + "-ac", "1", + "-acodec", "pcm_s16le", + "-ar", str(self.rate), + "-" + ] + ps = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + self.fp = ps.stdout + await super().read() + _, stderr = ps.communicate() + if ps.returncode not in (None, 0): + raise RuntimeError(f"Failed to load audio: {stderr.decode()}") + diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 1c075a2..d648196 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -2,24 +2,24 @@ import argparse import os import traceback import warnings -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Dict +from dataclasses import dataclass import numpy as np import torch -import tqdm +import tqdm # TODO from .audio import ( FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, - N_SAMPLES, SAMPLE_RATE, - log_mel_spectrogram, + CHUNK_LENGTH, pad_or_trim, ) from .decoding import DecodingOptions, DecodingResult from .timing import add_word_timestamps -from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer +from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer, Tokenizer from .utils import ( exact_div, format_timestamp, @@ -29,473 +29,639 @@ from .utils import ( optional_float, optional_int, str2bool, + PassthroughProperty, + PassthroughPropertyDefaults, ) +from .buffer import AudioFile if TYPE_CHECKING: from .model import Whisper +@dataclass +class LanguageHypothesis: + language: Optional[str] = None + since: int = 0 + evidence: int = 0 + last: int = -1 -def transcribe( - model: "Whisper", - audio: Union[str, np.ndarray, torch.Tensor], - *, - verbose: Optional[bool] = None, - temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), - compression_ratio_threshold: Optional[float] = 2.4, - logprob_threshold: Optional[float] = -1.0, - no_speech_threshold: Optional[float] = 0.6, - condition_on_previous_text: bool = True, - initial_prompt: Optional[str] = None, - word_timestamps: bool = False, - prepend_punctuations: str = "\"'“¿([{-", - append_punctuations: str = "\"'.。,,!!??::”)]}、", - clip_timestamps: Union[str, List[float]] = "0", - hallucination_silence_threshold: Optional[float] = None, - **decode_options, -): - """ - Transcribe an audio file using Whisper +class Transcriber(metaclass=PassthroughPropertyDefaults): + prefix: str = '''"'\u201c\u00bf([{-''' + postfix: str = '''"'.\u3002,\uff0c!\uff01?\uff1f:\uff1a\u201d)]}\u3001''' + punctuation: str = prefix + postfix - Parameters - ---------- - model: Whisper - The Whisper model instance + verbose: bool = False - audio: Union[str, np.ndarray, torch.Tensor] - The path to the audio file to open, or the audio waveform + _decode_options: dict = {} + decode_props: Tuple[str, ...] 
= ("fp16", "language", "task") + @property + def decode_options(self) -> dict: + for k in self.decode_props: + self._decode_options[k] = getattr(self, k) + return self._decode_options - verbose: bool - Whether to display the text being decoded to the console. If True, displays all the details, - If False, displays minimal details. If None, does not display anything + @decode_options.setter + def decode_options(self, value: dict) -> None: + self._decode_options = value + for k in self.decode_props: + if k in value: + setattr(self, k, value[k]) - temperature: Union[float, Tuple[float, ...]] - Temperature for sampling. It can be a tuple of temperatures, which will be successively used - upon failures according to either `compression_ratio_threshold` or `logprob_threshold`. + dtype: torch.dtype = torch.float16 + @property + def fp16(self) -> bool: + return self.dtype == torch.float16 - compression_ratio_threshold: float - If the gzip compression ratio is above this value, treat as failed + @fp16.setter + def fp16(self, value: bool) -> None: + self.dtype = torch.float16 if value else torch.float32 + self.fp16device() - logprob_threshold: float - If the average log probability over sampled tokens is below this value, treat as failed + @PassthroughProperty(None).setter + def model(self, value: "Whisper") -> None: + self._model = value + self.device = value.device + self.input_stride = exact_div( + N_FRAMES, self.model.dims.n_audio_ctx + ) # mel frames per output token: 2 + self.time_precision = ( + self.input_stride * HOP_LENGTH / SAMPLE_RATE + ) # time per output token: 0.02 (seconds) - no_speech_threshold: float - If the no_speech probability is higher than this value AND the average log probability - over sampled tokens is below `logprob_threshold`, consider the segment as silent + @PassthroughProperty[Optional[torch.device]](None).setter + def device(self, value: Optional[torch.device]) -> None: + self._device = value + if value == torch.device("cpu"): + if torch.cuda.is_available(): + warnings.warn( + "Performing inference on CPU when CUDA is available") + self.fp16device() - condition_on_previous_text: bool - if True, the previous output of the model is provided as a prompt for the next window; - disabling may make the text inconsistent across windows, but the model becomes less prone to - getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. - - word_timestamps: bool - Extract word-level timestamps using the cross-attention pattern and dynamic time warping, - and include the timestamps for each word in each segment. - - prepend_punctuations: str - If word_timestamps is True, merge these punctuation symbols with the next word - - append_punctuations: str - If word_timestamps is True, merge these punctuation symbols with the previous word - - initial_prompt: Optional[str] - Optional text to provide as a prompt for the first window. This can be used to provide, or - "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns - to make it more likely to predict those word correctly. - - decode_options: dict - Keyword arguments to construct `DecodingOptions` instances - - clip_timestamps: Union[str, List[float]] - Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process. - The last end timestamp defaults to the end of the file. 
- - hallucination_silence_threshold: Optional[float] - When word_timestamps is True, skip silent periods longer than this threshold (in seconds) - when a possible hallucination is detected - - Returns - ------- - A dictionary containing the resulting text ("text") and segment-level details ("segments"), and - the spoken language ("language"), which is detected when `decode_options["language"]` is None. - """ - dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 - if model.device == torch.device("cpu"): - if torch.cuda.is_available(): - warnings.warn("Performing inference on CPU when CUDA is available") - if dtype == torch.float16: + def fp16device(self) -> None: + if self.device == torch.device("cpu") and self.dtype == torch.float16: warnings.warn("FP16 is not supported on CPU; using FP32 instead") - dtype = torch.float32 + self.dtype = torch.float32 - if dtype == torch.float32: - decode_options["fp16"] = False + def detect_language(self, mel: Optional[torch.Tensor] = None) -> str: + mel_segment = pad_or_trim(self.latest if mel is None else mel, N_FRAMES) + mel_segment = mel_segment.to(self.device).to(self.dtype) + _, probs = self.model.detect_language(mel_segment) + return max(probs, key=probs.get) - # Pad 30-seconds of silence to the input audio, for slicing - mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES) - content_frames = mel.shape[-1] - N_FRAMES - content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE) + prev: Optional[torch.Tensor] = None + _latest: Optional[torch.Tensor] = None + @PassthroughProperty[Optional[torch.Tensor]](None).setter + def latest(self, value: Optional[torch.Tensor]) -> None: + self.prev = self._latest + self._latest = value - if decode_options.get("language", None) is None: - if not model.is_multilingual: - decode_options["language"] = "en" + _hypothesis: LanguageHypothesis = LanguageHypothesis() + _language: Optional[str] + @PassthroughProperty[Optional[str]](None).property + def language(self) -> Optional[str]: + if self._language is not None: + return self._language + if not self.model.is_multilingual: + return "en" + if self.verbose: + print( + "Detecting language using up to the first 30 seconds." + "Use `--language` to specify the language") + if self.latest is None: + return None + if self._seek == self._hypothesis.last: + return self._hypothesis.language + if self.frame_offset > 0 or self.latest.shape[-1] == N_FRAMES * 2: + mel = self.latest if self.prev is None else torch.cat( + (self.prev[:self.frame_offset], self.latest), -1) + self._language = self.detect_language(mel) + return self._language + self._hypothesis.last = self._seek or 0 + self._hypothesis.since += 1 + if 2 ** self._hypothesis.evidence < self._hypothesis.since: + return self._hypothesis.language + self._hypothesis.since = 0 + guess = self.detect_language() + if guess == self._hypothesis.language: + self._hypothesis.evidence += 1 + self._hypothesis.language = guess + self._hypothesis.evidence = 1 + return None + + @PassthroughProperty[Union[str, List[float], Tuple[float]]]((0,)).setter + def clip_timestamps(self, value: Union[str, List[float], Tuple[float]]): + self._seek_clips = None + if isinstance(value, str): + self._clip_timestamps = tuple(map(float, value.split(","))) \ + if value else (0,) else: - if verbose: - print( - "Detecting language using up to the first 30 seconds. 
Use `--language` to specify the language" - ) - mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) - _, probs = model.detect_language(mel_segment) - decode_options["language"] = max(probs, key=probs.get) - if verbose is not None: - print( - f"Detected language: {LANGUAGES[decode_options['language']].title()}" + self._clip_timestamps = tuple(value) or (0,) + + _seek_clips: Optional[List[Tuple[int, Optional[int]]]] = None + @property + def seek_clips(self) -> List[Tuple[int, Optional[int]]]: + if self._seek_clips is None: + seek_points = tuple( + round(ts * FRAMES_PER_SECOND) + for ts in self.clip_timestamps) + (None,) + self._seek_clips = list(zip(seek_points[::2], seek_points[1::2])) + return self._seek_clips + + _seek: Optional[int] + @PassthroughProperty[Optional[int]](None).property + def seek(self) -> Optional[int]: + return self.seek_clips[0][0] if self._seek is None else self._seek + + @PassthroughProperty[int](0).setter + def clip_idx(self, value: int): + self._clip_idx = value + clips = self.seek_clips + if value < len(clips): + self.seek = clips[value][0] + + time_offset = property(lambda self: float( + self.seek * HOP_LENGTH / SAMPLE_RATE)) + window_end_time = property(lambda self: float( + (self.seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)) + + _temperature: Union[Optional[float], Tuple[float, ...]] + @PassthroughProperty[Union[Optional[float], Tuple[float, ...]]]( + (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)).setter + def temperature(self, value: Union[Optional[float], Tuple[float, ...]]): + self._temperature = (value,) if isinstance(value, (int, float)) else ( + Transcriber._temperature if value is None else value) + + @PassthroughProperty("transcribe").setter + def task(self, value: str): + self._task = value + if self.word_timestamps and value == "translate": + warnings.warn( + "Word-level timestamps on translations may not be " + "reliable.") + + @PassthroughProperty(False).setter + def word_timestamps(self, value: bool): + self._word_timestamps = value + self.task = self.task + + get_tokenizer = staticmethod(get_tokenizer) + _tokenizer: Optional[Tokenizer] = None + _tokenizer_cache: Dict[str, Tokenizer] = {} + @property + def tokenizer(self) -> Optional[Tokenizer]: + if self._tokenizer is None: + lang = self.language + if self._language is not None: + if self._language in self._tokenizer_cache: + self._tokenizer = self._tokenizer_cache[self._language] + else: + self._tokenizer = self.get_tokenizer( + self.model.is_multilingual, + num_languages=self.model.num_languages, + language=self.language, + task=self.task, + ) + return self._tokenizer + if lang is None: + return None + if lang not in self._tokenizer_cache: + self._tokenizer_cache[lang] = self.get_tokenizer( + self.model.is_multilingual, + num_languages=self.model.num_languages, + language=lang, + task=self.task, ) + return self._tokenizer_cache[lang] + return self._tokenizer - language: str = decode_options["language"] - task: str = decode_options.get("task", "transcribe") - tokenizer = get_tokenizer( - model.is_multilingual, - num_languages=model.num_languages, - language=language, - task=task, - ) + _initial_prompt_tokens: Optional[List[int]] = None + _initial_prompt_cache: Dict[Tokenizer, List[int]] = {} + @property + def initial_prompt_tokens(self) -> List[int]: + if self._initial_prompt_tokens is None: + if self.initial_prompt is None: + self._initial_prompt_tokens = [] + else: + tokenizer = self.tokenizer + if tokenizer is None: + return [] + if tokenizer not in self._initial_prompt_cache: + 
self._initial_prompt_cache[tokenizer] = tokenizer.encode( + " " + self.initial_prompt.strip()) + if self._tokenizer is not None: + self._initial_prompt_tokens = \ + self._initial_prompt_cache[tokenizer] + return self._initial_prompt_cache[tokenizer] + return self._initial_prompt_tokens - if isinstance(clip_timestamps, str): - clip_timestamps = [ - float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else []) - ] - seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps] - if len(seek_points) == 0: - seek_points.append(0) - if len(seek_points) % 2 == 1: - seek_points.append(content_frames) - seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2])) + prompt_reset_since: int = 0 + last_speech_timestamp: float = 0.0 + frame_offset: int = 0 + all_segments: List[dict] + def __init__( + self, + model: "Whisper", + *, + verbose: Optional[bool] = None, + temperature: Union[Optional[float], Tuple[float, ...]] = None, + compression_ratio_threshold: Optional[float] = 2.4, + logprob_threshold: Optional[float] = -1.0, + no_speech_threshold: Optional[float] = 0.6, + condition_on_previous_text: bool = True, + initial_prompt: Optional[str] = None, + word_timestamps: bool = False, + prepend_punctuations: str = prefix, + append_punctuations: str = postfix, + clip_timestamps: Union[str, List[float]] = "0", + hallucination_silence_threshold: Optional[float] = None, + **decode_options): + self.model = model + if verbose is not None: + self.verbose = verbose + self.temperature = temperature + self.compression_ratio_threshold = compression_ratio_threshold + self.logprob_threshold = logprob_threshold + self.no_speech_threshold = no_speech_threshold + self.condition_on_previous_text = condition_on_previous_text + self.initial_prompt = initial_prompt + self.word_timestamps = word_timestamps + self.prepend_punctuations = prepend_punctuations + self.append_punctuations = append_punctuations + self.clip_timestamps = clip_timestamps + self.hallucination_silence_threshold = hallucination_silence_threshold + self.decode_options = decode_options - punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、" + self.all_tokens = self.initial_prompt_tokens[:] + self.all_segments = [] - if word_timestamps and task == "translate": - warnings.warn("Word-level timestamps on translations may not be reliable.") - - def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: - temperatures = ( - [temperature] if isinstance(temperature, (int, float)) else temperature - ) + def decode_with_fallback(self, segment: torch.Tensor) -> DecodingResult: decode_result = None - - for t in temperatures: - kwargs = {**decode_options} + for t in self.temperature: + kw = {**self.decode_options, "temperature": t} if t > 0: # disable beam_size and patience when t > 0 - kwargs.pop("beam_size", None) - kwargs.pop("patience", None) + kw.pop("beam_size", None) + kw.pop("patience", None) else: # disable best_of when t == 0 - kwargs.pop("best_of", None) - - options = DecodingOptions(**kwargs, temperature=t) - decode_result = model.decode(segment, options) + kw.pop("best_of", None) + decode_result = self.model.decode(segment, DecodingOptions(**kw)) needs_fallback = False - if ( - compression_ratio_threshold is not None - and decode_result.compression_ratio > compression_ratio_threshold - ): + if self.compression_ratio_threshold is not None and ( + decode_result.compression_ratio > + self.compression_ratio_threshold): needs_fallback = True # too repetitive - if ( - logprob_threshold is not None - and 
decode_result.avg_logprob < logprob_threshold - ): + if self.logprob_threshold is not None and ( + decode_result.avg_logprob < self.logprob_threshold): needs_fallback = True # average log probability is too low - if ( - no_speech_threshold is not None - and decode_result.no_speech_prob > no_speech_threshold - ): + if self.no_speech_threshold is not None and ( + decode_result.no_speech_prob > self.no_speech_threshold): needs_fallback = False # silence if not needs_fallback: break - return decode_result - clip_idx = 0 - seek = seek_clips[clip_idx][0] - input_stride = exact_div( - N_FRAMES, model.dims.n_audio_ctx - ) # mel frames per output token: 2 - time_precision = ( - input_stride * HOP_LENGTH / SAMPLE_RATE - ) # time per output token: 0.02 (seconds) - all_tokens = [] - all_segments = [] - prompt_reset_since = 0 - - if initial_prompt is not None: - initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip()) - all_tokens.extend(initial_prompt_tokens) - else: - initial_prompt_tokens = [] - def new_segment( - *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult - ): - tokens = tokens.tolist() - text_tokens = [token for token in tokens if token < tokenizer.eot] + self, *, start: float, end: float, tokens: torch.Tensor, + result: DecodingResult) -> dict: + _tokens = tokens.tolist() + _tokenizer = self.tokenizer + assert _tokenizer is not None + text_tokens = [token for token in _tokens if token < _tokenizer.eot] return { - "seek": seek, + "seek": self.seek, "start": start, "end": end, - "text": tokenizer.decode(text_tokens), - "tokens": tokens, + "text": _tokenizer.decode(text_tokens), + "tokens": _tokens, "temperature": result.temperature, "avg_logprob": result.avg_logprob, "compression_ratio": result.compression_ratio, "no_speech_prob": result.no_speech_prob, } - # show the progress bar when verbose is False (if True, transcribed text will be printed) - with tqdm.tqdm( - total=content_frames, unit="frames", disable=verbose is not False - ) as pbar: - last_speech_timestamp = 0.0 - # NOTE: This loop is obscurely flattened to make the diff readable. - # A later commit should turn this into a simpler nested loop. 
- # for seek_clip_start, seek_clip_end in seek_clips: - # while seek < seek_clip_end - while clip_idx < len(seek_clips): - seek_clip_start, seek_clip_end = seek_clips[clip_idx] - if seek < seek_clip_start: - seek = seek_clip_start - if seek >= seek_clip_end: - clip_idx += 1 - if clip_idx < len(seek_clips): - seek = seek_clips[clip_idx][0] - continue - time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) - window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE) - segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek) - mel_segment = mel[:, seek : seek + segment_size] - segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE - mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) + # anomalous words are very long/short/improbable + @staticmethod + def word_anomaly_score(word: dict) -> float: + probability = word.get("probability", 0.0) + duration = word["end"] - word["start"] + score = 0.0 + if probability < 0.15: + score += 1.0 + if duration < 0.133: + score += (0.133 - duration) * 15 + if duration > 2.0: + score += duration - 2.0 + return score - decode_options["prompt"] = all_tokens[prompt_reset_since:] - result: DecodingResult = decode_with_fallback(mel_segment) - tokens = torch.tensor(result.tokens) + def is_segment_anomaly(self, segment: Optional[dict]) -> bool: + if segment is None or not segment["words"]: + return False + words = [ + w for w in segment["words"] + if w["word"] not in self.punctuation][:8] + score = sum(self.word_anomaly_score(w) for w in words) + return score >= 3 or score + 0.01 >= len(words) - if no_speech_threshold is not None: - # no voice activity check - should_skip = result.no_speech_prob > no_speech_threshold - if ( - logprob_threshold is not None - and result.avg_logprob > logprob_threshold - ): - # don't skip if the logprob is high enough, despite the no_speech_prob - should_skip = False + @staticmethod + def next_words_segment(segments: List[dict]) -> Optional[dict]: + return next((s for s in segments if s["words"]), None) - if should_skip: - seek += segment_size # fast-forward to the next segment boundary - continue - - previous_seek = seek - current_segments = [] - - # anomalous words are very long/short/improbable - def word_anomaly_score(word: dict) -> float: - probability = word.get("probability", 0.0) - duration = word["end"] - word["start"] - score = 0.0 - if probability < 0.15: - score += 1.0 - if duration < 0.133: - score += (0.133 - duration) * 15 - if duration > 2.0: - score += duration - 2.0 - return score - - def is_segment_anomaly(segment: Optional[dict]) -> bool: - if segment is None or not segment["words"]: - return False - words = [w for w in segment["words"] if w["word"] not in punctuation] - words = words[:8] - score = sum(word_anomaly_score(w) for w in words) - return score >= 3 or score + 0.01 >= len(words) - - def next_words_segment(segments: List[dict]) -> Optional[dict]: - return next((s for s in segments if s["words"]), None) - - timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) - single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] - - consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] - consecutive.add_(1) - if len(consecutive) > 0: - # if the output contains two consecutive timestamp tokens - slices = consecutive.tolist() - if single_timestamp_ending: - slices.append(len(tokens)) - - last_slice = 0 - for current_slice in slices: - sliced_tokens = tokens[last_slice:current_slice] - start_timestamp_pos = ( - 
sliced_tokens[0].item() - tokenizer.timestamp_begin - ) - end_timestamp_pos = ( - sliced_tokens[-1].item() - tokenizer.timestamp_begin - ) - current_segments.append( - new_segment( - start=time_offset + start_timestamp_pos * time_precision, - end=time_offset + end_timestamp_pos * time_precision, - tokens=sliced_tokens, - result=result, - ) - ) - last_slice = current_slice - - if single_timestamp_ending: - # single timestamp at the end means no speech after the last timestamp. - seek += segment_size - else: - # otherwise, ignore the unfinished segment and seek to the last timestamp - last_timestamp_pos = ( - tokens[last_slice - 1].item() - tokenizer.timestamp_begin - ) - seek += last_timestamp_pos * input_stride - else: - duration = segment_duration - timestamps = tokens[timestamp_tokens.nonzero().flatten()] - if ( - len(timestamps) > 0 - and timestamps[-1].item() != tokenizer.timestamp_begin - ): - # no consecutive timestamps but it has a timestamp; use the last one. - last_timestamp_pos = ( - timestamps[-1].item() - tokenizer.timestamp_begin - ) - duration = last_timestamp_pos * time_precision + def reseek( + self, current_segments: List[dict], segment_size: int, + single_timestamp_ending: bool, tokens: torch.Tensor, + timestamp_tokens: torch.Tensor, result: DecodingResult): + _tokenizer = self.tokenizer + assert _tokenizer is not None + consecutive = torch.where( + timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + consecutive.add_(1) + if len(consecutive) > 0: + # if the output contains two consecutive timestamp tokens + slices = consecutive.tolist() + if single_timestamp_ending: + slices.append(len(tokens)) + last_slice = 0 + for current_slice in slices: + sliced_tokens = tokens[last_slice:current_slice] + start_timestamp_pos = ( + sliced_tokens[0].item() - + _tokenizer.timestamp_begin) + end_timestamp_pos = ( + sliced_tokens[-1].item() - + _tokenizer.timestamp_begin) current_segments.append( - new_segment( - start=time_offset, - end=time_offset + duration, - tokens=tokens, + self.new_segment( + start=self.time_offset + \ + start_timestamp_pos * self.time_precision, + end=self.time_offset + \ + end_timestamp_pos * self.time_precision, + tokens=sliced_tokens, result=result, ) ) - seek += segment_size + last_slice = current_slice - if word_timestamps: - add_word_timestamps( - segments=current_segments, - model=model, - tokenizer=tokenizer, - mel=mel_segment, - num_frames=segment_size, - prepend_punctuations=prepend_punctuations, - append_punctuations=append_punctuations, - last_speech_timestamp=last_speech_timestamp, - ) + if single_timestamp_ending: + # single timestamp at the end means no speech after the last + # timestamp. + self.seek += segment_size + else: + # otherwise, ignore the unfinished segment and seek to the last + # timestamp + last_timestamp_pos = ( + tokens[last_slice - 1].item() - + _tokenizer.timestamp_begin) + self.seek += last_timestamp_pos * self.input_stride + else: + duration = segment_size * HOP_LENGTH / SAMPLE_RATE + timestamps = tokens[timestamp_tokens.nonzero().flatten()] + if len(timestamps) > 0 and \ + timestamps[-1].item() != _tokenizer.timestamp_begin: + # no consecutive timestamps but it has a timestamp; use the last + # one. 
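+ # the token-position offset times time_precision (0.02 s per token) gives the duration in seconds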
+ last_timestamp_pos = \ + timestamps[-1].item() - _tokenizer.timestamp_begin + duration = last_timestamp_pos * self.time_precision - if not single_timestamp_ending: - last_word_end = get_end(current_segments) - if last_word_end is not None and last_word_end > time_offset: - seek = round(last_word_end * FRAMES_PER_SECOND) + current_segments.append(self.new_segment( + start=self.time_offset, + end=self.time_offset + duration, + tokens=tokens, + result=result)) + self.seek += segment_size - # skip silence before possible hallucinations - if hallucination_silence_threshold is not None: - threshold = hallucination_silence_threshold - if not single_timestamp_ending: - last_word_end = get_end(current_segments) - if last_word_end is not None and last_word_end > time_offset: - remaining_duration = window_end_time - last_word_end - if remaining_duration > threshold: - seek = round(last_word_end * FRAMES_PER_SECOND) - else: - seek = previous_seek + segment_size + def timestamp( + self, current_segments: List[dict], segment_size: int, + single_timestamp_ending: bool, mel_segment: torch.Tensor, + previous_seek: int, content_frames: int) -> bool: + add_word_timestamps( + segments=current_segments, + model=self.model, + tokenizer=self.tokenizer, + mel=mel_segment, + num_frames=segment_size, + prepend_punctuations=self.prepend_punctuations, + append_punctuations=self.append_punctuations, + last_speech_timestamp=self.last_speech_timestamp, + ) - # if first segment might be a hallucination, skip leading silence - first_segment = next_words_segment(current_segments) - if first_segment is not None and is_segment_anomaly(first_segment): - gap = first_segment["start"] - time_offset - if gap > threshold: - seek = previous_seek + round(gap * FRAMES_PER_SECOND) - continue + if not single_timestamp_ending: + last_word_end = get_end(current_segments) + if last_word_end is not None and \ + last_word_end > self.time_offset: + self.seek = round(last_word_end * FRAMES_PER_SECOND) - # skip silence before any possible hallucination that is surrounded - # by silence or more hallucinations - hal_last_end = last_speech_timestamp - for si in range(len(current_segments)): - segment = current_segments[si] - if not segment["words"]: - continue - if is_segment_anomaly(segment): - next_segment = next_words_segment( - current_segments[si + 1 :] - ) - if next_segment is not None: - hal_next_start = next_segment["words"][0]["start"] - else: - hal_next_start = time_offset + segment_duration - silence_before = ( - segment["start"] - hal_last_end > threshold - or segment["start"] < threshold - or segment["start"] - time_offset < 2.0 - ) - silence_after = ( - hal_next_start - segment["end"] > threshold - or is_segment_anomaly(next_segment) - or window_end_time - segment["end"] < 2.0 - ) - if silence_before and silence_after: - seek = round( - max(time_offset + 1, segment["start"]) - * FRAMES_PER_SECOND - ) - if content_duration - segment["end"] < threshold: - seek = content_frames - current_segments[si:] = [] - break - hal_last_end = segment["end"] + # skip silence before possible hallucinations + if self.hallucination_silence_threshold is not None: + threshold = self.hallucination_silence_threshold + if not single_timestamp_ending: + last_word_end = get_end(current_segments) + if last_word_end is not None and \ + last_word_end > self.time_offset: + remaining_duration = \ + self.window_end_time - last_word_end + if remaining_duration > threshold: + self.seek = round( + last_word_end * FRAMES_PER_SECOND) + else: + self.seek = 
previous_seek + segment_size + + # if first segment might be a hallucination, skip leading silence + first_segment = self.next_words_segment(current_segments) + if first_segment is not None and self.is_segment_anomaly( + first_segment): + gap = first_segment["start"] - self.time_offset + if gap > threshold: + self.seek = previous_seek + round( + gap * FRAMES_PER_SECOND) + return True + + # skip silence before any possible hallucination that is + # surrounded by silence or more hallucinations + hal_last_end = self.last_speech_timestamp + content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE) + for si in range(len(current_segments)): + segment = current_segments[si] + if not segment["words"]: + continue + if self.is_segment_anomaly(segment): + next_segment = self.next_words_segment( + current_segments[si + 1 :]) + if next_segment is not None: + hal_next_start = \ + next_segment["words"][0]["start"] + else: + hal_next_start = self.time_offset + \ + segment_size * HOP_LENGTH / SAMPLE_RATE + silence_before = ( + segment["start"] - hal_last_end > threshold + or segment["start"] < threshold + or segment["start"] - self.time_offset < 2.0 + ) + silence_after = ( + hal_next_start - segment["end"] > threshold + or self.is_segment_anomaly(next_segment) + or self.window_end_time - segment["end"] < 2.0 + ) + if silence_before and silence_after: + self.seek = round( + max(self.time_offset + 1, segment["start"]) + * FRAMES_PER_SECOND + ) + if content_duration - segment["end"] < \ + threshold: + self.seek = content_frames + current_segments[si:] = [] + break + hal_last_end = segment["end"] last_word_end = get_end(current_segments) if last_word_end is not None: - last_speech_timestamp = last_word_end + self.last_speech_timestamp = last_word_end + return False - if verbose: + def __call__( + self, mel: torch.Tensor, offset: int = 0, + single_pass: bool = False) -> dict: + self.latest, self.frame_offset = mel, offset + content_frames = mel.shape[-1] - N_FRAMES + offset + content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE) + while self.clip_idx < len(self.seek_clips): + seek_clip_start, seek_clip_end = self.seek_clips[self.clip_idx] + seek_clip_end = content_frames if seek_clip_end is None else \ + seek_clip_end + if self.seek < seek_clip_start: + self.seek = seek_clip_start + if self.seek >= seek_clip_end: + if self.clip_idx == len(self.seek_clips) - 1: + break + self.clip_idx += 1 + continue + segment_size = min( + N_FRAMES, content_frames - self.seek, + seek_clip_end - self.seek) + mel_segment = mel[ + :, self.seek - offset : self.seek + segment_size - offset] + mel_segment = pad_or_trim(mel_segment, N_FRAMES).to( + self.device).to(self.dtype) + + self.decode_options["prompt"] = \ + self.all_tokens[self.prompt_reset_since:] + result: DecodingResult = self.decode_with_fallback(mel_segment) + + if self.no_speech_threshold is not None: + # no voice activity check + should_skip = result.no_speech_prob > self.no_speech_threshold + if self.logprob_threshold is not None and \ + result.avg_logprob > self.logprob_threshold: + # don't skip if the logprob is high enough, despite the + # no_speech_prob + should_skip = False + + if should_skip: + # fast-forward to the next segment boundary + self.seek += segment_size + continue + + previous_seek = self.seek + current_segments: List[dict] = [] + + tokens = torch.tensor(result.tokens) + _tokenizer = self.tokenizer + assert _tokenizer is not None + timestamp_tokens: torch.Tensor = tokens.ge( + _tokenizer.timestamp_begin) + 
single_timestamp_ending = ( + timestamp_tokens[-2:].tolist() == [False, True]) + + self.reseek( + current_segments, segment_size, single_timestamp_ending, + tokens, timestamp_tokens, result) + + if self.word_timestamps: + if self.timestamp( + current_segments, segment_size, single_timestamp_ending, + mel_segment, previous_seek, content_frames): + continue + + if self.verbose: for segment in current_segments: - start, end, text = segment["start"], segment["end"], segment["text"] - line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}" + start, end = segment["start"], segment["end"] + text = segment["text"] + line = ( + f"[{format_timestamp(start)} --> " + f"{format_timestamp(end)}] {text}") print(make_safe(line)) # if a segment is instantaneous or does not contain text, clear it for i, segment in enumerate(current_segments): - if segment["start"] == segment["end"] or segment["text"].strip() == "": + if segment["start"] == segment["end"] or \ + segment["text"].strip() == "": segment["text"] = "" segment["tokens"] = [] segment["words"] = [] - all_segments.extend( - [ + self.all_segments.extend([ {"id": i, **segment} for i, segment in enumerate( - current_segments, start=len(all_segments) - ) - ] - ) - all_tokens.extend( - [token for segment in current_segments for token in segment["tokens"]] - ) + current_segments, start=len(self.all_segments))]) + self.all_tokens.extend([ + token for segment in current_segments + for token in segment["tokens"]]) - if not condition_on_previous_text or result.temperature > 0.5: + if not self.condition_on_previous_text or result.temperature > 0.5: # do not feed the prompt tokens if a high temperature was used - prompt_reset_since = len(all_tokens) + self.prompt_reset_since = len(self.all_tokens) - # update progress bar - pbar.update(min(content_frames, seek) - previous_seek) + if single_pass: + break - return dict( - text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), - segments=all_segments, - language=language, - ) + _tokenizer = self.tokenizer + assert _tokenizer is not None + res = dict( + segments=self.all_segments, language=self.language, + text=_tokenizer.decode( + self.all_tokens[len(self.initial_prompt_tokens):])) + self.latest = None + return res + + def restore(self, offset: int): + processing, seconds = 0, offset * HOP_LENGTH / SAMPLE_RATE + while len(self.all_segments) > 0 and ( + self.all_segments[-1]["start"] >= seconds + if len(self.all_segments) == 1 else + self.all_segments[-2]["end"] > seconds): + rewriting = self.all_segments.pop() + processing += len(rewriting["tokens"]) + self.all_tokens = self.all_tokens[:len(self.all_tokens) - processing] + if len(self.all_segments) > 0 and ( + self.all_segments[-1]["start"] < seconds and + self.all_segments[-1]["end"] >= seconds): + self.seek = round( + self.all_segments[-1]["end"] * SAMPLE_RATE / HOP_LENGTH) + else: + self.seek = offset + + +class InMemoryAudio(AudioFile): + dft_pad = True + + +def audio_tensor(audio: Union[str, np.ndarray, torch.Tensor]) -> torch.Tensor: + if isinstance(audio, str): + return InMemoryAudio(fname=audio).sequential() + if isinstance(audio, np.dtype): + return torch.from_numpy(audio) + return audio + + +def transcribe( + model: "Whisper", + audio: Union[str, np.ndarray, torch.Tensor], + **kw): + return Transcriber(model, **kw)(audio_tensor(audio)) def cli(): diff --git a/whisper/utils.py b/whisper/utils.py index 9b9b138..7bec242 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -1,9 +1,20 @@ import json import os +import pathlib import 
re import sys +import time import zlib -from typing import Callable, List, Optional, TextIO +from typing import ( + Callable, + List, + Optional, + TextIO, + Union, + TypeVar, + Generic, + Any +) system_encoding = sys.getdefaultencoding() @@ -21,11 +32,19 @@ else: return string +PathType = Union[str, pathlib.Path] + + def exact_div(x, y): assert x % y == 0 return x // y +# https://stackoverflow.com/a/17511341/3476782 +def ceildiv(a: Union[int, float], b: Union[int, float]) -> int: + return int(-(a // -b)) + + def str2bool(string): str2val = {"True": True, "False": False} if string in str2val: @@ -67,6 +86,18 @@ def format_timestamp( f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" ) +def hms(sec: float) -> str: + trim = sec < 3600 + h = "" if trim else str(int(sec) // 3600) + ":" + m_fill = " " if trim else "0" + m = " " if sec < 60 else str(int(sec) // 60 % 60).rjust(2, m_fill) + ":" + s = str(int(sec) % 60).rjust(2, '0') + "." + c = str(round((sec % 1) * 100)).rjust(2, '0') + return h + m + s + c + +def tod(seconds: float) -> str: + return time.strftime("%H:%M:%S", time.localtime(seconds)) + def get_start(segments: List[dict]) -> Optional[float]: return next( @@ -314,3 +345,42 @@ def get_writer( return write_all return writers[output_format](output_dir) + + +T = TypeVar("T") + +# boilerplate for property with _{name} storage and passthrough getter/setter +class PassthroughProperty(Generic[T]): + def __init__(self, default: T): + self.value = default + + f: Optional[Callable[[Any, T], None]] = None + def setter(self, f: Callable[[Any, T], None]): + self.f = f + return self + + g: Optional[property] = None + def property(self, g: Callable[[Any], T]): + self.g = property(g) + return self + +class PassthroughPropertyDefaults(type): + def __new__(cls, clsname, bases, attrs): + def closure(f, v): + def prop(self): + return getattr(self, v) + def setter(self, value): + setattr(self, v, value) + prop.__name__ = setter.__name__ = f + return property(prop), setter + + updates = {} + for k, v in attrs.items(): + if not isinstance(v, PassthroughProperty): + continue + private = "_" + k + updates[private] = v.value + getter, setter = closure(k, private) + updates[k] = (v.g or getter).setter(v.f or setter) + return super().__new__(cls, clsname, bases, {**attrs, **updates}) + From b4fd954955a8ff2a36e2c00222ecc875dd9c230c Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sun, 14 Jul 2024 16:14:37 -0600 Subject: [PATCH 02/10] progress bar support and buffered cli option --- whisper/buffer.py | 18 +++++- whisper/transcribe.py | 134 ++++++++++++++++++++++++++++++------------ 2 files changed, 114 insertions(+), 38 deletions(-) diff --git a/whisper/buffer.py b/whisper/buffer.py index 3075db1..ab43b71 100644 --- a/whisper/buffer.py +++ b/whisper/buffer.py @@ -1,5 +1,5 @@ import numpy as np -import asyncio, pathlib, subprocess, torch +import asyncio, pathlib, subprocess, torch, json from .audio import ( SAMPLE_RATE, @@ -271,3 +271,19 @@ class AudioFile(RawAudioFile): if ps.returncode not in (None, 0): raise RuntimeError(f"Failed to load audio: {stderr.decode()}") + @property + def duration(self): + cmd = [ + "ffprobe", + "-hide_banner", + "-show_format", + "-of", "json", + "-i", self.fname, + ] + ps = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = ps.communicate() + if ps.returncode not in (None, 0): + raise RuntimeError(f"Failed to load audio: {stderr.decode()}") + return float(json.loads(stdout)['format']['duration']) + diff --git 
a/whisper/transcribe.py b/whisper/transcribe.py index d648196..17ec3cc 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -2,12 +2,13 @@ import argparse import os import traceback import warnings +import asyncio from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Dict from dataclasses import dataclass import numpy as np import torch -import tqdm # TODO +import tqdm from .audio import ( FRAMES_PER_SECOND, @@ -32,7 +33,7 @@ from .utils import ( PassthroughProperty, PassthroughPropertyDefaults, ) -from .buffer import AudioFile +from .buffer import ArrayStream, AudioFile if TYPE_CHECKING: from .model import Whisper @@ -49,7 +50,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): postfix: str = '''"'.\u3002,\uff0c!\uff01?\uff1f:\uff1a\u201d)]}\u3001''' punctuation: str = prefix + postfix - verbose: bool = False + verbose: Optional[bool] = None _decode_options: dict = {} decode_props: Tuple[str, ...] = ("fp16", "language", "task") @@ -76,10 +77,10 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): self.dtype = torch.float16 if value else torch.float32 self.fp16device() - @PassthroughProperty(None).setter - def model(self, value: "Whisper") -> None: + @PassthroughProperty[Optional["Whisper"]](None).setter + def model(self, value: Optional["Whisper"]) -> None: self._model = value - self.device = value.device + self.device = None if value is None else value.device self.input_stride = exact_div( N_FRAMES, self.model.dims.n_audio_ctx ) # mel frames per output token: 2 @@ -207,7 +208,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): _tokenizer: Optional[Tokenizer] = None _tokenizer_cache: Dict[str, Tokenizer] = {} @property - def tokenizer(self) -> Optional[Tokenizer]: + def tokenizer(self) -> Tokenizer: if self._tokenizer is None: lang = self.language if self._language is not None: @@ -221,8 +222,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): task=self.task, ) return self._tokenizer - if lang is None: - return None + assert lang is not None if lang not in self._tokenizer_cache: self._tokenizer_cache[lang] = self.get_tokenizer( self.model.is_multilingual, @@ -247,9 +247,8 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): if tokenizer not in self._initial_prompt_cache: self._initial_prompt_cache[tokenizer] = tokenizer.encode( " " + self.initial_prompt.strip()) - if self._tokenizer is not None: - self._initial_prompt_tokens = \ - self._initial_prompt_cache[tokenizer] + self._initial_prompt_tokens = \ + self._initial_prompt_cache[tokenizer] return self._initial_prompt_cache[tokenizer] return self._initial_prompt_tokens @@ -275,8 +274,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): hallucination_silence_threshold: Optional[float] = None, **decode_options): self.model = model - if verbose is not None: - self.verbose = verbose + self.verbose = verbose self.temperature = temperature self.compression_ratio_threshold = compression_ratio_threshold self.logprob_threshold = logprob_threshold @@ -319,20 +317,19 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): needs_fallback = False # silence if not needs_fallback: break + assert decode_result is not None return decode_result def new_segment( self, *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult) -> dict: _tokens = tokens.tolist() - _tokenizer = self.tokenizer - assert _tokenizer is not None - text_tokens = [token for token in _tokens if token < _tokenizer.eot] + text_tokens = [token for token in _tokens if token < 
self.tokenizer.eot] return { "seek": self.seek, "start": start, "end": end, - "text": _tokenizer.decode(text_tokens), + "text": self.tokenizer.decode(text_tokens), "tokens": _tokens, "temperature": result.temperature, "avg_logprob": result.avg_logprob, @@ -371,8 +368,6 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): self, current_segments: List[dict], segment_size: int, single_timestamp_ending: bool, tokens: torch.Tensor, timestamp_tokens: torch.Tensor, result: DecodingResult): - _tokenizer = self.tokenizer - assert _tokenizer is not None consecutive = torch.where( timestamp_tokens[:-1] & timestamp_tokens[1:])[0] consecutive.add_(1) @@ -387,10 +382,10 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): sliced_tokens = tokens[last_slice:current_slice] start_timestamp_pos = ( sliced_tokens[0].item() - - _tokenizer.timestamp_begin) + self.tokenizer.timestamp_begin) end_timestamp_pos = ( sliced_tokens[-1].item() - - _tokenizer.timestamp_begin) + self.tokenizer.timestamp_begin) current_segments.append( self.new_segment( start=self.time_offset + \ @@ -412,17 +407,17 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): # timestamp last_timestamp_pos = ( tokens[last_slice - 1].item() - - _tokenizer.timestamp_begin) + self.tokenizer.timestamp_begin) self.seek += last_timestamp_pos * self.input_stride else: duration = segment_size * HOP_LENGTH / SAMPLE_RATE timestamps = tokens[timestamp_tokens.nonzero().flatten()] if len(timestamps) > 0 and \ - timestamps[-1].item() != _tokenizer.timestamp_begin: + timestamps[-1].item() != self.tokenizer.timestamp_begin: # no consecutive timestamps but it has a timestamp; use the last # one. last_timestamp_pos = \ - timestamps[-1].item() - _tokenizer.timestamp_begin + timestamps[-1].item() - self.tokenizer.timestamp_begin duration = last_timestamp_pos * self.time_precision current_segments.append(self.new_segment( @@ -569,10 +564,8 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): current_segments: List[dict] = [] tokens = torch.tensor(result.tokens) - _tokenizer = self.tokenizer - assert _tokenizer is not None timestamp_tokens: torch.Tensor = tokens.ge( - _tokenizer.timestamp_begin) + self.tokenizer.timestamp_begin) single_timestamp_ending = ( timestamp_tokens[-2:].tolist() == [False, True]) @@ -615,19 +608,22 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): # do not feed the prompt tokens if a high temperature was used self.prompt_reset_since = len(self.all_tokens) + self.reporthook() + if single_pass: break - _tokenizer = self.tokenizer - assert _tokenizer is not None - res = dict( + self.result = dict( segments=self.all_segments, language=self.language, - text=_tokenizer.decode( + text=self.tokenizer.decode( self.all_tokens[len(self.initial_prompt_tokens):])) self.latest = None - return res + return self.result - def restore(self, offset: int): + def reporthook(self) -> None: + pass + + def restore(self, offset: int) -> None: processing, seconds = 0, offset * HOP_LENGTH / SAMPLE_RATE while len(self.all_segments) > 0 and ( self.all_segments[-1]["start"] >= seconds @@ -652,16 +648,78 @@ class InMemoryAudio(AudioFile): def audio_tensor(audio: Union[str, np.ndarray, torch.Tensor]) -> torch.Tensor: if isinstance(audio, str): return InMemoryAudio(fname=audio).sequential() - if isinstance(audio, np.dtype): + if isinstance(audio, np.ndarray): return torch.from_numpy(audio) return audio +class MinimalTranscriber(Transcriber): + exact: bool = True + chlen: float = CHUNK_LENGTH + async def process(self, stream: 
ArrayStream, **kw) -> dict: + data = await stream.request(self.chlen, self.exact) + while data.shape[-1] > 0: + self(data, stream.offset, True) + t = self.chlen - (stream.offset + data.shape[-1] - self.seek) \ + / FRAMES_PER_SECOND + CHUNK_LENGTH + data = await stream.request(t, self.exact) + return self.result + + +class ProgressTranscriber(MinimalTranscriber): + def __init__(self, *a, duration: Optional[float] = None, **kw): + super().__init__(*a, **kw) + self.duration, self.progress = duration, 0 + + def __call__(self, *a, **kw) -> dict: + if self._pbar is None: + try: + return super().__call__(*a, **kw) + finally: + self.close() + else: + return super().__call__(*a, **kw) + + @PassthroughProperty(None).property + def pbar(self): + if self._pbar is None: + n = self.latest.shape[-1] if self.duration is None \ + else -int(self.duration * -FRAMES_PER_SECOND) + self._pbar = tqdm.tqdm( + total=n, unit="frames", disable=self.verbose is not False) + self._pbar.__enter__() + return self._pbar + + def reporthook(self) -> None: + update_to = min(self._seek, self.frame_offset + self.latest.shape[-1]) + self.pbar.update(update_to - self.progress) + self.progress = update_to + + def close(self): + self.pbar.__exit__(None, None, None) + + async def process(self, stream: ArrayStream, **kw) -> dict: + self.pbar + try: + return await super().process(stream, **kw) + finally: + self.close() + + async def progressive(self, stream: AudioFile, **kw) -> dict: + self.duration = stream.duration + return await self.process(stream, **kw) + + def transcribe( model: "Whisper", audio: Union[str, np.ndarray, torch.Tensor], **kw): - return Transcriber(model, **kw)(audio_tensor(audio)) + return ProgressTranscriber(model, **kw)(audio_tensor(audio)) + + +def buffered_transcribe(model: "Whisper", audio: str, **kw): + transcriber = ProgressTranscriber(model, **kw) + return asyncio.run(transcriber.progressive(AudioFile(fname=audio))) def cli(): @@ -712,6 +770,7 @@ def cli(): parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file") parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected") + parser.add_argument("--buffered", type=str2bool, default=False, help="whether to load the audio data on demand instead of all at once") # fmt: on args = parser.parse_args().__dict__ @@ -741,6 +800,7 @@ def cli(): from . 
import load_model

     model = load_model(model_name, device=device, download_root=model_dir)
+    transcriber = buffered_transcribe if args.pop("buffered") else transcribe

     writer = get_writer(output_format, output_dir)
     word_options = [
@@ -760,7 +820,7 @@
     writer_args = {arg: args.pop(arg) for arg in word_options}
     for audio_path in args.pop("audio"):
         try:
-            result = transcribe(model, audio_path, temperature=temperature, **args)
+            result = transcriber(model, audio_path, temperature=temperature, **args)
             writer(result, audio_path, **writer_args)
         except Exception as e:
             traceback.print_exc()

From e0704ddeba84a475ff09674fea043bfe8eb25b48 Mon Sep 17 00:00:00 2001
From: Kent Slaney
Date: Sun, 14 Jul 2024 16:24:14 -0600
Subject: [PATCH 03/10] add parameter documentation back in

---
 whisper/transcribe.py | 91 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 91 insertions(+)

diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 17ec3cc..0964c8a 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -273,6 +273,74 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             clip_timestamps: Union[str, List[float]] = "0",
             hallucination_silence_threshold: Optional[float] = None,
             **decode_options):
+        """
+        Transcribe an audio file using Whisper
+
+        Parameters
+        ----------
+        model: Whisper
+            The Whisper model instance
+
+        verbose: bool
+            Whether to display the text being decoded to the console. If True,
+            displays all the details, If False, displays minimal details. If
+            None, does not display anything
+
+        temperature: Union[float, Tuple[float, ...]]
+            Temperature for sampling. It can be a tuple of temperatures, which
+            will be successively used upon failures according to either
+            `compression_ratio_threshold` or `logprob_threshold`.
+
+        compression_ratio_threshold: float
+            If the gzip compression ratio is above this value, treat as failed
+
+        logprob_threshold: float
+            If the average log probability over sampled tokens is below this
+            value, treat as failed
+
+        no_speech_threshold: float
+            If the no_speech probability is higher than this value AND the
+            average log probability over sampled tokens is below
+            `logprob_threshold`, consider the segment as silent
+
+        condition_on_previous_text: bool
+            if True, the previous output of the model is provided as a prompt
+            for the next window; disabling may make the text inconsistent across
+            windows, but the model becomes less prone to getting stuck in a
+            failure loop, such as repetition looping or timestamps going out of
+            sync.
+
+        word_timestamps: bool
+            Extract word-level timestamps using the cross-attention pattern and
+            dynamic time warping, and include the timestamps for each word in
+            each segment.
+
+        prepend_punctuations: str
+            If word_timestamps is True, merge these punctuation symbols with the
+            next word
+
+        append_punctuations: str
+            If word_timestamps is True, merge these punctuation symbols with the
+            previous word
+
+        initial_prompt: Optional[str]
+            Optional text to provide as a prompt for the first window. This can
+            be used to provide, or "prompt-engineer" a context for
+            transcription, e.g. custom vocabularies or proper nouns to make it
+            more likely to predict those words correctly.
+
+        decode_options: dict
+            Keyword arguments to construct `DecodingOptions` instances
+
+        clip_timestamps: Union[str, List[float]]
+            Comma-separated list start,end,start,end,... timestamps (in seconds)
+            of clips to process. The last end timestamp defaults to the end of
+            the file.
+
+        hallucination_silence_threshold: Optional[float]
+            When word_timestamps is True, skip silent periods longer than this
+            threshold (in seconds) when a possible hallucination is detected
+        """
         self.model = model
         self.verbose = verbose
         self.temperature = temperature
@@ -523,6 +591,10 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         self.latest, self.frame_offset = mel, offset
         content_frames = mel.shape[-1] - N_FRAMES + offset
         content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
+        # NOTE: This loop is obscurely flattened to make the diff readable.
+        # A later commit should turn this into a simpler nested loop.
+        # for seek_clip_start, seek_clip_end in seek_clips:
+        #     while seek < seek_clip_end
         while self.clip_idx < len(self.seek_clips):
             seek_clip_start, seek_clip_end = self.seek_clips[self.clip_idx]
             seek_clip_end = content_frames if seek_clip_end is None else \
                 seek_clip_end
             if self.seek < seek_clip_start:
                 self.seek = seek_clip_start
             if self.seek >= seek_clip_end:
@@ -685,6 +757,8 @@ class ProgressTranscriber(MinimalTranscriber):
         if self._pbar is None:
             n = self.latest.shape[-1] if self.duration is None \
                 else -int(self.duration * -FRAMES_PER_SECOND)
+            # show the progress bar when verbose is False
+            # (if True, transcribed text will be printed)
             self._pbar = tqdm.tqdm(
                 total=n, unit="frames", disable=self.verbose is not False)
             self._pbar.__enter__()
         return self._pbar
@@ -714,6 +788,23 @@
 def transcribe(
     model: "Whisper",
     audio: Union[str, np.ndarray, torch.Tensor],
     **kw):
+    """
+    Transcribe an audio file using Whisper
+
+    Parameters
+    ----------
+    model: Whisper
+        The Whisper model instance
+
+    audio: Union[str, np.ndarray, torch.Tensor]
+        The path to the audio file to open, or the audio waveform
+
+    Returns
+    -------
+    A dictionary containing the resulting text ("text") and segment-level
+    details ("segments"), and the spoken language ("language"), which is
+    detected when `decode_options["language"]` is None.
+ """ return ProgressTranscriber(model, **kw)(audio_tensor(audio)) From 0621ed8094d9d72611fa2cf64332e268784fbfe0 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sun, 14 Jul 2024 19:07:06 -0600 Subject: [PATCH 04/10] pre-commit formatting --- whisper/batching.py | 55 ++++--- whisper/buffer.py | 141 ++++++++++------- whisper/transcribe.py | 353 ++++++++++++++++++++++++------------------ whisper/utils.py | 24 ++- 4 files changed, 328 insertions(+), 245 deletions(-) diff --git a/whisper/batching.py b/whisper/batching.py index 684ecba..d8e0e12 100644 --- a/whisper/batching.py +++ b/whisper/batching.py @@ -1,15 +1,19 @@ -import torch -import numpy as np -from collections.abc import Callable, AsyncIterable, AsyncIterator, Awaitable +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable from typing import Generic, TypeVar, Union +import numpy as np +import torch + A = TypeVar("A", bound=Union[np.ndarray, torch.Tensor]) + class ArrayWrapper(Generic[A]): pass + ArrayTypes = Union[A, ArrayWrapper[A]] + class LoopbackIterator(Generic[A]): async def iter(self): raise NotImplementedError @@ -23,14 +27,17 @@ class LoopbackIterator(Generic[A]): self.__aiter__() return await anext(self._iter) + async def empty(): return yield + class Unwrap(LoopbackIterator): _initial: Union[ArrayTypes, Awaitable[ArrayTypes]] started: bool iterator: AsyncIterable[ArrayTypes] + def __init__(self, iterator: AsyncIterable[ArrayTypes]): while isinstance(iterator, PassthroughTransform): iterator = iterator.handoff() @@ -74,13 +81,14 @@ class Unwrap(LoopbackIterator): @property async def concat(self): - return np.concatenate if isinstance(await self.dtype, np.dtype) \ - else torch.cat + return np.concatenate if isinstance(await self.dtype, np.dtype) else torch.cat + class PassthroughTransform(LoopbackIterator): def handoff(self) -> AsyncIterable[ArrayTypes]: raise NotImplementedError + class BoxedIterator(PassthroughTransform): def __init__(self, iterator): self.iterator = iterator @@ -99,6 +107,7 @@ class BoxedIterator(PassthroughTransform): if self.flag != flag: raise Exception("source can only be used by one iterator") + def LookAlong(axis: int): assert axis >= 0 empties = (slice(None),) * axis @@ -119,10 +128,9 @@ def LookAlong(axis: int): return LookAlong + class PassthroughMap(PassthroughTransform): - def __init__( - self, apply: Callable[[A], ArrayTypes], - iterator: AsyncIterator[A]): + def __init__(self, apply: Callable[[A], ArrayTypes], iterator: AsyncIterator[A]): self.iterator, self.apply = iterator, apply def handoff(self) -> AsyncIterator[A]: @@ -132,6 +140,7 @@ class PassthroughMap(PassthroughTransform): async for i in self.iterator: yield self.apply(i) + class Group: def __init__(self, concat, axis=-1): self.concat = concat @@ -155,16 +164,18 @@ class Group: if taking == amount or not exact: self.shape += amount - taking self.consumed = 0 - res = self.concat([self.holding[0][start:]] + [ - i.value for i in self.holding[1 : i + 1]]) - self.holding = self.holding[i + 1:] + res = self.concat( + [self.holding[0][start:]] + [i.value for i in self.holding[1 : i + 1]] + ) + self.holding = self.holding[i + 1 :] return res if i == 0: - return self.holding[0][start:self.consumed] + return self.holding[0][start : self.consumed] res = self.concat( - [self.holding[0][start:]] + - [i.value for i in self.holding[1 : i]] + - [self.holding[i][:self.consumed]]) + [self.holding[0][start:]] + + [i.value for i in self.holding[1:i]] + + [self.holding[i][: self.consumed]] + ) self.holding = self.holding[i:] 
return res @@ -175,10 +186,12 @@ class Group: self.holding = [] return res + class Taken: def take(self, *a, **kw): raise Exception("batch queue moved") + class Batcher(PassthroughTransform): def __init__(self, iterator, size, axis=-1, exact=False): assert isinstance(size, int) and size > 0 @@ -192,14 +205,19 @@ class Batcher(PassthroughTransform): return lambda tensors: f(tensors, self.axis) _iterator = None + async def iterator(self): if self._iterator is None: - self.axis = len(await self.preview.shape) + self._axis \ - if self._axis < 0 else self._axis + self.axis = ( + len(await self.preview.shape) + self._axis + if self._axis < 0 + else self._axis + ) if not hasattr(self, "group"): self.group = Group(await self.concat()) self._iterator = PassthroughMap( - LookAlong(self.axis), BoxedIterator(self.preview)) + LookAlong(self.axis), BoxedIterator(self.preview) + ) return self._iterator def handoff(self): @@ -219,4 +237,3 @@ class Batcher(PassthroughTransform): return self.group.all() raise return self.group.take(self.size, self.exact) - diff --git a/whisper/buffer.py b/whisper/buffer.py index ab43b71..2dd27b4 100644 --- a/whisper/buffer.py +++ b/whisper/buffer.py @@ -1,18 +1,16 @@ +import asyncio +import json +import subprocess +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Coroutine +from typing import IO, Optional, Union + import numpy as np -import asyncio, pathlib, subprocess, torch, json +import torch -from .audio import ( - SAMPLE_RATE, - N_FFT, - HOP_LENGTH, - N_FRAMES, - mel_filters, -) - -from .utils import PathType, ceildiv +from .audio import HOP_LENGTH, N_FFT, N_FRAMES, SAMPLE_RATE, mel_filters from .batching import Batcher -from typing import Optional, Union, IO, Tuple, Any, Type -from collections.abc import Coroutine, AsyncIterable, AsyncIterator, Awaitable +from .utils import PathType, ceildiv + class AudioSink: def __init__(self, *, rate: int = SAMPLE_RATE, **kw): @@ -25,11 +23,19 @@ class AudioSink: def write(self, data): raise NotImplementedError + class ArrayStream(AudioSink): q: asyncio.Queue + def __init__( - self, *, device: Optional[Union[str, torch.device]] = None, - batch: int = 1, n_mels: int = 80, capacity: int = 1_000_000, **kw): + self, + *, + device: Optional[Union[str, torch.device]] = None, + batch: int = 1, + n_mels: int = 80, + capacity: int = 1_000_000, + **kw, + ): super().__init__(**kw) self.q = asyncio.Queue(capacity) self.finished = asyncio.Event() @@ -43,6 +49,7 @@ class ArrayStream(AudioSink): return torch.zeros(shape, dtype=torch.float32, device=self.device) write_blockable: bool = True + def write(self, data: bytes) -> Optional[Coroutine]: if self.write_blockable: return self.q.put(data) @@ -51,11 +58,9 @@ class ArrayStream(AudioSink): return None def load(self, data: bytes) -> np.ndarray: - return np.frombuffer( - data, np.int16).flatten().astype(np.float32) / 32768.0 + return np.frombuffer(data, np.int16).flatten().astype(np.float32) / 32768.0 - async def loader(self, iterator: AsyncIterable[bytes]) -> \ - AsyncIterator[np.ndarray]: + async def loader(self, iterator: AsyncIterable[bytes]) -> AsyncIterator[np.ndarray]: async for data in iterator: yield self.load(data) @@ -64,7 +69,8 @@ class ArrayStream(AudioSink): while not self.finished.is_set(): getter = asyncio.create_task(self.q.get()) done, pending = await asyncio.wait( - (waiter, getter), return_when=asyncio.FIRST_COMPLETED) + (waiter, getter), return_when=asyncio.FIRST_COMPLETED + ) if getter in done: yield getter.result() while not self.q.empty(): @@ -78,8 +84,10 
@@ class ArrayStream(AudioSink): pass loading: Optional[Batcher] = None - async def fft_offset(self, iterator: AsyncIterable[bytes]) -> \ - AsyncIterator[np.ndarray]: + + async def fft_offset( + self, iterator: AsyncIterable[bytes] + ) -> AsyncIterator[np.ndarray]: init = self.loader(iterator) if self.loading is None else self.loading self.loading = Batcher(init, HOP_LENGTH) _iterator = aiter(self.loading) @@ -89,7 +97,7 @@ class ArrayStream(AudioSink): window = np.concatenate((window, await anext(_iterator))) except StopAsyncIteration: return - window = np.pad(window, (N_FFT // 2, 0), 'reflect') + window = np.pad(window, (N_FFT // 2, 0), "reflect") yield window async for data in _iterator: yield data @@ -101,8 +109,9 @@ class ArrayStream(AudioSink): hopped = ((sees.shape[0] - N_FFT) // HOP_LENGTH + 1) * HOP_LENGTH return sees[hopped:] - async def window(self, iterator: AsyncIterable[bytes]) -> \ - AsyncIterator[torch.Tensor]: + async def window( + self, iterator: AsyncIterable[bytes] + ) -> AsyncIterator[torch.Tensor]: _iterator = self.fft_offset(iterator) async for data in _iterator: _data = torch.from_numpy(data) @@ -121,19 +130,23 @@ class ArrayStream(AudioSink): def dft(self, amp: torch.Tensor) -> torch.Tensor: return torch.stft( - amp, N_FFT, HOP_LENGTH, window=self.hann, center=False, - return_complex=True) + amp, N_FFT, HOP_LENGTH, window=self.hann, center=False, return_complex=True + ) # https://github.com/openai/whisper/blob/c5d4256/whisper/audio.py#L149 log_spec_bound: Optional[torch.Tensor] = None + def transform(self, stft: torch.Tensor) -> torch.Tensor: magnitudes = stft.abs() ** 2 mel_spec = self.filters @ magnitudes log_spec = torch.clamp(mel_spec, min=1e-10).log10() # causes values to not precisely match the original - self.log_spec_bound = log_spec.max() if self.log_spec_bound is None \ - else torch.maximum(log_spec.max(), self.log_spec_bound) + self.log_spec_bound = ( + log_spec.max() + if self.log_spec_bound is None + else torch.maximum(log_spec.max(), self.log_spec_bound) + ) log_spec = torch.maximum(log_spec, self.log_spec_bound - 8.0) log_spec = (log_spec + 4.0) / 4.0 return log_spec @@ -143,6 +156,7 @@ class ArrayStream(AudioSink): # dft_pad: add ending content frames to match padding from a centered STFT dft_pad: bool = False + def runoff(self, dft_pad: Optional[bool] = None) -> torch.Tensor: dft_pad = self.dft_pad if dft_pad is None else dft_pad if dft_pad: @@ -170,34 +184,36 @@ class ArrayStream(AudioSink): return self.runoff() staging: Optional[Batcher] = None - async def _push(self, sec: float, exact: bool = False) -> \ - AsyncIterator[torch.Tensor]: + + async def _push( + self, sec: float, exact: bool = False + ) -> AsyncIterator[torch.Tensor]: batching = int(sec * SAMPLE_RATE // HOP_LENGTH) - init = self.window(self.buffer()) if self.staging is None \ - else self.staging + init = self.window(self.buffer()) if self.staging is None else self.staging self.staging = Batcher(init, batching, exact=exact) async for frame in self.staging: batched = batching if exact else frame.shape[-1] cutoff = max(self.spectogram.shape[-1] + batched - N_FRAMES, 0) self.offset += cutoff - self.spectogram = torch.cat(( - self.spectogram[:, cutoff:], frame), -1) + self.spectogram = torch.cat((self.spectogram[:, cutoff:], frame), -1) yield self.runoff() reader: Optional[Awaitable] = None + def start(self, **kw) -> None: if self.reader is None: self.reader = asyncio.create_task(self.read(**kw)) - async def push(self, sec: float, exact: bool=False, **kw) -> \ - 
AsyncIterator[torch.Tensor]: + async def push( + self, sec: float, exact: bool = False, **kw + ) -> AsyncIterator[torch.Tensor]: self.start(**kw) async for i in self._push(sec, exact): yield i assert self.reader is not None await self.reader - async def request(self, sec: float, exact: bool=True, **kw) -> torch.Tensor: + async def request(self, sec: float, exact: bool = True, **kw) -> torch.Tensor: try: return await anext(self.push(sec, exact)) except StopAsyncIteration: @@ -224,17 +240,17 @@ class ArrayStream(AudioSink): def all_amplitudes(self, **kw) -> np.ndarray: return asyncio.run(self.amplitudes(**kw)) + class RawAudioFile(ArrayStream): - def __init__( - self, *, period: int = HOP_LENGTH, fname: PathType = 'out.raw', - **kw): + def __init__(self, *, period: int = HOP_LENGTH, fname: PathType = "out.raw", **kw): super().__init__(**kw) self.fname = fname self.period = period fp: Optional[IO[bytes]] = None + async def read(self) -> None: - fp = open(self.fname, 'rb') if self.fp is None else self.fp + fp = open(self.fname, "rb") if self.fp is None else self.fp data = fp.read(self.period) while len(data) != 0: io_hold = self.write(data) @@ -243,28 +259,33 @@ class RawAudioFile(ArrayStream): data = fp.read(self.period) self.finished.set() + class AudioFile(RawAudioFile): - def __init__( - self, *, period: int = SAMPLE_RATE, fname: PathType = 'out.wav', - **kw): + def __init__(self, *, period: int = SAMPLE_RATE, fname: PathType = "out.wav", **kw): assert not subprocess.run( - ["which", "ffmpeg"], stdout=subprocess.PIPE).returncode + ["which", "ffmpeg"], stdout=subprocess.PIPE + ).returncode super().__init__(period=period or -1, fname=fname, **kw) async def read(self) -> None: cmd = [ "ffmpeg", "-nostdin", - "-threads", "0", - "-i", self.fname, - "-f", "s16le", - "-ac", "1", - "-acodec", "pcm_s16le", - "-ar", str(self.rate), - "-" + "-threads", + "0", + "-i", + self.fname, + "-f", + "s16le", + "-ac", + "1", + "-acodec", + "pcm_s16le", + "-ar", + str(self.rate), + "-", ] - ps = subprocess.Popen( - cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + ps = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) self.fp = ps.stdout await super().read() _, stderr = ps.communicate() @@ -277,13 +298,13 @@ class AudioFile(RawAudioFile): "ffprobe", "-hide_banner", "-show_format", - "-of", "json", - "-i", self.fname, + "-of", + "json", + "-i", + self.fname, ] - ps = subprocess.Popen( - cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + ps = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout, stderr = ps.communicate() if ps.returncode not in (None, 0): raise RuntimeError(f"Failed to load audio: {stderr.decode()}") - return float(json.loads(stdout)['format']['duration']) - + return float(json.loads(stdout)["format"]["duration"]) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 0964c8a..758c0b4 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -1,27 +1,30 @@ import argparse +import asyncio import os import traceback import warnings -import asyncio -from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Dict from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import numpy as np import torch import tqdm from .audio import ( + CHUNK_LENGTH, FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, SAMPLE_RATE, - CHUNK_LENGTH, pad_or_trim, ) +from .buffer import ArrayStream, AudioFile from .decoding import DecodingOptions, DecodingResult from .timing import add_word_timestamps -from 
.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer, Tokenizer +from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, Tokenizer, get_tokenizer from .utils import ( + PassthroughProperty, + PassthroughPropertyDefaults, exact_div, format_timestamp, get_end, @@ -30,14 +33,12 @@ from .utils import ( optional_float, optional_int, str2bool, - PassthroughProperty, - PassthroughPropertyDefaults, ) -from .buffer import ArrayStream, AudioFile if TYPE_CHECKING: from .model import Whisper + @dataclass class LanguageHypothesis: language: Optional[str] = None @@ -45,15 +46,17 @@ class LanguageHypothesis: evidence: int = 0 last: int = -1 + class Transcriber(metaclass=PassthroughPropertyDefaults): - prefix: str = '''"'\u201c\u00bf([{-''' - postfix: str = '''"'.\u3002,\uff0c!\uff01?\uff1f:\uff1a\u201d)]}\u3001''' + prefix: str = """"'\u201c\u00bf([{-""" + postfix: str = """"'.\u3002,\uff0c!\uff01?\uff1f:\uff1a\u201d)]}\u3001""" punctuation: str = prefix + postfix verbose: Optional[bool] = None _decode_options: dict = {} decode_props: Tuple[str, ...] = ("fp16", "language", "task") + @property def decode_options(self) -> dict: for k in self.decode_props: @@ -68,6 +71,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): setattr(self, k, value[k]) dtype: torch.dtype = torch.float16 + @property def fp16(self) -> bool: return self.dtype == torch.float16 @@ -93,8 +97,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): self._device = value if value == torch.device("cpu"): if torch.cuda.is_available(): - warnings.warn( - "Performing inference on CPU when CUDA is available") + warnings.warn("Performing inference on CPU when CUDA is available") self.fp16device() def fp16device(self) -> None: @@ -110,6 +113,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): prev: Optional[torch.Tensor] = None _latest: Optional[torch.Tensor] = None + @PassthroughProperty[Optional[torch.Tensor]](None).setter def latest(self, value: Optional[torch.Tensor]) -> None: self.prev = self._latest @@ -117,6 +121,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): _hypothesis: LanguageHypothesis = LanguageHypothesis() _language: Optional[str] + @PassthroughProperty[Optional[str]](None).property def language(self) -> Optional[str]: if self._language is not None: @@ -125,20 +130,24 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): return "en" if self.verbose: print( - "Detecting language using up to the first 30 seconds." - "Use `--language` to specify the language") + "Detecting language using up to the first 30 seconds." 
+ "Use `--language` to specify the language" + ) if self.latest is None: return None if self._seek == self._hypothesis.last: return self._hypothesis.language if self.frame_offset > 0 or self.latest.shape[-1] == N_FRAMES * 2: - mel = self.latest if self.prev is None else torch.cat( - (self.prev[:self.frame_offset], self.latest), -1) + mel = ( + self.latest + if self.prev is None + else torch.cat((self.prev[: self.frame_offset], self.latest), -1) + ) self._language = self.detect_language(mel) return self._language self._hypothesis.last = self._seek or 0 self._hypothesis.since += 1 - if 2 ** self._hypothesis.evidence < self._hypothesis.since: + if 2**self._hypothesis.evidence < self._hypothesis.since: return self._hypothesis.language self._hypothesis.since = 0 guess = self.detect_language() @@ -152,22 +161,25 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): def clip_timestamps(self, value: Union[str, List[float], Tuple[float]]): self._seek_clips = None if isinstance(value, str): - self._clip_timestamps = tuple(map(float, value.split(","))) \ - if value else (0,) + self._clip_timestamps = ( + tuple(map(float, value.split(","))) if value else (0,) + ) else: self._clip_timestamps = tuple(value) or (0,) _seek_clips: Optional[List[Tuple[int, Optional[int]]]] = None + @property def seek_clips(self) -> List[Tuple[int, Optional[int]]]: if self._seek_clips is None: seek_points = tuple( - round(ts * FRAMES_PER_SECOND) - for ts in self.clip_timestamps) + (None,) + round(ts * FRAMES_PER_SECOND) for ts in self.clip_timestamps + ) + (None,) self._seek_clips = list(zip(seek_points[::2], seek_points[1::2])) return self._seek_clips _seek: Optional[int] + @PassthroughProperty[Optional[int]](None).property def seek(self) -> Optional[int]: return self.seek_clips[0][0] if self._seek is None else self._seek @@ -179,25 +191,30 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): if value < len(clips): self.seek = clips[value][0] - time_offset = property(lambda self: float( - self.seek * HOP_LENGTH / SAMPLE_RATE)) - window_end_time = property(lambda self: float( - (self.seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)) + time_offset = property(lambda self: float(self.seek * HOP_LENGTH / SAMPLE_RATE)) + window_end_time = property( + lambda self: float((self.seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE) + ) _temperature: Union[Optional[float], Tuple[float, ...]] + @PassthroughProperty[Union[Optional[float], Tuple[float, ...]]]( - (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)).setter + (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) + ).setter def temperature(self, value: Union[Optional[float], Tuple[float, ...]]): - self._temperature = (value,) if isinstance(value, (int, float)) else ( - Transcriber._temperature if value is None else value) + self._temperature = ( + (value,) + if isinstance(value, (int, float)) + else (Transcriber._temperature if value is None else value) + ) @PassthroughProperty("transcribe").setter def task(self, value: str): self._task = value if self.word_timestamps and value == "translate": warnings.warn( - "Word-level timestamps on translations may not be " - "reliable.") + "Word-level timestamps on translations may not be " "reliable." 
+ ) @PassthroughProperty(False).setter def word_timestamps(self, value: bool): @@ -207,6 +224,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): get_tokenizer = staticmethod(get_tokenizer) _tokenizer: Optional[Tokenizer] = None _tokenizer_cache: Dict[str, Tokenizer] = {} + @property def tokenizer(self) -> Tokenizer: if self._tokenizer is None: @@ -235,6 +253,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): _initial_prompt_tokens: Optional[List[int]] = None _initial_prompt_cache: Dict[Tokenizer, List[int]] = {} + @property def initial_prompt_tokens(self) -> List[int]: if self._initial_prompt_tokens is None: @@ -246,9 +265,9 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): return [] if tokenizer not in self._initial_prompt_cache: self._initial_prompt_cache[tokenizer] = tokenizer.encode( - " " + self.initial_prompt.strip()) - self._initial_prompt_tokens = \ - self._initial_prompt_cache[tokenizer] + " " + self.initial_prompt.strip() + ) + self._initial_prompt_tokens = self._initial_prompt_cache[tokenizer] return self._initial_prompt_cache[tokenizer] return self._initial_prompt_tokens @@ -256,23 +275,25 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): last_speech_timestamp: float = 0.0 frame_offset: int = 0 all_segments: List[dict] + def __init__( - self, - model: "Whisper", - *, - verbose: Optional[bool] = None, - temperature: Union[Optional[float], Tuple[float, ...]] = None, - compression_ratio_threshold: Optional[float] = 2.4, - logprob_threshold: Optional[float] = -1.0, - no_speech_threshold: Optional[float] = 0.6, - condition_on_previous_text: bool = True, - initial_prompt: Optional[str] = None, - word_timestamps: bool = False, - prepend_punctuations: str = prefix, - append_punctuations: str = postfix, - clip_timestamps: Union[str, List[float]] = "0", - hallucination_silence_threshold: Optional[float] = None, - **decode_options): + self, + model: "Whisper", + *, + verbose: Optional[bool] = None, + temperature: Union[Optional[float], Tuple[float, ...]] = None, + compression_ratio_threshold: Optional[float] = 2.4, + logprob_threshold: Optional[float] = -1.0, + no_speech_threshold: Optional[float] = 0.6, + condition_on_previous_text: bool = True, + initial_prompt: Optional[str] = None, + word_timestamps: bool = False, + prepend_punctuations: str = prefix, + append_punctuations: str = postfix, + clip_timestamps: Union[str, List[float]] = "0", + hallucination_silence_threshold: Optional[float] = None, + **decode_options, + ): """ Transcribe an audio file using Whisper @@ -374,14 +395,16 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): needs_fallback = False if self.compression_ratio_threshold is not None and ( - decode_result.compression_ratio > - self.compression_ratio_threshold): + decode_result.compression_ratio > self.compression_ratio_threshold + ): needs_fallback = True # too repetitive if self.logprob_threshold is not None and ( - decode_result.avg_logprob < self.logprob_threshold): + decode_result.avg_logprob < self.logprob_threshold + ): needs_fallback = True # average log probability is too low if self.no_speech_threshold is not None and ( - decode_result.no_speech_prob > self.no_speech_threshold): + decode_result.no_speech_prob > self.no_speech_threshold + ): needs_fallback = False # silence if not needs_fallback: break @@ -389,8 +412,8 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): return decode_result def new_segment( - self, *, start: float, end: float, tokens: torch.Tensor, - result: DecodingResult) -> 
dict: + self, *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult + ) -> dict: _tokens = tokens.tolist() text_tokens = [token for token in _tokens if token < self.tokenizer.eot] return { @@ -422,9 +445,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): def is_segment_anomaly(self, segment: Optional[dict]) -> bool: if segment is None or not segment["words"]: return False - words = [ - w for w in segment["words"] - if w["word"] not in self.punctuation][:8] + words = [w for w in segment["words"] if w["word"] not in self.punctuation][:8] score = sum(self.word_anomaly_score(w) for w in words) return score >= 3 or score + 0.01 >= len(words) @@ -433,11 +454,15 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): return next((s for s in segments if s["words"]), None) def reseek( - self, current_segments: List[dict], segment_size: int, - single_timestamp_ending: bool, tokens: torch.Tensor, - timestamp_tokens: torch.Tensor, result: DecodingResult): - consecutive = torch.where( - timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + self, + current_segments: List[dict], + segment_size: int, + single_timestamp_ending: bool, + tokens: torch.Tensor, + timestamp_tokens: torch.Tensor, + result: DecodingResult, + ): + consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] consecutive.add_(1) if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens @@ -449,17 +474,16 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): for current_slice in slices: sliced_tokens = tokens[last_slice:current_slice] start_timestamp_pos = ( - sliced_tokens[0].item() - - self.tokenizer.timestamp_begin) + sliced_tokens[0].item() - self.tokenizer.timestamp_begin + ) end_timestamp_pos = ( - sliced_tokens[-1].item() - - self.tokenizer.timestamp_begin) + sliced_tokens[-1].item() - self.tokenizer.timestamp_begin + ) current_segments.append( self.new_segment( - start=self.time_offset + \ - start_timestamp_pos * self.time_precision, - end=self.time_offset + \ - end_timestamp_pos * self.time_precision, + start=self.time_offset + + start_timestamp_pos * self.time_precision, + end=self.time_offset + end_timestamp_pos * self.time_precision, tokens=sliced_tokens, result=result, ) @@ -474,31 +498,42 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): # otherwise, ignore the unfinished segment and seek to the last # timestamp last_timestamp_pos = ( - tokens[last_slice - 1].item() - - self.tokenizer.timestamp_begin) + tokens[last_slice - 1].item() - self.tokenizer.timestamp_begin + ) self.seek += last_timestamp_pos * self.input_stride else: duration = segment_size * HOP_LENGTH / SAMPLE_RATE timestamps = tokens[timestamp_tokens.nonzero().flatten()] - if len(timestamps) > 0 and \ - timestamps[-1].item() != self.tokenizer.timestamp_begin: + if ( + len(timestamps) > 0 + and timestamps[-1].item() != self.tokenizer.timestamp_begin + ): # no consecutive timestamps but it has a timestamp; use the last # one. 
- last_timestamp_pos = \ - timestamps[-1].item() - self.tokenizer.timestamp_begin + last_timestamp_pos = ( + timestamps[-1].item() - self.tokenizer.timestamp_begin + ) duration = last_timestamp_pos * self.time_precision - current_segments.append(self.new_segment( + current_segments.append( + self.new_segment( start=self.time_offset, end=self.time_offset + duration, tokens=tokens, - result=result)) + result=result, + ) + ) self.seek += segment_size def timestamp( - self, current_segments: List[dict], segment_size: int, - single_timestamp_ending: bool, mel_segment: torch.Tensor, - previous_seek: int, content_frames: int) -> bool: + self, + current_segments: List[dict], + segment_size: int, + single_timestamp_ending: bool, + mel_segment: torch.Tensor, + previous_seek: int, + content_frames: int, + ) -> bool: add_word_timestamps( segments=current_segments, model=self.model, @@ -512,8 +547,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): if not single_timestamp_ending: last_word_end = get_end(current_segments) - if last_word_end is not None and \ - last_word_end > self.time_offset: + if last_word_end is not None and last_word_end > self.time_offset: self.seek = round(last_word_end * FRAMES_PER_SECOND) # skip silence before possible hallucinations @@ -521,24 +555,19 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): threshold = self.hallucination_silence_threshold if not single_timestamp_ending: last_word_end = get_end(current_segments) - if last_word_end is not None and \ - last_word_end > self.time_offset: - remaining_duration = \ - self.window_end_time - last_word_end + if last_word_end is not None and last_word_end > self.time_offset: + remaining_duration = self.window_end_time - last_word_end if remaining_duration > threshold: - self.seek = round( - last_word_end * FRAMES_PER_SECOND) + self.seek = round(last_word_end * FRAMES_PER_SECOND) else: self.seek = previous_seek + segment_size # if first segment might be a hallucination, skip leading silence first_segment = self.next_words_segment(current_segments) - if first_segment is not None and self.is_segment_anomaly( - first_segment): + if first_segment is not None and self.is_segment_anomaly(first_segment): gap = first_segment["start"] - self.time_offset if gap > threshold: - self.seek = previous_seek + round( - gap * FRAMES_PER_SECOND) + self.seek = previous_seek + round(gap * FRAMES_PER_SECOND) return True # skip silence before any possible hallucination that is @@ -550,14 +579,13 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): if not segment["words"]: continue if self.is_segment_anomaly(segment): - next_segment = self.next_words_segment( - current_segments[si + 1 :]) + next_segment = self.next_words_segment(current_segments[si + 1 :]) if next_segment is not None: - hal_next_start = \ - next_segment["words"][0]["start"] + hal_next_start = next_segment["words"][0]["start"] else: - hal_next_start = self.time_offset + \ - segment_size * HOP_LENGTH / SAMPLE_RATE + hal_next_start = ( + self.time_offset + segment_size * HOP_LENGTH / SAMPLE_RATE + ) silence_before = ( segment["start"] - hal_last_end > threshold or segment["start"] < threshold @@ -573,8 +601,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): max(self.time_offset + 1, segment["start"]) * FRAMES_PER_SECOND ) - if content_duration - segment["end"] < \ - threshold: + if content_duration - segment["end"] < threshold: self.seek = content_frames current_segments[si:] = [] break @@ -586,19 +613,17 @@ class 
Transcriber(metaclass=PassthroughPropertyDefaults): return False def __call__( - self, mel: torch.Tensor, offset: int = 0, - single_pass: bool = False) -> dict: + self, mel: torch.Tensor, offset: int = 0, single_pass: bool = False + ) -> dict: self.latest, self.frame_offset = mel, offset content_frames = mel.shape[-1] - N_FRAMES + offset - content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE) # NOTE: This loop is obscurely flattened to make the diff readable. # A later commit should turn this into a simpler nested loop. # for seek_clip_start, seek_clip_end in seek_clips: # while seek < seek_clip_end while self.clip_idx < len(self.seek_clips): seek_clip_start, seek_clip_end = self.seek_clips[self.clip_idx] - seek_clip_end = content_frames if seek_clip_end is None else \ - seek_clip_end + seek_clip_end = content_frames if seek_clip_end is None else seek_clip_end if self.seek < seek_clip_start: self.seek = seek_clip_start if self.seek >= seek_clip_end: @@ -607,22 +632,23 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): self.clip_idx += 1 continue segment_size = min( - N_FRAMES, content_frames - self.seek, - seek_clip_end - self.seek) - mel_segment = mel[ - :, self.seek - offset : self.seek + segment_size - offset] - mel_segment = pad_or_trim(mel_segment, N_FRAMES).to( - self.device).to(self.dtype) + N_FRAMES, content_frames - self.seek, seek_clip_end - self.seek + ) + mel_segment = mel[:, self.seek - offset : self.seek + segment_size - offset] + mel_segment = ( + pad_or_trim(mel_segment, N_FRAMES).to(self.device).to(self.dtype) + ) - self.decode_options["prompt"] = \ - self.all_tokens[self.prompt_reset_since:] + self.decode_options["prompt"] = self.all_tokens[self.prompt_reset_since :] result: DecodingResult = self.decode_with_fallback(mel_segment) if self.no_speech_threshold is not None: # no voice activity check should_skip = result.no_speech_prob > self.no_speech_threshold - if self.logprob_threshold is not None and \ - result.avg_logprob > self.logprob_threshold: + if ( + self.logprob_threshold is not None + and result.avg_logprob > self.logprob_threshold + ): # don't skip if the logprob is high enough, despite the # no_speech_prob should_skip = False @@ -636,19 +662,27 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): current_segments: List[dict] = [] tokens = torch.tensor(result.tokens) - timestamp_tokens: torch.Tensor = tokens.ge( - self.tokenizer.timestamp_begin) - single_timestamp_ending = ( - timestamp_tokens[-2:].tolist() == [False, True]) + timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin) + single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] self.reseek( - current_segments, segment_size, single_timestamp_ending, - tokens, timestamp_tokens, result) + current_segments, + segment_size, + single_timestamp_ending, + tokens, + timestamp_tokens, + result, + ) if self.word_timestamps: if self.timestamp( - current_segments, segment_size, single_timestamp_ending, - mel_segment, previous_seek, content_frames): + current_segments, + segment_size, + single_timestamp_ending, + mel_segment, + previous_seek, + content_frames, + ): continue if self.verbose: @@ -656,25 +690,29 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): start, end = segment["start"], segment["end"] text = segment["text"] line = ( - f"[{format_timestamp(start)} --> " - f"{format_timestamp(end)}] {text}") + f"[{format_timestamp(start)} --> " + f"{format_timestamp(end)}] {text}" + ) print(make_safe(line)) # if a segment is 
instantaneous or does not contain text, clear it for i, segment in enumerate(current_segments): - if segment["start"] == segment["end"] or \ - segment["text"].strip() == "": + if segment["start"] == segment["end"] or segment["text"].strip() == "": segment["text"] = "" segment["tokens"] = [] segment["words"] = [] - self.all_segments.extend([ + self.all_segments.extend( + [ {"id": i, **segment} for i, segment in enumerate( - current_segments, start=len(self.all_segments))]) - self.all_tokens.extend([ - token for segment in current_segments - for token in segment["tokens"]]) + current_segments, start=len(self.all_segments) + ) + ] + ) + self.all_tokens.extend( + [token for segment in current_segments for token in segment["tokens"]] + ) if not self.condition_on_previous_text or result.temperature > 0.5: # do not feed the prompt tokens if a high temperature was used @@ -686,9 +724,12 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): break self.result = dict( - segments=self.all_segments, language=self.language, - text=self.tokenizer.decode( - self.all_tokens[len(self.initial_prompt_tokens):])) + segments=self.all_segments, + language=self.language, + text=self.tokenizer.decode( + self.all_tokens[len(self.initial_prompt_tokens) :] + ), + ) self.latest = None return self.result @@ -698,17 +739,18 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): def restore(self, offset: int) -> None: processing, seconds = 0, offset * HOP_LENGTH / SAMPLE_RATE while len(self.all_segments) > 0 and ( - self.all_segments[-1]["start"] >= seconds - if len(self.all_segments) == 1 else - self.all_segments[-2]["end"] > seconds): + self.all_segments[-1]["start"] >= seconds + if len(self.all_segments) == 1 + else self.all_segments[-2]["end"] > seconds + ): rewriting = self.all_segments.pop() processing += len(rewriting["tokens"]) - self.all_tokens = self.all_tokens[:len(self.all_tokens) - processing] + self.all_tokens = self.all_tokens[: len(self.all_tokens) - processing] if len(self.all_segments) > 0 and ( - self.all_segments[-1]["start"] < seconds and - self.all_segments[-1]["end"] >= seconds): - self.seek = round( - self.all_segments[-1]["end"] * SAMPLE_RATE / HOP_LENGTH) + self.all_segments[-1]["start"] < seconds + and self.all_segments[-1]["end"] >= seconds + ): + self.seek = round(self.all_segments[-1]["end"] * SAMPLE_RATE / HOP_LENGTH) else: self.seek = offset @@ -728,12 +770,16 @@ def audio_tensor(audio: Union[str, np.ndarray, torch.Tensor]) -> torch.Tensor: class MinimalTranscriber(Transcriber): exact: bool = True chlen: float = CHUNK_LENGTH + async def process(self, stream: ArrayStream, **kw) -> dict: data = await stream.request(self.chlen, self.exact) while data.shape[-1] > 0: self(data, stream.offset, True) - t = self.chlen - (stream.offset + data.shape[-1] - self.seek) \ - / FRAMES_PER_SECOND + CHUNK_LENGTH + t = ( + self.chlen + - (stream.offset + data.shape[-1] - self.seek) / FRAMES_PER_SECOND + + CHUNK_LENGTH + ) data = await stream.request(t, self.exact) return self.result @@ -755,12 +801,16 @@ class ProgressTranscriber(MinimalTranscriber): @PassthroughProperty(None).property def pbar(self): if self._pbar is None: - n = self.latest.shape[-1] if self.duration is None \ - else -int(self.duration * -FRAMES_PER_SECOND) + n = ( + self.latest.shape[-1] + if self.duration is None + else -int(self.duration * -FRAMES_PER_SECOND) + ) # show the progress bar when verbose is False # (if True, transcribed text will be printed) self._pbar = tqdm.tqdm( - total=n, unit="frames", disable=self.verbose 
is not False) + total=n, unit="frames", disable=self.verbose is not False + ) self._pbar.__enter__() return self._pbar @@ -784,10 +834,7 @@ class ProgressTranscriber(MinimalTranscriber): return await self.process(stream, **kw) -def transcribe( - model: "Whisper", - audio: Union[str, np.ndarray, torch.Tensor], - **kw): +def transcribe(model: "Whisper", audio: Union[str, np.ndarray, torch.Tensor], **kw): """ Transcribe an audio file using Whisper diff --git a/whisper/utils.py b/whisper/utils.py index 7bec242..27a0465 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -5,16 +5,7 @@ import re import sys import time import zlib -from typing import ( - Callable, - List, - Optional, - TextIO, - Union, - TypeVar, - Generic, - Any -) +from typing import Any, Callable, Generic, List, Optional, TextIO, TypeVar, Union system_encoding = sys.getdefaultencoding() @@ -86,15 +77,17 @@ def format_timestamp( f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" ) + def hms(sec: float) -> str: trim = sec < 3600 h = "" if trim else str(int(sec) // 3600) + ":" m_fill = " " if trim else "0" m = " " if sec < 60 else str(int(sec) // 60 % 60).rjust(2, m_fill) + ":" - s = str(int(sec) % 60).rjust(2, '0') + "." - c = str(round((sec % 1) * 100)).rjust(2, '0') + s = str(int(sec) % 60).rjust(2, "0") + "." + c = str(round((sec % 1) * 100)).rjust(2, "0") return h + m + s + c + def tod(seconds: float) -> str: return time.strftime("%H:%M:%S", time.localtime(seconds)) @@ -349,28 +342,34 @@ def get_writer( T = TypeVar("T") + # boilerplate for property with _{name} storage and passthrough getter/setter class PassthroughProperty(Generic[T]): def __init__(self, default: T): self.value = default f: Optional[Callable[[Any, T], None]] = None + def setter(self, f: Callable[[Any, T], None]): self.f = f return self g: Optional[property] = None + def property(self, g: Callable[[Any], T]): self.g = property(g) return self + class PassthroughPropertyDefaults(type): def __new__(cls, clsname, bases, attrs): def closure(f, v): def prop(self): return getattr(self, v) + def setter(self, value): setattr(self, v, value) + prop.__name__ = setter.__name__ = f return property(prop), setter @@ -383,4 +382,3 @@ class PassthroughPropertyDefaults(type): getter, setter = closure(k, private) updates[k] = (v.g or getter).setter(v.f or setter) return super().__new__(cls, clsname, bases, {**attrs, **updates}) - From c09790488bfa5a2fbf9d7decbd8d5747df4f5b27 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sun, 14 Jul 2024 19:37:18 -0600 Subject: [PATCH 05/10] simplify auto-formatting output --- whisper/transcribe.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 758c0b4..069c5ed 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -212,9 +212,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): def task(self, value: str): self._task = value if self.word_timestamps and value == "translate": - warnings.warn( - "Word-level timestamps on translations may not be " "reliable." 
- ) + warnings.warn("Word-level timestamps on translations may not be reliable.") @PassthroughProperty(False).setter def word_timestamps(self, value: bool): From 610f82ffba57540caac71147c28571ec6e362ced Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sun, 21 Jul 2024 20:21:02 -0700 Subject: [PATCH 06/10] remove realtime-specific code --- whisper/transcribe.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 069c5ed..aaf261a 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -734,24 +734,6 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): def reporthook(self) -> None: pass - def restore(self, offset: int) -> None: - processing, seconds = 0, offset * HOP_LENGTH / SAMPLE_RATE - while len(self.all_segments) > 0 and ( - self.all_segments[-1]["start"] >= seconds - if len(self.all_segments) == 1 - else self.all_segments[-2]["end"] > seconds - ): - rewriting = self.all_segments.pop() - processing += len(rewriting["tokens"]) - self.all_tokens = self.all_tokens[: len(self.all_tokens) - processing] - if len(self.all_segments) > 0 and ( - self.all_segments[-1]["start"] < seconds - and self.all_segments[-1]["end"] >= seconds - ): - self.seek = round(self.all_segments[-1]["end"] * SAMPLE_RATE / HOP_LENGTH) - else: - self.seek = offset - class InMemoryAudio(AudioFile): dft_pad = True From 247391a2afc5ea23e4d19f9b94d88a30b07d0523 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Mon, 22 Jul 2024 13:16:53 -0700 Subject: [PATCH 07/10] language detection patch and test --- tests/test_transcribe.py | 78 ++++++++++++++++++++++++++++++++++++++++ whisper/buffer.py | 1 - whisper/transcribe.py | 35 +++++++++++++----- 3 files changed, 104 insertions(+), 10 deletions(-) diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index 599221a..39108d7 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -4,7 +4,9 @@ import pytest import torch import whisper +from whisper.audio import CHUNK_LENGTH from whisper.tokenizer import get_tokenizer +from whisper.transcribe import Transcriber @pytest.mark.parametrize("model_name", whisper.available_models()) @@ -40,3 +42,79 @@ def test_transcribe(model_name: str): timing_checked = True assert timing_checked + + +class MockTokenizer: + def __init__(self, language, **kw): + self.language, self._kw = language, kw + for k, v in kw.items(): + setattr(self, k, v) + + def encode(self, prompt): + return [self.language, self, prompt] + + +class OnDemand: + def __init__(self, seq=(), relative=True): + self.seq, self.relative = seq, relative + self.prev, self.given = 0, 0 + + def __getitem__(self, key): + _key = self.given if self.relative else key + self.prev = ( + self.seq[_key] + if _key < len(self.seq) + else int(input(f"lang @ {_key}: ") or self.prev) + ) + self.given += 1 + return self.prev + + def __len__(self): + return CHUNK_LENGTH + 2 if self.relative else len(self.seq) + + +class TranscriberTest(Transcriber): + sample = object() + dtype = torch.float32 + model = type( + "MockModel", + (), + {"is_multilingual": True, "num_languages": None, "device": torch.device("cpu")}, + )() + _seek = 0 + + def __init__(self, seq=None): + super().__init__(self.model, initial_prompt="") + self.seq = OnDemand(seq or ()) + self.result = [] + self.latest = torch.zeros((0,)) + for i in range(len(self.seq)): + self._seek = i + self.frame_offset = max(0, i + 1 - CHUNK_LENGTH) + res = self.initial_prompt_tokens + assert res[0] == self.seq.prev + self.result.append(res[1:]) + if 
seq is None: + print(res) + + def detect_language(self, mel=None): + self.result.append([self.sample, mel]) + return self.seq[self._seek] + + def get_tokenizer(self, multilingual, language, **kw): + return MockTokenizer(language, **{"multilingual": multilingual, **kw}) + + @property + def rle(self): + res = [] + for i, *j in self.result: + if i is self.sample: + res.append(0) + else: + res[-1] += 1 + return res + + +def test_language(): + res = TranscriberTest([0, 0, 1, 0, 0, 0, 0, 0, 0]).rle + assert res == [1, 2, 1, 1, 2, 4, 8, 11, 2] diff --git a/whisper/buffer.py b/whisper/buffer.py index 2dd27b4..5229ad8 100644 --- a/whisper/buffer.py +++ b/whisper/buffer.py @@ -133,7 +133,6 @@ class ArrayStream(AudioSink): amp, N_FFT, HOP_LENGTH, window=self.hann, center=False, return_complex=True ) - # https://github.com/openai/whisper/blob/c5d4256/whisper/audio.py#L149 log_spec_bound: Optional[torch.Tensor] = None def transform(self, stft: torch.Tensor) -> torch.Tensor: diff --git a/whisper/transcribe.py b/whisper/transcribe.py index aaf261a..dbc81bd 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -4,6 +4,7 @@ import os import traceback import warnings from dataclasses import dataclass +from math import ceil from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import numpy as np @@ -147,15 +148,16 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): return self._language self._hypothesis.last = self._seek or 0 self._hypothesis.since += 1 - if 2**self._hypothesis.evidence < self._hypothesis.since: + if 2**self._hypothesis.evidence > self._hypothesis.since: return self._hypothesis.language self._hypothesis.since = 0 guess = self.detect_language() if guess == self._hypothesis.language: self._hypothesis.evidence += 1 - self._hypothesis.language = guess - self._hypothesis.evidence = 1 - return None + else: + self._hypothesis.language = guess + self._hypothesis.evidence = 0 + return guess @PassthroughProperty[Union[str, List[float], Tuple[float]]]((0,)).setter def clip_timestamps(self, value: Union[str, List[float], Tuple[float]]): @@ -257,18 +259,34 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): if self._initial_prompt_tokens is None: if self.initial_prompt is None: self._initial_prompt_tokens = [] + elif self.language is None: + return [] else: tokenizer = self.tokenizer - if tokenizer is None: - return [] if tokenizer not in self._initial_prompt_cache: self._initial_prompt_cache[tokenizer] = tokenizer.encode( " " + self.initial_prompt.strip() ) - self._initial_prompt_tokens = self._initial_prompt_cache[tokenizer] + if self._tokenizer is not None: + self._initial_prompt_tokens = self._initial_prompt_cache[tokenizer] return self._initial_prompt_cache[tokenizer] return self._initial_prompt_tokens + _initial_tokens: int = 0 + _initial_finalized: bool = False + _all_tokens: Optional[list] = None + + @property + def all_tokens(self): + if self._all_tokens is None: + self._all_tokens = [] + if not self._initial_finalized: + initial = self.initial_prompt_tokens + self._all_tokens = initial + self._all_tokens[self._initial_tokens :] + self._initial_tokens = len(initial) + self._initial_finalized = self._initial_prompt_tokens is not None + return self._all_tokens + prompt_reset_since: int = 0 last_speech_timestamp: float = 0.0 frame_offset: int = 0 @@ -375,7 +393,6 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): self.hallucination_silence_threshold = hallucination_silence_threshold self.decode_options = decode_options - self.all_tokens = 
self.initial_prompt_tokens[:] self.all_segments = [] def decode_with_fallback(self, segment: torch.Tensor) -> DecodingResult: @@ -784,7 +801,7 @@ class ProgressTranscriber(MinimalTranscriber): n = ( self.latest.shape[-1] if self.duration is None - else -int(self.duration * -FRAMES_PER_SECOND) + else ceil(self.duration * FRAMES_PER_SECOND) ) # show the progress bar when verbose is False # (if True, transcribed text will be printed) From 092cb3409e5a0ba367a99de7e67834c8c8435370 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Mon, 22 Jul 2024 13:40:12 -0700 Subject: [PATCH 08/10] detect language based on available frames not seek --- whisper/transcribe.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index dbc81bd..f66a146 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -122,6 +122,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): _hypothesis: LanguageHypothesis = LanguageHypothesis() _language: Optional[str] + _language_detection_warned: bool = False @PassthroughProperty[Optional[str]](None).property def language(self) -> Optional[str]: @@ -129,15 +130,18 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): return self._language if not self.model.is_multilingual: return "en" - if self.verbose: + if self.verbose and not self._language_detection_warned: print( "Detecting language using up to the first 30 seconds." "Use `--language` to specify the language" ) + self._language_detection_warned = True if self.latest is None: return None - if self._seek == self._hypothesis.last: + available = self.frame_offset + self.latest.shape[-1] + if available == self._hypothesis.last: return self._hypothesis.language + self._hypothesis.last = available if self.frame_offset > 0 or self.latest.shape[-1] == N_FRAMES * 2: mel = ( self.latest @@ -146,7 +150,6 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): ) self._language = self.detect_language(mel) return self._language - self._hypothesis.last = self._seek or 0 self._hypothesis.since += 1 if 2**self._hypothesis.evidence > self._hypothesis.since: return self._hypothesis.language From 1caba7d5d46d8ee91376c7f25d376cf450ad5fde Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Mon, 22 Jul 2024 16:14:30 -0700 Subject: [PATCH 09/10] clarify transcription parameter --- jfk.json | 1 + jfk.srt | 4 ++++ jfk.tsv | 2 ++ jfk.txt | 1 + jfk.vtt | 5 +++++ whisper/transcribe.py | 7 ++++--- 6 files changed, 17 insertions(+), 3 deletions(-) create mode 100644 jfk.json create mode 100644 jfk.srt create mode 100644 jfk.tsv create mode 100644 jfk.txt create mode 100644 jfk.vtt diff --git a/jfk.json b/jfk.json new file mode 100644 index 0000000..360fc47 --- /dev/null +++ b/jfk.json @@ -0,0 +1 @@ +{"segments": [{"id": 0, "seek": 0, "start": 0.0, "end": 11.0, "text": " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.", "tokens": [50364, 400, 370, 452, 7177, 6280, 11, 1029, 406, 437, 428, 1941, 393, 360, 337, 291, 11, 1029, 437, 291, 393, 360, 337, 428, 1941, 13, 50914], "temperature": 0.0, "avg_logprob": -0.20427462032863072, "compression_ratio": 1.3544303797468353, "no_speech_prob": 0.04382958635687828}], "language": "en", "text": " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country."} \ No newline at end of file diff --git a/jfk.srt b/jfk.srt new file mode 100644 index 0000000..a2c8946 --- /dev/null +++ b/jfk.srt @@ -0,0 +1,4 @@ +1 +00:00:00,000 
--> 00:00:11,000 +And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country. + diff --git a/jfk.tsv b/jfk.tsv new file mode 100644 index 0000000..ad86260 --- /dev/null +++ b/jfk.tsv @@ -0,0 +1,2 @@ +start end text +0 11000 And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country. diff --git a/jfk.txt b/jfk.txt new file mode 100644 index 0000000..64b97d3 --- /dev/null +++ b/jfk.txt @@ -0,0 +1 @@ +And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country. diff --git a/jfk.vtt b/jfk.vtt new file mode 100644 index 0000000..ae50503 --- /dev/null +++ b/jfk.vtt @@ -0,0 +1,5 @@ +WEBVTT + +00:00.000 --> 00:11.000 +And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country. + diff --git a/whisper/transcribe.py b/whisper/transcribe.py index f66a146..a215833 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -769,14 +769,15 @@ def audio_tensor(audio: Union[str, np.ndarray, torch.Tensor]) -> torch.Tensor: class MinimalTranscriber(Transcriber): exact: bool = True - chlen: float = CHUNK_LENGTH + # amount of time per chunk that is considered in-context + contextualized: float = CHUNK_LENGTH async def process(self, stream: ArrayStream, **kw) -> dict: - data = await stream.request(self.chlen, self.exact) + data = await stream.request(CHUNK_LENGTH, self.exact) while data.shape[-1] > 0: self(data, stream.offset, True) t = ( - self.chlen + self.contextualized - (stream.offset + data.shape[-1] - self.seek) / FRAMES_PER_SECOND + CHUNK_LENGTH ) From 41ca6713383c50a32f19e7720acacfb7955530c9 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Mon, 22 Jul 2024 16:15:32 -0700 Subject: [PATCH 10/10] remove accidentally added test output --- jfk.json | 1 - jfk.srt | 4 ---- jfk.tsv | 2 -- jfk.txt | 1 - jfk.vtt | 5 ----- 5 files changed, 13 deletions(-) delete mode 100644 jfk.json delete mode 100644 jfk.srt delete mode 100644 jfk.tsv delete mode 100644 jfk.txt delete mode 100644 jfk.vtt diff --git a/jfk.json b/jfk.json deleted file mode 100644 index 360fc47..0000000 --- a/jfk.json +++ /dev/null @@ -1 +0,0 @@ -{"segments": [{"id": 0, "seek": 0, "start": 0.0, "end": 11.0, "text": " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.", "tokens": [50364, 400, 370, 452, 7177, 6280, 11, 1029, 406, 437, 428, 1941, 393, 360, 337, 291, 11, 1029, 437, 291, 393, 360, 337, 428, 1941, 13, 50914], "temperature": 0.0, "avg_logprob": -0.20427462032863072, "compression_ratio": 1.3544303797468353, "no_speech_prob": 0.04382958635687828}], "language": "en", "text": " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country."} \ No newline at end of file diff --git a/jfk.srt b/jfk.srt deleted file mode 100644 index a2c8946..0000000 --- a/jfk.srt +++ /dev/null @@ -1,4 +0,0 @@ -1 -00:00:00,000 --> 00:00:11,000 -And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country. - diff --git a/jfk.tsv b/jfk.tsv deleted file mode 100644 index ad86260..0000000 --- a/jfk.tsv +++ /dev/null @@ -1,2 +0,0 @@ -start end text -0 11000 And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country. 
diff --git a/jfk.txt b/jfk.txt deleted file mode 100644 index 64b97d3..0000000 --- a/jfk.txt +++ /dev/null @@ -1 +0,0 @@ -And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country. diff --git a/jfk.vtt b/jfk.vtt deleted file mode 100644 index ae50503..0000000 --- a/jfk.vtt +++ /dev/null @@ -1,5 +0,0 @@ -WEBVTT - -00:00.000 --> 00:11.000 -And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country. -
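
Note (illustration, not part of the patches): patches 07 and 08 change Transcriber.language so that, while the first 30 seconds of audio are still streaming in, it caches a language hypothesis and only re-runs detection with an exponential backoff instead of on every call. The sketch below reduces that backoff to a standalone helper as a minimal illustration; the Hypothesis dataclass, the hypothesize function, and the detect callback are assumed names standing in for Transcriber._hypothesis, the language property, and Transcriber.detect_language, not identifiers the patch series defines.

from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Hypothesis:
    language: Optional[str] = None  # best guess so far
    evidence: int = 0               # consecutive confirmations of that guess
    since: int = 0                  # calls since the last real detection


def hypothesize(hypothesis: Hypothesis, detect: Callable[[], str]) -> str:
    # Return a language guess, only re-running `detect` every 2**evidence calls.
    hypothesis.since += 1
    if 2 ** hypothesis.evidence > hypothesis.since:
        # still inside the backoff window: trust the cached hypothesis
        return hypothesis.language
    hypothesis.since = 0
    guess = detect()
    if guess == hypothesis.language:
        # confirmation: double the interval until the next detection
        hypothesis.evidence += 1
    else:
        # contradiction: adopt the new guess and shrink the interval again
        hypothesis.language = guess
        hypothesis.evidence = 0
    return guess


if __name__ == "__main__":
    h = Hypothesis()
    frames = ["en", "en", "de", "en", "en", "en", "en", "en"]
    print([hypothesize(h, lambda lang=lang: lang) for lang in frames])
    # prints ['en', 'en', 'en', 'en', 'en', 'en', 'en', 'en']

Running this sketch, detection actually fires on calls 1, 2, 4, and 8 and the cached guess is returned in between, so the one-off "de" on the third call never reaches the decoder; the doubling run lengths are the same pattern the new test_language test exercises through TranscriberTest.rle.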