Add compression_ratio_hallucination_threshold

Add compression_ratio_hallucination_threshold to Discard High Compression Ratio Segments in transcribe()

https://github.com/openai/whisper/discussions/2420
This commit is contained in:
Alexander Kuznetsov 2024-11-01 20:16:57 +03:00 committed by GitHub
parent 5979f03701
commit bb8c47519d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -42,11 +42,11 @@ def transcribe(
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4, compression_ratio_threshold: Optional[float] = 2.4,
compression_ratio_halucination_threshold: Optional[float] = 3,
logprob_threshold: Optional[float] = -1.0, logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6, no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True, condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None, initial_prompt: Optional[str] = None,
carry_initial_prompt: bool = False,
word_timestamps: bool = False, word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-", prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、", append_punctuations: str = "\"'.。,!?::”)]}、",
@ -76,6 +76,9 @@ def transcribe(
compression_ratio_threshold: float compression_ratio_threshold: float
If the gzip compression ratio is above this value, treat as failed If the gzip compression ratio is above this value, treat as failed
compression_ratio_halcination_threshold: float
If the gzip compression ratio is above this value after all attempts to decode, treat as a halucination and skip
logprob_threshold: float logprob_threshold: float
If the average log probability over sampled tokens is below this value, treat as failed If the average log probability over sampled tokens is below this value, treat as failed
@ -205,7 +208,7 @@ def transcribe(
compression_ratio_threshold is not None compression_ratio_threshold is not None
and decode_result.compression_ratio > compression_ratio_threshold and decode_result.compression_ratio > compression_ratio_threshold
): ):
needs_fallback = True # too repetitive needs_fallback = True # too repetitive <-- We can inprove it...
if ( if (
logprob_threshold is not None logprob_threshold is not None
and decode_result.avg_logprob < logprob_threshold and decode_result.avg_logprob < logprob_threshold
@ -216,6 +219,13 @@ def transcribe(
and decode_result.no_speech_prob > no_speech_threshold and decode_result.no_speech_prob > no_speech_threshold
): ):
needs_fallback = False # silence needs_fallback = False # silence
if (
compression_ratio_halucination_threshold is not None
and decode_result.compression_ratio > compression_ratio_halucination_threshold
and t == temperatures[-1]
):
# Discard the segment
continue # Skip to the next segment
if not needs_fallback: if not needs_fallback:
break break