mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 14:35:57 +00:00
add TranscribeProgressReceiver for update monitoring
The current `transcribe` API only outputs progress and transcribed text on stdout. Callers can only access the result after the whole transcription is done, and they need to hijack the `tqdm` interface to get real-time transcription progress. This commit adds a simple interface that can be passed as a parameter to `transcribe`, so API users don't need to fall back to the above hacks or low-level APIs for this need. Signed-off-by: Austin Chang <austin880625@gmail.com>
This commit is contained in:
parent
25639fc17d
commit
bdbe6bfb47
@ -7,18 +7,37 @@ import whisper
|
|||||||
from whisper.tokenizer import get_tokenizer
|
from whisper.tokenizer import get_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class TestingProgressReceiver(whisper.TranscribeProgressReceiver):
    """Progress receiver used by the tests: it records whatever it is told."""

    def start(self, total: int):
        """Reset the recorded state, remember `total`, and return this receiver."""
        self.progress = 0
        self.total = total
        self.result = ""
        return self

    def update_line(self, start: float, end: float, text: str):
        """Append the reported `text` to the accumulated result."""
        self.result = self.result + text

    def update(self, n):
        """Add the increment `n` to the running progress count."""
        self.progress = self.progress + n

    def get_result(self):
        """Return the text accumulated so far."""
        return self.result

    def verify_total(self):
        """Return whether the accumulated progress equals the total given to `start`."""
        return self.progress == self.total
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_name", whisper.available_models())
|
@pytest.mark.parametrize("model_name", whisper.available_models())
|
||||||
def test_transcribe(model_name: str):
|
def test_transcribe(model_name: str):
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
model = whisper.load_model(model_name).to(device)
|
model = whisper.load_model(model_name).to(device)
|
||||||
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
|
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
|
||||||
|
receiver = TestingProgressReceiver()
|
||||||
|
|
||||||
language = "en" if model_name.endswith(".en") else None
|
language = "en" if model_name.endswith(".en") else None
|
||||||
result = model.transcribe(
|
result = model.transcribe(
|
||||||
audio_path, language=language, temperature=0.0, word_timestamps=True
|
audio_path, language=language, temperature=0.0, word_timestamps=True,
|
||||||
|
progress_receiver=receiver
|
||||||
)
|
)
|
||||||
|
assert receiver.verify_total()
|
||||||
assert result["language"] == "en"
|
assert result["language"] == "en"
|
||||||
assert result["text"] == "".join([s["text"] for s in result["segments"]])
|
assert result["text"] == "".join([s["text"] for s in result["segments"]])
|
||||||
|
assert result["text"] == receiver.get_result()
|
||||||
|
|
||||||
transcription = result["text"].lower()
|
transcription = result["text"].lower()
|
||||||
assert "my fellow americans" in transcription
|
assert "my fellow americans" in transcription
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from tqdm import tqdm
|
|||||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||||
from .model import ModelDimensions, Whisper
|
from .model import ModelDimensions, Whisper
|
||||||
from .transcribe import transcribe
|
from .transcribe import TranscribeProgressReceiver, transcribe
|
||||||
from .version import __version__
|
from .version import __version__
|
||||||
|
|
||||||
_MODELS = {
|
_MODELS = {
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Self
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -34,12 +34,57 @@ from .utils import (
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .model import Whisper
|
from .model import Whisper
|
||||||
|
|
||||||
|
class TranscribeProgressReceiver:
    """
    Base class for receiving transcription progress updates.

    Subclass it and override the methods of interest to handle transcription progress in a
    customized manner. Every method defined here is a no-op default.
    """

    # NOTE: return types are spelled as a string forward reference instead of `typing.Self`,
    # because `typing.Self` only exists on Python 3.11 and later.
    def start(self, total: int) -> "TranscribeProgressReceiver":
        """
        Called when the transcription starts.

        Parameters
        ----------
        total: int
            The total amount of work, in frames

        Returns
        -------
        The object to be used as a context manager while transcribing. In most cases this
        method should return `self`.
        """
        return self

    def update(self, n: int) -> None:
        """
        Called with the increment `n`, in frames, whenever a segment is transcribed.
        """
        pass

    def update_line(self, start: float, end: float, text: str) -> None:
        """
        Called whenever a segment is transcribed.

        Parameters
        ----------
        start: float
            The floating point start time of the segment in seconds

        end: float
            The floating point end time of the segment in seconds

        text: str
            The transcribed text
        """
        pass

    def __enter__(self) -> "TranscribeProgressReceiver":
        """
        Override this method if resources need to be allocated at the start of the transcription.
        In most cases this method should return `self`.
        """
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        """
        Override this method if resources need to be released when the transcription is finished
        or terminated.

        The default implementation returns `None`, so an exception raised inside the `with`
        block is not suppressed.
        """
        pass
|
||||||
|
|
||||||
def transcribe(
|
def transcribe(
|
||||||
model: "Whisper",
|
model: "Whisper",
|
||||||
audio: Union[str, np.ndarray, torch.Tensor],
|
audio: Union[str, np.ndarray, torch.Tensor],
|
||||||
*,
|
*,
|
||||||
verbose: Optional[bool] = None,
|
verbose: Optional[bool] = None,
|
||||||
|
progress_receiver: TranscribeProgressReceiver = TranscribeProgressReceiver(),
|
||||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||||
compression_ratio_threshold: Optional[float] = 2.4,
|
compression_ratio_threshold: Optional[float] = 2.4,
|
||||||
logprob_threshold: Optional[float] = -1.0,
|
logprob_threshold: Optional[float] = -1.0,
|
||||||
@ -253,7 +298,8 @@ def transcribe(
|
|||||||
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
||||||
with tqdm.tqdm(
|
with tqdm.tqdm(
|
||||||
total=content_frames, unit="frames", disable=verbose is not False
|
total=content_frames, unit="frames", disable=verbose is not False
|
||||||
) as pbar:
|
) as pbar, \
|
||||||
|
progress_receiver.start(total=content_frames) as ext_progress:
|
||||||
last_speech_timestamp = 0.0
|
last_speech_timestamp = 0.0
|
||||||
# NOTE: This loop is obscurely flattened to make the diff readable.
|
# NOTE: This loop is obscurely flattened to make the diff readable.
|
||||||
# A later commit should turn this into a simpler nested loop.
|
# A later commit should turn this into a simpler nested loop.
|
||||||
@ -459,10 +505,11 @@ def transcribe(
|
|||||||
if last_word_end is not None:
|
if last_word_end is not None:
|
||||||
last_speech_timestamp = last_word_end
|
last_speech_timestamp = last_word_end
|
||||||
|
|
||||||
if verbose:
|
|
||||||
for segment in current_segments:
|
for segment in current_segments:
|
||||||
start, end, text = segment["start"], segment["end"], segment["text"]
|
start, end, text = segment["start"], segment["end"], segment["text"]
|
||||||
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
|
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
|
||||||
|
ext_progress.update_line(start, end, make_safe(text))
|
||||||
|
if verbose:
|
||||||
print(make_safe(line))
|
print(make_safe(line))
|
||||||
|
|
||||||
# if a segment is instantaneous or does not contain text, clear it
|
# if a segment is instantaneous or does not contain text, clear it
|
||||||
@ -490,6 +537,7 @@ def transcribe(
|
|||||||
|
|
||||||
# update progress bar
|
# update progress bar
|
||||||
pbar.update(min(content_frames, seek) - previous_seek)
|
pbar.update(min(content_frames, seek) - previous_seek)
|
||||||
|
ext_progress.update(min(content_frames, seek) - previous_seek)
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user