diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index e4f8fd0..599221a 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -25,7 +25,7 @@ def test_transcribe(model_name: str): assert "your country" in transcription assert "do for you" in transcription - tokenizer = get_tokenizer(model.is_multilingual) + tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages) all_tokens = [t for s in result["segments"] for t in s["tokens"]] assert tokenizer.decode(all_tokens) == result["text"] assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>") diff --git a/whisper/audio.py b/whisper/audio.py index 01d8d15..cf6c66a 100644 --- a/whisper/audio.py +++ b/whisper/audio.py @@ -12,7 +12,6 @@ from .utils import exact_div # hard-coded audio hyperparameters SAMPLE_RATE = 16000 N_FFT = 400 -N_MELS = 80 HOP_LENGTH = 160 CHUNK_LENGTH = 30 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk @@ -90,7 +89,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): @lru_cache(maxsize=None) -def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: +def mel_filters(device, n_mels: int) -> torch.Tensor: """ load the mel filterbank matrix for projecting STFT into a Mel spectrogram. Allows decoupling librosa dependency; saved using: @@ -110,7 +109,7 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: def log_mel_spectrogram( audio: Union[str, np.ndarray, torch.Tensor], - n_mels: int = N_MELS, + n_mels: int = 80, padding: int = 0, device: Optional[Union[str, torch.device]] = None, ): diff --git a/whisper/decoding.py b/whisper/decoding.py index ecd98a4..8316d81 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -32,7 +32,7 @@ def detect_language( list of dictionaries containing the probability distribution over all languages. """ if tokenizer is None: - tokenizer = get_tokenizer(model.is_multilingual) + tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages) if ( tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence @@ -514,7 +514,10 @@ class DecodingTask: language = options.language or "en" tokenizer = get_tokenizer( - model.is_multilingual, language=language, task=options.task + model.is_multilingual, + num_languages=model.num_languages, + language=language, + task=options.task, ) self.tokenizer: Tokenizer = tokenizer self.options: DecodingOptions = self._verify_options(options) diff --git a/whisper/model.py b/whisper/model.py index 6913002..431211e 100644 --- a/whisper/model.py +++ b/whisper/model.py @@ -269,7 +269,11 @@ class Whisper(nn.Module): @property def is_multilingual(self): - return self.dims.n_vocab == 51865 + return self.dims.n_vocab >= 51865 + + @property + def num_languages(self): + return self.dims.n_vocab - 51765 - int(self.is_multilingual) def install_kv_cache_hooks(self, cache: Optional[dict] = None): """ diff --git a/whisper/tokenizer.py b/whisper/tokenizer.py index 3b23991..ae8d1ad 100644 --- a/whisper/tokenizer.py +++ b/whisper/tokenizer.py @@ -6,6 +6,12 @@ from functools import cached_property, lru_cache from typing import Dict, List, Optional, Tuple import tiktoken +from tiktoken.load import load_tiktoken_bpe + +ENCODINGS_BASE = os.environ.get( + "OPENAI_ENCODINGS_BASE", + "az://oaiappliedai/encodings/applied-encodings", +) LANGUAGES = { "en": "english", @@ -107,6 +113,7 @@ LANGUAGES = { "ba": "bashkir", "jw": "javanese", "su": "sundanese", + "yue": "cantonese", } # language code lookup by name, with a few language aliases @@ -131,6 +138,7 @@ class Tokenizer: """A thin wrapper around `tiktoken` providing quick access to special tokens""" encoding: tiktoken.Encoding + num_languages: int language: Optional[str] = None task: Optional[str] = None sot_sequence: Tuple[int] = () @@ -145,7 +153,7 @@ class Tokenizer: translate: int = self.special_tokens["<|translate|>"] transcribe: int = self.special_tokens["<|transcribe|>"] - langs = tuple(LANGUAGES.keys()) + langs = tuple(LANGUAGES.keys())[: self.num_languages] sot_sequence = [sot] if self.language is not None: sot_sequence.append(sot + 1 + langs.index(self.language)) @@ -211,10 +219,13 @@ class Tokenizer: if self.language is None: raise ValueError("This tokenizer does not have language token configured") - if token := self.special_tokens.get(f"<|{self.language}|>", None): + return self.to_language_token(self.language) + + def to_language_token(self, language): + if token := self.special_tokens.get(f"<|{language}|>", None): return token - raise KeyError(f"Language {self.language} not found in tokenizer.") + raise KeyError(f"Language {language} not found in tokenizer.") @cached_property def all_language_tokens(self) -> Tuple[int]: @@ -222,11 +233,11 @@ class Tokenizer: for token, token_id in self.special_tokens.items(): if token.strip("<|>") in LANGUAGES: result.append(token_id) - return tuple(result) + return tuple(result)[: self.num_languages] @cached_property def all_language_codes(self) -> Tuple[str]: - return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens) + return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens) @cached_property def sot_sequence_including_notimestamps(self) -> Tuple[int]: @@ -245,9 +256,7 @@ class Tokenizer: keeping basic punctuations like commas, periods, question marks, exclamation points, etc. """ symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』') - symbols += ( - "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split() - ) + symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split() # symbols that may be a single token or multiple tokens depending on the tokenizer. # In case they're multiple tokens, suppress the first token, which is safe because: @@ -269,7 +278,7 @@ class Tokenizer: return tuple(sorted(result)) def split_to_word_tokens(self, tokens: List[int]): - if self.language in {"zh", "ja", "th", "lo", "my"}: + if self.language in {"zh", "ja", "th", "lo", "my", "yue"}: # These languages don't typically use spaces, so it is difficult to split words # without morpheme analysis. Here, we instead split words at any # position where the tokens are decoded as valid unicode points @@ -322,7 +331,7 @@ class Tokenizer: @lru_cache(maxsize=None) -def get_encoding(name: str = "gpt2"): +def get_encoding(name: str = "gpt2", num_languages: int = 99): vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken") ranks = { base64.b64decode(token): int(rank) @@ -334,7 +343,7 @@ def get_encoding(name: str = "gpt2"): specials = [ "<|endoftext|>", "<|startoftranscript|>", - *[f"<|{lang}|>" for lang in LANGUAGES.keys()], + *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], "<|translate|>", "<|transcribe|>", "<|startoflm|>", @@ -361,6 +370,7 @@ def get_encoding(name: str = "gpt2"): def get_tokenizer( multilingual: bool, *, + num_languages: int = 99, language: Optional[str] = None, task: Optional[str] = None, # Literal["transcribe", "translate", None] ) -> Tokenizer: @@ -381,6 +391,6 @@ def get_tokenizer( language = None task = None - encoding = get_encoding(name=encoding_name) + encoding = get_encoding(name=encoding_name, num_languages=num_languages) - return Tokenizer(encoding=encoding, language=language, task=task) + return Tokenizer(encoding=encoding, num_languages=num_languages, language=language, task=task) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index d5b3d43..e80bede 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -119,7 +119,7 @@ def transcribe( decode_options["fp16"] = False # Pad 30-seconds of silence to the input audio, for slicing - mel = log_mel_spectrogram(audio, padding=N_SAMPLES) + mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES) content_frames = mel.shape[-1] - N_FRAMES if decode_options.get("language", None) is None: @@ -140,7 +140,12 @@ def transcribe( language: str = decode_options["language"] task: str = decode_options.get("task", "transcribe") - tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task) + tokenizer = get_tokenizer( + model.is_multilingual, + num_languages=model.num_languages, + language=language, + task=task, + ) if word_timestamps and task == "translate": warnings.warn("Word-level timestamps on translations may not be reliable.")