feat: improve language detection

parent 6dea21fd7f
commit ef14efdc54
whisper/transcribe.py

@@ -42,6 +42,8 @@ def transcribe(
     compression_ratio_threshold: Optional[float] = 2.4,
     logprob_threshold: Optional[float] = -1.0,
     no_speech_threshold: Optional[float] = 0.6,
+    language_threshold: Optional[float] = 0.6,
+    language_detection_segments: int = 1,
     condition_on_previous_text: bool = True,
     initial_prompt: Optional[str] = None,
     word_timestamps: bool = False,
@@ -78,6 +80,13 @@ def transcribe(
         If the no_speech probability is higher than this value AND the average log probability
         over sampled tokens is below `logprob_threshold`, consider the segment as silent

+    language_threshold: float
+        If the maximum probability of the language tokens is higher than this value, the language is
+        detected
+
+    language_detection_segments: int
+        Number of segments to consider for the language detection
+
     condition_on_previous_text: bool
         if True, the previous output of the model is provided as a prompt for the next window;
         disabling may make the text inconsistent across windows, but the model becomes less prone to
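Taken together, the two new parameters let callers trade detection latency against confidence. A minimal usage sketch, assuming this patch is applied; "audio.mp3" is a placeholder file name, and the rest of the call is the existing whisper API:

    # Sketch only: the two keyword arguments below exist only once this patch is applied.
    import whisper

    model = whisper.load_model("base")
    result = model.transcribe(
        "audio.mp3",
        language_threshold=0.8,         # accept a language once its probability exceeds 0.8
        language_detection_segments=3,  # otherwise scan up to three 30-second windows
    )
    print(result["text"])

Note that with the defaults (threshold 0.6, a single segment) the result reduces to the old single-window behaviour, since a lone segment also wins its own majority vote.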
@@ -126,12 +135,27 @@ def transcribe(
             decode_options["language"] = "en"
         else:
             if verbose:
-                print(
-                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
-                )
-            mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
-            _, probs = model.detect_language(mel_segment)
-            decode_options["language"] = max(probs, key=probs.get)
+                print("Detecting language. Use `--language` to specify the language")
+            if language_detection_segments is None or language_detection_segments < 1:
+                language_detection_segments = 1
+            seek = 0
+            languages = []
+            while seek < content_frames and seek < N_FRAMES * language_detection_segments:
+                segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
+                _, probs = model.detect_language(segment)
+                lang = max(probs, key=probs.get)
+                lang_prob = probs[lang]
+                if language_threshold is not None and lang_prob > language_threshold:
+                    decode_options["language"] = lang
+                    break
+                else:
+                    languages.append(lang)
+                    seek += segment.shape[-1]
+            else:
+                # If no language detected for all segments, the majority vote of the highest projected
+                # languages for all segments is used to determine the language.
+                decode_options["language"] = max(set(languages), key=languages.count)
+
             if verbose is not None:
                 print(
                     f"Detected language: {LANGUAGES[decode_options['language']].title()}"
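The loop above relies on Python's `while ... else`: the `else` block runs only when the loop finishes without hitting `break`, i.e. when no segment's top language clears `language_threshold`. A self-contained sketch of the same decision rule, with a made-up list of (language, probability) pairs standing in for per-segment calls to `model.detect_language`:

    # Illustration only: each (lang, prob) pair stands in for the (lang, lang_prob)
    # computed from one 30-second mel segment in the patch above.
    def pick_language(segment_probs, threshold=0.6):
        votes = []
        for lang, prob in segment_probs:
            if threshold is not None and prob > threshold:
                return lang                      # confident segment: early exit, like `break`
            votes.append(lang)
        # no segment was confident enough: majority vote, like the `while ... else` branch
        return max(set(votes), key=votes.count)

    print(pick_language([("de", 0.41), ("en", 0.43), ("en", 0.55)]))  # -> en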
@@ -382,6 +406,8 @@ def cli():

     parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
     parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
+    parser.add_argument("--language_threshold", type=optional_float, default=None, help="if the maximum probability of the language tokens is higher than this value, the language is detected")
+    parser.add_argument("--language_detection_segments", type=int, default=1, help="number of segments to consider for the language detection")

     parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
     parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
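For completeness, a rough sketch of how the two new flags are parsed on the command line. The `optional_float` below is a stand-in written to match the behaviour the flag needs (the literal string "None" maps to None); the real CLI imports a converter of that name from whisper's utilities rather than defining it inline:

    # Sketch: a minimal parser with just the two new flags.
    import argparse

    def optional_float(string):
        # stand-in converter: "None" -> None, anything else -> float
        return None if string == "None" else float(string)

    parser = argparse.ArgumentParser()
    parser.add_argument("--language_threshold", type=optional_float, default=None)
    parser.add_argument("--language_detection_segments", type=int, default=1)

    args = parser.parse_args(["--language_threshold", "0.8", "--language_detection_segments", "3"])
    print(args.language_threshold, args.language_detection_segments)  # 0.8 3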