From 541c7e360f24f6d0f9fb0eb7cd178c6b50684caa Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Fri, 12 Apr 2024 23:32:03 +0900 Subject: [PATCH] add enable_generator parameter --- whisper/transcribe.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 1c075a2..01a94d4 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -51,6 +51,7 @@ def transcribe( append_punctuations: str = "\"'.。,,!!??::”)]}、", clip_timestamps: Union[str, List[float]] = "0", hallucination_silence_threshold: Optional[float] = None, + enable_generator: bool = False, **decode_options, ): """ @@ -113,6 +114,10 @@ def transcribe( When word_timestamps is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected + enable_generator: bool + Whether to use a generator for output. If True, yields incremental results as a generator. + If False, returns the complete transcription results as a dictionary. + Returns ------- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and @@ -491,6 +496,13 @@ def transcribe( # update progress bar pbar.update(min(content_frames, seek) - previous_seek) + if enable_generator: + yield dict( + text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), + segments=all_segments, + language=language, + ) + return dict( text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), segments=all_segments,