mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 06:26:03 +00:00
add TranscribeProgressReceiver for update monitoring
Current `transcribe` API only outputs the progress and transcribed texts on stdout. Callers can only access the result after the whole transcription is done, and they need to hijack `tqdm` interface to get the realtime transcription progress. This commit adds a simple interface that can be passed as a parameter in `transcribe` so the API users don't need to fallback to above hacks or low-level APIs for this need. Signed-off-by: Austin Chang <austin880625@gmail.com>
This commit is contained in:
parent
25639fc17d
commit
bdbe6bfb47
@ -7,18 +7,37 @@ import whisper
|
||||
from whisper.tokenizer import get_tokenizer
|
||||
|
||||
|
||||
class TestingProgressReceiver(whisper.TranscribeProgressReceiver):
|
||||
def start(self, total: int):
|
||||
self.result = ""
|
||||
self.total = total
|
||||
self.progress = 0
|
||||
return self
|
||||
def update_line(self, start: float, end: float, text: str):
|
||||
self.result += text
|
||||
def update(self, n):
|
||||
self.progress += n
|
||||
def get_result(self):
|
||||
return self.result
|
||||
def verify_total(self):
|
||||
return self.total == self.progress
|
||||
|
||||
@pytest.mark.parametrize("model_name", whisper.available_models())
|
||||
def test_transcribe(model_name: str):
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = whisper.load_model(model_name).to(device)
|
||||
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
|
||||
receiver = TestingProgressReceiver()
|
||||
|
||||
language = "en" if model_name.endswith(".en") else None
|
||||
result = model.transcribe(
|
||||
audio_path, language=language, temperature=0.0, word_timestamps=True
|
||||
audio_path, language=language, temperature=0.0, word_timestamps=True,
|
||||
progress_receiver=receiver
|
||||
)
|
||||
assert receiver.verify_total()
|
||||
assert result["language"] == "en"
|
||||
assert result["text"] == "".join([s["text"] for s in result["segments"]])
|
||||
assert result["text"] == receiver.get_result()
|
||||
|
||||
transcription = result["text"].lower()
|
||||
assert "my fellow americans" in transcription
|
||||
|
||||
@ -11,7 +11,7 @@ from tqdm import tqdm
|
||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||
from .model import ModelDimensions, Whisper
|
||||
from .transcribe import transcribe
|
||||
from .transcribe import TranscribeProgressReceiver, transcribe
|
||||
from .version import __version__
|
||||
|
||||
_MODELS = {
|
||||
|
||||
@ -2,7 +2,7 @@ import argparse
|
||||
import os
|
||||
import traceback
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Self
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -34,12 +34,57 @@ from .utils import (
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
|
||||
class TranscribeProgressReceiver:
|
||||
"""
|
||||
A class that allows external classes to inherit and handle transcription progress in customized
|
||||
manners.
|
||||
"""
|
||||
def start(self, total: int) -> Self:
|
||||
"""
|
||||
The method is called when the transcription starts with integral `total` parameter in frames.
|
||||
In most case this method should return `self`
|
||||
"""
|
||||
return self
|
||||
def update(self, n: int):
|
||||
"""
|
||||
The `update` method is called with increment `n` in frames whenever a segment is transcribed.
|
||||
"""
|
||||
pass
|
||||
def update_line(self, start: float, end: float, text: str):
|
||||
"""
|
||||
It is called whenever a segment is transcribed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start: float
|
||||
The floating point start time of the segment in seconds
|
||||
|
||||
end: float
|
||||
The floating point end time of the segment in seconds
|
||||
|
||||
text: str
|
||||
The transcribed text
|
||||
"""
|
||||
pass
|
||||
def __enter__(self) -> Self:
|
||||
"""
|
||||
Inherit this method if resources allocation is needed at the start of the transcription.
|
||||
In most cases this method should return `self`
|
||||
"""
|
||||
return self
|
||||
def __exit__(self, exception_type, exception_value, exception_traceback):
|
||||
"""
|
||||
Inherit this method if resources need to be released when the transcription is finished or
|
||||
terminated.
|
||||
"""
|
||||
pass
|
||||
|
||||
def transcribe(
|
||||
model: "Whisper",
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
*,
|
||||
verbose: Optional[bool] = None,
|
||||
progress_receiver: TranscribeProgressReceiver = TranscribeProgressReceiver(),
|
||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||
compression_ratio_threshold: Optional[float] = 2.4,
|
||||
logprob_threshold: Optional[float] = -1.0,
|
||||
@ -253,7 +298,8 @@ def transcribe(
|
||||
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
||||
with tqdm.tqdm(
|
||||
total=content_frames, unit="frames", disable=verbose is not False
|
||||
) as pbar:
|
||||
) as pbar, \
|
||||
progress_receiver.start(total=content_frames) as ext_progress:
|
||||
last_speech_timestamp = 0.0
|
||||
# NOTE: This loop is obscurely flattened to make the diff readable.
|
||||
# A later commit should turn this into a simpler nested loop.
|
||||
@ -459,10 +505,11 @@ def transcribe(
|
||||
if last_word_end is not None:
|
||||
last_speech_timestamp = last_word_end
|
||||
|
||||
if verbose:
|
||||
for segment in current_segments:
|
||||
start, end, text = segment["start"], segment["end"], segment["text"]
|
||||
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
|
||||
for segment in current_segments:
|
||||
start, end, text = segment["start"], segment["end"], segment["text"]
|
||||
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
|
||||
ext_progress.update_line(start, end, make_safe(text))
|
||||
if verbose:
|
||||
print(make_safe(line))
|
||||
|
||||
# if a segment is instantaneous or does not contain text, clear it
|
||||
@ -490,6 +537,7 @@ def transcribe(
|
||||
|
||||
# update progress bar
|
||||
pbar.update(min(content_frames, seek) - previous_seek)
|
||||
ext_progress.update(min(content_frames, seek) - previous_seek)
|
||||
|
||||
return dict(
|
||||
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user