diff --git a/tests/fdr.mp3 b/tests/fdr.mp3
new file mode 100644
index 0000000..ecef762
Binary files /dev/null and b/tests/fdr.mp3 differ
diff --git a/tests/test_progress_callback.py b/tests/test_progress_callback.py
new file mode 100644
index 0000000..266aab3
--- /dev/null
+++ b/tests/test_progress_callback.py
@@ -0,0 +1,28 @@
+import os
+
+import pytest
+import torch
+
+import whisper
+
+
+def test_progress_callback():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model = whisper.load_model("tiny").to(device)
+    audio_path = os.path.join(os.path.dirname(__file__), "fdr.mp3")
+
+    progress = []
+
+    def callback(progress_data):
+        progress.append(progress_data)
+
+    model.transcribe(
+        audio_path,
+        language="en",
+        verbose=False,  # purely for visualization purposes, not needed for the progress callback
+        progress_callback=callback,
+    )
+    assert len(progress) > 0
+    # progress is reported as a percentage, monotonically non-decreasing, ending at 100%
+    assert all(a <= b for a, b in zip(progress, progress[1:]))
+    assert progress[-1] == pytest.approx(100.0)
diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 0a4cc36..482ccba 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -40,6 +40,7 @@ def transcribe(
     audio: Union[str, np.ndarray, torch.Tensor],
     *,
     verbose: Optional[bool] = None,
+    progress_callback: Optional[callable] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
@@ -505,7 +506,11 @@ def transcribe(
                 prompt_reset_since = len(all_tokens)
 
             # update progress bar
-            pbar.update(min(content_frames, seek) - previous_seek)
+            frames_processed = min(content_frames, seek)
+            if progress_callback is not None and content_frames > 0:
+                # report cumulative progress as a percentage in [0.0, 100.0]
+                progress_callback(frames_processed / content_frames * 100)
+            pbar.update(frames_processed - previous_seek)
 
     return dict(
         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),