From a72a04414bf7bb1b5178a366ef38fb2272ad3754 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Boyer?= Date: Sun, 21 May 2023 16:30:46 +0200 Subject: [PATCH 1/2] Return best option on fallback When the avg_logprob condition isn't satisfied and the result is re-computed with a greater temperature, the best option is returned --- whisper/transcribe.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index ff73a55..9ca47c3 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -148,7 +148,7 @@ def transcribe( temperatures = ( [temperature] if isinstance(temperature, (int, float)) else temperature ) - decode_result = None + results = {} for t in temperatures: kwargs = {**decode_options} @@ -163,6 +163,8 @@ def transcribe( options = DecodingOptions(**kwargs, temperature=t) decode_result = model.decode(segment, options) + results[t] = decode_result + needs_fallback = False if ( compression_ratio_threshold is not None @@ -182,7 +184,7 @@ def transcribe( if not needs_fallback: break - return decode_result + return max(results.values(), key=lambda r: r.avg_logprob) seek = 0 input_stride = exact_div( From f677284d118eb15dc9d0c56a6034ae47181575ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Boyer?= Date: Sun, 21 May 2023 16:58:12 +0200 Subject: [PATCH 2/2] Return the best only if all fallbacks failed --- whisper/transcribe.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 9ca47c3..d92d7da 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -148,6 +148,7 @@ def transcribe( temperatures = ( [temperature] if isinstance(temperature, (int, float)) else temperature ) + decode_result = None results = {} for t in temperatures: @@ -183,8 +184,11 @@ def transcribe( needs_fallback = False # silence if not needs_fallback: break + else: + # all failed + return max(results.values(), key=lambda r: r.avg_logprob) - return max(results.values(), key=lambda r: r.avg_logprob) + return decode_result seek = 0 input_stride = exact_div(