drop python 3.7 support (#889)

Author: Jong Wook Kim, 2023-01-24 14:05:57 -08:00 (committed by GitHub)
Parent: 55f690af79
Commit: a6b36ede1f
3 changed files with 33 additions and 49 deletions
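
The changes below are a cleanup pass enabled by requiring Python >= 3.8: assignment expressions (the walrus operator `:=`, PEP 572) fold assign-then-test statements into their conditionals, and `functools.cached_property` (also new in 3.8) replaces the stacked `@property` + `@lru_cache()` idiom. Neither feature is available in Python 3.7.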

diff --git a/whisper/decoding.py b/whisper/decoding.py

@@ -252,11 +252,10 @@ class GreedyDecoder(TokenDecoder):
         self.eot = eot
 
     def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
-        temperature = self.temperature
-        if temperature == 0:
+        if self.temperature == 0:
             next_tokens = logits.argmax(dim=-1)
         else:
-            next_tokens = Categorical(logits=logits / temperature).sample()
+            next_tokens = Categorical(logits=logits / self.temperature).sample()
 
         logprobs = F.log_softmax(logits.float(), dim=-1)
         current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
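
The refactor above is behavior-preserving; the sampling rule it touches is: at temperature 0, pick the argmax token, otherwise sample from a categorical distribution over temperature-scaled logits. A minimal standalone sketch of that rule (the tensor values are made up for illustration):

```python
import torch
from torch.distributions import Categorical

logits = torch.tensor([[2.0, 1.0, 0.5]])  # one sequence, three candidate tokens
temperature = 0.7

if temperature == 0:
    next_tokens = logits.argmax(dim=-1)  # greedy: highest-logit token
else:
    # dividing logits by the temperature sharpens (<1) or flattens (>1) the distribution
    next_tokens = Categorical(logits=logits / temperature).sample()

print(next_tokens)  # e.g. tensor([0])
```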
@@ -511,10 +510,8 @@ class DecodingTask:
     def _get_initial_tokens(self) -> Tuple[int]:
         tokens = list(self.sot_sequence)
-        prefix = self.options.prefix
-        prompt = self.options.prompt
 
-        if prefix:
+        if prefix := self.options.prefix:
             prefix_tokens = (
                 self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
             )
 
@@ -523,7 +520,7 @@ class DecodingTask:
             prefix_tokens = prefix_tokens[-max_prefix_len:]
             tokens = tokens + prefix_tokens
 
-        if prompt:
+        if prompt := self.options.prompt:
             prompt_tokens = (
                 self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
             )
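
Both hunks above fold the option lookup into the `if` via the walrus operator; the truthiness test is unchanged, so a `None` or empty prefix/prompt still skips the block, and the bound name stays available inside it. A tiny standalone demonstration (plain values, not Whisper objects):

```python
options_prefix = None  # stand-in for self.options.prefix

if prefix := options_prefix:
    print("tokenizing prefix:", prefix)
else:
    print("no prefix")  # this branch runs: None is falsy
```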
@@ -698,13 +695,9 @@ def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()):
     result: Union[DecodingResult, List[DecodingResult]]
         The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
     """
-    single = mel.ndim == 2
-    if single:
+    if single := mel.ndim == 2:
         mel = mel.unsqueeze(0)
 
     result = DecodingTask(model, options).run(mel)
 
-    if single:
-        result = result[0]
-
-    return result
+    return result[0] if single else result
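
Note what `single := mel.ndim == 2` binds: the walrus operator has lower precedence than `==`, so `single` captures the boolean result of the comparison, which is exactly what `result[0] if single else result` consumes later. A quick standalone check of the pattern:

```python
import torch

mel = torch.zeros(80, 3000)      # an unbatched (2-D) mel spectrogram

if single := mel.ndim == 2:      # binds the bool, not mel.ndim
    mel = mel.unsqueeze(0)       # add a batch dimension

print(single, mel.shape)         # True torch.Size([1, 80, 3000])
```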

diff --git a/whisper/tokenizer.py b/whisper/tokenizer.py

@@ -1,6 +1,6 @@
 import os
 from dataclasses import dataclass
-from functools import lru_cache
+from functools import lru_cache, cached_property
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -156,43 +156,35 @@ class Tokenizer:
         outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
         return "".join(outputs)
 
-    @property
-    @lru_cache()
+    @cached_property
     def eot(self) -> int:
         return self.tokenizer.eos_token_id
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot(self) -> int:
         return self._get_single_token_id("<|startoftranscript|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_lm(self) -> int:
         return self._get_single_token_id("<|startoflm|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_prev(self) -> int:
         return self._get_single_token_id("<|startofprev|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def no_speech(self) -> int:
         return self._get_single_token_id("<|nospeech|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def no_timestamps(self) -> int:
         return self._get_single_token_id("<|notimestamps|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def timestamp_begin(self) -> int:
         return self.tokenizer.all_special_ids[-1] + 1
 
-    @property
-    @lru_cache()
+    @cached_property
     def language_token(self) -> int:
         """Returns the token id corresponding to the value of the `language` field"""
         if self.language is None:
@ -210,8 +202,7 @@ class Tokenizer:
raise KeyError(f"Language {self.language} not found in tokenizer.") raise KeyError(f"Language {self.language} not found in tokenizer.")
@property @cached_property
@lru_cache()
def all_language_tokens(self) -> Tuple[int]: def all_language_tokens(self) -> Tuple[int]:
result = [] result = []
for token, token_id in zip( for token, token_id in zip(
@@ -222,18 +213,15 @@ class Tokenizer:
                 result.append(token_id)
         return tuple(result)
 
-    @property
-    @lru_cache()
+    @cached_property
     def all_language_codes(self) -> Tuple[str]:
         return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_sequence_including_notimestamps(self) -> Tuple[int]:
         return tuple(list(self.sot_sequence) + [self.no_timestamps])
 
-    @property
-    @lru_cache()
+    @cached_property
     def non_speech_tokens(self) -> Tuple[int]:
         """
         Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
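
`@cached_property` is the reason this file needs Python >= 3.8. It computes the value on first access and stores it on the instance, which is what the old `@property` + `@lru_cache()` stack approximated; the stacked form keys its cache on `self` in a module-level table, so it also keeps every `Tokenizer` instance alive for the life of the process. A minimal sketch of the difference (hypothetical `Demo` class, not Whisper code):

```python
from functools import cached_property, lru_cache

class Demo:
    @cached_property            # Python 3.8+: result stored on the instance
    def value(self) -> int:
        print("computing value")
        return 42

    @property
    @lru_cache()                # old pattern: cache keyed on `self`, held globally
    def legacy_value(self) -> int:
        return 42

d = Demo()
print(d.value)  # prints "computing value", then 42
print(d.value)  # cached: prints only 42
```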

diff --git a/whisper/transcribe.py b/whisper/transcribe.py

@@ -26,6 +26,7 @@ def transcribe(
     logprob_threshold: Optional[float] = -1.0,
     no_speech_threshold: Optional[float] = 0.6,
     condition_on_previous_text: bool = True,
+    initial_prompt: Optional[str] = None,
     **decode_options,
 ):
     """
@@ -138,10 +139,11 @@ def transcribe(
     all_segments = []
     prompt_reset_since = 0
 
-    initial_prompt = decode_options.pop("initial_prompt", None) or []
-    if initial_prompt:
-        initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
-        all_tokens.extend(initial_prompt)
+    if initial_prompt is not None:
+        initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
+        all_tokens.extend(initial_prompt_tokens)
+    else:
+        initial_prompt_tokens = []
 
     def add_segment(
         *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
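
With `initial_prompt` promoted to a named parameter, the option is visible in the `transcribe` signature instead of hiding in `**decode_options`, and the tokenized form gets its own name rather than rebinding `initial_prompt`. A hypothetical call (assuming a local `audio.mp3`; the prompt text is made up):

```python
import whisper

model = whisper.load_model("base")
result = model.transcribe(
    "audio.mp3",
    initial_prompt="Glossary: Whisper, mel spectrogram, token decoder.",
)
print(result["text"])  # the prompt tokens themselves are stripped from the output
```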
@@ -243,7 +245,11 @@ def transcribe(
             pbar.update(min(num_frames, seek) - previous_seek_value)
             previous_seek_value = seek
 
-    return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
+    return dict(
+        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
+        segments=all_segments,
+        language=language
+    )
 
 
 def cli():
@@ -292,21 +298,18 @@ def cli():
         args["language"] = "en"
 
     temperature = args.pop("temperature")
-    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
-    if temperature_increment_on_fallback is not None:
-        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
+    if (increment := args.pop("temperature_increment_on_fallback")) is not None:
+        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
     else:
         temperature = [temperature]
 
-    threads = args.pop("threads")
-    if threads > 0:
+    if (threads := args.pop("threads")) > 0:
         torch.set_num_threads(threads)
 
     from . import load_model
     model = load_model(model_name, device=device, download_root=model_dir)
 
     writer = get_writer(output_format, output_dir)
 
     for audio_path in args.pop("audio"):
         result = transcribe(model, audio_path, temperature=temperature, **args)
         writer(result, audio_path)
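
One wrinkle with the walrus in the `cli()` changes: `:=` binds more loosely than comparison operators, so the parentheses in `(threads := args.pop("threads")) > 0` are required; without them, `threads` would capture the boolean result of the comparison. A standalone illustration (dummy `args` dict):

```python
args = {"threads": 4}

# parenthesized: `threads` captures the popped value, then it is compared to 0
if (threads := args.pop("threads")) > 0:
    print(threads)  # 4

args = {"threads": 4}

# unparenthesized: `threads` captures the comparison result instead
if threads := args.pop("threads") > 0:
    print(threads)  # True
```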