Merge bab8297000119678733d80e2cca1c1e1fe12e3f4 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2

This commit is contained in:
Alexander Kuznetsov 2025-06-26 00:06:39 -04:00 committed by GitHub
commit cc21fe3b9b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -42,6 +42,7 @@ def transcribe(
verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
compression_ratio_hallucination_threshold: Optional[float] = 3,
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
@ -76,6 +77,9 @@ def transcribe(
compression_ratio_threshold: float
If the gzip compression ratio is above this value, treat as failed
compression_ratio_hallucination_threshold: float
If the gzip compression ratio is above this value after all attempts to decode, treat as a hallucination and skip
logprob_threshold: float
If the average log probability over sampled tokens is below this value, treat as failed
@ -218,6 +222,13 @@ def transcribe(
and decode_result.avg_logprob < logprob_threshold
):
needs_fallback = False # silence
if (
compression_ratio_hallucination_threshold is not None
and decode_result.compression_ratio > compression_ratio_hallucination_threshold
and t == temperatures[-1]
):
# Discard the segment
return None # Skip to the next segment
if not needs_fallback:
break
@ -293,6 +304,14 @@ def transcribe(
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(mel_segment)
if result is None:
if verbose:
print(
f"Discarding segment {format_timestamp(time_offset)} - {format_timestamp(time_offset + segment_duration)} "
"due to high compression ratio."
)
seek += segment_size # Move to the next segment
continue # Skip processing this segment
tokens = torch.tensor(result.tokens)
if no_speech_threshold is not None: