mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 14:35:57 +00:00
feat: improve language detection
This commit is contained in:
parent
6dea21fd7f
commit
ef14efdc54
@ -42,6 +42,8 @@ def transcribe(
|
|||||||
compression_ratio_threshold: Optional[float] = 2.4,
|
compression_ratio_threshold: Optional[float] = 2.4,
|
||||||
logprob_threshold: Optional[float] = -1.0,
|
logprob_threshold: Optional[float] = -1.0,
|
||||||
no_speech_threshold: Optional[float] = 0.6,
|
no_speech_threshold: Optional[float] = 0.6,
|
||||||
|
language_threshold: Optional[float] = 0.6,
|
||||||
|
language_detection_segments: int = 1,
|
||||||
condition_on_previous_text: bool = True,
|
condition_on_previous_text: bool = True,
|
||||||
initial_prompt: Optional[str] = None,
|
initial_prompt: Optional[str] = None,
|
||||||
word_timestamps: bool = False,
|
word_timestamps: bool = False,
|
||||||
@ -78,6 +80,13 @@ def transcribe(
|
|||||||
If the no_speech probability is higher than this value AND the average log probability
|
If the no_speech probability is higher than this value AND the average log probability
|
||||||
over sampled tokens is below `logprob_threshold`, consider the segment as silent
|
over sampled tokens is below `logprob_threshold`, consider the segment as silent
|
||||||
|
|
||||||
|
language_threshold: float
|
||||||
|
If the maximum probability of the language tokens is higher than this value, the language is
|
||||||
|
detected
|
||||||
|
|
||||||
|
language_detection_segments: int
|
||||||
|
Number of segments to consider for the language detection
|
||||||
|
|
||||||
condition_on_previous_text: bool
|
condition_on_previous_text: bool
|
||||||
if True, the previous output of the model is provided as a prompt for the next window;
|
if True, the previous output of the model is provided as a prompt for the next window;
|
||||||
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
||||||
@ -126,12 +135,27 @@ def transcribe(
|
|||||||
decode_options["language"] = "en"
|
decode_options["language"] = "en"
|
||||||
else:
|
else:
|
||||||
if verbose:
|
if verbose:
|
||||||
print(
|
print("Detecting language. Use `--language` to specify the language")
|
||||||
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
if language_detection_segments is None or language_detection_segments < 1:
|
||||||
)
|
language_detection_segments = 1
|
||||||
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
seek = 0
|
||||||
_, probs = model.detect_language(mel_segment)
|
languages = []
|
||||||
decode_options["language"] = max(probs, key=probs.get)
|
while seek < content_frames and seek < N_FRAMES * language_detection_segments:
|
||||||
|
segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
|
||||||
|
_, probs = model.detect_language(segment)
|
||||||
|
lang = max(probs, key=probs.get)
|
||||||
|
lang_prob = probs[lang]
|
||||||
|
if language_threshold is not None and lang_prob > language_threshold:
|
||||||
|
decode_options["language"] = lang
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
languages.append(lang)
|
||||||
|
seek += segment.shape[-1]
|
||||||
|
else:
|
||||||
|
# If no language detected for all segments, the majority vote of the highest projected
|
||||||
|
# languages for all segments is used to determine the language.
|
||||||
|
decode_options["language"] = max(set(languages), key=languages.count)
|
||||||
|
|
||||||
if verbose is not None:
|
if verbose is not None:
|
||||||
print(
|
print(
|
||||||
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
|
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
|
||||||
@ -382,6 +406,8 @@ def cli():
|
|||||||
|
|
||||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
||||||
|
parser.add_argument("--language_threshold", type=optional_float, default=None, help="if the maximum probability of the language tokens is higher than this value, the language is detected")
|
||||||
|
parser.add_argument("--language_detection_segments", type=int, default=1, help="number of segments to consider for the language detection")
|
||||||
|
|
||||||
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
||||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user