mirror of https://github.com/openai/whisper.git
progress bar support and buffered cli option
parent 4ccbd70012
commit b4fd954955
@@ -1,5 +1,5 @@
 import numpy as np
-import asyncio, pathlib, subprocess, torch
+import asyncio, pathlib, subprocess, torch, json

 from .audio import (
     SAMPLE_RATE,
@@ -271,3 +271,19 @@ class AudioFile(RawAudioFile):
         if ps.returncode not in (None, 0):
             raise RuntimeError(f"Failed to load audio: {stderr.decode()}")
+
+    @property
+    def duration(self):
+        cmd = [
+            "ffprobe",
+            "-hide_banner",
+            "-show_format",
+            "-of", "json",
+            "-i", self.fname,
+        ]
+        ps = subprocess.Popen(
+            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        stdout, stderr = ps.communicate()
+        if ps.returncode not in (None, 0):
+            raise RuntimeError(f"Failed to load audio: {stderr.decode()}")
+        return float(json.loads(stdout)['format']['duration'])

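The new AudioFile.duration property asks ffprobe for the container's format metadata and parses the duration out of its JSON output; this is what later lets the progress bar be sized up front. A minimal standalone sketch of the same probe, assuming ffprobe is on PATH and media.mp3 is a placeholder path:

import json, subprocess

def probe_duration(fname: str) -> float:
    # ffprobe emits format-level metadata (including duration) as JSON
    out = subprocess.run(
        ["ffprobe", "-hide_banner", "-show_format", "-of", "json", "-i", fname],
        capture_output=True, check=True).stdout
    return float(json.loads(out)["format"]["duration"])

print(probe_duration("media.mp3"))  # e.g. 123.456 (seconds)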
@@ -2,12 +2,13 @@ import argparse
 import os
 import traceback
 import warnings
+import asyncio
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Dict
 from dataclasses import dataclass

 import numpy as np
 import torch
-import tqdm  # TODO
+import tqdm

 from .audio import (
     FRAMES_PER_SECOND,
@@ -32,7 +33,7 @@ from .utils import (
     PassthroughProperty,
     PassthroughPropertyDefaults,
 )
-from .buffer import AudioFile
+from .buffer import ArrayStream, AudioFile

 if TYPE_CHECKING:
     from .model import Whisper
@@ -49,7 +50,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     postfix: str = '''"'.。,，!！?？:：”)]}、'''
     punctuation: str = prefix + postfix

-    verbose: bool = False
+    verbose: Optional[bool] = None

     _decode_options: dict = {}
     decode_props: Tuple[str, ...] = ("fp16", "language", "task")
@@ -76,10 +77,10 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         self.dtype = torch.float16 if value else torch.float32
         self.fp16device()

-    @PassthroughProperty(None).setter
-    def model(self, value: "Whisper") -> None:
+    @PassthroughProperty[Optional["Whisper"]](None).setter
+    def model(self, value: Optional["Whisper"]) -> None:
         self._model = value
-        self.device = value.device
+        self.device = None if value is None else value.device
         self.input_stride = exact_div(
             N_FRAMES, self.model.dims.n_audio_ctx
         )  # mel frames per output token: 2
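PassthroughProperty is the fork's own descriptor helper from .utils and is not part of this diff, so only its call surface is visible here; the hunk just widens the typing so a None model is representable. As a purely hypothetical sketch of the subscription pattern the new decorator line uses (not the fork's implementation), a descriptor class can accept a type parameter before instantiation via Generic:

from typing import Generic, Optional, TypeVar

T = TypeVar("T")

class Prop(Generic[T]):
    # Generic[T] supplies __class_getitem__, so Prop[Optional[X]](default)
    # parses exactly like the decorator line in the hunk above.
    def __init__(self, default: T) -> None:
        self.default = default

    def setter(self, fset):
        self.fset = fset  # a real descriptor would also wire up __get__/__set__
        return self

prop = Prop[Optional[int]](None)  # subscripted for typing, then instantiated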
@@ -207,7 +208,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     _tokenizer: Optional[Tokenizer] = None
     _tokenizer_cache: Dict[str, Tokenizer] = {}
     @property
-    def tokenizer(self) -> Optional[Tokenizer]:
+    def tokenizer(self) -> Tokenizer:
         if self._tokenizer is None:
             lang = self.language
             if self._language is not None:
@@ -221,8 +222,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 task=self.task,
             )
             return self._tokenizer
-        if lang is None:
-            return None
+        assert lang is not None
         if lang not in self._tokenizer_cache:
             self._tokenizer_cache[lang] = self.get_tokenizer(
                 self.model.is_multilingual,
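With the Optional return gone, the property now guarantees a Tokenizer (asserting the language is known before consulting the per-language cache), which is what lets every downstream `_tokenizer = self.tokenizer; assert ...` pair be dropped in the hunks below. The caching shape, reduced to its essentials with hypothetical names (a sketch, not the fork's code):

from typing import Callable, Dict

class TokenizerCache:
    def __init__(self, build: Callable[[str], object]) -> None:
        self._build = build            # stand-in for get_tokenizer(...)
        self._cache: Dict[str, object] = {}

    def get(self, lang: str) -> object:
        # construct each language's tokenizer once, then reuse it
        if lang not in self._cache:
            self._cache[lang] = self._build(lang)
        return self._cache[lang]

cache = TokenizerCache(lambda lang: f"tokenizer<{lang}>")
assert cache.get("en") is cache.get("en")  # same object on repeat lookups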
@@ -247,7 +247,6 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         if tokenizer not in self._initial_prompt_cache:
             self._initial_prompt_cache[tokenizer] = tokenizer.encode(
                 " " + self.initial_prompt.strip())
-        if self._tokenizer is not None:
         self._initial_prompt_tokens = \
             self._initial_prompt_cache[tokenizer]
         return self._initial_prompt_cache[tokenizer]
@@ -275,7 +274,6 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             hallucination_silence_threshold: Optional[float] = None,
             **decode_options):
         self.model = model
-        if verbose is not None:
         self.verbose = verbose
         self.temperature = temperature
         self.compression_ratio_threshold = compression_ratio_threshold
@@ -319,20 +317,19 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 needs_fallback = False  # silence
             if not needs_fallback:
                 break
+        assert decode_result is not None
         return decode_result

     def new_segment(
             self, *, start: float, end: float, tokens: torch.Tensor,
             result: DecodingResult) -> dict:
         _tokens = tokens.tolist()
-        _tokenizer = self.tokenizer
-        assert _tokenizer is not None
-        text_tokens = [token for token in _tokens if token < _tokenizer.eot]
+        text_tokens = [token for token in _tokens if token < self.tokenizer.eot]
         return {
             "seek": self.seek,
             "start": start,
             "end": end,
-            "text": _tokenizer.decode(text_tokens),
+            "text": self.tokenizer.decode(text_tokens),
             "tokens": _tokens,
             "temperature": result.temperature,
             "avg_logprob": result.avg_logprob,
@@ -371,8 +368,6 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             self, current_segments: List[dict], segment_size: int,
             single_timestamp_ending: bool, tokens: torch.Tensor,
             timestamp_tokens: torch.Tensor, result: DecodingResult):
-        _tokenizer = self.tokenizer
-        assert _tokenizer is not None
         consecutive = torch.where(
             timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
         consecutive.add_(1)
@@ -387,10 +382,10 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 sliced_tokens = tokens[last_slice:current_slice]
                 start_timestamp_pos = (
                     sliced_tokens[0].item() -
-                    _tokenizer.timestamp_begin)
+                    self.tokenizer.timestamp_begin)
                 end_timestamp_pos = (
                     sliced_tokens[-1].item() -
-                    _tokenizer.timestamp_begin)
+                    self.tokenizer.timestamp_begin)
                 current_segments.append(
                     self.new_segment(
                         start=self.time_offset + \
@@ -412,17 +407,17 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 # timestamp
                 last_timestamp_pos = (
                     tokens[last_slice - 1].item() -
-                    _tokenizer.timestamp_begin)
+                    self.tokenizer.timestamp_begin)
                 self.seek += last_timestamp_pos * self.input_stride
             else:
                 duration = segment_size * HOP_LENGTH / SAMPLE_RATE
                 timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                 if len(timestamps) > 0 and \
-                        timestamps[-1].item() != _tokenizer.timestamp_begin:
+                        timestamps[-1].item() != self.tokenizer.timestamp_begin:
                     # no consecutive timestamps but it has a timestamp; use the
                     # last one.
                     last_timestamp_pos = \
-                        timestamps[-1].item() - _tokenizer.timestamp_begin
+                        timestamps[-1].item() - self.tokenizer.timestamp_begin
                     duration = last_timestamp_pos * self.time_precision

             current_segments.append(self.new_segment(
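For orientation, the arithmetic these renamed lines perform: a timestamp token encodes its offset as token_id - tokenizer.timestamp_begin positions, and each position is worth time_precision = input_stride * HOP_LENGTH / SAMPLE_RATE = 2 * 160 / 16000 = 0.02 s. A worked example with an illustrative token id (actual ids depend on the vocabulary):

timestamp_begin = 50364            # illustrative; varies by tokenizer vocabulary
time_precision = 2 * 160 / 16000   # input_stride * HOP_LENGTH / SAMPLE_RATE = 0.02 s

token_id = 50464                   # a hypothetical timestamp token
position = token_id - timestamp_begin   # 100 positions
print(position * time_precision)        # 2.0 seconds into the current window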
@@ -569,10 +564,8 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             current_segments: List[dict] = []

             tokens = torch.tensor(result.tokens)
-            _tokenizer = self.tokenizer
-            assert _tokenizer is not None
             timestamp_tokens: torch.Tensor = tokens.ge(
-                _tokenizer.timestamp_begin)
+                self.tokenizer.timestamp_begin)
             single_timestamp_ending = (
                 timestamp_tokens[-2:].tolist() == [False, True])

@@ -615,19 +608,22 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 # do not feed the prompt tokens if a high temperature was used
                 self.prompt_reset_since = len(self.all_tokens)

+            self.reporthook()
+
             if single_pass:
                 break

-        _tokenizer = self.tokenizer
-        assert _tokenizer is not None
-        res = dict(
+        self.result = dict(
             segments=self.all_segments, language=self.language,
-            text=_tokenizer.decode(
+            text=self.tokenizer.decode(
                 self.all_tokens[len(self.initial_prompt_tokens):]))
         self.latest = None
-        return res
+        return self.result

-    def restore(self, offset: int):
+    def reporthook(self) -> None:
+        pass
+
+    def restore(self, offset: int) -> None:
         processing, seconds = 0, offset * HOP_LENGTH / SAMPLE_RATE
         while len(self.all_segments) > 0 and (
             self.all_segments[-1]["start"] >= seconds
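reporthook is a deliberate no-op hook: the main decoding loop now calls it once per processed window, and ProgressTranscriber (added below) overrides it to advance the tqdm bar. The template-method pattern in isolation, with hypothetical class names:

class Base:
    def run(self) -> None:
        for _ in range(3):
            # ... one unit of work ...
            self.reporthook()  # subclasses may observe progress here

    def reporthook(self) -> None:
        pass  # default: no reporting

class Reporting(Base):
    def reporthook(self) -> None:
        print("one step done")

Reporting().run()  # prints three times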
@@ -652,16 +648,78 @@ class InMemoryAudio(AudioFile):
 def audio_tensor(audio: Union[str, np.ndarray, torch.Tensor]) -> torch.Tensor:
     if isinstance(audio, str):
         return InMemoryAudio(fname=audio).sequential()
-    if isinstance(audio, np.dtype):
+    if isinstance(audio, np.ndarray):
         return torch.from_numpy(audio)
     return audio


+class MinimalTranscriber(Transcriber):
+    exact: bool = True
+    chlen: float = CHUNK_LENGTH
+    async def process(self, stream: ArrayStream, **kw) -> dict:
+        data = await stream.request(self.chlen, self.exact)
+        while data.shape[-1] > 0:
+            self(data, stream.offset, True)
+            t = self.chlen - (stream.offset + data.shape[-1] - self.seek) \
+                / FRAMES_PER_SECOND + CHUNK_LENGTH
+            data = await stream.request(t, self.exact)
+        return self.result
+
+
+class ProgressTranscriber(MinimalTranscriber):
+    def __init__(self, *a, duration: Optional[float] = None, **kw):
+        super().__init__(*a, **kw)
+        self.duration, self.progress = duration, 0
+
+    def __call__(self, *a, **kw) -> dict:
+        if self._pbar is None:
+            try:
+                return super().__call__(*a, **kw)
+            finally:
+                self.close()
+        else:
+            return super().__call__(*a, **kw)
+
+    @PassthroughProperty(None).property
+    def pbar(self):
+        if self._pbar is None:
+            n = self.latest.shape[-1] if self.duration is None \
+                else -int(self.duration * -FRAMES_PER_SECOND)
+            self._pbar = tqdm.tqdm(
+                total=n, unit="frames", disable=self.verbose is not False)
+            self._pbar.__enter__()
+        return self._pbar
+
+    def reporthook(self) -> None:
+        update_to = min(self._seek, self.frame_offset + self.latest.shape[-1])
+        self.pbar.update(update_to - self.progress)
+        self.progress = update_to
+
+    def close(self):
+        self.pbar.__exit__(None, None, None)
+
+    async def process(self, stream: ArrayStream, **kw) -> dict:
+        self.pbar
+        try:
+            return await super().process(stream, **kw)
+        finally:
+            self.close()
+
+    async def progressive(self, stream: AudioFile, **kw) -> dict:
+        self.duration = stream.duration
+        return await self.process(stream, **kw)
+
+
 def transcribe(
         model: "Whisper",
         audio: Union[str, np.ndarray, torch.Tensor],
         **kw):
-    return Transcriber(model, **kw)(audio_tensor(audio))
+    return ProgressTranscriber(model, **kw)(audio_tensor(audio))
+
+
+def buffered_transcribe(model: "Whisper", audio: str, **kw):
+    transcriber = ProgressTranscriber(model, **kw)
+    return asyncio.run(transcriber.progressive(AudioFile(fname=audio)))


 def cli():
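Taken together: MinimalTranscriber pulls roughly CHUNK_LENGTH (30 s) of frames at a time from an ArrayStream, and the `t = ...` line converts the not-yet-consumed tail of the buffer back to seconds before topping it up by another chunk; ProgressTranscriber layers a tqdm bar on top, sized from the ffprobe duration when one is known. A hedged usage sketch, assuming this fork keeps both entry points in whisper.transcribe and that audio.mp3 exists:

import whisper
from whisper.transcribe import buffered_transcribe, transcribe

model = whisper.load_model("base")

# eager path: the whole file is decoded from an in-memory tensor,
# now with a progress bar by default
result = transcribe(model, "audio.mp3")

# buffered path: audio is pulled from ffmpeg on demand via AudioFile,
# driven by asyncio under the hood
result = buffered_transcribe(model, "audio.mp3")
print(result["text"])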
@@ -712,6 +770,7 @@ def cli():
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS")
     parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
     parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
+    parser.add_argument("--buffered", type=str2bool, default=False, help="whether to load the audio data on demand instead of all at once")
     # fmt: on

     args = parser.parse_args().__dict__
@@ -741,6 +800,7 @@ def cli():
     from . import load_model

     model = load_model(model_name, device=device, download_root=model_dir)
+    transcriber = buffered_transcribe if args.pop("buffered") else transcribe

     writer = get_writer(output_format, output_dir)
     word_options = [
@@ -760,7 +820,7 @@ def cli():
     writer_args = {arg: args.pop(arg) for arg in word_options}
     for audio_path in args.pop("audio"):
         try:
-            result = transcribe(model, audio_path, temperature=temperature, **args)
+            result = transcriber(model, audio_path, temperature=temperature, **args)
             writer(result, audio_path, **writer_args)
         except Exception as e:
             traceback.print_exc()

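End to end, the new flag switches every per-file call between the eager and buffered paths. A typical invocation, assuming the standard whisper console entry point, would be:

whisper audio.mp3 --model base --buffered True

Since --buffered is parsed by str2bool, the value must be the string True or False; when True, each file goes through buffered_transcribe (streaming via AudioFile) instead of transcribe.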