Merge 5f850028a771439cd61a6c6f290f34f0334d2c66 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2

This commit is contained in:
Shashank Prasanna 2025-07-16 23:15:52 -04:00 committed by GitHub
commit c0287e9184
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 35 additions and 2 deletions

BIN
tests/fdr.mp3 Normal file

Binary file not shown.

View File

@ -0,0 +1,27 @@
import os
import pytest
import torch
import whisper
def test_progress_callback():
    """Transcribe a short fixture clip and check that the progress
    callback is invoked at least once and finishes at exactly 100.0."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisper.load_model("tiny").to(device)
    audio_path = os.path.join(os.path.dirname(__file__), "fdr.mp3")

    reported = []

    model.transcribe(
        audio_path,
        language="en",
        # verbose=False only controls console output; the callback fires regardless
        verbose=False,
        progress_callback=reported.append,
    )

    print(reported)
    assert len(reported) > 0
    assert reported[-1] == 100.0

View File

@ -40,6 +40,7 @@ def transcribe(
audio: Union[str, np.ndarray, torch.Tensor], audio: Union[str, np.ndarray, torch.Tensor],
*, *,
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
progress_callback: Optional[callable] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4, compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0, logprob_threshold: Optional[float] = -1.0,
@ -138,6 +139,7 @@ def transcribe(
# Pad 30-seconds of silence to the input audio, for slicing # Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES) mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
content_frames = mel.shape[-1] - N_FRAMES content_frames = mel.shape[-1] - N_FRAMES
curr_frames = 0
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE) content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
if decode_options.get("language", None) is None: if decode_options.get("language", None) is None:
@ -262,7 +264,7 @@ def transcribe(
# show the progress bar when verbose is False (if True, transcribed text will be printed) # show the progress bar when verbose is False (if True, transcribed text will be printed)
with tqdm.tqdm( with tqdm.tqdm(
total=content_frames, unit="frames", disable=verbose is not False total=content_frames, unit="frame", disable=verbose is not False
) as pbar: ) as pbar:
last_speech_timestamp = 0.0 last_speech_timestamp = 0.0
# NOTE: This loop is obscurely flattened to make the diff readable. # NOTE: This loop is obscurely flattened to make the diff readable.
@ -505,7 +507,11 @@ def transcribe(
prompt_reset_since = len(all_tokens) prompt_reset_since = len(all_tokens)
# update progress bar # update progress bar
pbar.update(min(content_frames, seek) - previous_seek) frames_processed = min(content_frames, seek) - previous_seek
if progress_callback is not None:
curr_frames = frames_processed + curr_frames
progress_callback(curr_frames / content_frames * 100)
pbar.update(frames_processed)
return dict( return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),