mirror of
https://github.com/openai/whisper.git
synced 2025-11-28 08:11:11 +00:00
Merge 600506583baee13c2feab6f7900f2dbee7cecba7 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2
This commit is contained in:
commit
44b8a4cf0f
@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def detect_language(
|
def detect_language(
|
||||||
model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
|
model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None, language_bias: Optional[dict[str,float]] = None
|
||||||
) -> Tuple[Tensor, List[dict]]:
|
) -> Tuple[Tensor, List[dict]]:
|
||||||
"""
|
"""
|
||||||
Detect the spoken language in the audio, and return them as list of strings, along with the ids
|
Detect the spoken language in the audio, and return them as list of strings, along with the ids
|
||||||
@ -56,6 +56,14 @@ def detect_language(
|
|||||||
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
|
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
|
||||||
logits = model.logits(x, mel)[:, 0]
|
logits = model.logits(x, mel)[:, 0]
|
||||||
|
|
||||||
|
# apply language_bias to logits
|
||||||
|
if language_bias:
|
||||||
|
biases = torch.zeros(logits.size(1), device=logits.device)
|
||||||
|
for lang, bias in language_bias.items():
|
||||||
|
token = tokenizer.to_language_token(lang)
|
||||||
|
biases[token] = bias
|
||||||
|
logits += biases
|
||||||
|
|
||||||
# collect detected languages; suppress all non-language tokens
|
# collect detected languages; suppress all non-language tokens
|
||||||
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
||||||
mask[list(tokenizer.all_language_tokens)] = False
|
mask[list(tokenizer.all_language_tokens)] = False
|
||||||
|
|||||||
@ -52,6 +52,7 @@ def transcribe(
|
|||||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||||
clip_timestamps: Union[str, List[float]] = "0",
|
clip_timestamps: Union[str, List[float]] = "0",
|
||||||
hallucination_silence_threshold: Optional[float] = None,
|
hallucination_silence_threshold: Optional[float] = None,
|
||||||
|
language_bias: Optional[dict[str,float]] = None,
|
||||||
**decode_options,
|
**decode_options,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -119,6 +120,10 @@ def transcribe(
|
|||||||
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
|
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
|
||||||
when a possible hallucination is detected
|
when a possible hallucination is detected
|
||||||
|
|
||||||
|
language_bias: Optional[dict[str,float]] = None
|
||||||
|
A dictionary of language codes to positive or negative float values. These values will be
|
||||||
|
applied to the language detection logits before choosing the language.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||||
@ -149,7 +154,7 @@ def transcribe(
|
|||||||
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
||||||
)
|
)
|
||||||
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||||
_, probs = model.detect_language(mel_segment)
|
_, probs = model.detect_language(mel_segment, language_bias=language_bias)
|
||||||
decode_options["language"] = max(probs, key=probs.get)
|
decode_options["language"] = max(probs, key=probs.get)
|
||||||
if verbose is not None:
|
if verbose is not None:
|
||||||
print(
|
print(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user