add TranscribeProgressReceiver for update monitoring

Current `transcribe` API only outputs the progress and transcribed texts
on stdout. Callers can only access the result after the whole
transcription is done, and they need to hijack `tqdm` interface to get
the realtime transcription progress. This commit adds a simple interface
that can be passed as a parameter in `transcribe` so the API users don't
need to fallback to above hacks or low-level APIs for this need.

Signed-off-by: Austin Chang <austin880625@gmail.com>
This commit is contained in:
Austin Chang 2024-10-19 11:20:55 +08:00
parent 25639fc17d
commit bdbe6bfb47
3 changed files with 75 additions and 8 deletions

View File

@ -7,18 +7,37 @@ import whisper
from whisper.tokenizer import get_tokenizer
class TestingProgressReceiver(whisper.TranscribeProgressReceiver):
def start(self, total: int):
self.result = ""
self.total = total
self.progress = 0
return self
def update_line(self, start: float, end: float, text: str):
self.result += text
def update(self, n):
self.progress += n
def get_result(self):
return self.result
def verify_total(self):
return self.total == self.progress
@pytest.mark.parametrize("model_name", whisper.available_models())
def test_transcribe(model_name: str):
device = "cuda" if torch.cuda.is_available() else "cpu"
model = whisper.load_model(model_name).to(device)
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
receiver = TestingProgressReceiver()
language = "en" if model_name.endswith(".en") else None
result = model.transcribe(
audio_path, language=language, temperature=0.0, word_timestamps=True
audio_path, language=language, temperature=0.0, word_timestamps=True,
progress_receiver=receiver
)
assert receiver.verify_total()
assert result["language"] == "en"
assert result["text"] == "".join([s["text"] for s in result["segments"]])
assert result["text"] == receiver.get_result()
transcription = result["text"].lower()
assert "my fellow americans" in transcription

View File

@ -11,7 +11,7 @@ from tqdm import tqdm
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import ModelDimensions, Whisper
from .transcribe import transcribe
from .transcribe import TranscribeProgressReceiver, transcribe
from .version import __version__
_MODELS = {

View File

@ -2,7 +2,7 @@ import argparse
import os
import traceback
import warnings
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Self
import numpy as np
import torch
@ -34,12 +34,57 @@ from .utils import (
if TYPE_CHECKING:
from .model import Whisper
class TranscribeProgressReceiver:
"""
A class that allows external classes to inherit and handle transcription progress in customized
manners.
"""
def start(self, total: int) -> Self:
"""
The method is called when the transcription starts with integral `total` parameter in frames.
In most case this method should return `self`
"""
return self
def update(self, n: int):
"""
The `update` method is called with increment `n` in frames whenever a segment is transcribed.
"""
pass
def update_line(self, start: float, end: float, text: str):
"""
It is called whenever a segment is transcribed.
Parameters
----------
start: float
The floating point start time of the segment in seconds
end: float
The floating point end time of the segment in seconds
text: str
The transcribed text
"""
pass
def __enter__(self) -> Self:
"""
Inherit this method if resources allocation is needed at the start of the transcription.
In most cases this method should return `self`
"""
return self
def __exit__(self, exception_type, exception_value, exception_traceback):
"""
Inherit this method if resources need to be released when the transcription is finished or
terminated.
"""
pass
def transcribe(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
*,
verbose: Optional[bool] = None,
progress_receiver: TranscribeProgressReceiver = TranscribeProgressReceiver(),
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0,
@ -253,7 +298,8 @@ def transcribe(
# show the progress bar when verbose is False (if True, transcribed text will be printed)
with tqdm.tqdm(
total=content_frames, unit="frames", disable=verbose is not False
) as pbar:
) as pbar, \
progress_receiver.start(total=content_frames) as ext_progress:
last_speech_timestamp = 0.0
# NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
@ -459,10 +505,11 @@ def transcribe(
if last_word_end is not None:
last_speech_timestamp = last_word_end
if verbose:
for segment in current_segments:
start, end, text = segment["start"], segment["end"], segment["text"]
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
ext_progress.update_line(start, end, make_safe(text))
if verbose:
print(make_safe(line))
# if a segment is instantaneous or does not contain text, clear it
@ -490,6 +537,7 @@ def transcribe(
# update progress bar
pbar.update(min(content_frames, seek) - previous_seek)
ext_progress.update(min(content_frames, seek) - previous_seek)
return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),