feat: improve language detection

This commit is contained in:
petrosvav 2023-03-24 16:02:48 +02:00
parent 6dea21fd7f
commit ef14efdc54

View File

@ -42,6 +42,8 @@ def transcribe(
compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
language_threshold: Optional[float] = 0.6,
language_detection_segments: int = 1,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
word_timestamps: bool = False,
@ -78,6 +80,13 @@ def transcribe(
If the no_speech probability is higher than this value AND the average log probability
over sampled tokens is below `logprob_threshold`, consider the segment as silent
language_threshold: float
If the maximum probability of the language tokens is higher than this value, the language is
detected
language_detection_segments: int
Number of segments to consider for the language detection
condition_on_previous_text: bool
if True, the previous output of the model is provided as a prompt for the next window;
disabling may make the text inconsistent across windows, but the model becomes less prone to
@ -126,12 +135,27 @@ def transcribe(
decode_options["language"] = "en"
else:
if verbose:
print(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
)
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)
print("Detecting language. Use `--language` to specify the language")
if language_detection_segments is None or language_detection_segments < 1:
language_detection_segments = 1
seek = 0
languages = []
while seek < content_frames and seek < N_FRAMES * language_detection_segments:
segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(segment)
lang = max(probs, key=probs.get)
lang_prob = probs[lang]
if language_threshold is not None and lang_prob > language_threshold:
decode_options["language"] = lang
break
else:
languages.append(lang)
seek += segment.shape[-1]
else:
# If no language detected for all segments, the majority vote of the highest projected
# languages for all segments is used to determine the language.
decode_options["language"] = max(set(languages), key=languages.count)
if verbose is not None:
print(
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
@ -382,6 +406,8 @@ def cli():
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
parser.add_argument("--language_threshold", type=optional_float, default=None, help="if the maximum probability of the language tokens is higher than this value, the language is detected")
parser.add_argument("--language_detection_segments", type=int, default=1, help="number of segments to consider for the language detection")
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")