mirror of
https://github.com/openai/whisper.git
synced 2025-11-28 00:03:40 +00:00
Merge 600506583baee13c2feab6f7900f2dbee7cecba7 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2
This commit is contained in:
commit
44b8a4cf0f
@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
||||
|
||||
@torch.no_grad()
|
||||
def detect_language(
|
||||
model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
|
||||
model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None, language_bias: Optional[dict[str,float]] = None
|
||||
) -> Tuple[Tensor, List[dict]]:
|
||||
"""
|
||||
Detect the spoken language in the audio, and return them as list of strings, along with the ids
|
||||
@ -56,6 +56,14 @@ def detect_language(
|
||||
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
|
||||
logits = model.logits(x, mel)[:, 0]
|
||||
|
||||
# apply language_bias to logits
|
||||
if language_bias:
|
||||
biases = torch.zeros(logits.size(1), device=logits.device)
|
||||
for lang, bias in language_bias.items():
|
||||
token = tokenizer.to_language_token(lang)
|
||||
biases[token] = bias
|
||||
logits += biases
|
||||
|
||||
# collect detected languages; suppress all non-language tokens
|
||||
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
||||
mask[list(tokenizer.all_language_tokens)] = False
|
||||
|
||||
@ -52,6 +52,7 @@ def transcribe(
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
clip_timestamps: Union[str, List[float]] = "0",
|
||||
hallucination_silence_threshold: Optional[float] = None,
|
||||
language_bias: Optional[dict[str,float]] = None,
|
||||
**decode_options,
|
||||
):
|
||||
"""
|
||||
@ -119,6 +120,10 @@ def transcribe(
|
||||
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
|
||||
when a possible hallucination is detected
|
||||
|
||||
language_bias: Optional[dict[str,float]] = None
|
||||
A dictionary of language codes to positive or negative float values. These values will be
|
||||
applied to the language detection logits before choosing the language.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||
@ -149,7 +154,7 @@ def transcribe(
|
||||
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
||||
)
|
||||
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||
_, probs = model.detect_language(mel_segment)
|
||||
_, probs = model.detect_language(mel_segment, language_bias=language_bias)
|
||||
decode_options["language"] = max(probs, key=probs.get)
|
||||
if verbose is not None:
|
||||
print(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user