mirror of
https://github.com/openai/whisper.git
synced 2025-03-30 14:28:27 +00:00
Use tiktoken (#1044)
* use tiktoken==0.3.0 * formatting * tuple should be safer * Update whisper/tokenizer.py Co-authored-by: Ruhollah Majdoddin <r.majdodin@gmail.com> * use tiktoken 0.3.1 * reflecting suggestions * cleanup * bypassing load_tiktoken_bpe to avoid blobfile dep --------- Co-authored-by: Ruhollah Majdoddin <r.majdodin@gmail.com>
This commit is contained in:
parent
ad3250a846
commit
839639a223
@ -2,6 +2,4 @@ include requirements.txt
|
||||
include README.md
|
||||
include LICENSE
|
||||
include whisper/assets/*
|
||||
include whisper/assets/gpt2/*
|
||||
include whisper/assets/multilingual/*
|
||||
include whisper/normalizers/english.json
|
||||
|
@ -3,5 +3,5 @@ numpy
|
||||
torch
|
||||
tqdm
|
||||
more-itertools
|
||||
transformers>=4.19.0
|
||||
tiktoken==0.3.1
|
||||
ffmpeg-python==0.2.0
|
||||
|
@ -4,6 +4,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
import whisper
|
||||
from whisper.tokenizer import get_tokenizer
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", whisper.available_models())
|
||||
@ -24,6 +25,11 @@ def test_transcribe(model_name: str):
|
||||
assert "your country" in transcription
|
||||
assert "do for you" in transcription
|
||||
|
||||
tokenizer = get_tokenizer(model.is_multilingual)
|
||||
all_tokens = [t for s in result["segments"] for t in s["tokens"]]
|
||||
assert tokenizer.decode(all_tokens) == result["text"]
|
||||
assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")
|
||||
|
||||
timing_checked = False
|
||||
for segment in result["segments"]:
|
||||
for timing in segment["words"]:
|
||||
@ -31,7 +37,6 @@ def test_transcribe(model_name: str):
|
||||
if timing["word"].strip(" ,") == "Americans":
|
||||
assert timing["start"] <= 1.8
|
||||
assert timing["end"] >= 1.8
|
||||
print(timing)
|
||||
timing_checked = True
|
||||
|
||||
assert timing_checked
|
||||
|
50256
whisper/assets/gpt2.tiktoken
Normal file
50256
whisper/assets/gpt2.tiktoken
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1 +0,0 @@
|
||||
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
|
@ -1 +0,0 @@
|
||||
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}
|
File diff suppressed because one or more lines are too long
50257
whisper/assets/multilingual.tiktoken
Normal file
50257
whisper/assets/multilingual.tiktoken
Normal file
File diff suppressed because it is too large
Load Diff
@ -1 +0,0 @@
|
||||
{"<|endoftext|>": 50257}
|
File diff suppressed because it is too large
Load Diff
@ -1 +0,0 @@
|
||||
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
|
@ -1 +0,0 @@
|
||||
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}
|
File diff suppressed because one or more lines are too long
@ -1,12 +1,12 @@
|
||||
import base64
|
||||
import os
|
||||
import string
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property, lru_cache
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import GPT2TokenizerFast
|
||||
import tiktoken
|
||||
from tiktoken_ext.openai_public import gpt2
|
||||
|
||||
LANGUAGES = {
|
||||
"en": "english",
|
||||
@ -127,74 +127,84 @@ TO_LANGUAGE_CODE = {
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass
|
||||
class Tokenizer:
|
||||
"""A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
|
||||
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
|
||||
|
||||
tokenizer: "GPT2TokenizerFast"
|
||||
language: Optional[str]
|
||||
sot_sequence: Tuple[int]
|
||||
encoding: tiktoken.Encoding
|
||||
language: Optional[str] = None
|
||||
task: Optional[str] = None
|
||||
sot_sequence: Tuple[int] = ()
|
||||
special_tokens: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
for special in self.encoding.special_tokens_set:
|
||||
special_token = self.encoding.encode_single_token(special)
|
||||
self.special_tokens[special] = special_token
|
||||
|
||||
sot: int = self.special_tokens["<|startoftranscript|>"]
|
||||
translate: int = self.special_tokens["<|translate|>"]
|
||||
transcribe: int = self.special_tokens["<|transcribe|>"]
|
||||
|
||||
langs = tuple(LANGUAGES.keys())
|
||||
sot_sequence = [sot]
|
||||
if self.language is not None:
|
||||
sot_sequence.append(sot + 1 + langs.index(self.language))
|
||||
if self.task is not None:
|
||||
task_token: int = transcribe if self.task == "transcribe" else translate
|
||||
sot_sequence.append(task_token)
|
||||
|
||||
self.sot_sequence = tuple(sot_sequence)
|
||||
|
||||
def encode(self, text, **kwargs):
|
||||
return self.tokenizer.encode(text, **kwargs)
|
||||
return self.encoding.encode(text, **kwargs)
|
||||
|
||||
def decode(
|
||||
self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs
|
||||
):
|
||||
return self.tokenizer.decode(token_ids, **kwargs)
|
||||
def decode(self, token_ids: List[int], **kwargs) -> str:
|
||||
token_ids = [t for t in token_ids if t < self.timestamp_begin]
|
||||
return self.encoding.decode(token_ids, **kwargs)
|
||||
|
||||
def decode_with_timestamps(self, tokens) -> str:
|
||||
def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
|
||||
"""
|
||||
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
|
||||
Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
|
||||
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
||||
"""
|
||||
outputs = [[]]
|
||||
for token in tokens:
|
||||
if token >= self.timestamp_begin:
|
||||
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
|
||||
outputs.append(timestamp)
|
||||
outputs.append([])
|
||||
else:
|
||||
outputs[-1].append(token)
|
||||
return "".join(
|
||||
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
||||
)
|
||||
return self.encoding.decode(token_ids, **kwargs)
|
||||
|
||||
@cached_property
|
||||
def eot(self) -> int:
|
||||
return self.tokenizer.eos_token_id
|
||||
return self.encoding.eot_token
|
||||
|
||||
@cached_property
|
||||
def transcribe(self) -> int:
|
||||
return self._get_single_token_id("<|transcribe|>")
|
||||
return self.special_tokens["<|transcribe|>"]
|
||||
|
||||
@cached_property
|
||||
def translate(self) -> int:
|
||||
return self._get_single_token_id("<|translate|>")
|
||||
return self.special_tokens["<|translate|>"]
|
||||
|
||||
@cached_property
|
||||
def sot(self) -> int:
|
||||
return self._get_single_token_id("<|startoftranscript|>")
|
||||
return self.special_tokens["<|startoftranscript|>"]
|
||||
|
||||
@cached_property
|
||||
def sot_lm(self) -> int:
|
||||
return self._get_single_token_id("<|startoflm|>")
|
||||
return self.special_tokens["<|startoflm|>"]
|
||||
|
||||
@cached_property
|
||||
def sot_prev(self) -> int:
|
||||
return self._get_single_token_id("<|startofprev|>")
|
||||
return self.special_tokens["<|startofprev|>"]
|
||||
|
||||
@cached_property
|
||||
def no_speech(self) -> int:
|
||||
return self._get_single_token_id("<|nospeech|>")
|
||||
return self.special_tokens["<|nospeech|>"]
|
||||
|
||||
@cached_property
|
||||
def no_timestamps(self) -> int:
|
||||
return self._get_single_token_id("<|notimestamps|>")
|
||||
return self.special_tokens["<|notimestamps|>"]
|
||||
|
||||
@cached_property
|
||||
def timestamp_begin(self) -> int:
|
||||
return self.tokenizer.all_special_ids[-1] + 1
|
||||
return self.special_tokens["<|0.00|>"]
|
||||
|
||||
@cached_property
|
||||
def language_token(self) -> int:
|
||||
@ -202,25 +212,15 @@ class Tokenizer:
|
||||
if self.language is None:
|
||||
raise ValueError("This tokenizer does not have language token configured")
|
||||
|
||||
additional_tokens = dict(
|
||||
zip(
|
||||
self.tokenizer.additional_special_tokens,
|
||||
self.tokenizer.additional_special_tokens_ids,
|
||||
)
|
||||
)
|
||||
candidate = f"<|{self.language}|>"
|
||||
if candidate in additional_tokens:
|
||||
return additional_tokens[candidate]
|
||||
if token := self.special_tokens.get(f"<|{self.language}|>", None):
|
||||
return token
|
||||
|
||||
raise KeyError(f"Language {self.language} not found in tokenizer.")
|
||||
|
||||
@cached_property
|
||||
def all_language_tokens(self) -> Tuple[int]:
|
||||
result = []
|
||||
for token, token_id in zip(
|
||||
self.tokenizer.additional_special_tokens,
|
||||
self.tokenizer.additional_special_tokens_ids,
|
||||
):
|
||||
for token, token_id in self.special_tokens.items():
|
||||
if token.strip("<|>") in LANGUAGES:
|
||||
result.append(token_id)
|
||||
return tuple(result)
|
||||
@ -258,22 +258,17 @@ class Tokenizer:
|
||||
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
||||
|
||||
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
||||
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
|
||||
result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
|
||||
for symbol in symbols + list(miscellaneous):
|
||||
for tokens in [
|
||||
self.tokenizer.encode(symbol),
|
||||
self.tokenizer.encode(" " + symbol),
|
||||
self.encoding.encode(symbol),
|
||||
self.encoding.encode(" " + symbol),
|
||||
]:
|
||||
if len(tokens) == 1 or symbol in miscellaneous:
|
||||
result.add(tokens[0])
|
||||
|
||||
return tuple(sorted(result))
|
||||
|
||||
def _get_single_token_id(self, text) -> int:
|
||||
tokens = self.tokenizer.encode(text)
|
||||
assert len(tokens) == 1, f"{text} is not encoded as a single token"
|
||||
return tokens[0]
|
||||
|
||||
def split_to_word_tokens(self, tokens: List[int]):
|
||||
if self.language in {"zh", "ja", "th", "lo", "my"}:
|
||||
# These languages don't typically use spaces, so it is difficult to split words
|
||||
@ -318,12 +313,17 @@ class Tokenizer:
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def build_tokenizer(name: str = "gpt2"):
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
path = os.path.join(os.path.dirname(__file__), "assets", name)
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained(path)
|
||||
def get_encoding(name: str = "gpt2"):
|
||||
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
||||
ranks = {
|
||||
base64.b64decode(token): int(rank)
|
||||
for token, rank in (line.split() for line in open(vocab_path) if line)
|
||||
}
|
||||
n_vocab = len(ranks)
|
||||
special_tokens = {}
|
||||
|
||||
specials = [
|
||||
"<|endoftext|>",
|
||||
"<|startoftranscript|>",
|
||||
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
|
||||
"<|translate|>",
|
||||
@ -332,18 +332,28 @@ def build_tokenizer(name: str = "gpt2"):
|
||||
"<|startofprev|>",
|
||||
"<|nospeech|>",
|
||||
"<|notimestamps|>",
|
||||
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
|
||||
]
|
||||
|
||||
tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
|
||||
return tokenizer
|
||||
for token in specials:
|
||||
special_tokens[token] = n_vocab
|
||||
n_vocab += 1
|
||||
|
||||
return tiktoken.Encoding(
|
||||
name=os.path.basename(vocab_path),
|
||||
explicit_n_vocab=n_vocab,
|
||||
pat_str=gpt2()["pat_str"],
|
||||
mergeable_ranks=ranks,
|
||||
special_tokens=special_tokens,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_tokenizer(
|
||||
multilingual: bool,
|
||||
*,
|
||||
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
||||
language: Optional[str] = None,
|
||||
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
||||
) -> Tokenizer:
|
||||
if language is not None:
|
||||
language = language.lower()
|
||||
@ -354,27 +364,14 @@ def get_tokenizer(
|
||||
raise ValueError(f"Unsupported language: {language}")
|
||||
|
||||
if multilingual:
|
||||
tokenizer_name = "multilingual"
|
||||
task = task or "transcribe"
|
||||
encoding_name = "multilingual"
|
||||
language = language or "en"
|
||||
task = task or "transcribe"
|
||||
else:
|
||||
tokenizer_name = "gpt2"
|
||||
task = None
|
||||
encoding_name = "gpt2"
|
||||
language = None
|
||||
task = None
|
||||
|
||||
tokenizer = build_tokenizer(name=tokenizer_name)
|
||||
all_special_ids: List[int] = tokenizer.all_special_ids
|
||||
sot: int = all_special_ids[1]
|
||||
translate: int = all_special_ids[-6]
|
||||
transcribe: int = all_special_ids[-5]
|
||||
encoding = get_encoding(name=encoding_name)
|
||||
|
||||
langs = tuple(LANGUAGES.keys())
|
||||
sot_sequence = [sot]
|
||||
if language is not None:
|
||||
sot_sequence.append(sot + 1 + langs.index(language))
|
||||
if task is not None:
|
||||
sot_sequence.append(transcribe if task == "transcribe" else translate)
|
||||
|
||||
return Tokenizer(
|
||||
tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
|
||||
)
|
||||
return Tokenizer(encoding=encoding, language=language, task=task)
|
||||
|
Loading…
x
Reference in New Issue
Block a user