diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 1c075a2..01a94d4 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -51,6 +51,7 @@ def transcribe( append_punctuations: str = "\"'.。,,!!??::”)]}、", clip_timestamps: Union[str, List[float]] = "0", hallucination_silence_threshold: Optional[float] = None, + enable_generator: bool = False, **decode_options, ): """ @@ -113,6 +114,10 @@ def transcribe( When word_timestamps is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected + enable_generator: bool + Whether to use a generator for output. If True, yields incremental results as a generator. + If False, returns the complete transcription results as a dictionary. + Returns ------- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and @@ -491,6 +496,13 @@ def transcribe( # update progress bar pbar.update(min(content_frames, seek) - previous_seek) + if enable_generator: + yield dict( + text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), + segments=all_segments, + language=language, + ) + return dict( text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), segments=all_segments,