can load 100-language models

This commit is contained in:
Jong Wook Kim 2023-11-06 04:04:24 -08:00
parent 216f2c5fec
commit b7ac4888a2
6 changed files with 43 additions and 22 deletions

View File

@ -25,7 +25,7 @@ def test_transcribe(model_name: str):
assert "your country" in transcription
assert "do for you" in transcription
tokenizer = get_tokenizer(model.is_multilingual)
tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
all_tokens = [t for s in result["segments"] for t in s["tokens"]]
assert tokenizer.decode(all_tokens) == result["text"]
assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")

View File

@ -12,7 +12,6 @@ from .utils import exact_div
# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
@ -90,7 +89,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
def mel_filters(device, n_mels: int) -> torch.Tensor:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:
@ -110,7 +109,7 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor],
n_mels: int = N_MELS,
n_mels: int = 80,
padding: int = 0,
device: Optional[Union[str, torch.device]] = None,
):

View File

@ -32,7 +32,7 @@ def detect_language(
list of dictionaries containing the probability distribution over all languages.
"""
if tokenizer is None:
tokenizer = get_tokenizer(model.is_multilingual)
tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
if (
tokenizer.language is None
or tokenizer.language_token not in tokenizer.sot_sequence
@ -514,7 +514,10 @@ class DecodingTask:
language = options.language or "en"
tokenizer = get_tokenizer(
model.is_multilingual, language=language, task=options.task
model.is_multilingual,
num_languages=model.num_languages,
language=language,
task=options.task,
)
self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)

View File

@ -269,7 +269,11 @@ class Whisper(nn.Module):
@property
def is_multilingual(self):
    """Whether the model was built with the multilingual vocabulary.

    The check is on vocabulary size only: any ``dims.n_vocab`` of 51865 or
    more counts as multilingual, so vocabularies that are larger than 51865
    (e.g. with additional language tokens) are still recognised.
    """
    return self.dims.n_vocab >= 51865
@property
def num_languages(self):
    """Number of language tokens implied by the vocabulary size.

    Computed purely from ``dims.n_vocab``: a 51865-entry (multilingual)
    vocabulary gives 99, and every additional entry counts as one more
    language (51866 -> 100). ``int(self.is_multilingual)`` removes the
    one-entry difference between the two vocabularies, so a 51864-entry
    non-multilingual vocabulary also gives 99.

    NOTE(review): 51765 is presumably the number of non-language entries in
    the non-multilingual vocabulary -- confirm against the tokenizer assets.
    """
    return self.dims.n_vocab - 51765 - int(self.is_multilingual)
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
"""

View File

@ -6,6 +6,12 @@ from functools import cached_property, lru_cache
from typing import Dict, List, Optional, Tuple
import tiktoken
from tiktoken.load import load_tiktoken_bpe
# Value of the OPENAI_ENCODINGS_BASE environment variable, falling back to the
# built-in default below when the variable is not set.
# NOTE(review): no use of this constant is visible in this excerpt, and
# get_encoding() still reads its vocabulary from the local "assets" directory;
# confirm the constant (and its default location) is actually needed.
_DEFAULT_ENCODINGS_BASE = "az://oaiappliedai/encodings/applied-encodings"
ENCODINGS_BASE = os.getenv("OPENAI_ENCODINGS_BASE", _DEFAULT_ENCODINGS_BASE)
LANGUAGES = {
"en": "english",
@ -107,6 +113,7 @@ LANGUAGES = {
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
"yue": "cantonese",
}
# language code lookup by name, with a few language aliases
@ -131,6 +138,7 @@ class Tokenizer:
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
encoding: tiktoken.Encoding
num_languages: int
language: Optional[str] = None
task: Optional[str] = None
sot_sequence: Tuple[int] = ()
@ -145,7 +153,7 @@ class Tokenizer:
translate: int = self.special_tokens["<|translate|>"]
transcribe: int = self.special_tokens["<|transcribe|>"]
langs = tuple(LANGUAGES.keys())
langs = tuple(LANGUAGES.keys())[: self.num_languages]
sot_sequence = [sot]
if self.language is not None:
sot_sequence.append(sot + 1 + langs.index(self.language))
@ -211,10 +219,13 @@ class Tokenizer:
if self.language is None:
raise ValueError("This tokenizer does not have language token configured")
if token := self.special_tokens.get(f"<|{self.language}|>", None):
return self.to_language_token(self.language)
def to_language_token(self, language):
    """Return the id of the special token ``<|language|>``.

    Args:
        language: language code to look up, e.g. ``"en"``.

    Returns:
        The token id stored in ``self.special_tokens`` for that language.

    Raises:
        KeyError: if the tokenizer has no special token for ``language``.
    """
    token = self.special_tokens.get(f"<|{language}|>")
    # Compare against None rather than truthiness so an id of 0 is not
    # mistaken for "missing".
    if token is not None:
        return token

    # Report the language that was asked for, not the tokenizer's own one.
    raise KeyError(f"Language {language} not found in tokenizer.")
@cached_property
def all_language_tokens(self) -> Tuple[int]:
@ -222,11 +233,11 @@ class Tokenizer:
for token, token_id in self.special_tokens.items():
if token.strip("<|>") in LANGUAGES:
result.append(token_id)
return tuple(result)
return tuple(result)[: self.num_languages]
@cached_property
def all_language_codes(self) -> Tuple[str]:
    """Language codes for every entry of ``self.all_language_tokens``.

    Each token id is decoded on its own and the surrounding ``<``, ``|`` and
    ``>`` characters are stripped, so ``<|en|>`` becomes ``"en"``. The order
    follows ``self.all_language_tokens``.
    """
    return tuple(
        self.decode([token_id]).strip("<|>")
        for token_id in self.all_language_tokens
    )
@cached_property
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
@ -245,9 +256,7 @@ class Tokenizer:
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
symbols += (
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
)
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
@ -269,7 +278,7 @@ class Tokenizer:
return tuple(sorted(result))
def split_to_word_tokens(self, tokens: List[int]):
if self.language in {"zh", "ja", "th", "lo", "my"}:
if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
# These languages don't typically use spaces, so it is difficult to split words
# without morpheme analysis. Here, we instead split words at any
# position where the tokens are decoded as valid unicode points
@ -322,7 +331,7 @@ class Tokenizer:
@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2"):
def get_encoding(name: str = "gpt2", num_languages: int = 99):
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
ranks = {
base64.b64decode(token): int(rank)
@ -334,7 +343,7 @@ def get_encoding(name: str = "gpt2"):
specials = [
"<|endoftext|>",
"<|startoftranscript|>",
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
@ -361,6 +370,7 @@ def get_encoding(name: str = "gpt2"):
def get_tokenizer(
multilingual: bool,
*,
num_languages: int = 99,
language: Optional[str] = None,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
) -> Tokenizer:
@ -381,6 +391,6 @@ def get_tokenizer(
language = None
task = None
encoding = get_encoding(name=encoding_name)
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
return Tokenizer(encoding=encoding, language=language, task=task)
return Tokenizer(encoding=encoding, num_languages=num_languages, language=language, task=task)

View File

@ -119,7 +119,7 @@ def transcribe(
decode_options["fp16"] = False
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
content_frames = mel.shape[-1] - N_FRAMES
if decode_options.get("language", None) is None:
@ -140,7 +140,12 @@ def transcribe(
language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
tokenizer = get_tokenizer(
model.is_multilingual,
num_languages=model.num_languages,
language=language,
task=task,
)
if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.")