transcribe can take a callback that is fn(int, int, float) -> None

The callback receives the current position (in frames), the total number of frames, and an estimate of the remaining time in seconds.
This commit is contained in:
Millan Kumar 2024-12-11 18:59:52 -05:00
parent 90db0de189
commit abf6778935

View File

@ -2,7 +2,7 @@ import argparse
import os
import traceback
import warnings
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Callable
import numpy as np
import torch
@ -52,6 +52,7 @@ def transcribe(
append_punctuations: str = "\"'.。,!?::”)]}、",
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
callback: Optional[Callable[[int, int, float], None]] = None,
**decode_options,
):
"""
@ -119,6 +120,10 @@ def transcribe(
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
when a possible hallucination is detected
callback: Optional[Callable[[int, int, float], None]]
After each step in the transcription process, call the callback function with the
current position, the total number of frames, and the estimated time remaining in seconds
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
@ -504,8 +509,17 @@ def transcribe(
# do not feed the prompt tokens if a high temperature was used
prompt_reset_since = len(all_tokens)
total_position = min(content_frames, seek)
increase = total_position - previous_seek
if callback is not None:
rate = pbar.format_dict["rate"]
remaining = (pbar.total - pbar.n) / rate if rate and pbar.total else 0
callback(total_position, content_frames, remaining)
# update progress bar
pbar.update(min(content_frames, seek) - previous_seek)
pbar.update(increase)
return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),