can load 100-language models

Jong Wook Kim 2023-11-06 04:04:24 -08:00
parent 216f2c5fec
commit b7ac4888a2
6 changed files with 43 additions and 22 deletions

View File: tests/test_transcribe.py

@@ -25,7 +25,7 @@ def test_transcribe(model_name: str):
     assert "your country" in transcription
     assert "do for you" in transcription
 
-    tokenizer = get_tokenizer(model.is_multilingual)
+    tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
     all_tokens = [t for s in result["segments"] for t in s["tokens"]]
     assert tokenizer.decode(all_tokens) == result["text"]
     assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")
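
Note: the language tokens sit immediately after <|startoftranscript|>, so appending a 100th language shifts every later special token (<|translate|>, <|transcribe|>, the timestamp tokens) up by one id; a tokenizer built with the wrong num_languages would decode shifted tokens, which is why the test now threads model.num_languages through. A minimal sketch of the shift, using assumed ids from the standard multilingual vocabulary:

    # Sketch: special-token layout for 99- vs. 100-language vocabularies (assumed ids).
    sot = 50258                    # <|startoftranscript|>
    translate_99 = sot + 1 + 99    # 50358: <|translate|> when 99 language slots follow
    translate_100 = sot + 1 + 100  # 50359: one higher once "yue" is appended
    assert translate_100 - translate_99 == 1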

View File: whisper/audio.py

@@ -12,7 +12,6 @@ from .utils import exact_div
 # hard-coded audio hyperparameters
 SAMPLE_RATE = 16000
 N_FFT = 400
-N_MELS = 80
 HOP_LENGTH = 160
 CHUNK_LENGTH = 30
 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
@@ -90,7 +89,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
 @lru_cache(maxsize=None)
-def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
+def mel_filters(device, n_mels: int) -> torch.Tensor:
     """
     load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
     Allows decoupling librosa dependency; saved using:
@@ -110,7 +109,7 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
 def log_mel_spectrogram(
     audio: Union[str, np.ndarray, torch.Tensor],
-    n_mels: int = N_MELS,
+    n_mels: int = 80,
     padding: int = 0,
     device: Optional[Union[str, torch.device]] = None,
 ):
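
Note: with the global N_MELS gone, the filterbank size becomes a per-checkpoint choice; 80 bins stays the default, and the 100-language checkpoints this commit prepares for use a larger filterbank. A hedged sketch, assuming the bundled mel_filters.npz also provides a 128-bin filterbank:

    # Sketch: callers pick the bin count per model instead of using a module constant.
    from whisper.audio import N_SAMPLES, log_mel_spectrogram

    mel_80 = log_mel_spectrogram("audio.wav", padding=N_SAMPLES)               # default, 80 bins
    mel_128 = log_mel_spectrogram("audio.wav", n_mels=128, padding=N_SAMPLES)  # larger checkpoints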

View File: whisper/decoding.py

@@ -32,7 +32,7 @@ def detect_language(
     list of dictionaries containing the probability distribution over all languages.
     """
     if tokenizer is None:
-        tokenizer = get_tokenizer(model.is_multilingual)
+        tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
     if (
         tokenizer.language is None
         or tokenizer.language_token not in tokenizer.sot_sequence
@@ -514,7 +514,10 @@ class DecodingTask:
         language = options.language or "en"
         tokenizer = get_tokenizer(
-            model.is_multilingual, language=language, task=options.task
+            model.is_multilingual,
+            num_languages=model.num_languages,
+            language=language,
+            task=options.task,
         )
         self.tokenizer: Tokenizer = tokenizer
         self.options: DecodingOptions = self._verify_options(options)
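
Note: detect_language now sizes its fallback tokenizer from the model, so a 100-language checkpoint can return probabilities over all 100 languages. A small usage sketch, assuming a loaded model and a single padded 30-second mel segment (as produced by log_mel_spectrogram plus pad_or_trim):

    # Sketch: the returned probability dict now covers model.num_languages entries,
    # so a 100th language such as "yue" is reachable.
    language_tokens, language_probs = model.detect_language(mel_segment)
    best = max(language_probs, key=language_probs.get)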

View File: whisper/model.py

@@ -269,7 +269,11 @@ class Whisper(nn.Module):
     @property
     def is_multilingual(self):
-        return self.dims.n_vocab == 51865
+        return self.dims.n_vocab >= 51865
+
+    @property
+    def num_languages(self):
+        return self.dims.n_vocab - 51765 - int(self.is_multilingual)
 
     def install_kv_cache_hooks(self, cache: Optional[dict] = None):
         """

View File: whisper/tokenizer.py

@@ -6,6 +6,12 @@ from functools import cached_property, lru_cache
 from typing import Dict, List, Optional, Tuple
 
 import tiktoken
+from tiktoken.load import load_tiktoken_bpe
+
+ENCODINGS_BASE = os.environ.get(
+    "OPENAI_ENCODINGS_BASE",
+    "az://oaiappliedai/encodings/applied-encodings",
+)
 
 LANGUAGES = {
     "en": "english",
@@ -107,6 +113,7 @@ LANGUAGES = {
     "ba": "bashkir",
     "jw": "javanese",
     "su": "sundanese",
+    "yue": "cantonese",
 }
 
 # language code lookup by name, with a few language aliases
@@ -131,6 +138,7 @@ class Tokenizer:
     """A thin wrapper around `tiktoken` providing quick access to special tokens"""
 
     encoding: tiktoken.Encoding
+    num_languages: int
     language: Optional[str] = None
     task: Optional[str] = None
     sot_sequence: Tuple[int] = ()
@@ -145,7 +153,7 @@ class Tokenizer:
         translate: int = self.special_tokens["<|translate|>"]
         transcribe: int = self.special_tokens["<|transcribe|>"]
 
-        langs = tuple(LANGUAGES.keys())
+        langs = tuple(LANGUAGES.keys())[: self.num_languages]
         sot_sequence = [sot]
         if self.language is not None:
            sot_sequence.append(sot + 1 + langs.index(self.language))
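
Note: truncating langs keeps langs.index(self.language) aligned with the token ids, because the language tokens occupy consecutive ids directly after <|startoftranscript|>, in LANGUAGES order. A sketch with assumed multilingual ids:

    # Sketch: a language-token id is sot + 1 + its index in the truncated list.
    from whisper.tokenizer import LANGUAGES

    sot = 50258                            # assumed id of <|startoftranscript|>
    langs = tuple(LANGUAGES.keys())[:100]  # "en", "zh", ..., "su", "yue"
    assert sot + 1 + langs.index("en") == 50259   # first language slot
    assert sot + 1 + langs.index("yue") == 50358  # 100th slot, new in this commit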
@@ -211,10 +219,13 @@ class Tokenizer:
         if self.language is None:
             raise ValueError("This tokenizer does not have language token configured")
 
-        if token := self.special_tokens.get(f"<|{self.language}|>", None):
+        return self.to_language_token(self.language)
+
+    def to_language_token(self, language):
+        if token := self.special_tokens.get(f"<|{language}|>", None):
             return token
 
-        raise KeyError(f"Language {self.language} not found in tokenizer.")
+        raise KeyError(f"Language {language} not found in tokenizer.")
 
     @cached_property
     def all_language_tokens(self) -> Tuple[int]:
@@ -222,11 +233,11 @@ class Tokenizer:
         for token, token_id in self.special_tokens.items():
             if token.strip("<|>") in LANGUAGES:
                 result.append(token_id)
-        return tuple(result)
+        return tuple(result)[: self.num_languages]
 
     @cached_property
     def all_language_codes(self) -> Tuple[str]:
-        return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
+        return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
 
     @cached_property
     def sot_sequence_including_notimestamps(self) -> Tuple[int]:
@@ -245,9 +256,7 @@ class Tokenizer:
         keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
         """
         symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
-        symbols += (
-            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
-        )
+        symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
 
         # symbols that may be a single token or multiple tokens depending on the tokenizer.
         # In case they're multiple tokens, suppress the first token, which is safe because:
@@ -269,7 +278,7 @@ class Tokenizer:
         return tuple(sorted(result))
 
     def split_to_word_tokens(self, tokens: List[int]):
-        if self.language in {"zh", "ja", "th", "lo", "my"}:
+        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
             # These languages don't typically use spaces, so it is difficult to split words
             # without morpheme analysis. Here, we instead split words at any
             # position where the tokens are decoded as valid unicode points
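
Note: for these languages the splitter falls back to cutting wherever the accumulated tokens decode to complete unicode. A simplified sketch of that idea (not the library's exact helper, which also guards against genuine U+FFFD characters in the source text):

    # Simplified sketch: cut a token stream wherever it decodes without a
    # dangling, incomplete code point (signalled by U+FFFD).
    def split_at_unicode_points(tokenizer, tokens):
        pieces, current = [], []
        for token in tokens:
            current.append(token)
            decoded = tokenizer.decode(current)
            if "\ufffd" not in decoded:  # all bytes so far form complete code points
                pieces.append(decoded)
                current = []
        return pieces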
@@ -322,7 +331,7 @@ class Tokenizer:
 @lru_cache(maxsize=None)
-def get_encoding(name: str = "gpt2"):
+def get_encoding(name: str = "gpt2", num_languages: int = 99):
     vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
     ranks = {
         base64.b64decode(token): int(rank)
@@ -334,7 +343,7 @@ def get_encoding(name: str = "gpt2"):
     specials = [
         "<|endoftext|>",
         "<|startoftranscript|>",
-        *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
+        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
         "<|translate|>",
         "<|transcribe|>",
         "<|startoflm|>",
@@ -361,6 +370,7 @@ def get_encoding(name: str = "gpt2"):
 def get_tokenizer(
     multilingual: bool,
     *,
+    num_languages: int = 99,
     language: Optional[str] = None,
     task: Optional[str] = None,  # Literal["transcribe", "translate", None]
 ) -> Tokenizer:
@@ -381,6 +391,6 @@ def get_tokenizer(
         language = None
         task = None
 
-    encoding = get_encoding(name=encoding_name)
+    encoding = get_encoding(name=encoding_name, num_languages=num_languages)
 
-    return Tokenizer(encoding=encoding, language=language, task=task)
+    return Tokenizer(encoding=encoding, num_languages=num_languages, language=language, task=task)
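
Note: the defaults preserve the old behavior, so existing callers still get a 99-language tokenizer and 100-language models opt in explicitly. A hedged usage sketch:

    # Sketch: constructing tokenizers for legacy and 100-language vocabularies.
    from whisper.tokenizer import get_tokenizer

    tok_legacy = get_tokenizer(multilingual=True)  # num_languages defaults to 99
    tok_100 = get_tokenizer(
        multilingual=True, num_languages=100, language="yue", task="transcribe"
    )
    assert len(tok_100.all_language_tokens) == 100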

View File: whisper/transcribe.py

@@ -119,7 +119,7 @@ def transcribe(
             decode_options["fp16"] = False
 
     # Pad 30-seconds of silence to the input audio, for slicing
-    mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
+    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
     content_frames = mel.shape[-1] - N_FRAMES
 
     if decode_options.get("language", None) is None:
@@ -140,7 +140,12 @@ def transcribe(
     language: str = decode_options["language"]
     task: str = decode_options.get("task", "transcribe")
-    tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
+    tokenizer = get_tokenizer(
+        model.is_multilingual,
+        num_languages=model.num_languages,
+        language=language,
+        task=task,
+    )
 
     if word_timestamps and task == "translate":
         warnings.warn("Word-level timestamps on translations may not be reliable.")