Mirror of https://github.com/openai/whisper.git (synced 2025-11-24 14:35:57 +00:00)

Commit: b4fd954955 "progress bar support and buffered cli option"
Parent: 4ccbd70012
@@ -1,5 +1,5 @@
 import numpy as np
-import asyncio, pathlib, subprocess, torch
+import asyncio, pathlib, subprocess, torch, json
 
 from .audio import (
     SAMPLE_RATE,
@@ -271,3 +271,19 @@ class AudioFile(RawAudioFile):
         if ps.returncode not in (None, 0):
             raise RuntimeError(f"Failed to load audio: {stderr.decode()}")
 
+    @property
+    def duration(self):
+        cmd = [
+            "ffprobe",
+            "-hide_banner",
+            "-show_format",
+            "-of", "json",
+            "-i", self.fname,
+        ]
+        ps = subprocess.Popen(
+            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        stdout, stderr = ps.communicate()
+        if ps.returncode not in (None, 0):
+            raise RuntimeError(f"Failed to load audio: {stderr.decode()}")
+        return float(json.loads(stdout)['format']['duration'])
+
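Note: the new AudioFile.duration property asks ffprobe for container
metadata instead of decoding the stream, so it stays cheap even for long
files. A sketch of the JSON it parses, assuming a typical ffprobe build
(the exact field set varies by container; "duration" arrives as a string,
hence the float() conversion):

    import json

    example = '{"format": {"filename": "sample.wav", "duration": "4.521000"}}'
    print(float(json.loads(example)["format"]["duration"]))  # 4.521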
@@ -2,12 +2,13 @@ import argparse
 import os
 import traceback
 import warnings
+import asyncio
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Dict
 from dataclasses import dataclass
 
 import numpy as np
 import torch
-import tqdm # TODO
+import tqdm
 
 from .audio import (
     FRAMES_PER_SECOND,
@@ -32,7 +33,7 @@ from .utils import (
     PassthroughProperty,
     PassthroughPropertyDefaults,
 )
-from .buffer import AudioFile
+from .buffer import ArrayStream, AudioFile
 
 if TYPE_CHECKING:
     from .model import Whisper
@@ -49,7 +50,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     postfix: str = '''"'.\u3002,\uff0c!\uff01?\uff1f:\uff1a\u201d)]}\u3001'''
     punctuation: str = prefix + postfix
 
-    verbose: bool = False
+    verbose: Optional[bool] = None
 
     _decode_options: dict = {}
     decode_props: Tuple[str, ...] = ("fp16", "language", "task")
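Note: verbose becomes a three-state Optional[bool] so "unset" is
distinguishable from an explicit choice, letting __init__ (below) assign
it unconditionally. Its effect on the new progress bar follows from the
tqdm call added further down, disable=(self.verbose is not False):

    # Illustrative: the bar is shown only when verbose is exactly False,
    # mirroring upstream whisper's "minimal details" mode.
    for verbose in (True, False, None):
        print(verbose, "bar shown" if verbose is False else "bar hidden")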
@@ -76,10 +77,10 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         self.dtype = torch.float16 if value else torch.float32
         self.fp16device()
 
-    @PassthroughProperty(None).setter
-    def model(self, value: "Whisper") -> None:
+    @PassthroughProperty[Optional["Whisper"]](None).setter
+    def model(self, value: Optional["Whisper"]) -> None:
         self._model = value
-        self.device = value.device
+        self.device = None if value is None else value.device
         self.input_stride = exact_div(
             N_FRAMES, self.model.dims.n_audio_ctx
         ) # mel frames per output token: 2
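Note: the model setter now tolerates None, so a Transcriber can exist
before a model is attached. The decorator syntax relies on the repo's
PassthroughProperty descriptor from .utils; the following is a minimal
stand-in with the same shape, purely illustrative of how a
subscriptable, setter-registering descriptor can work, not the actual
implementation:

    from typing import Generic, TypeVar

    T = TypeVar("T")

    class Passthrough(Generic[T]):
        # Stores a default; a decorated method becomes the setter.
        def __init__(self, default: T):
            self.default, self.fset = default, None

        def setter(self, fset):
            self.fset = fset
            return self

        def __set_name__(self, owner, name):
            self.attr = "_" + name

        def __get__(self, obj, objtype=None):
            if obj is None:
                return self
            return getattr(obj, self.attr, self.default)

        def __set__(self, obj, value):
            if self.fset is None:
                setattr(obj, self.attr, value)
            else:
                self.fset(obj, value)

    class Host:
        @Passthrough[int](0).setter
        def count(self, value: int) -> None:
            self._count = value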
@@ -207,7 +208,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     _tokenizer: Optional[Tokenizer] = None
     _tokenizer_cache: Dict[str, Tokenizer] = {}
     @property
-    def tokenizer(self) -> Optional[Tokenizer]:
+    def tokenizer(self) -> Tokenizer:
         if self._tokenizer is None:
             lang = self.language
             if self._language is not None:
@@ -221,8 +222,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 task=self.task,
             )
             return self._tokenizer
-        if lang is None:
-            return None
+        assert lang is not None
         if lang not in self._tokenizer_cache:
             self._tokenizer_cache[lang] = self.get_tokenizer(
                 self.model.is_multilingual,
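Note: the tokenizer property now always returns a Tokenizer (or raises),
so callers can drop their own None checks; the asserts removed in the
hunks below all moved behind this property. Per-language instances are
memoized in _tokenizer_cache. The pattern in isolation, assuming the
get_tokenizer signature from whisper.tokenizer:

    from typing import Dict
    from whisper.tokenizer import Tokenizer, get_tokenizer

    cache: Dict[str, Tokenizer] = {}

    def tokenizer_for(lang: str, multilingual: bool, task: str) -> Tokenizer:
        if lang not in cache:
            cache[lang] = get_tokenizer(multilingual, language=lang, task=task)
        return cache[lang]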
@@ -247,9 +247,8 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             if tokenizer not in self._initial_prompt_cache:
                 self._initial_prompt_cache[tokenizer] = tokenizer.encode(
                     " " + self.initial_prompt.strip())
-            if self._tokenizer is not None:
-                self._initial_prompt_tokens = \
-                    self._initial_prompt_cache[tokenizer]
+            self._initial_prompt_tokens = \
+                self._initial_prompt_cache[tokenizer]
             return self._initial_prompt_cache[tokenizer]
         return self._initial_prompt_tokens
 
@@ -275,8 +274,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             hallucination_silence_threshold: Optional[float] = None,
             **decode_options):
         self.model = model
-        if verbose is not None:
-            self.verbose = verbose
+        self.verbose = verbose
         self.temperature = temperature
         self.compression_ratio_threshold = compression_ratio_threshold
         self.logprob_threshold = logprob_threshold
@@ -319,20 +317,19 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 needs_fallback = False # silence
             if not needs_fallback:
                 break
+        assert decode_result is not None
         return decode_result
 
     def new_segment(
             self, *, start: float, end: float, tokens: torch.Tensor,
             result: DecodingResult) -> dict:
         _tokens = tokens.tolist()
-        _tokenizer = self.tokenizer
-        assert _tokenizer is not None
-        text_tokens = [token for token in _tokens if token < _tokenizer.eot]
+        text_tokens = [token for token in _tokens if token < self.tokenizer.eot]
         return {
             "seek": self.seek,
             "start": start,
             "end": end,
-            "text": _tokenizer.decode(text_tokens),
+            "text": self.tokenizer.decode(text_tokens),
             "tokens": _tokens,
             "temperature": result.temperature,
             "avg_logprob": result.avg_logprob,
@@ -371,8 +368,6 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             self, current_segments: List[dict], segment_size: int,
             single_timestamp_ending: bool, tokens: torch.Tensor,
             timestamp_tokens: torch.Tensor, result: DecodingResult):
-        _tokenizer = self.tokenizer
-        assert _tokenizer is not None
         consecutive = torch.where(
             timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
         consecutive.add_(1)
@@ -387,10 +382,10 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 sliced_tokens = tokens[last_slice:current_slice]
                 start_timestamp_pos = (
                     sliced_tokens[0].item() -
-                    _tokenizer.timestamp_begin)
+                    self.tokenizer.timestamp_begin)
                 end_timestamp_pos = (
                     sliced_tokens[-1].item() -
-                    _tokenizer.timestamp_begin)
+                    self.tokenizer.timestamp_begin)
                 current_segments.append(
                     self.new_segment(
                         start=self.time_offset + \
@@ -412,17 +407,17 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 # timestamp
                 last_timestamp_pos = (
                     tokens[last_slice - 1].item() -
-                    _tokenizer.timestamp_begin)
+                    self.tokenizer.timestamp_begin)
                 self.seek += last_timestamp_pos * self.input_stride
         else:
             duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             timestamps = tokens[timestamp_tokens.nonzero().flatten()]
             if len(timestamps) > 0 and \
-                    timestamps[-1].item() != _tokenizer.timestamp_begin:
+                    timestamps[-1].item() != self.tokenizer.timestamp_begin:
                 # no consecutive timestamps but it has a timestamp; use the last
                 # one.
                 last_timestamp_pos = \
-                    timestamps[-1].item() - _tokenizer.timestamp_begin
+                    timestamps[-1].item() - self.tokenizer.timestamp_begin
                 duration = last_timestamp_pos * self.time_precision
 
             current_segments.append(self.new_segment(
@@ -569,10 +564,8 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         current_segments: List[dict] = []
 
         tokens = torch.tensor(result.tokens)
-        _tokenizer = self.tokenizer
-        assert _tokenizer is not None
         timestamp_tokens: torch.Tensor = tokens.ge(
-            _tokenizer.timestamp_begin)
+            self.tokenizer.timestamp_begin)
         single_timestamp_ending = (
             timestamp_tokens[-2:].tolist() == [False, True])
 
@@ -615,19 +608,22 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 # do not feed the prompt tokens if a high temperature was used
                 self.prompt_reset_since = len(self.all_tokens)
 
+            self.reporthook()
+
             if single_pass:
                 break
 
-        _tokenizer = self.tokenizer
-        assert _tokenizer is not None
-        res = dict(
+        self.result = dict(
             segments=self.all_segments, language=self.language,
-            text=_tokenizer.decode(
+            text=self.tokenizer.decode(
                 self.all_tokens[len(self.initial_prompt_tokens):]))
         self.latest = None
-        return res
+        return self.result
 
-    def restore(self, offset: int):
+    def reporthook(self) -> None:
+        pass
+
+    def restore(self, offset: int) -> None:
         processing, seconds = 0, offset * HOP_LENGTH / SAMPLE_RATE
         while len(self.all_segments) > 0 and (
                 self.all_segments[-1]["start"] >= seconds
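Note: reporthook() is a deliberate no-op hook, invoked once per decoded
window just before the single_pass check; ProgressTranscriber overrides
it below to advance the progress bar. The template-method shape in
miniature (illustrative names, not the diff's code):

    class Base:
        def run(self):
            for _ in range(3):
                ...  # decode one window
                self.reporthook()

        def reporthook(self) -> None:
            pass  # subclasses may report progress here

    class Loud(Base):
        def reporthook(self) -> None:
            print("window done")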
@@ -652,16 +648,78 @@ class InMemoryAudio(AudioFile):
 def audio_tensor(audio: Union[str, np.ndarray, torch.Tensor]) -> torch.Tensor:
     if isinstance(audio, str):
         return InMemoryAudio(fname=audio).sequential()
-    if isinstance(audio, np.dtype):
+    if isinstance(audio, np.ndarray):
         return torch.from_numpy(audio)
     return audio
 
 
+class MinimalTranscriber(Transcriber):
+    exact: bool = True
+    chlen: float = CHUNK_LENGTH
+    async def process(self, stream: ArrayStream, **kw) -> dict:
+        data = await stream.request(self.chlen, self.exact)
+        while data.shape[-1] > 0:
+            self(data, stream.offset, True)
+            t = self.chlen - (stream.offset + data.shape[-1] - self.seek) \
+                / FRAMES_PER_SECOND + CHUNK_LENGTH
+            data = await stream.request(t, self.exact)
+        return self.result
+
+
+class ProgressTranscriber(MinimalTranscriber):
+    def __init__(self, *a, duration: Optional[float] = None, **kw):
+        super().__init__(*a, **kw)
+        self.duration, self.progress = duration, 0
+
+    def __call__(self, *a, **kw) -> dict:
+        if self._pbar is None:
+            try:
+                return super().__call__(*a, **kw)
+            finally:
+                self.close()
+        else:
+            return super().__call__(*a, **kw)
+
+    @PassthroughProperty(None).property
+    def pbar(self):
+        if self._pbar is None:
+            n = self.latest.shape[-1] if self.duration is None \
+                else -int(self.duration * -FRAMES_PER_SECOND)
+            self._pbar = tqdm.tqdm(
+                total=n, unit="frames", disable=self.verbose is not False)
+            self._pbar.__enter__()
+        return self._pbar
+
+    def reporthook(self) -> None:
+        update_to = min(self._seek, self.frame_offset + self.latest.shape[-1])
+        self.pbar.update(update_to - self.progress)
+        self.progress = update_to
+
+    def close(self):
+        self.pbar.__exit__(None, None, None)
+
+    async def process(self, stream: ArrayStream, **kw) -> dict:
+        self.pbar
+        try:
+            return await super().process(stream, **kw)
+        finally:
+            self.close()
+
+    async def progressive(self, stream: AudioFile, **kw) -> dict:
+        self.duration = stream.duration
+        return await self.process(stream, **kw)
+
+
 def transcribe(
         model: "Whisper",
         audio: Union[str, np.ndarray, torch.Tensor],
         **kw):
-    return Transcriber(model, **kw)(audio_tensor(audio))
+    return ProgressTranscriber(model, **kw)(audio_tensor(audio))
 
 
+def buffered_transcribe(model: "Whisper", audio: str, **kw):
+    transcriber = ProgressTranscriber(model, **kw)
+    return asyncio.run(transcriber.progressive(AudioFile(fname=audio)))
+
+
 def cli():
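Note: the audio_tensor change is a straight bug fix: isinstance(audio,
np.dtype) can never match an array, so NumPy input previously fell
through unconverted. MinimalTranscriber.process is the streaming core:
it awaits one chunk, transcribes it, then computes how many more seconds
are needed to slide the window past self.seek before requesting again,
so audio is pulled on demand rather than loaded whole.
ProgressTranscriber layers a tqdm bar on top via the reporthook() hook,
sizing it from ffprobe's duration when available, and
buffered_transcribe drives the whole pipeline under asyncio.run.
Hypothetical usage (model size and file name assumed):

    import whisper
    from whisper.transcribe import buffered_transcribe

    model = whisper.load_model("base")
    result = buffered_transcribe(model, "interview.mp3", verbose=False)
    print(result["text"])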
@@ -712,6 +770,7 @@ def cli():
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
     parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
     parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
+    parser.add_argument("--buffered", type=str2bool, default=False, help="whether to load the audio data on demand instead of all at once")
     # fmt: on
 
     args = parser.parse_args().__dict__
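Note: the new flag is parsed with the CLI's existing str2bool helper, so
it is toggled as e.g. "whisper interview.mp3 --buffered True" (file name
hypothetical); the next hunk routes it to buffered_transcribe, while the
default False keeps the original load-everything behavior.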
@@ -741,6 +800,7 @@ def cli():
     from . import load_model
 
     model = load_model(model_name, device=device, download_root=model_dir)
+    transcriber = buffered_transcribe if args.pop("buffered") else transcribe
 
     writer = get_writer(output_format, output_dir)
     word_options = [
@@ -760,7 +820,7 @@ def cli():
     writer_args = {arg: args.pop(arg) for arg in word_options}
     for audio_path in args.pop("audio"):
         try:
-            result = transcribe(model, audio_path, temperature=temperature, **args)
+            result = transcriber(model, audio_path, temperature=temperature, **args)
             writer(result, audio_path, **writer_args)
         except Exception as e:
             traceback.print_exc()
|
|||||||
Loading…
x
Reference in New Issue
Block a user