From b4fd954955a8ff2a36e2c00222ecc875dd9c230c Mon Sep 17 00:00:00 2001
From: Kent Slaney
Date: Sun, 14 Jul 2024 16:14:37 -0600
Subject: [PATCH] progress bar support and buffered cli option
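
Streams audio through the existing ArrayStream/AudioFile buffering so
transcription no longer has to load the whole file up front, and
reports progress in mel frames through tqdm. AudioFile gains an
ffprobe-backed duration property so ProgressTranscriber can size the
bar before streaming begins, Transcriber gains a reporthook() callback
invoked after each decoded segment, and a --buffered CLI flag selects
buffered_transcribe, which loads audio on demand instead of all at
once. The default transcribe() path now routes through
ProgressTranscriber as well, so it also shows a progress bar.

Usage sketch (hedged: assumes a local audio.mp3 and a downloadable
model; names below are the ones introduced by this patch):

    # CLI: load audio on demand instead of all at once
    whisper audio.mp3 --buffered True

    # Python API added by this patch
    import whisper
    from whisper.transcribe import buffered_transcribe

    model = whisper.load_model("base")
    result = buffered_transcribe(model, "audio.mp3")
    print(result["text"])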
= ("fp16", "language", "task") @@ -76,10 +77,10 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): self.dtype = torch.float16 if value else torch.float32 self.fp16device() - @PassthroughProperty(None).setter - def model(self, value: "Whisper") -> None: + @PassthroughProperty[Optional["Whisper"]](None).setter + def model(self, value: Optional["Whisper"]) -> None: self._model = value - self.device = value.device + self.device = None if value is None else value.device self.input_stride = exact_div( N_FRAMES, self.model.dims.n_audio_ctx ) # mel frames per output token: 2 @@ -207,7 +208,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): _tokenizer: Optional[Tokenizer] = None _tokenizer_cache: Dict[str, Tokenizer] = {} @property - def tokenizer(self) -> Optional[Tokenizer]: + def tokenizer(self) -> Tokenizer: if self._tokenizer is None: lang = self.language if self._language is not None: @@ -221,8 +222,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): task=self.task, ) return self._tokenizer - if lang is None: - return None + assert lang is not None if lang not in self._tokenizer_cache: self._tokenizer_cache[lang] = self.get_tokenizer( self.model.is_multilingual, @@ -247,9 +247,8 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): if tokenizer not in self._initial_prompt_cache: self._initial_prompt_cache[tokenizer] = tokenizer.encode( " " + self.initial_prompt.strip()) - if self._tokenizer is not None: - self._initial_prompt_tokens = \ - self._initial_prompt_cache[tokenizer] + self._initial_prompt_tokens = \ + self._initial_prompt_cache[tokenizer] return self._initial_prompt_cache[tokenizer] return self._initial_prompt_tokens @@ -275,8 +274,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): hallucination_silence_threshold: Optional[float] = None, **decode_options): self.model = model - if verbose is not None: - self.verbose = verbose + self.verbose = verbose self.temperature = temperature self.compression_ratio_threshold = compression_ratio_threshold self.logprob_threshold = logprob_threshold @@ -319,20 +317,19 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): needs_fallback = False # silence if not needs_fallback: break + assert decode_result is not None return decode_result def new_segment( self, *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult) -> dict: _tokens = tokens.tolist() - _tokenizer = self.tokenizer - assert _tokenizer is not None - text_tokens = [token for token in _tokens if token < _tokenizer.eot] + text_tokens = [token for token in _tokens if token < self.tokenizer.eot] return { "seek": self.seek, "start": start, "end": end, - "text": _tokenizer.decode(text_tokens), + "text": self.tokenizer.decode(text_tokens), "tokens": _tokens, "temperature": result.temperature, "avg_logprob": result.avg_logprob, @@ -371,8 +368,6 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): self, current_segments: List[dict], segment_size: int, single_timestamp_ending: bool, tokens: torch.Tensor, timestamp_tokens: torch.Tensor, result: DecodingResult): - _tokenizer = self.tokenizer - assert _tokenizer is not None consecutive = torch.where( timestamp_tokens[:-1] & timestamp_tokens[1:])[0] consecutive.add_(1) @@ -387,10 +382,10 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): sliced_tokens = tokens[last_slice:current_slice] start_timestamp_pos = ( sliced_tokens[0].item() - - _tokenizer.timestamp_begin) + self.tokenizer.timestamp_begin) end_timestamp_pos = ( sliced_tokens[-1].item() - - 
_tokenizer.timestamp_begin) + self.tokenizer.timestamp_begin) current_segments.append( self.new_segment( start=self.time_offset + \ @@ -412,17 +407,17 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): # timestamp last_timestamp_pos = ( tokens[last_slice - 1].item() - - _tokenizer.timestamp_begin) + self.tokenizer.timestamp_begin) self.seek += last_timestamp_pos * self.input_stride else: duration = segment_size * HOP_LENGTH / SAMPLE_RATE timestamps = tokens[timestamp_tokens.nonzero().flatten()] if len(timestamps) > 0 and \ - timestamps[-1].item() != _tokenizer.timestamp_begin: + timestamps[-1].item() != self.tokenizer.timestamp_begin: # no consecutive timestamps but it has a timestamp; use the last # one. last_timestamp_pos = \ - timestamps[-1].item() - _tokenizer.timestamp_begin + timestamps[-1].item() - self.tokenizer.timestamp_begin duration = last_timestamp_pos * self.time_precision current_segments.append(self.new_segment( @@ -569,10 +564,8 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): current_segments: List[dict] = [] tokens = torch.tensor(result.tokens) - _tokenizer = self.tokenizer - assert _tokenizer is not None timestamp_tokens: torch.Tensor = tokens.ge( - _tokenizer.timestamp_begin) + self.tokenizer.timestamp_begin) single_timestamp_ending = ( timestamp_tokens[-2:].tolist() == [False, True]) @@ -615,19 +608,22 @@ class Transcriber(metaclass=PassthroughPropertyDefaults): # do not feed the prompt tokens if a high temperature was used self.prompt_reset_since = len(self.all_tokens) + self.reporthook() + if single_pass: break - _tokenizer = self.tokenizer - assert _tokenizer is not None - res = dict( + self.result = dict( segments=self.all_segments, language=self.language, - text=_tokenizer.decode( + text=self.tokenizer.decode( self.all_tokens[len(self.initial_prompt_tokens):])) self.latest = None - return res + return self.result - def restore(self, offset: int): + def reporthook(self) -> None: + pass + + def restore(self, offset: int) -> None: processing, seconds = 0, offset * HOP_LENGTH / SAMPLE_RATE while len(self.all_segments) > 0 and ( self.all_segments[-1]["start"] >= seconds @@ -652,16 +648,78 @@ class InMemoryAudio(AudioFile): def audio_tensor(audio: Union[str, np.ndarray, torch.Tensor]) -> torch.Tensor: if isinstance(audio, str): return InMemoryAudio(fname=audio).sequential() - if isinstance(audio, np.dtype): + if isinstance(audio, np.ndarray): return torch.from_numpy(audio) return audio +class MinimalTranscriber(Transcriber): + exact: bool = True + chlen: float = CHUNK_LENGTH + async def process(self, stream: ArrayStream, **kw) -> dict: + data = await stream.request(self.chlen, self.exact) + while data.shape[-1] > 0: + self(data, stream.offset, True) + t = self.chlen - (stream.offset + data.shape[-1] - self.seek) \ + / FRAMES_PER_SECOND + CHUNK_LENGTH + data = await stream.request(t, self.exact) + return self.result + + +class ProgressTranscriber(MinimalTranscriber): + def __init__(self, *a, duration: Optional[float] = None, **kw): + super().__init__(*a, **kw) + self.duration, self.progress = duration, 0 + + def __call__(self, *a, **kw) -> dict: + if self._pbar is None: + try: + return super().__call__(*a, **kw) + finally: + self.close() + else: + return super().__call__(*a, **kw) + + @PassthroughProperty(None).property + def pbar(self): + if self._pbar is None: + n = self.latest.shape[-1] if self.duration is None \ + else -int(self.duration * -FRAMES_PER_SECOND) + self._pbar = tqdm.tqdm( + total=n, unit="frames", disable=self.verbose is not False) 
+            self._pbar.__enter__()
+        return self._pbar
+
+    def reporthook(self) -> None:
+        update_to = min(self._seek, self.frame_offset + self.latest.shape[-1])
+        self.pbar.update(update_to - self.progress)
+        self.progress = update_to
+
+    def close(self):
+        self.pbar.__exit__(None, None, None)
+
+    async def process(self, stream: ArrayStream, **kw) -> dict:
+        self.pbar
+        try:
+            return await super().process(stream, **kw)
+        finally:
+            self.close()
+
+    async def progressive(self, stream: AudioFile, **kw) -> dict:
+        self.duration = stream.duration
+        return await self.process(stream, **kw)
+
+
 def transcribe(
         model: "Whisper", audio: Union[str, np.ndarray, torch.Tensor], **kw):
-    return Transcriber(model, **kw)(audio_tensor(audio))
+    return ProgressTranscriber(model, **kw)(audio_tensor(audio))
+
+
+def buffered_transcribe(model: "Whisper", audio: str, **kw):
+    transcriber = ProgressTranscriber(model, **kw)
+    return asyncio.run(transcriber.progressive(AudioFile(fname=audio)))
 
 
 def cli():
@@ -712,6 +770,7 @@ def cli():
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
     parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
     parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
+    parser.add_argument("--buffered", type=str2bool, default=False, help="whether to load the audio data on demand instead of all at once")
     # fmt: on
 
     args = parser.parse_args().__dict__
@@ -741,6 +800,7 @@ def cli():
     from . import load_model
 
     model = load_model(model_name, device=device, download_root=model_dir)
+    transcriber = buffered_transcribe if args.pop("buffered") else transcribe
 
     writer = get_writer(output_format, output_dir)
     word_options = [
@@ -760,7 +820,7 @@ def cli():
     writer_args = {arg: args.pop(arg) for arg in word_options}
     for audio_path in args.pop("audio"):
         try:
-            result = transcribe(model, audio_path, temperature=temperature, **args)
+            result = transcriber(model, audio_path, temperature=temperature, **args)
             writer(result, audio_path, **writer_args)
         except Exception as e:
             traceback.print_exc()
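-- 
Reviewer note, not part of the commit: the --buffered path is
equivalent to driving the new classes by hand, roughly as below
(a sketch; "audio.mp3" and "base" are placeholders, and asyncio.run
assumes no event loop is already running):

    import asyncio, whisper
    from whisper.buffer import AudioFile
    from whisper.transcribe import ProgressTranscriber

    model = whisper.load_model("base")
    transcriber = ProgressTranscriber(model)
    # progressive() reads the duration via ffprobe to size the bar,
    # then streams chunks through MinimalTranscriber.process()
    result = asyncio.run(
        transcriber.progressive(AudioFile(fname="audio.mp3")))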