diff --git a/whisper/batching.py b/whisper/batching.py
index 684ecba..d8e0e12 100644
--- a/whisper/batching.py
+++ b/whisper/batching.py
@@ -1,15 +1,19 @@
-import torch
-import numpy as np
-from collections.abc import Callable, AsyncIterable, AsyncIterator, Awaitable
+from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
 from typing import Generic, TypeVar, Union
 
+import numpy as np
+import torch
+
 A = TypeVar("A", bound=Union[np.ndarray, torch.Tensor])
 
+
 class ArrayWrapper(Generic[A]):
     pass
 
+
 ArrayTypes = Union[A, ArrayWrapper[A]]
 
+
 class LoopbackIterator(Generic[A]):
     async def iter(self):
         raise NotImplementedError
@@ -23,14 +27,17 @@ class LoopbackIterator(Generic[A]):
             self.__aiter__()
         return await anext(self._iter)
 
+
 async def empty():
     return
     yield
 
+
 class Unwrap(LoopbackIterator):
     _initial: Union[ArrayTypes, Awaitable[ArrayTypes]]
     started: bool
     iterator: AsyncIterable[ArrayTypes]
+
     def __init__(self, iterator: AsyncIterable[ArrayTypes]):
         while isinstance(iterator, PassthroughTransform):
             iterator = iterator.handoff()
@@ -74,13 +81,14 @@ class Unwrap(LoopbackIterator):
 
     @property
     async def concat(self):
-        return np.concatenate if isinstance(await self.dtype, np.dtype) \
-            else torch.cat
+        return np.concatenate if isinstance(await self.dtype, np.dtype) else torch.cat
 
+
 class PassthroughTransform(LoopbackIterator):
     def handoff(self) -> AsyncIterable[ArrayTypes]:
         raise NotImplementedError
 
+
 class BoxedIterator(PassthroughTransform):
     def __init__(self, iterator):
         self.iterator = iterator
@@ -99,6 +107,7 @@ class BoxedIterator(PassthroughTransform):
         if self.flag != flag:
             raise Exception("source can only be used by one iterator")
 
+
 def LookAlong(axis: int):
     assert axis >= 0
     empties = (slice(None),) * axis
@@ -119,10 +128,9 @@ def LookAlong(axis: int):
 
     return LookAlong
 
+
 class PassthroughMap(PassthroughTransform):
-    def __init__(
-            self, apply: Callable[[A], ArrayTypes],
-            iterator: AsyncIterator[A]):
+    def __init__(self, apply: Callable[[A], ArrayTypes], iterator: AsyncIterator[A]):
         self.iterator, self.apply = iterator, apply
 
     def handoff(self) -> AsyncIterator[A]:
@@ -132,6 +140,7 @@ class PassthroughMap(PassthroughTransform):
         async for i in self.iterator:
             yield self.apply(i)
 
+
 class Group:
     def __init__(self, concat, axis=-1):
         self.concat = concat
@@ -155,16 +164,18 @@ class Group:
         if taking == amount or not exact:
             self.shape += amount - taking
             self.consumed = 0
-            res = self.concat([self.holding[0][start:]] + [
-                i.value for i in self.holding[1 : i + 1]])
-            self.holding = self.holding[i + 1:]
+            res = self.concat(
+                [self.holding[0][start:]] + [i.value for i in self.holding[1 : i + 1]]
+            )
+            self.holding = self.holding[i + 1 :]
             return res
         if i == 0:
-            return self.holding[0][start:self.consumed]
+            return self.holding[0][start : self.consumed]
         res = self.concat(
-            [self.holding[0][start:]] +
-            [i.value for i in self.holding[1 : i]] +
-            [self.holding[i][:self.consumed]])
+            [self.holding[0][start:]]
+            + [i.value for i in self.holding[1:i]]
+            + [self.holding[i][: self.consumed]]
+        )
         self.holding = self.holding[i:]
         return res
 
@@ -175,10 +186,12 @@ class Group:
         self.holding = []
         return res
 
+
 class Taken:
     def take(self, *a, **kw):
         raise Exception("batch queue moved")
 
+
 class Batcher(PassthroughTransform):
     def __init__(self, iterator, size, axis=-1, exact=False):
         assert isinstance(size, int) and size > 0
@@ -192,14 +205,19 @@ class Batcher(PassthroughTransform):
         return lambda tensors: f(tensors, self.axis)
 
     _iterator = None
+
     async def iterator(self):
         if self._iterator is None:
-            self.axis = len(await self.preview.shape) + self._axis \
-                if self._axis < 0 else self._axis
+            self.axis = (
+                len(await self.preview.shape) + self._axis
+                if self._axis < 0
+                else self._axis
+            )
             if not hasattr(self, "group"):
                 self.group = Group(await self.concat())
             self._iterator = PassthroughMap(
-                LookAlong(self.axis), BoxedIterator(self.preview))
+                LookAlong(self.axis), BoxedIterator(self.preview)
+            )
         return self._iterator
 
     def handoff(self):
@@ -219,4 +237,3 @@ class Batcher(PassthroughTransform):
                 return self.group.all()
             raise
         return self.group.take(self.size, self.exact)
-
diff --git a/whisper/buffer.py b/whisper/buffer.py
index ab43b71..2dd27b4 100644
--- a/whisper/buffer.py
+++ b/whisper/buffer.py
@@ -1,18 +1,16 @@
+import asyncio
+import json
+import subprocess
+from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Coroutine
+from typing import IO, Optional, Union
+
 import numpy as np
-import asyncio, pathlib, subprocess, torch, json
+import torch
 
-from .audio import (
-    SAMPLE_RATE,
-    N_FFT,
-    HOP_LENGTH,
-    N_FRAMES,
-    mel_filters,
-)
-
-from .utils import PathType, ceildiv
+from .audio import HOP_LENGTH, N_FFT, N_FRAMES, SAMPLE_RATE, mel_filters
 from .batching import Batcher
-from typing import Optional, Union, IO, Tuple, Any, Type
-from collections.abc import Coroutine, AsyncIterable, AsyncIterator, Awaitable
+from .utils import PathType, ceildiv
 
+
 class AudioSink:
     def __init__(self, *, rate: int = SAMPLE_RATE, **kw):
@@ -25,11 +23,19 @@ class AudioSink:
     def write(self, data):
         raise NotImplementedError
 
+
 class ArrayStream(AudioSink):
     q: asyncio.Queue
+
     def __init__(
-            self, *, device: Optional[Union[str, torch.device]] = None,
-            batch: int = 1, n_mels: int = 80, capacity: int = 1_000_000, **kw):
+        self,
+        *,
+        device: Optional[Union[str, torch.device]] = None,
+        batch: int = 1,
+        n_mels: int = 80,
+        capacity: int = 1_000_000,
+        **kw,
+    ):
         super().__init__(**kw)
         self.q = asyncio.Queue(capacity)
         self.finished = asyncio.Event()
@@ -43,6 +49,7 @@ class ArrayStream(AudioSink):
         return torch.zeros(shape, dtype=torch.float32, device=self.device)
 
     write_blockable: bool = True
+
     def write(self, data: bytes) -> Optional[Coroutine]:
         if self.write_blockable:
             return self.q.put(data)
@@ -51,11 +58,9 @@ class ArrayStream(AudioSink):
         return None
 
     def load(self, data: bytes) -> np.ndarray:
-        return np.frombuffer(
-            data, np.int16).flatten().astype(np.float32) / 32768.0
+        return np.frombuffer(data, np.int16).flatten().astype(np.float32) / 32768.0
 
-    async def loader(self, iterator: AsyncIterable[bytes]) -> \
-            AsyncIterator[np.ndarray]:
+    async def loader(self, iterator: AsyncIterable[bytes]) -> AsyncIterator[np.ndarray]:
         async for data in iterator:
             yield self.load(data)
 
@@ -64,7 +69,8 @@ class ArrayStream(AudioSink):
         while not self.finished.is_set():
             getter = asyncio.create_task(self.q.get())
             done, pending = await asyncio.wait(
-                (waiter, getter), return_when=asyncio.FIRST_COMPLETED)
+                (waiter, getter), return_when=asyncio.FIRST_COMPLETED
+            )
             if getter in done:
                 yield getter.result()
             while not self.q.empty():
@@ -78,8 +84,10 @@ class ArrayStream(AudioSink):
             pass
 
     loading: Optional[Batcher] = None
-    async def fft_offset(self, iterator: AsyncIterable[bytes]) -> \
-            AsyncIterator[np.ndarray]:
+
+    async def fft_offset(
+        self, iterator: AsyncIterable[bytes]
+    ) -> AsyncIterator[np.ndarray]:
         init = self.loader(iterator) if self.loading is None else self.loading
         self.loading = Batcher(init, HOP_LENGTH)
         _iterator = aiter(self.loading)
@@ -89,7 +97,7 @@ class ArrayStream(AudioSink):
                 window = np.concatenate((window, await anext(_iterator)))
             except StopAsyncIteration:
                 return
-        window = np.pad(window, (N_FFT // 2, 0), 'reflect')
+        window = np.pad(window, (N_FFT // 2, 0), "reflect")
         yield window
         async for data in _iterator:
             yield data
@@ -101,8 +109,9 @@ class ArrayStream(AudioSink):
         hopped = ((sees.shape[0] - N_FFT) // HOP_LENGTH + 1) * HOP_LENGTH
         return sees[hopped:]
 
-    async def window(self, iterator: AsyncIterable[bytes]) -> \
-            AsyncIterator[torch.Tensor]:
+    async def window(
+        self, iterator: AsyncIterable[bytes]
+    ) -> AsyncIterator[torch.Tensor]:
         _iterator = self.fft_offset(iterator)
         async for data in _iterator:
             _data = torch.from_numpy(data)
@@ -121,19 +130,23 @@ class ArrayStream(AudioSink):
 
     def dft(self, amp: torch.Tensor) -> torch.Tensor:
         return torch.stft(
-            amp, N_FFT, HOP_LENGTH, window=self.hann, center=False,
-            return_complex=True)
+            amp, N_FFT, HOP_LENGTH, window=self.hann, center=False, return_complex=True
+        )
 
     # https://github.com/openai/whisper/blob/c5d4256/whisper/audio.py#L149
     log_spec_bound: Optional[torch.Tensor] = None
+
     def transform(self, stft: torch.Tensor) -> torch.Tensor:
         magnitudes = stft.abs() ** 2
         mel_spec = self.filters @ magnitudes
         log_spec = torch.clamp(mel_spec, min=1e-10).log10()
 
         # causes values to not precisely match the original
-        self.log_spec_bound = log_spec.max() if self.log_spec_bound is None \
-            else torch.maximum(log_spec.max(), self.log_spec_bound)
+        self.log_spec_bound = (
+            log_spec.max()
+            if self.log_spec_bound is None
+            else torch.maximum(log_spec.max(), self.log_spec_bound)
+        )
         log_spec = torch.maximum(log_spec, self.log_spec_bound - 8.0)
         log_spec = (log_spec + 4.0) / 4.0
         return log_spec
@@ -143,6 +156,7 @@ class ArrayStream(AudioSink):
 
     # dft_pad: add ending content frames to match padding from a centered STFT
     dft_pad: bool = False
+
     def runoff(self, dft_pad: Optional[bool] = None) -> torch.Tensor:
         dft_pad = self.dft_pad if dft_pad is None else dft_pad
         if dft_pad:
@@ -170,34 +184,36 @@ class ArrayStream(AudioSink):
         return self.runoff()
 
     staging: Optional[Batcher] = None
-    async def _push(self, sec: float, exact: bool = False) -> \
-            AsyncIterator[torch.Tensor]:
+
+    async def _push(
+        self, sec: float, exact: bool = False
+    ) -> AsyncIterator[torch.Tensor]:
         batching = int(sec * SAMPLE_RATE // HOP_LENGTH)
-        init = self.window(self.buffer()) if self.staging is None \
-            else self.staging
+        init = self.window(self.buffer()) if self.staging is None else self.staging
         self.staging = Batcher(init, batching, exact=exact)
         async for frame in self.staging:
             batched = batching if exact else frame.shape[-1]
             cutoff = max(self.spectogram.shape[-1] + batched - N_FRAMES, 0)
             self.offset += cutoff
-            self.spectogram = torch.cat((
-                self.spectogram[:, cutoff:], frame), -1)
+            self.spectogram = torch.cat((self.spectogram[:, cutoff:], frame), -1)
             yield self.runoff()
 
     reader: Optional[Awaitable] = None
+
     def start(self, **kw) -> None:
         if self.reader is None:
             self.reader = asyncio.create_task(self.read(**kw))
 
-    async def push(self, sec: float, exact: bool=False, **kw) -> \
-            AsyncIterator[torch.Tensor]:
+    async def push(
+        self, sec: float, exact: bool = False, **kw
+    ) -> AsyncIterator[torch.Tensor]:
         self.start(**kw)
         async for i in self._push(sec, exact):
             yield i
         assert self.reader is not None
         await self.reader
 
-    async def request(self, sec: float, exact: bool=True, **kw) -> torch.Tensor:
+    async def request(self, sec: float, exact: bool = True, **kw) -> torch.Tensor:
         try:
             return await anext(self.push(sec, exact))
         except StopAsyncIteration:
@@ -224,17 +240,17 @@ class ArrayStream(AudioSink):
     def all_amplitudes(self, **kw) -> np.ndarray:
         return asyncio.run(self.amplitudes(**kw))
 
+
 class RawAudioFile(ArrayStream):
-    def __init__(
-            self, *, period: int = HOP_LENGTH, fname: PathType = 'out.raw',
-            **kw):
+    def __init__(self, *, period: int = HOP_LENGTH, fname: PathType = "out.raw", **kw):
         super().__init__(**kw)
         self.fname = fname
         self.period = period
 
     fp: Optional[IO[bytes]] = None
+
     async def read(self) -> None:
-        fp = open(self.fname, 'rb') if self.fp is None else self.fp
+        fp = open(self.fname, "rb") if self.fp is None else self.fp
         data = fp.read(self.period)
         while len(data) != 0:
             io_hold = self.write(data)
@@ -243,28 +259,33 @@ class RawAudioFile(ArrayStream):
             data = fp.read(self.period)
         self.finished.set()
 
+
 class AudioFile(RawAudioFile):
-    def __init__(
-            self, *, period: int = SAMPLE_RATE, fname: PathType = 'out.wav',
-            **kw):
+    def __init__(self, *, period: int = SAMPLE_RATE, fname: PathType = "out.wav", **kw):
         assert not subprocess.run(
-            ["which", "ffmpeg"], stdout=subprocess.PIPE).returncode
+            ["which", "ffmpeg"], stdout=subprocess.PIPE
+        ).returncode
         super().__init__(period=period or -1, fname=fname, **kw)
 
     async def read(self) -> None:
         cmd = [
             "ffmpeg",
             "-nostdin",
-            "-threads", "0",
-            "-i", self.fname,
-            "-f", "s16le",
-            "-ac", "1",
-            "-acodec", "pcm_s16le",
-            "-ar", str(self.rate),
-            "-"
+            "-threads",
+            "0",
+            "-i",
+            self.fname,
+            "-f",
+            "s16le",
+            "-ac",
+            "1",
+            "-acodec",
+            "pcm_s16le",
+            "-ar",
+            str(self.rate),
+            "-",
         ]
-        ps = subprocess.Popen(
-            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        ps = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
         self.fp = ps.stdout
         await super().read()
         _, stderr = ps.communicate()
@@ -277,13 +298,13 @@ class AudioFile(RawAudioFile):
             "ffprobe",
             "-hide_banner",
             "-show_format",
-            "-of", "json",
-            "-i", self.fname,
+            "-of",
+            "json",
+            "-i",
+            self.fname,
         ]
-        ps = subprocess.Popen(
-            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        ps = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
         stdout, stderr = ps.communicate()
         if ps.returncode not in (None, 0):
             raise RuntimeError(f"Failed to load audio: {stderr.decode()}")
-        return float(json.loads(stdout)['format']['duration'])
-
+        return float(json.loads(stdout)["format"]["duration"])
diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 0964c8a..758c0b4 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -1,27 +1,30 @@
 import argparse
+import asyncio
 import os
 import traceback
 import warnings
-import asyncio
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Dict
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 import tqdm
 
 from .audio import (
+    CHUNK_LENGTH,
     FRAMES_PER_SECOND,
     HOP_LENGTH,
     N_FRAMES,
     SAMPLE_RATE,
-    CHUNK_LENGTH,
     pad_or_trim,
 )
+from .buffer import ArrayStream, AudioFile
 from .decoding import DecodingOptions, DecodingResult
 from .timing import add_word_timestamps
-from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer, Tokenizer
+from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, Tokenizer, get_tokenizer
 from .utils import (
+    PassthroughProperty,
+    PassthroughPropertyDefaults,
     exact_div,
     format_timestamp,
     get_end,
@@ -30,14 +33,12 @@ from .utils import (
     optional_float,
     optional_int,
     str2bool,
-    PassthroughProperty,
-    PassthroughPropertyDefaults,
 )
-from .buffer import ArrayStream, AudioFile
 
 if TYPE_CHECKING:
     from .model import Whisper
 
+
 @dataclass
 class LanguageHypothesis:
     language: Optional[str] = None
@@ -45,15 +46,17 @@ class LanguageHypothesis:
     evidence: int = 0
     last: int = -1
 
+
 class Transcriber(metaclass=PassthroughPropertyDefaults):
-    prefix: str = '''"'\u201c\u00bf([{-'''
-    postfix: str = '''"'.\u3002,\uff0c!\uff01?\uff1f:\uff1a\u201d)]}\u3001'''
+    prefix: str = """"'\u201c\u00bf([{-"""
+    postfix: str = """"'.\u3002,\uff0c!\uff01?\uff1f:\uff1a\u201d)]}\u3001"""
     punctuation: str = prefix + postfix
 
     verbose: Optional[bool] = None
 
     _decode_options: dict = {}
    decode_props: Tuple[str, ...] = ("fp16", "language", "task")
+
     @property
     def decode_options(self) -> dict:
         for k in self.decode_props:
@@ -68,6 +71,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             setattr(self, k, value[k])
 
     dtype: torch.dtype = torch.float16
+
     @property
     def fp16(self) -> bool:
         return self.dtype == torch.float16
@@ -93,8 +97,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         self._device = value
         if value == torch.device("cpu"):
             if torch.cuda.is_available():
-                warnings.warn(
-                    "Performing inference on CPU when CUDA is available")
+                warnings.warn("Performing inference on CPU when CUDA is available")
         self.fp16device()
 
     def fp16device(self) -> None:
@@ -110,6 +113,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
 
     prev: Optional[torch.Tensor] = None
     _latest: Optional[torch.Tensor] = None
+
     @PassthroughProperty[Optional[torch.Tensor]](None).setter
     def latest(self, value: Optional[torch.Tensor]) -> None:
         self.prev = self._latest
@@ -117,6 +121,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
 
     _hypothesis: LanguageHypothesis = LanguageHypothesis()
     _language: Optional[str]
+
     @PassthroughProperty[Optional[str]](None).property
     def language(self) -> Optional[str]:
         if self._language is not None:
@@ -125,20 +130,24 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             return "en"
         if self.verbose:
             print(
-                "Detecting language using up to the first 30 seconds."
-                "Use `--language` to specify the language")
+                "Detecting language using up to the first 30 seconds. "
+                "Use `--language` to specify the language"
+            )
         if self.latest is None:
             return None
         if self._seek == self._hypothesis.last:
             return self._hypothesis.language
         if self.frame_offset > 0 or self.latest.shape[-1] == N_FRAMES * 2:
-            mel = self.latest if self.prev is None else torch.cat(
-                (self.prev[:self.frame_offset], self.latest), -1)
+            mel = (
+                self.latest
+                if self.prev is None
+                else torch.cat((self.prev[: self.frame_offset], self.latest), -1)
+            )
             self._language = self.detect_language(mel)
             return self._language
         self._hypothesis.last = self._seek or 0
         self._hypothesis.since += 1
-        if 2 ** self._hypothesis.evidence < self._hypothesis.since:
+        if 2**self._hypothesis.evidence < self._hypothesis.since:
             return self._hypothesis.language
         self._hypothesis.since = 0
         guess = self.detect_language()
@@ -152,22 +161,25 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     def clip_timestamps(self, value: Union[str, List[float], Tuple[float]]):
         self._seek_clips = None
         if isinstance(value, str):
-            self._clip_timestamps = tuple(map(float, value.split(","))) \
-                if value else (0,)
+            self._clip_timestamps = (
+                tuple(map(float, value.split(","))) if value else (0,)
+            )
         else:
             self._clip_timestamps = tuple(value) or (0,)
 
     _seek_clips: Optional[List[Tuple[int, Optional[int]]]] = None
+
     @property
     def seek_clips(self) -> List[Tuple[int, Optional[int]]]:
         if self._seek_clips is None:
             seek_points = tuple(
-                round(ts * FRAMES_PER_SECOND)
-                for ts in self.clip_timestamps) + (None,)
+                round(ts * FRAMES_PER_SECOND) for ts in self.clip_timestamps
+            ) + (None,)
             self._seek_clips = list(zip(seek_points[::2], seek_points[1::2]))
         return self._seek_clips
 
     _seek: Optional[int]
+
     @PassthroughProperty[Optional[int]](None).property
     def seek(self) -> Optional[int]:
         return self.seek_clips[0][0] if self._seek is None else self._seek
@@ -179,25 +191,30 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         if value < len(clips):
             self.seek = clips[value][0]
 
-    time_offset = property(lambda self: float(
-        self.seek * HOP_LENGTH / SAMPLE_RATE))
-    window_end_time = property(lambda self: float(
-        (self.seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE))
+    time_offset = property(lambda self: float(self.seek * HOP_LENGTH / SAMPLE_RATE))
+    window_end_time = property(
+        lambda self: float((self.seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
+    )
 
     _temperature: Union[Optional[float], Tuple[float, ...]]
+
     @PassthroughProperty[Union[Optional[float], Tuple[float, ...]]](
-        (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)).setter
+        (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
+    ).setter
     def temperature(self, value: Union[Optional[float], Tuple[float, ...]]):
-        self._temperature = (value,) if isinstance(value, (int, float)) else (
-            Transcriber._temperature if value is None else value)
+        self._temperature = (
+            (value,)
+            if isinstance(value, (int, float))
+            else (Transcriber._temperature if value is None else value)
+        )
 
     @PassthroughProperty("transcribe").setter
     def task(self, value: str):
         self._task = value
         if self.word_timestamps and value == "translate":
             warnings.warn(
-                "Word-level timestamps on translations may not be "
-                "reliable.")
+                "Word-level timestamps on translations may not be reliable."
+            )
 
     @PassthroughProperty(False).setter
     def word_timestamps(self, value: bool):
@@ -207,6 +224,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     get_tokenizer = staticmethod(get_tokenizer)
     _tokenizer: Optional[Tokenizer] = None
     _tokenizer_cache: Dict[str, Tokenizer] = {}
+
     @property
     def tokenizer(self) -> Tokenizer:
         if self._tokenizer is None:
@@ -235,6 +253,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
 
     _initial_prompt_tokens: Optional[List[int]] = None
     _initial_prompt_cache: Dict[Tokenizer, List[int]] = {}
+
     @property
     def initial_prompt_tokens(self) -> List[int]:
         if self._initial_prompt_tokens is None:
@@ -246,9 +265,9 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 return []
             if tokenizer not in self._initial_prompt_cache:
                 self._initial_prompt_cache[tokenizer] = tokenizer.encode(
-                    " " + self.initial_prompt.strip())
-                self._initial_prompt_tokens = \
-                    self._initial_prompt_cache[tokenizer]
+                    " " + self.initial_prompt.strip()
+                )
+                self._initial_prompt_tokens = self._initial_prompt_cache[tokenizer]
             return self._initial_prompt_cache[tokenizer]
         return self._initial_prompt_tokens
 
@@ -256,23 +275,25 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     last_speech_timestamp: float = 0.0
     frame_offset: int = 0
     all_segments: List[dict]
+
     def __init__(
-            self,
-            model: "Whisper",
-            *,
-            verbose: Optional[bool] = None,
-            temperature: Union[Optional[float], Tuple[float, ...]] = None,
-            compression_ratio_threshold: Optional[float] = 2.4,
-            logprob_threshold: Optional[float] = -1.0,
-            no_speech_threshold: Optional[float] = 0.6,
-            condition_on_previous_text: bool = True,
-            initial_prompt: Optional[str] = None,
-            word_timestamps: bool = False,
-            prepend_punctuations: str = prefix,
-            append_punctuations: str = postfix,
-            clip_timestamps: Union[str, List[float]] = "0",
-            hallucination_silence_threshold: Optional[float] = None,
-            **decode_options):
+        self,
+        model: "Whisper",
+        *,
+        verbose: Optional[bool] = None,
+        temperature: Union[Optional[float], Tuple[float, ...]] = None,
+        compression_ratio_threshold: Optional[float] = 2.4,
+        logprob_threshold: Optional[float] = -1.0,
+        no_speech_threshold: Optional[float] = 0.6,
+        condition_on_previous_text: bool = True,
+        initial_prompt: Optional[str] = None,
+        word_timestamps: bool = False,
+        prepend_punctuations: str = prefix,
+        append_punctuations: str = postfix,
+        clip_timestamps: Union[str, List[float]] = "0",
+        hallucination_silence_threshold: Optional[float] = None,
+        **decode_options,
+    ):
         """
         Transcribe an audio file using Whisper
 
@@ -374,14 +395,16 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
 
                 needs_fallback = False
                 if self.compression_ratio_threshold is not None and (
-                        decode_result.compression_ratio >
-                        self.compression_ratio_threshold):
+                    decode_result.compression_ratio > self.compression_ratio_threshold
+                ):
                     needs_fallback = True  # too repetitive
                 if self.logprob_threshold is not None and (
-                        decode_result.avg_logprob < self.logprob_threshold):
+                    decode_result.avg_logprob < self.logprob_threshold
+                ):
                     needs_fallback = True  # average log probability is too low
                 if self.no_speech_threshold is not None and (
-                        decode_result.no_speech_prob > self.no_speech_threshold):
+                    decode_result.no_speech_prob > self.no_speech_threshold
+                ):
                     needs_fallback = False  # silence
                 if not needs_fallback:
                     break
@@ -389,8 +412,8 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         return decode_result
 
     def new_segment(
-            self, *, start: float, end: float, tokens: torch.Tensor,
-            result: DecodingResult) -> dict:
+        self, *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
+    ) -> dict:
         _tokens = tokens.tolist()
         text_tokens = [token for token in _tokens if token < self.tokenizer.eot]
         return {
@@ -422,9 +445,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     def is_segment_anomaly(self, segment: Optional[dict]) -> bool:
         if segment is None or not segment["words"]:
             return False
-        words = [
-            w for w in segment["words"]
-            if w["word"] not in self.punctuation][:8]
+        words = [w for w in segment["words"] if w["word"] not in self.punctuation][:8]
         score = sum(self.word_anomaly_score(w) for w in words)
         return score >= 3 or score + 0.01 >= len(words)
 
@@ -433,11 +454,15 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         return next((s for s in segments if s["words"]), None)
 
     def reseek(
-            self, current_segments: List[dict], segment_size: int,
-            single_timestamp_ending: bool, tokens: torch.Tensor,
-            timestamp_tokens: torch.Tensor, result: DecodingResult):
-        consecutive = torch.where(
-            timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
+        self,
+        current_segments: List[dict],
+        segment_size: int,
+        single_timestamp_ending: bool,
+        tokens: torch.Tensor,
+        timestamp_tokens: torch.Tensor,
+        result: DecodingResult,
+    ):
+        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
         consecutive.add_(1)
         if len(consecutive) > 0:
             # if the output contains two consecutive timestamp tokens
@@ -449,17 +474,16 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             for current_slice in slices:
                 sliced_tokens = tokens[last_slice:current_slice]
                 start_timestamp_pos = (
-                    sliced_tokens[0].item() -
-                    self.tokenizer.timestamp_begin)
+                    sliced_tokens[0].item() - self.tokenizer.timestamp_begin
+                )
                 end_timestamp_pos = (
-                    sliced_tokens[-1].item() -
-                    self.tokenizer.timestamp_begin)
+                    sliced_tokens[-1].item() - self.tokenizer.timestamp_begin
+                )
                 current_segments.append(
                     self.new_segment(
-                        start=self.time_offset + \
-                            start_timestamp_pos * self.time_precision,
-                        end=self.time_offset + \
-                            end_timestamp_pos * self.time_precision,
+                        start=self.time_offset
+                        + start_timestamp_pos * self.time_precision,
+                        end=self.time_offset + end_timestamp_pos * self.time_precision,
                         tokens=sliced_tokens,
                         result=result,
                     )
@@ -474,31 +498,42 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 # otherwise, ignore the unfinished segment and seek to the last
                 # timestamp
                 last_timestamp_pos = (
-                    tokens[last_slice - 1].item() -
-                    self.tokenizer.timestamp_begin)
+                    tokens[last_slice - 1].item() - self.tokenizer.timestamp_begin
+                )
                 self.seek += last_timestamp_pos * self.input_stride
         else:
             duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             timestamps = tokens[timestamp_tokens.nonzero().flatten()]
 
-            if len(timestamps) > 0 and \
-                    timestamps[-1].item() != self.tokenizer.timestamp_begin:
+            if (
+                len(timestamps) > 0
+                and timestamps[-1].item() != self.tokenizer.timestamp_begin
+            ):
                 # no consecutive timestamps but it has a timestamp; use the last
                 # one.
-                last_timestamp_pos = \
-                    timestamps[-1].item() - self.tokenizer.timestamp_begin
+                last_timestamp_pos = (
+                    timestamps[-1].item() - self.tokenizer.timestamp_begin
+                )
                 duration = last_timestamp_pos * self.time_precision
-            current_segments.append(self.new_segment(
+            current_segments.append(
+                self.new_segment(
                     start=self.time_offset,
                     end=self.time_offset + duration,
                     tokens=tokens,
-                    result=result))
+                    result=result,
+                )
+            )
             self.seek += segment_size
 
     def timestamp(
-            self, current_segments: List[dict], segment_size: int,
-            single_timestamp_ending: bool, mel_segment: torch.Tensor,
-            previous_seek: int, content_frames: int) -> bool:
+        self,
+        current_segments: List[dict],
+        segment_size: int,
+        single_timestamp_ending: bool,
+        mel_segment: torch.Tensor,
+        previous_seek: int,
+        content_frames: int,
+    ) -> bool:
         add_word_timestamps(
             segments=current_segments,
             model=self.model,
@@ -512,8 +547,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
 
         if not single_timestamp_ending:
             last_word_end = get_end(current_segments)
-            if last_word_end is not None and \
-                    last_word_end > self.time_offset:
+            if last_word_end is not None and last_word_end > self.time_offset:
                 self.seek = round(last_word_end * FRAMES_PER_SECOND)
 
         # skip silence before possible hallucinations
@@ -521,24 +555,19 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             threshold = self.hallucination_silence_threshold
             if not single_timestamp_ending:
                 last_word_end = get_end(current_segments)
-                if last_word_end is not None and \
-                        last_word_end > self.time_offset:
-                    remaining_duration = \
-                        self.window_end_time - last_word_end
+                if last_word_end is not None and last_word_end > self.time_offset:
+                    remaining_duration = self.window_end_time - last_word_end
                     if remaining_duration > threshold:
-                        self.seek = round(
-                            last_word_end * FRAMES_PER_SECOND)
+                        self.seek = round(last_word_end * FRAMES_PER_SECOND)
                     else:
                         self.seek = previous_seek + segment_size
 
             # if first segment might be a hallucination, skip leading silence
             first_segment = self.next_words_segment(current_segments)
-            if first_segment is not None and self.is_segment_anomaly(
-                    first_segment):
+            if first_segment is not None and self.is_segment_anomaly(first_segment):
                 gap = first_segment["start"] - self.time_offset
                 if gap > threshold:
-                    self.seek = previous_seek + round(
-                        gap * FRAMES_PER_SECOND)
+                    self.seek = previous_seek + round(gap * FRAMES_PER_SECOND)
                     return True
 
             # skip silence before any possible hallucination that is
@@ -550,14 +579,13 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 if not segment["words"]:
                     continue
                 if self.is_segment_anomaly(segment):
-                    next_segment = self.next_words_segment(
-                        current_segments[si + 1 :])
+                    next_segment = self.next_words_segment(current_segments[si + 1 :])
                     if next_segment is not None:
-                        hal_next_start = \
-                            next_segment["words"][0]["start"]
+                        hal_next_start = next_segment["words"][0]["start"]
                     else:
-                        hal_next_start = self.time_offset + \
-                            segment_size * HOP_LENGTH / SAMPLE_RATE
+                        hal_next_start = (
+                            self.time_offset + segment_size * HOP_LENGTH / SAMPLE_RATE
+                        )
                     silence_before = (
                         segment["start"] - hal_last_end > threshold
                         or segment["start"] < threshold
@@ -573,8 +601,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                         max(self.time_offset + 1, segment["start"])
                         * FRAMES_PER_SECOND
                     )
-                    if content_duration - segment["end"] < \
-                            threshold:
+                    if content_duration - segment["end"] < threshold:
                         self.seek = content_frames
                         current_segments[si:] = []
                         break
@@ -586,19 +613,17 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         return False
 
     def __call__(
-            self, mel: torch.Tensor, offset: int = 0,
-            single_pass: bool = False) -> dict:
+        self, mel: torch.Tensor, offset: int = 0, single_pass: bool = False
+    ) -> dict:
         self.latest, self.frame_offset = mel, offset
         content_frames = mel.shape[-1] - N_FRAMES + offset
-        content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
         # NOTE: This loop is obscurely flattened to make the diff readable.
         # A later commit should turn this into a simpler nested loop.
         # for seek_clip_start, seek_clip_end in seek_clips:
         #     while seek < seek_clip_end
         while self.clip_idx < len(self.seek_clips):
             seek_clip_start, seek_clip_end = self.seek_clips[self.clip_idx]
-            seek_clip_end = content_frames if seek_clip_end is None else \
-                seek_clip_end
+            seek_clip_end = content_frames if seek_clip_end is None else seek_clip_end
             if self.seek < seek_clip_start:
                 self.seek = seek_clip_start
             if self.seek >= seek_clip_end:
@@ -607,22 +632,23 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 self.clip_idx += 1
                 continue
             segment_size = min(
-                N_FRAMES, content_frames - self.seek,
-                seek_clip_end - self.seek)
+                N_FRAMES, content_frames - self.seek, seek_clip_end - self.seek
+            )
-            mel_segment = mel[
-                :, self.seek - offset : self.seek + segment_size - offset]
-            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(
-                self.device).to(self.dtype)
+            mel_segment = mel[:, self.seek - offset : self.seek + segment_size - offset]
+            mel_segment = (
+                pad_or_trim(mel_segment, N_FRAMES).to(self.device).to(self.dtype)
+            )
 
-            self.decode_options["prompt"] = \
-                self.all_tokens[self.prompt_reset_since:]
+            self.decode_options["prompt"] = self.all_tokens[self.prompt_reset_since :]
             result: DecodingResult = self.decode_with_fallback(mel_segment)
 
             if self.no_speech_threshold is not None:
                 # no voice activity check
                 should_skip = result.no_speech_prob > self.no_speech_threshold
-                if self.logprob_threshold is not None and \
-                        result.avg_logprob > self.logprob_threshold:
+                if (
+                    self.logprob_threshold is not None
+                    and result.avg_logprob > self.logprob_threshold
+                ):
                     # don't skip if the logprob is high enough, despite the
                     # no_speech_prob
                     should_skip = False
@@ -636,19 +662,27 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
 
             current_segments: List[dict] = []
             tokens = torch.tensor(result.tokens)
-            timestamp_tokens: torch.Tensor = tokens.ge(
-                self.tokenizer.timestamp_begin)
+            timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin)
-            single_timestamp_ending = (
-                timestamp_tokens[-2:].tolist() == [False, True])
+            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
 
             self.reseek(
-                current_segments, segment_size, single_timestamp_ending,
-                tokens, timestamp_tokens, result)
+                current_segments,
+                segment_size,
+                single_timestamp_ending,
+                tokens,
+                timestamp_tokens,
+                result,
+            )
 
             if self.word_timestamps:
                 if self.timestamp(
-                        current_segments, segment_size, single_timestamp_ending,
-                        mel_segment, previous_seek, content_frames):
+                    current_segments,
+                    segment_size,
+                    single_timestamp_ending,
+                    mel_segment,
+                    previous_seek,
+                    content_frames,
+                ):
                     continue
 
             if self.verbose:
@@ -656,25 +690,29 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                     start, end = segment["start"], segment["end"]
                     text = segment["text"]
                     line = (
-                        f"[{format_timestamp(start)} --> "
-                        f"{format_timestamp(end)}] {text}")
+                        f"[{format_timestamp(start)} --> "
+                        f"{format_timestamp(end)}] {text}"
+                    )
                     print(make_safe(line))
 
             # if a segment is instantaneous or does not contain text, clear it
             for i, segment in enumerate(current_segments):
-                if segment["start"] == segment["end"] or \
-                        segment["text"].strip() == "":
+                if segment["start"] == segment["end"] or segment["text"].strip() == "":
                     segment["text"] = ""
                     segment["tokens"] = []
                     segment["words"] = []
 
-            self.all_segments.extend([
+            self.all_segments.extend(
+                [
                     {"id": i, **segment}
                     for i, segment in enumerate(
-                        current_segments, start=len(self.all_segments))])
+                        current_segments, start=len(self.all_segments)
+                    )
+                ]
+            )
-            self.all_tokens.extend([
-                token for segment in current_segments
-                for token in segment["tokens"]])
+            self.all_tokens.extend(
+                [token for segment in current_segments for token in segment["tokens"]]
+            )
 
             if not self.condition_on_previous_text or result.temperature > 0.5:
                 # do not feed the prompt tokens if a high temperature was used
@@ -686,9 +724,12 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 break
 
         self.result = dict(
-            segments=self.all_segments, language=self.language,
-            text=self.tokenizer.decode(
-                self.all_tokens[len(self.initial_prompt_tokens):]))
+            segments=self.all_segments,
+            language=self.language,
+            text=self.tokenizer.decode(
+                self.all_tokens[len(self.initial_prompt_tokens) :]
+            ),
+        )
         self.latest = None
         return self.result
 
@@ -698,17 +739,18 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     def restore(self, offset: int) -> None:
         processing, seconds = 0, offset * HOP_LENGTH / SAMPLE_RATE
         while len(self.all_segments) > 0 and (
-                self.all_segments[-1]["start"] >= seconds
-                if len(self.all_segments) == 1 else
-                self.all_segments[-2]["end"] > seconds):
+            self.all_segments[-1]["start"] >= seconds
+            if len(self.all_segments) == 1
+            else self.all_segments[-2]["end"] > seconds
+        ):
             rewriting = self.all_segments.pop()
             processing += len(rewriting["tokens"])
-        self.all_tokens = self.all_tokens[:len(self.all_tokens) - processing]
+        self.all_tokens = self.all_tokens[: len(self.all_tokens) - processing]
         if len(self.all_segments) > 0 and (
-                self.all_segments[-1]["start"] < seconds and
-                self.all_segments[-1]["end"] >= seconds):
+            self.all_segments[-1]["start"] < seconds
+            and self.all_segments[-1]["end"] >= seconds
+        ):
-            self.seek = round(
-                self.all_segments[-1]["end"] * SAMPLE_RATE / HOP_LENGTH)
+            self.seek = round(self.all_segments[-1]["end"] * SAMPLE_RATE / HOP_LENGTH)
         else:
             self.seek = offset
 
@@ -728,12 +770,16 @@ def audio_tensor(audio: Union[str, np.ndarray, torch.Tensor]) -> torch.Tensor:
 class MinimalTranscriber(Transcriber):
     exact: bool = True
     chlen: float = CHUNK_LENGTH
+
     async def process(self, stream: ArrayStream, **kw) -> dict:
         data = await stream.request(self.chlen, self.exact)
         while data.shape[-1] > 0:
             self(data, stream.offset, True)
-            t = self.chlen - (stream.offset + data.shape[-1] - self.seek) \
-                / FRAMES_PER_SECOND + CHUNK_LENGTH
+            t = (
+                self.chlen
+                - (stream.offset + data.shape[-1] - self.seek) / FRAMES_PER_SECOND
+                + CHUNK_LENGTH
+            )
             data = await stream.request(t, self.exact)
         return self.result
 
@@ -755,12 +801,16 @@ class ProgressTranscriber(MinimalTranscriber):
     @PassthroughProperty(None).property
     def pbar(self):
         if self._pbar is None:
-            n = self.latest.shape[-1] if self.duration is None \
-                else -int(self.duration * -FRAMES_PER_SECOND)
+            n = (
+                self.latest.shape[-1]
+                if self.duration is None
+                else -int(self.duration * -FRAMES_PER_SECOND)
+            )
             # show the progress bar when verbose is False
             # (if True, transcribed text will be printed)
             self._pbar = tqdm.tqdm(
-                total=n, unit="frames", disable=self.verbose is not False)
+                total=n, unit="frames", disable=self.verbose is not False
+            )
             self._pbar.__enter__()
         return self._pbar
 
@@ -784,10 +834,7 @@ class ProgressTranscriber(MinimalTranscriber):
         return await self.process(stream, **kw)
 
 
-def transcribe(
-        model: "Whisper",
-        audio: Union[str, np.ndarray, torch.Tensor],
-        **kw):
+def transcribe(model: "Whisper", audio: Union[str, np.ndarray, torch.Tensor], **kw):
     """
     Transcribe an audio file using Whisper
 
diff --git a/whisper/utils.py b/whisper/utils.py
index 7bec242..27a0465 100644
--- a/whisper/utils.py
+++ b/whisper/utils.py
@@ -5,16 +5,7 @@ import re
 import sys
 import time
 import zlib
-from typing import (
-    Callable,
-    List,
-    Optional,
-    TextIO,
-    Union,
-    TypeVar,
-    Generic,
-    Any
-)
+from typing import Any, Callable, Generic, List, Optional, TextIO, TypeVar, Union
 
 system_encoding = sys.getdefaultencoding()
 
@@ -86,15 +77,17 @@ def format_timestamp(
         f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
     )
 
+
 def hms(sec: float) -> str:
     trim = sec < 3600
     h = "" if trim else str(int(sec) // 3600) + ":"
     m_fill = " " if trim else "0"
     m = " " if sec < 60 else str(int(sec) // 60 % 60).rjust(2, m_fill) + ":"
-    s = str(int(sec) % 60).rjust(2, '0') + "."
-    c = str(round((sec % 1) * 100)).rjust(2, '0')
+    s = str(int(sec) % 60).rjust(2, "0") + "."
+    c = str(round((sec % 1) * 100)).rjust(2, "0")
    return h + m + s + c
 
+
 def tod(seconds: float) -> str:
     return time.strftime("%H:%M:%S", time.localtime(seconds))
 
@@ -349,28 +342,34 @@ def get_writer(
 
 T = TypeVar("T")
 
+
 # boilerplate for property with _{name} storage and passthrough getter/setter
 class PassthroughProperty(Generic[T]):
     def __init__(self, default: T):
         self.value = default
 
     f: Optional[Callable[[Any, T], None]] = None
+
     def setter(self, f: Callable[[Any, T], None]):
         self.f = f
         return self
 
     g: Optional[property] = None
+
     def property(self, g: Callable[[Any], T]):
         self.g = property(g)
         return self
 
+
 class PassthroughPropertyDefaults(type):
     def __new__(cls, clsname, bases, attrs):
         def closure(f, v):
             def prop(self):
                 return getattr(self, v)
 
+
             def setter(self, value):
                 setattr(self, v, value)
 
+
             prop.__name__ = setter.__name__ = f
             return property(prop), setter
@@ -383,4 +382,3 @@ class PassthroughPropertyDefaults(type):
             getter, setter = closure(k, private)
             updates[k] = (v.g or getter).setter(v.f or setter)
         return super().__new__(cls, clsname, bases, {**attrs, **updates})
-
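
Reviewer note, not part of the patch: two usage sketches follow to show how the
pieces in this diff fit together. They assume only the interfaces visible above
plus whisper.load_model() from the existing package; the file name "out.wav" and
the "base" model size are illustrative defaults, not requirements.

A minimal sketch of the Batcher from whisper/batching.py, which re-slices an
async stream of arrays into fixed-size windows along one axis (the last, by
default):

    import asyncio

    import numpy as np

    from whisper.batching import Batcher

    async def chunks():
        # uneven chunks adding up to 16 samples
        for n in (3, 5, 2, 6):
            yield np.arange(n, dtype=np.float32)

    async def main():
        # with exact=True every yielded window holds exactly 4 elements,
        # regardless of how the source chunks were sized
        async for window in Batcher(chunks(), 4, exact=True):
            print(window.shape)  # (4,), four times

    asyncio.run(main())

A sketch of end-to-end streaming transcription with the classes from
whisper/buffer.py and whisper/transcribe.py; it assumes an ffmpeg binary on
PATH (AudioFile asserts this at construction) and a loadable Whisper model:

    import asyncio

    import whisper
    from whisper.buffer import AudioFile
    from whisper.transcribe import MinimalTranscriber

    async def main():
        # AudioFile.read() spawns ffmpeg to decode the file to 16 kHz mono s16le
        stream = AudioFile(fname="out.wav")
        scribe = MinimalTranscriber(whisper.load_model("base"))
        # process() starts the reader task and pulls ~30-second mel windows
        # from the stream as audio becomes available
        result = await scribe.process(stream)
        print(result["text"])

    asyncio.run(main())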