add enable_generator parameter

This commit is contained in:
jhj0517 2024-04-12 23:32:03 +09:00
parent ba3f3cd54b
commit 541c7e360f

View File

@ -51,6 +51,7 @@ def transcribe(
append_punctuations: str = "\"'.。,!?::”)]}、", append_punctuations: str = "\"'.。,!?::”)]}、",
clip_timestamps: Union[str, List[float]] = "0", clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None, hallucination_silence_threshold: Optional[float] = None,
enable_generator: bool = False,
**decode_options, **decode_options,
): ):
""" """
@ -113,6 +114,10 @@ def transcribe(
When word_timestamps is True, skip silent periods longer than this threshold (in seconds) When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
when a possible hallucination is detected when a possible hallucination is detected
enable_generator: bool
Whether to use a generator for output. If True, yields incremental results as a generator.
If False, returns the complete transcription results as a dictionary.
Returns Returns
------- -------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
@ -491,6 +496,13 @@ def transcribe(
# update progress bar # update progress bar
pbar.update(min(content_frames, seek) - previous_seek) pbar.update(min(content_frames, seek) - previous_seek)
if enable_generator:
yield dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
segments=all_segments,
language=language,
)
return dict( return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
segments=all_segments, segments=all_segments,