Mirror of https://github.com/openai/whisper.git
Fix truncated words list when the replacement character is decoded (#1089)
This commit is contained in:
parent ba88b8e1b3
commit 5f9ac653b7
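split_tokens_on_unicode() builds words by decoding the tokens accumulated for the current word and flushing the word whenever the partial decode contains no U+FFFD. Before this fix, any U+FFFD was taken to mean "the last token ended in the middle of a multi-byte character, keep accumulating", so a transcript that genuinely contained the replacement character never flushed again and the returned words list was truncated from that point on. The fix decodes the full token sequence once and, when the partial decode shows U+FFFD, checks whether the full decode also has U+FFFD at the same character offset: if it does, the character is real and the word is emitted; if not, the decoder is simply mid-character and more tokens are needed.

Below is a minimal, self-contained sketch of that idea. It is not Whisper's code: byte chunks stand in for BPE tokens, and bytes.decode(errors="replace") stands in for Tokenizer.decode_with_timestamps.

    from typing import List


    def split_on_unicode(tokens: List[bytes]):
        replacement_char = "\ufffd"
        # Decode everything once; genuine U+FFFD characters appear here too.
        decoded_full = b"".join(tokens).decode("utf-8", errors="replace")

        words, word_tokens, current_tokens = [], [], []
        unicode_offset = 0  # characters already emitted into `words`

        for token in tokens:
            current_tokens.append(token)
            decoded = b"".join(current_tokens).decode("utf-8", errors="replace")
            # U+FFFD in the partial decode means either (a) the accumulated bytes
            # end mid-character (keep accumulating), or (b) the text genuinely
            # contains U+FFFD at this position (emit it). Comparing against the
            # full decode at the same offset tells the two cases apart.
            if (
                replacement_char not in decoded
                or decoded_full[unicode_offset + decoded.index(replacement_char)]
                == replacement_char
            ):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)

        return words, word_tokens


    # "é" is split across two byte chunks, so its partial decode briefly shows
    # U+FFFD, while the third chunk is a literal U+FFFD that must be emitted.
    tokens = [b"l", b"'", "\ufffd".encode("utf-8"), b"\xc3", b"\xa9"]
    print(split_on_unicode(tokens)[0])  # ['l', "'", '�', 'é']

With the old check (if "\ufffd" not in decoded:), the same input would stop after ['l', "'"], since every later partial decode still contains the genuine replacement character. The actual change follows.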
tests/test_tokenizer.py
@@ -12,3 +12,13 @@ def test_tokenizer():
     assert gpt2_tokenizer.decode(gpt2_tokens) == text
     assert multilingual_tokenizer.decode(multilingual_tokens) == text
     assert len(gpt2_tokens) > len(multilingual_tokens)
+
+
+def test_split_on_unicode():
+    multilingual_tokenizer = get_tokenizer(multilingual=True)
+
+    tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378]
+    words, word_tokens = multilingual_tokenizer.split_tokens_on_unicode(tokens)
+
+    assert words == [" elle", " est", " l", "'", "�", "é", "rit", "oire"]
+    assert word_tokens == [[8404], [871], [287], [6], [246], [526], [3210], [20378]]
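For reference, the new assertion can be reproduced outside the test suite. A possible snippet (assuming the whisper package is importable, with get_tokenizer imported as in the test file):

    from whisper.tokenizer import get_tokenizer

    tokenizer = get_tokenizer(multilingual=True)
    tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378]
    words, word_tokens = tokenizer.split_tokens_on_unicode(tokens)
    print(words)  # per the test: [' elle', ' est', ' l', "'", '�', 'é', 'rit', 'oire']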
whisper/tokenizer.py
@@ -279,17 +279,27 @@ class Tokenizer:
         return self.split_tokens_on_spaces(tokens)
 
     def split_tokens_on_unicode(self, tokens: List[int]):
+        decoded_full = self.decode_with_timestamps(tokens)
+        replacement_char = "\ufffd"
+
         words = []
         word_tokens = []
         current_tokens = []
+        unicode_offset = 0
 
         for token in tokens:
             current_tokens.append(token)
             decoded = self.decode_with_timestamps(current_tokens)
-            if "\ufffd" not in decoded:
+
+            if (
+                replacement_char not in decoded
+                or decoded_full[unicode_offset + decoded.index(replacement_char)]
+                == replacement_char
+            ):
                 words.append(decoded)
                 word_tokens.append(current_tokens)
                 current_tokens = []
+                unicode_offset += len(decoded)
 
         return words, word_tokens
 
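The offset bookkeeping is what lets the new condition tell the two cases apart: unicode_offset counts the characters already emitted, so unicode_offset + decoded.index(replacement_char) is the position of the suspicious character inside decoded_full. A hypothetical mid-loop state (values are illustrative, not taken from the diff):

    decoded_full = " l'\ufffdé"   # full decode of every token
    unicode_offset = 3            # characters already emitted: " l" and "'"
    decoded = "\ufffd"            # partial decode of the current word so far

    i = unicode_offset + decoded.index("\ufffd")
    print(decoded_full[i] == "\ufffd")  # True: the text really contains U+FFFD, emit it
    # If the current word were "é" split across tokens, decoded would also be
    # "\ufffd", but the character at that offset in decoded_full would be "é",
    # so the loop would keep accumulating tokens instead.

Because unicode_offset advances by len(decoded) (a character count, not a token count) only when a word is emitted, the comparison always indexes decoded_full at the start of the word currently being assembled.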