progress bar support and buffered cli option

Kent Slaney 2024-07-14 16:14:37 -06:00
parent 4ccbd70012
commit b4fd954955
2 changed files with 114 additions and 38 deletions

buffer.py

@@ -1,5 +1,5 @@
 import numpy as np
-import asyncio, pathlib, subprocess, torch
+import asyncio, pathlib, subprocess, torch, json
 from .audio import (
     SAMPLE_RATE,
@@ -271,3 +271,19 @@ class AudioFile(RawAudioFile):
         if ps.returncode not in (None, 0):
             raise RuntimeError(f"Failed to load audio: {stderr.decode()}")
+
+    @property
+    def duration(self):
+        cmd = [
+            "ffprobe",
+            "-hide_banner",
+            "-show_format",
+            "-of", "json",
+            "-i", self.fname,
+        ]
+        ps = subprocess.Popen(
+            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        stdout, stderr = ps.communicate()
+        if ps.returncode not in (None, 0):
+            raise RuntimeError(f"Failed to load audio: {stderr.decode()}")
+        return float(json.loads(stdout)['format']['duration'])
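Note: ffprobe's JSON output nests the container duration (seconds, as a string) under the "format" key. A standalone sketch of the same probe using subprocess.run (the file name is a placeholder; ffprobe must be on PATH):

    import json, subprocess

    def probe_duration(fname: str) -> float:
        # ffprobe -show_format -of json prints {"format": {"duration": "..."}}
        out = subprocess.run(
            ["ffprobe", "-hide_banner", "-show_format", "-of", "json",
             "-i", fname],
            capture_output=True, check=True)  # raises on a nonzero exit
        return float(json.loads(out.stdout)["format"]["duration"])

    print(probe_duration("media.mp3"))  # e.g. 123.456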

transcribe.py

@@ -2,12 +2,13 @@ import argparse
 import os
 import traceback
 import warnings
+import asyncio
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Dict
 from dataclasses import dataclass

 import numpy as np
 import torch
-import tqdm  # TODO
+import tqdm
 from .audio import (
     FRAMES_PER_SECOND,
@@ -32,7 +33,7 @@ from .utils import (
     PassthroughProperty,
     PassthroughPropertyDefaults,
 )
-from .buffer import AudioFile
+from .buffer import ArrayStream, AudioFile

 if TYPE_CHECKING:
     from .model import Whisper
@@ -49,7 +50,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     postfix: str = '''"'.\u3002,\uff0c!\uff01?\uff1f:\uff1a\u201d)]}\u3001'''
     punctuation: str = prefix + postfix

-    verbose: bool = False
+    verbose: Optional[bool] = None
     _decode_options: dict = {}
     decode_props: Tuple[str, ...] = ("fp16", "language", "task")
@@ -76,10 +77,10 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         self.dtype = torch.float16 if value else torch.float32
         self.fp16device()

-    @PassthroughProperty(None).setter
-    def model(self, value: "Whisper") -> None:
+    @PassthroughProperty[Optional["Whisper"]](None).setter
+    def model(self, value: Optional["Whisper"]) -> None:
         self._model = value
-        self.device = value.device
+        self.device = None if value is None else value.device
         self.input_stride = exact_div(
             N_FRAMES, self.model.dims.n_audio_ctx
         )  # mel frames per output token: 2
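The PassthroughProperty helper lives in .utils and is not shown in this diff. A speculative sketch of the shape its usage here suggests (not the repo's actual code; the metaclass PassthroughPropertyDefaults presumably rewrites these descriptors into real properties backed by the _model, _pbar, etc. fields):

    # speculative: inferred from usage, the real implementation may differ
    from typing import Any, Callable, Generic, Optional, TypeVar

    T = TypeVar("T")

    class PassthroughProperty(Generic[T]):
        def __init__(self, default: T):
            self.default = default  # class-level default for the backing field
            self.fset: Optional[Callable[[Any, T], None]] = None
            self.fget: Optional[Callable[[Any], T]] = None

        def setter(self, fset):
            # custom setter; reads fall through to the `_name` backing field
            self.fset = fset
            return self

        def property(self, fget):
            # custom getter; writes fall through to the backing field
            self.fget = fget
            return self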
@@ -207,7 +208,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     _tokenizer: Optional[Tokenizer] = None
     _tokenizer_cache: Dict[str, Tokenizer] = {}
     @property
-    def tokenizer(self) -> Optional[Tokenizer]:
+    def tokenizer(self) -> Tokenizer:
         if self._tokenizer is None:
             lang = self.language
             if self._language is not None:
@@ -221,8 +222,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                     task=self.task,
                 )
                 return self._tokenizer
-            if lang is None:
-                return None
+            assert lang is not None
             if lang not in self._tokenizer_cache:
                 self._tokenizer_cache[lang] = self.get_tokenizer(
                     self.model.is_multilingual,
@@ -247,9 +247,8 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             if tokenizer not in self._initial_prompt_cache:
                 self._initial_prompt_cache[tokenizer] = tokenizer.encode(
                     " " + self.initial_prompt.strip())
-            if self._tokenizer is not None:
-                self._initial_prompt_tokens = \
-                    self._initial_prompt_cache[tokenizer]
+            self._initial_prompt_tokens = \
+                self._initial_prompt_cache[tokenizer]
             return self._initial_prompt_cache[tokenizer]
         return self._initial_prompt_tokens
@@ -275,8 +274,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             hallucination_silence_threshold: Optional[float] = None,
             **decode_options):
         self.model = model
-        if verbose is not None:
-            self.verbose = verbose
+        self.verbose = verbose
         self.temperature = temperature
         self.compression_ratio_threshold = compression_ratio_threshold
         self.logprob_threshold = logprob_threshold
@@ -319,20 +317,19 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                     needs_fallback = False  # silence
             if not needs_fallback:
                 break
+        assert decode_result is not None
         return decode_result

     def new_segment(
             self, *, start: float, end: float, tokens: torch.Tensor,
             result: DecodingResult) -> dict:
         _tokens = tokens.tolist()
-        _tokenizer = self.tokenizer
-        assert _tokenizer is not None
-        text_tokens = [token for token in _tokens if token < _tokenizer.eot]
+        text_tokens = [token for token in _tokens if token < self.tokenizer.eot]
         return {
             "seek": self.seek,
             "start": start,
             "end": end,
-            "text": _tokenizer.decode(text_tokens),
+            "text": self.tokenizer.decode(text_tokens),
             "tokens": _tokens,
             "temperature": result.temperature,
             "avg_logprob": result.avg_logprob,
@@ -371,8 +368,6 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             self, current_segments: List[dict], segment_size: int,
             single_timestamp_ending: bool, tokens: torch.Tensor,
             timestamp_tokens: torch.Tensor, result: DecodingResult):
-        _tokenizer = self.tokenizer
-        assert _tokenizer is not None
         consecutive = torch.where(
             timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
         consecutive.add_(1)
@@ -387,10 +382,10 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 sliced_tokens = tokens[last_slice:current_slice]
                 start_timestamp_pos = (
                     sliced_tokens[0].item() -
-                    _tokenizer.timestamp_begin)
+                    self.tokenizer.timestamp_begin)
                 end_timestamp_pos = (
                     sliced_tokens[-1].item() -
-                    _tokenizer.timestamp_begin)
+                    self.tokenizer.timestamp_begin)
                 current_segments.append(
                     self.new_segment(
                         start=self.time_offset + \
# timestamp # timestamp
last_timestamp_pos = ( last_timestamp_pos = (
tokens[last_slice - 1].item() - tokens[last_slice - 1].item() -
_tokenizer.timestamp_begin) self.tokenizer.timestamp_begin)
self.seek += last_timestamp_pos * self.input_stride self.seek += last_timestamp_pos * self.input_stride
else: else:
duration = segment_size * HOP_LENGTH / SAMPLE_RATE duration = segment_size * HOP_LENGTH / SAMPLE_RATE
timestamps = tokens[timestamp_tokens.nonzero().flatten()] timestamps = tokens[timestamp_tokens.nonzero().flatten()]
if len(timestamps) > 0 and \ if len(timestamps) > 0 and \
timestamps[-1].item() != _tokenizer.timestamp_begin: timestamps[-1].item() != self.tokenizer.timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last # no consecutive timestamps but it has a timestamp; use the last
# one. # one.
last_timestamp_pos = \ last_timestamp_pos = \
timestamps[-1].item() - _tokenizer.timestamp_begin timestamps[-1].item() - self.tokenizer.timestamp_begin
duration = last_timestamp_pos * self.time_precision duration = last_timestamp_pos * self.time_precision
current_segments.append(self.new_segment( current_segments.append(self.new_segment(
@@ -569,10 +564,8 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             current_segments: List[dict] = []
             tokens = torch.tensor(result.tokens)
-            _tokenizer = self.tokenizer
-            assert _tokenizer is not None
             timestamp_tokens: torch.Tensor = tokens.ge(
-                _tokenizer.timestamp_begin)
+                self.tokenizer.timestamp_begin)
             single_timestamp_ending = (
                 timestamp_tokens[-2:].tolist() == [False, True])
@@ -615,19 +608,22 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 # do not feed the prompt tokens if a high temperature was used
                 self.prompt_reset_since = len(self.all_tokens)

+            self.reporthook()
             if single_pass:
                 break

-        _tokenizer = self.tokenizer
-        assert _tokenizer is not None
-        res = dict(
+        self.result = dict(
             segments=self.all_segments, language=self.language,
-            text=_tokenizer.decode(
+            text=self.tokenizer.decode(
                 self.all_tokens[len(self.initial_prompt_tokens):]))
         self.latest = None
-        return res
+        return self.result

-    def restore(self, offset: int):
+    def reporthook(self) -> None:
+        pass
+
+    def restore(self, offset: int) -> None:
         processing, seconds = 0, offset * HOP_LENGTH / SAMPLE_RATE
         while len(self.all_segments) > 0 and (
                 self.all_segments[-1]["start"] >= seconds
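Note: reporthook is a deliberate no-op on the base class, invoked once per decoded window just before the single_pass check, so subclasses get a hook for progress reporting (ProgressTranscriber below drives its tqdm bar from it). A hypothetical minimal override:

    # hypothetical subclass for plain-text progress logging
    class LoggingTranscriber(Transcriber):
        def reporthook(self) -> None:
            # called after each window is committed; self.seek is in mel frames
            print(f"committed up to frame {self.seek}")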
@@ -652,16 +648,78 @@ class InMemoryAudio(AudioFile):

 def audio_tensor(audio: Union[str, np.ndarray, torch.Tensor]) -> torch.Tensor:
     if isinstance(audio, str):
         return InMemoryAudio(fname=audio).sequential()
-    if isinstance(audio, np.dtype):
+    if isinstance(audio, np.ndarray):
         return torch.from_numpy(audio)
     return audio

+
+class MinimalTranscriber(Transcriber):
+    exact: bool = True
+    chlen: float = CHUNK_LENGTH
+
+    async def process(self, stream: ArrayStream, **kw) -> dict:
+        data = await stream.request(self.chlen, self.exact)
+        while data.shape[-1] > 0:
+            self(data, stream.offset, True)
+            t = self.chlen - (stream.offset + data.shape[-1] - self.seek) \
+                / FRAMES_PER_SECOND + CHUNK_LENGTH
+            data = await stream.request(t, self.exact)
+        return self.result
+
+
+class ProgressTranscriber(MinimalTranscriber):
+    def __init__(self, *a, duration: Optional[float] = None, **kw):
+        super().__init__(*a, **kw)
+        self.duration, self.progress = duration, 0
+
+    def __call__(self, *a, **kw) -> dict:
+        if self._pbar is None:
+            try:
+                return super().__call__(*a, **kw)
+            finally:
+                self.close()
+        else:
+            return super().__call__(*a, **kw)
+
+    @PassthroughProperty(None).property
+    def pbar(self):
+        if self._pbar is None:
+            n = self.latest.shape[-1] if self.duration is None \
+                else -int(self.duration * -FRAMES_PER_SECOND)
+            self._pbar = tqdm.tqdm(
+                total=n, unit="frames", disable=self.verbose is not False)
+            self._pbar.__enter__()
+        return self._pbar
+
+    def reporthook(self) -> None:
+        update_to = min(self._seek, self.frame_offset + self.latest.shape[-1])
+        self.pbar.update(update_to - self.progress)
+        self.progress = update_to
+
+    def close(self):
+        self.pbar.__exit__(None, None, None)
+
+    async def process(self, stream: ArrayStream, **kw) -> dict:
+        self.pbar
+        try:
+            return await super().process(stream, **kw)
+        finally:
+            self.close()
+
+    async def progressive(self, stream: AudioFile, **kw) -> dict:
+        self.duration = stream.duration
+        return await self.process(stream, **kw)
+

 def transcribe(
         model: "Whisper",
         audio: Union[str, np.ndarray, torch.Tensor],
         **kw):
-    return Transcriber(model, **kw)(audio_tensor(audio))
+    return ProgressTranscriber(model, **kw)(audio_tensor(audio))
+
+
+def buffered_transcribe(model: "Whisper", audio: str, **kw):
+    transcriber = ProgressTranscriber(model, **kw)
+    return asyncio.run(transcriber.progressive(AudioFile(fname=audio)))


 def cli():
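A quick sketch of how the two entry points are meant to be used (the package import path is assumed to match upstream whisper; the audio file name is a placeholder):

    from whisper import load_model
    from whisper.transcribe import buffered_transcribe, transcribe

    model = load_model("base")

    # eager path: the whole file is decoded up front, now with a tqdm bar
    result = transcribe(model, "speech.mp3")

    # buffered path: audio is pulled from ffmpeg on demand and progress is
    # measured against the ffprobe duration
    result = buffered_transcribe(model, "speech.mp3")
    print(result["text"])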
@@ -712,6 +770,7 @@ def cli():
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
     parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
     parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
+    parser.add_argument("--buffered", type=str2bool, default=False, help="whether to load the audio data on demand instead of all at once")
     # fmt: on

     args = parser.parse_args().__dict__
@@ -741,6 +800,7 @@ def cli():
     from . import load_model

     model = load_model(model_name, device=device, download_root=model_dir)
+    transcriber = buffered_transcribe if args.pop("buffered") else transcribe
     writer = get_writer(output_format, output_dir)

     word_options = [
@@ -760,7 +820,7 @@ def cli():
     writer_args = {arg: args.pop(arg) for arg in word_options}
     for audio_path in args.pop("audio"):
         try:
-            result = transcribe(model, audio_path, temperature=temperature, **args)
+            result = transcriber(model, audio_path, temperature=temperature, **args)
             writer(result, audio_path, **writer_args)
         except Exception as e:
             traceback.print_exc()
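With the flag wired through, a buffered run from the shell looks like this (assuming the usual console entry point; --buffered is parsed by str2bool):

    whisper audio.mp3 --model base --buffered True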