mirror of
https://github.com/openai/whisper.git
synced 2025-11-23 22:15:58 +00:00
can load 100-language models
This commit is contained in:
parent
216f2c5fec
commit
b7ac4888a2
@ -25,7 +25,7 @@ def test_transcribe(model_name: str):
|
|||||||
assert "your country" in transcription
|
assert "your country" in transcription
|
||||||
assert "do for you" in transcription
|
assert "do for you" in transcription
|
||||||
|
|
||||||
tokenizer = get_tokenizer(model.is_multilingual)
|
tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
|
||||||
all_tokens = [t for s in result["segments"] for t in s["tokens"]]
|
all_tokens = [t for s in result["segments"] for t in s["tokens"]]
|
||||||
assert tokenizer.decode(all_tokens) == result["text"]
|
assert tokenizer.decode(all_tokens) == result["text"]
|
||||||
assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")
|
assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")
|
||||||
|
|||||||
@ -12,7 +12,6 @@ from .utils import exact_div
|
|||||||
# hard-coded audio hyperparameters
|
# hard-coded audio hyperparameters
|
||||||
SAMPLE_RATE = 16000
|
SAMPLE_RATE = 16000
|
||||||
N_FFT = 400
|
N_FFT = 400
|
||||||
N_MELS = 80
|
|
||||||
HOP_LENGTH = 160
|
HOP_LENGTH = 160
|
||||||
CHUNK_LENGTH = 30
|
CHUNK_LENGTH = 30
|
||||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
||||||
@ -90,7 +89,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
|||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
def mel_filters(device, n_mels: int) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
||||||
Allows decoupling librosa dependency; saved using:
|
Allows decoupling librosa dependency; saved using:
|
||||||
@ -110,7 +109,7 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
|||||||
|
|
||||||
def log_mel_spectrogram(
|
def log_mel_spectrogram(
|
||||||
audio: Union[str, np.ndarray, torch.Tensor],
|
audio: Union[str, np.ndarray, torch.Tensor],
|
||||||
n_mels: int = N_MELS,
|
n_mels: int = 80,
|
||||||
padding: int = 0,
|
padding: int = 0,
|
||||||
device: Optional[Union[str, torch.device]] = None,
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
):
|
):
|
||||||
|
|||||||
@ -32,7 +32,7 @@ def detect_language(
|
|||||||
list of dictionaries containing the probability distribution over all languages.
|
list of dictionaries containing the probability distribution over all languages.
|
||||||
"""
|
"""
|
||||||
if tokenizer is None:
|
if tokenizer is None:
|
||||||
tokenizer = get_tokenizer(model.is_multilingual)
|
tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
|
||||||
if (
|
if (
|
||||||
tokenizer.language is None
|
tokenizer.language is None
|
||||||
or tokenizer.language_token not in tokenizer.sot_sequence
|
or tokenizer.language_token not in tokenizer.sot_sequence
|
||||||
@ -514,7 +514,10 @@ class DecodingTask:
|
|||||||
|
|
||||||
language = options.language or "en"
|
language = options.language or "en"
|
||||||
tokenizer = get_tokenizer(
|
tokenizer = get_tokenizer(
|
||||||
model.is_multilingual, language=language, task=options.task
|
model.is_multilingual,
|
||||||
|
num_languages=model.num_languages,
|
||||||
|
language=language,
|
||||||
|
task=options.task,
|
||||||
)
|
)
|
||||||
self.tokenizer: Tokenizer = tokenizer
|
self.tokenizer: Tokenizer = tokenizer
|
||||||
self.options: DecodingOptions = self._verify_options(options)
|
self.options: DecodingOptions = self._verify_options(options)
|
||||||
|
|||||||
@ -269,7 +269,11 @@ class Whisper(nn.Module):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def is_multilingual(self):
|
def is_multilingual(self):
|
||||||
return self.dims.n_vocab == 51865
|
return self.dims.n_vocab >= 51865
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_languages(self):
|
||||||
|
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
|
||||||
|
|
||||||
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -6,6 +6,12 @@ from functools import cached_property, lru_cache
|
|||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
from tiktoken.load import load_tiktoken_bpe
|
||||||
|
|
||||||
|
ENCODINGS_BASE = os.environ.get(
|
||||||
|
"OPENAI_ENCODINGS_BASE",
|
||||||
|
"az://oaiappliedai/encodings/applied-encodings",
|
||||||
|
)
|
||||||
|
|
||||||
LANGUAGES = {
|
LANGUAGES = {
|
||||||
"en": "english",
|
"en": "english",
|
||||||
@ -107,6 +113,7 @@ LANGUAGES = {
|
|||||||
"ba": "bashkir",
|
"ba": "bashkir",
|
||||||
"jw": "javanese",
|
"jw": "javanese",
|
||||||
"su": "sundanese",
|
"su": "sundanese",
|
||||||
|
"yue": "cantonese",
|
||||||
}
|
}
|
||||||
|
|
||||||
# language code lookup by name, with a few language aliases
|
# language code lookup by name, with a few language aliases
|
||||||
@ -131,6 +138,7 @@ class Tokenizer:
|
|||||||
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
|
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
|
||||||
|
|
||||||
encoding: tiktoken.Encoding
|
encoding: tiktoken.Encoding
|
||||||
|
num_languages: int
|
||||||
language: Optional[str] = None
|
language: Optional[str] = None
|
||||||
task: Optional[str] = None
|
task: Optional[str] = None
|
||||||
sot_sequence: Tuple[int] = ()
|
sot_sequence: Tuple[int] = ()
|
||||||
@ -145,7 +153,7 @@ class Tokenizer:
|
|||||||
translate: int = self.special_tokens["<|translate|>"]
|
translate: int = self.special_tokens["<|translate|>"]
|
||||||
transcribe: int = self.special_tokens["<|transcribe|>"]
|
transcribe: int = self.special_tokens["<|transcribe|>"]
|
||||||
|
|
||||||
langs = tuple(LANGUAGES.keys())
|
langs = tuple(LANGUAGES.keys())[: self.num_languages]
|
||||||
sot_sequence = [sot]
|
sot_sequence = [sot]
|
||||||
if self.language is not None:
|
if self.language is not None:
|
||||||
sot_sequence.append(sot + 1 + langs.index(self.language))
|
sot_sequence.append(sot + 1 + langs.index(self.language))
|
||||||
@ -211,10 +219,13 @@ class Tokenizer:
|
|||||||
if self.language is None:
|
if self.language is None:
|
||||||
raise ValueError("This tokenizer does not have language token configured")
|
raise ValueError("This tokenizer does not have language token configured")
|
||||||
|
|
||||||
if token := self.special_tokens.get(f"<|{self.language}|>", None):
|
return self.to_language_token(self.language)
|
||||||
|
|
||||||
|
def to_language_token(self, language):
|
||||||
|
if token := self.special_tokens.get(f"<|{language}|>", None):
|
||||||
return token
|
return token
|
||||||
|
|
||||||
raise KeyError(f"Language {self.language} not found in tokenizer.")
|
raise KeyError(f"Language {language} not found in tokenizer.")
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def all_language_tokens(self) -> Tuple[int]:
|
def all_language_tokens(self) -> Tuple[int]:
|
||||||
@ -222,11 +233,11 @@ class Tokenizer:
|
|||||||
for token, token_id in self.special_tokens.items():
|
for token, token_id in self.special_tokens.items():
|
||||||
if token.strip("<|>") in LANGUAGES:
|
if token.strip("<|>") in LANGUAGES:
|
||||||
result.append(token_id)
|
result.append(token_id)
|
||||||
return tuple(result)
|
return tuple(result)[: self.num_languages]
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def all_language_codes(self) -> Tuple[str]:
|
def all_language_codes(self) -> Tuple[str]:
|
||||||
return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
|
return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
|
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
|
||||||
@ -245,9 +256,7 @@ class Tokenizer:
|
|||||||
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
|
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
|
||||||
"""
|
"""
|
||||||
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
|
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
|
||||||
symbols += (
|
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
||||||
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
|
||||||
)
|
|
||||||
|
|
||||||
# symbols that may be a single token or multiple tokens depending on the tokenizer.
|
# symbols that may be a single token or multiple tokens depending on the tokenizer.
|
||||||
# In case they're multiple tokens, suppress the first token, which is safe because:
|
# In case they're multiple tokens, suppress the first token, which is safe because:
|
||||||
@ -269,7 +278,7 @@ class Tokenizer:
|
|||||||
return tuple(sorted(result))
|
return tuple(sorted(result))
|
||||||
|
|
||||||
def split_to_word_tokens(self, tokens: List[int]):
|
def split_to_word_tokens(self, tokens: List[int]):
|
||||||
if self.language in {"zh", "ja", "th", "lo", "my"}:
|
if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
|
||||||
# These languages don't typically use spaces, so it is difficult to split words
|
# These languages don't typically use spaces, so it is difficult to split words
|
||||||
# without morpheme analysis. Here, we instead split words at any
|
# without morpheme analysis. Here, we instead split words at any
|
||||||
# position where the tokens are decoded as valid unicode points
|
# position where the tokens are decoded as valid unicode points
|
||||||
@ -322,7 +331,7 @@ class Tokenizer:
|
|||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def get_encoding(name: str = "gpt2"):
|
def get_encoding(name: str = "gpt2", num_languages: int = 99):
|
||||||
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
||||||
ranks = {
|
ranks = {
|
||||||
base64.b64decode(token): int(rank)
|
base64.b64decode(token): int(rank)
|
||||||
@ -334,7 +343,7 @@ def get_encoding(name: str = "gpt2"):
|
|||||||
specials = [
|
specials = [
|
||||||
"<|endoftext|>",
|
"<|endoftext|>",
|
||||||
"<|startoftranscript|>",
|
"<|startoftranscript|>",
|
||||||
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
|
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
|
||||||
"<|translate|>",
|
"<|translate|>",
|
||||||
"<|transcribe|>",
|
"<|transcribe|>",
|
||||||
"<|startoflm|>",
|
"<|startoflm|>",
|
||||||
@ -361,6 +370,7 @@ def get_encoding(name: str = "gpt2"):
|
|||||||
def get_tokenizer(
|
def get_tokenizer(
|
||||||
multilingual: bool,
|
multilingual: bool,
|
||||||
*,
|
*,
|
||||||
|
num_languages: int = 99,
|
||||||
language: Optional[str] = None,
|
language: Optional[str] = None,
|
||||||
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
||||||
) -> Tokenizer:
|
) -> Tokenizer:
|
||||||
@ -381,6 +391,6 @@ def get_tokenizer(
|
|||||||
language = None
|
language = None
|
||||||
task = None
|
task = None
|
||||||
|
|
||||||
encoding = get_encoding(name=encoding_name)
|
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
|
||||||
|
|
||||||
return Tokenizer(encoding=encoding, language=language, task=task)
|
return Tokenizer(encoding=encoding, num_languages=num_languages, language=language, task=task)
|
||||||
|
|||||||
@ -119,7 +119,7 @@ def transcribe(
|
|||||||
decode_options["fp16"] = False
|
decode_options["fp16"] = False
|
||||||
|
|
||||||
# Pad 30-seconds of silence to the input audio, for slicing
|
# Pad 30-seconds of silence to the input audio, for slicing
|
||||||
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
|
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
|
||||||
content_frames = mel.shape[-1] - N_FRAMES
|
content_frames = mel.shape[-1] - N_FRAMES
|
||||||
|
|
||||||
if decode_options.get("language", None) is None:
|
if decode_options.get("language", None) is None:
|
||||||
@ -140,7 +140,12 @@ def transcribe(
|
|||||||
|
|
||||||
language: str = decode_options["language"]
|
language: str = decode_options["language"]
|
||||||
task: str = decode_options.get("task", "transcribe")
|
task: str = decode_options.get("task", "transcribe")
|
||||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
tokenizer = get_tokenizer(
|
||||||
|
model.is_multilingual,
|
||||||
|
num_languages=model.num_languages,
|
||||||
|
language=language,
|
||||||
|
task=task,
|
||||||
|
)
|
||||||
|
|
||||||
if word_timestamps and task == "translate":
|
if word_timestamps and task == "translate":
|
||||||
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user