mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 14:35:57 +00:00
Add progress callback to transcribe function and corresponding test
This commit is contained in:
parent
c0d2f624c0
commit
5f850028a7
BIN
tests/fdr.mp3
Normal file
BIN
tests/fdr.mp3
Normal file
Binary file not shown.
27
tests/test_progress_callback.py
Normal file
27
tests/test_progress_callback.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import whisper
|
||||
|
||||
|
||||
def test_progress_callback():
    """Transcribe a short bundled clip and verify the progress callback fires.

    The callback should be invoked at least once during transcription, and
    the final reported value should be exactly 100.0 (percent complete).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisper.load_model("tiny").to(device)
    audio_path = os.path.join(os.path.dirname(__file__), "fdr.mp3")

    reported = []

    def record(progress_data):
        reported.append(progress_data)

    model.transcribe(
        audio_path,
        language="en",
        verbose=False,  # purely for visualization purposes, not needed for the progress callback
        progress_callback=record
    )

    print(reported)

    # At least one progress update must arrive, and the last one is 100%.
    assert len(reported) > 0
    assert reported[-1] == 100.0
|
||||
@@ -40,6 +40,7 @@ def transcribe(
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
*,
|
||||
verbose: Optional[bool] = None,
|
||||
progress_callback: Optional[callable] = None,
|
||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||
compression_ratio_threshold: Optional[float] = 2.4,
|
||||
logprob_threshold: Optional[float] = -1.0,
|
||||
@@ -138,6 +139,7 @@ def transcribe(
|
||||
# Pad 30-seconds of silence to the input audio, for slicing
|
||||
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
|
||||
content_frames = mel.shape[-1] - N_FRAMES
|
||||
curr_frames = 0
|
||||
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
|
||||
|
||||
if decode_options.get("language", None) is None:
|
||||
@@ -262,7 +264,7 @@ def transcribe(
|
||||
|
||||
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
||||
with tqdm.tqdm(
|
||||
total=content_frames, unit="frames", disable=verbose is not False
|
||||
total=content_frames, unit="frame", disable=verbose is not False
|
||||
) as pbar:
|
||||
last_speech_timestamp = 0.0
|
||||
# NOTE: This loop is obscurely flattened to make the diff readable.
|
||||
@@ -505,7 +507,11 @@ def transcribe(
|
||||
prompt_reset_since = len(all_tokens)
|
||||
|
||||
# update progress bar
|
||||
pbar.update(min(content_frames, seek) - previous_seek)
|
||||
frames_processed = min(content_frames, seek) - previous_seek
|
||||
if progress_callback is not None:
|
||||
curr_frames = frames_processed + curr_frames
|
||||
progress_callback(curr_frames / content_frames * 100)
|
||||
pbar.update(frames_processed)
|
||||
|
||||
return dict(
|
||||
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user