diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 0a4cc36..eb89e58 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -44,6 +44,8 @@ def transcribe(
     compression_ratio_threshold: Optional[float] = 2.4,
     logprob_threshold: Optional[float] = -1.0,
     no_speech_threshold: Optional[float] = 0.6,
+    language_threshold: Optional[float] = 0.6,
+    language_detection_segments: int = 1,
     condition_on_previous_text: bool = True,
     initial_prompt: Optional[str] = None,
     carry_initial_prompt: bool = False,
@@ -83,6 +85,13 @@ def transcribe(
         If the no_speech probability is higher than this value AND the average log probability
         over sampled tokens is below `logprob_threshold`, consider the segment as silent
 
+    language_threshold: float
+        If the maximum probability of the language tokens is higher than this value, the
+        language is detected and no further segments are considered
+
+    language_detection_segments: int
+        Number of 30-second segments to consider for language detection
+
     condition_on_previous_text: bool
         if True, the previous output of the model is provided as a prompt for the next window;
         disabling may make the text inconsistent across windows, but the model becomes less prone to
@@ -145,12 +154,27 @@ def transcribe(
             decode_options["language"] = "en"
     else:
         if verbose:
-            print(
-                "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
-            )
-        mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
-        _, probs = model.detect_language(mel_segment)
-        decode_options["language"] = max(probs, key=probs.get)
+            print("Detecting language. Use `--language` to specify the language")
+        if language_detection_segments is None or language_detection_segments < 1:
+            language_detection_segments = 1
+        seek = 0
+        languages = []
+        while seek < content_frames and seek < N_FRAMES * language_detection_segments:
+            segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
+            _, probs = model.detect_language(segment)
+            lang = max(probs, key=probs.get)
+            lang_prob = probs[lang]
+            if language_threshold is not None and lang_prob > language_threshold:
+                decode_options["language"] = lang
+                break
+            else:
+                languages.append(lang)
+            seek += segment.shape[-1]
+        else:
+            # If no segment reaches language_threshold, fall back to a majority vote over
+            # the most probable language of each inspected segment.
+            decode_options["language"] = max(set(languages), key=languages.count)
+
         if verbose is not None:
             print(
                 f"Detected language: {LANGUAGES[decode_options['language']].title()}"
@@ -536,6 +560,8 @@ def cli():
 
     parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
+    parser.add_argument("--language_threshold", type=optional_float, default=None, help="if the maximum probability of the language tokens is higher than this value, the language is detected")
+    parser.add_argument("--language_detection_segments", type=int, default=1, help="number of segments to consider for the language detection")
 
     parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
     parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")