Merge 600506583baee13c2feab6f7900f2dbee7cecba7 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2

Jonathan Baudanza 2025-06-27 02:27:48 +00:00 committed by GitHub
commit 44b8a4cf0f
2 changed files with 15 additions and 2 deletions

whisper/decoding.py

@@ -17,7 +17,7 @@ if TYPE_CHECKING:
 
 @torch.no_grad()
 def detect_language(
-    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
+    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None, language_bias: Optional[dict[str,float]] = None
 ) -> Tuple[Tensor, List[dict]]:
     """
     Detect the spoken language in the audio, and return them as list of strings, along with the ids
@@ -56,6 +56,14 @@ def detect_language(
     x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
     logits = model.logits(x, mel)[:, 0]
 
+    # apply language_bias to logits
+    if language_bias:
+        biases = torch.zeros(logits.size(1), device=logits.device)
+        for lang, bias in language_bias.items():
+            token = tokenizer.to_language_token(lang)
+            biases[token] = bias
+        logits += biases
+
     # collect detected languages; suppress all non-language tokens
     mask = torch.ones(logits.shape[-1], dtype=torch.bool)
     mask[list(tokenizer.all_language_tokens)] = False
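The bias values are added to the raw language-token logits before the softmax over languages, so a bias of +b multiplies that language's odds by e^b relative to the unbiased prediction. A minimal usage sketch against the patched `detect_language` (the model size and the file name sample.wav are illustrative, not part of this change):

import whisper

model = whisper.load_model("small")

# Standard 30-second mel segment, following the usual Whisper example
audio = whisper.pad_or_trim(whisper.load_audio("sample.wav"))
mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)

# Nudge detection toward English and away from Welsh; keys are Whisper
# language codes, values are added to that language's detection logit.
_, probs = model.detect_language(mel, language_bias={"en": 2.0, "cy": -2.0})
print(max(probs, key=probs.get))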

whisper/transcribe.py

@@ -52,6 +52,7 @@ def transcribe(
     append_punctuations: str = "\"'.。,!?::”)]}、",
     clip_timestamps: Union[str, List[float]] = "0",
     hallucination_silence_threshold: Optional[float] = None,
+    language_bias: Optional[dict[str,float]] = None,
     **decode_options,
 ):
     """
@@ -119,6 +120,10 @@ def transcribe(
         When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
         when a possible hallucination is detected
 
+    language_bias: Optional[dict[str,float]] = None
+        A dictionary of language codes to positive or negative float values. These values will be
+        applied to the language detection logits before choosing the language.
+
     Returns
     -------
     A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
@@ -149,7 +154,7 @@ def transcribe(
                 "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
             )
             mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
-            _, probs = model.detect_language(mel_segment)
+            _, probs = model.detect_language(mel_segment, language_bias=language_bias)
             decode_options["language"] = max(probs, key=probs.get)
             if verbose is not None:
                 print(
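At the `transcribe` level the bias is only consulted on the automatic language-detection path shown above, i.e. when no `language` option is supplied. A minimal usage sketch against the patched API (the file name and bias values are illustrative):

import whisper

model = whisper.load_model("small")

# No `language` option, so detection runs on the first 30 seconds;
# the bias favors Spanish and Portuguese over similar-sounding languages.
result = model.transcribe("meeting.mp3", language_bias={"es": 1.5, "pt": 1.5})
print(result["language"])
print(result["text"][:80])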