Merge 600506583baee13c2feab6f7900f2dbee7cecba7 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2

This commit is contained in:
Jonathan Baudanza 2025-06-27 02:27:48 +00:00 committed by GitHub
commit 44b8a4cf0f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 2 deletions

View File

@ -17,7 +17,7 @@ if TYPE_CHECKING:
@torch.no_grad() @torch.no_grad()
def detect_language( def detect_language(
model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None, language_bias: Optional[dict[str,float]] = None
) -> Tuple[Tensor, List[dict]]: ) -> Tuple[Tensor, List[dict]]:
""" """
Detect the spoken language in the audio, and return them as list of strings, along with the ids Detect the spoken language in the audio, and return them as list of strings, along with the ids
@ -56,6 +56,14 @@ def detect_language(
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1] x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
logits = model.logits(x, mel)[:, 0] logits = model.logits(x, mel)[:, 0]
# apply language_bias to logits
if language_bias:
biases = torch.zeros(logits.size(1), device=logits.device)
for lang, bias in language_bias.items():
token = tokenizer.to_language_token(lang)
biases[token] = bias
logits += biases
# collect detected languages; suppress all non-language tokens # collect detected languages; suppress all non-language tokens
mask = torch.ones(logits.shape[-1], dtype=torch.bool) mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[list(tokenizer.all_language_tokens)] = False mask[list(tokenizer.all_language_tokens)] = False

View File

@ -52,6 +52,7 @@ def transcribe(
append_punctuations: str = "\"'.。,!?::”)]}、", append_punctuations: str = "\"'.。,!?::”)]}、",
clip_timestamps: Union[str, List[float]] = "0", clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None, hallucination_silence_threshold: Optional[float] = None,
language_bias: Optional[dict[str,float]] = None,
**decode_options, **decode_options,
): ):
""" """
@ -119,6 +120,10 @@ def transcribe(
When word_timestamps is True, skip silent periods longer than this threshold (in seconds) When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
when a possible hallucination is detected when a possible hallucination is detected
language_bias: Optional[dict[str,float]] = None
A dictionary mapping language codes to float biases, which may be positive or negative. Each
bias is added to the language detection logit of its language before the language is chosen.
Returns Returns
------- -------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
@ -149,7 +154,7 @@ def transcribe(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language" "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
) )
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(mel_segment) _, probs = model.detect_language(mel_segment, language_bias=language_bias)
decode_options["language"] = max(probs, key=probs.get) decode_options["language"] = max(probs, key=probs.get)
if verbose is not None: if verbose is not None:
print( print(