Mirror of https://github.com/openai/whisper.git (synced 2025-11-27 15:54:00 +00:00)
drop python 3.7 support (#889)
commit a6b36ede1f
parent 55f690af79
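The changes below drop patterns kept for Python 3.7 compatibility in favor of Python 3.8 features: assignment expressions (the := walrus operator, PEP 572) in whisper/decoding.py and whisper/transcribe.py, and functools.cached_property in whisper/tokenizer.py. transcribe() also gains an explicit initial_prompt keyword argument in place of popping it out of **decode_options.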
whisper/decoding.py

@@ -252,11 +252,10 @@ class GreedyDecoder(TokenDecoder):
         self.eot = eot
 
     def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
-        temperature = self.temperature
-        if temperature == 0:
+        if self.temperature == 0:
            next_tokens = logits.argmax(dim=-1)
         else:
-            next_tokens = Categorical(logits=logits / temperature).sample()
+            next_tokens = Categorical(logits=logits / self.temperature).sample()
 
         logprobs = F.log_softmax(logits.float(), dim=-1)
         current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
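For context, the branch being touched chooses between greedy decoding and temperature sampling. A standalone sketch of the same logic with made-up logits:

import torch
from torch.distributions import Categorical

logits = torch.tensor([[2.0, 0.5, 0.1]])  # dummy logits for one sequence
temperature = 0.7

if temperature == 0:
    next_tokens = logits.argmax(dim=-1)  # greedy: pick the most likely token
else:
    # temperature sampling: sharper distribution for t < 1, flatter for t > 1
    next_tokens = Categorical(logits=logits / temperature).sample()

print(next_tokens)  # tensor with one sampled token index per sequence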
@@ -511,10 +510,8 @@ class DecodingTask:
 
     def _get_initial_tokens(self) -> Tuple[int]:
         tokens = list(self.sot_sequence)
-        prefix = self.options.prefix
-        prompt = self.options.prompt
 
-        if prefix:
+        if prefix := self.options.prefix:
             prefix_tokens = (
                 self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
             )
@@ -523,7 +520,7 @@ class DecodingTask:
             prefix_tokens = prefix_tokens[-max_prefix_len:]
             tokens = tokens + prefix_tokens
 
-        if prompt:
+        if prompt := self.options.prompt:
             prompt_tokens = (
                 self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
             )
@@ -698,13 +695,9 @@ def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOpt
     result: Union[DecodingResult, List[DecodingResult]]
         The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
     """
-    single = mel.ndim == 2
-    if single:
+    if single := mel.ndim == 2:
         mel = mel.unsqueeze(0)
 
     result = DecodingTask(model, options).run(mel)
 
-    if single:
-        result = result[0]
-
-    return result
+    return result[0] if single else result
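The if prefix := ... and if single := ... forms above are assignment expressions (PEP 572), added in Python 3.8, which is what makes these rewrites incompatible with 3.7. A minimal standalone sketch of the decode() pattern, with a hypothetical helper name:

# Hypothetical demo of the walrus pattern used in decode(): bind `single`
# and branch on it in one expression (requires Python 3.8+).
def normalize(batch):
    if single := not isinstance(batch, list):
        batch = [batch]           # promote a lone item to a batch of one
    results = [item * 2 for item in batch]
    return results[0] if single else results

print(normalize(3))        # 6
print(normalize([1, 2]))   # [2, 4]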
whisper/tokenizer.py

@@ -1,6 +1,6 @@
 import os
 from dataclasses import dataclass
-from functools import lru_cache
+from functools import lru_cache, cached_property
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -156,43 +156,35 @@ class Tokenizer:
         outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
         return "".join(outputs)
 
-    @property
-    @lru_cache()
+    @cached_property
     def eot(self) -> int:
         return self.tokenizer.eos_token_id
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot(self) -> int:
         return self._get_single_token_id("<|startoftranscript|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_lm(self) -> int:
         return self._get_single_token_id("<|startoflm|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_prev(self) -> int:
         return self._get_single_token_id("<|startofprev|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def no_speech(self) -> int:
         return self._get_single_token_id("<|nospeech|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def no_timestamps(self) -> int:
         return self._get_single_token_id("<|notimestamps|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def timestamp_begin(self) -> int:
         return self.tokenizer.all_special_ids[-1] + 1
 
-    @property
-    @lru_cache()
+    @cached_property
     def language_token(self) -> int:
         """Returns the token id corresponding to the value of the `language` field"""
         if self.language is None:
@@ -210,8 +202,7 @@ class Tokenizer:
 
         raise KeyError(f"Language {self.language} not found in tokenizer.")
 
-    @property
-    @lru_cache()
+    @cached_property
     def all_language_tokens(self) -> Tuple[int]:
         result = []
         for token, token_id in zip(
@@ -222,18 +213,15 @@ class Tokenizer:
             result.append(token_id)
         return tuple(result)
 
-    @property
-    @lru_cache()
+    @cached_property
     def all_language_codes(self) -> Tuple[str]:
         return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_sequence_including_notimestamps(self) -> Tuple[int]:
         return tuple(list(self.sot_sequence) + [self.no_timestamps])
 
-    @property
-    @lru_cache()
+    @cached_property
     def non_speech_tokens(self) -> Tuple[int]:
         """
         Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
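functools.cached_property, added in Python 3.8, computes the value once per instance and stores it in the instance's __dict__, so later accesses skip the getter entirely. The @property + @lru_cache() stack it replaces keys a module-level cache on self, which can keep instances alive for the lifetime of the cache. A minimal sketch with a hypothetical class:

from functools import cached_property

class SpecialTokens:
    # hypothetical stand-in for whisper's Tokenizer properties
    @cached_property
    def eot(self) -> int:
        print("computed once")    # only printed on the first access
        return 42                 # made-up value

tokens = SpecialTokens()
assert tokens.eot == tokens.eot   # second access reads the per-instance cache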
whisper/transcribe.py

@@ -26,6 +26,7 @@ def transcribe(
     logprob_threshold: Optional[float] = -1.0,
     no_speech_threshold: Optional[float] = 0.6,
     condition_on_previous_text: bool = True,
+    initial_prompt: Optional[str] = None,
     **decode_options,
 ):
     """
@@ -138,10 +139,11 @@ def transcribe(
     all_segments = []
     prompt_reset_since = 0
 
-    initial_prompt = decode_options.pop("initial_prompt", None) or []
-    if initial_prompt:
-        initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
-        all_tokens.extend(initial_prompt)
+    if initial_prompt is not None:
+        initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
+        all_tokens.extend(initial_prompt_tokens)
+    else:
+        initial_prompt_tokens = []
 
     def add_segment(
         *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
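With initial_prompt promoted to an explicit keyword argument instead of being popped out of **decode_options, it now appears in the function signature; call sites are unchanged. A hedged usage example, assuming a downloaded checkpoint and a local audio file named audio.mp3:

import whisper

model = whisper.load_model("base")  # loads the "base" checkpoint
# initial_prompt seeds the decoder's context for the first window,
# e.g. to bias the spelling of names or domain terms
result = model.transcribe("audio.mp3", initial_prompt="Names: Ana Kowalski, Q4 OKRs.")
print(result["text"])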
@@ -243,7 +245,11 @@ def transcribe(
             pbar.update(min(num_frames, seek) - previous_seek_value)
             previous_seek_value = seek
 
-    return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
+    return dict(
+        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
+        segments=all_segments,
+        language=language
+    )
 
 
 def cli():
@@ -292,21 +298,18 @@ def cli():
         args["language"] = "en"
 
     temperature = args.pop("temperature")
-    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
-    if temperature_increment_on_fallback is not None:
-        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
+    if (increment := args.pop("temperature_increment_on_fallback")) is not None:
+        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
     else:
         temperature = [temperature]
 
-    threads = args.pop("threads")
-    if threads > 0:
+    if (threads := args.pop("threads")) > 0:
         torch.set_num_threads(threads)
 
     from . import load_model
     model = load_model(model_name, device=device, download_root=model_dir)
 
     writer = get_writer(output_format, output_dir)
 
     for audio_path in args.pop("audio"):
         result = transcribe(model, audio_path, temperature=temperature, **args)
         writer(result, audio_path)
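The fallback logic in cli() expands a starting temperature into the ladder of values tried when a decode fails quality checks. A quick standalone check of what np.arange produces here, assuming the CLI defaults at the time (temperature 0, increment 0.2):

import numpy as np

temperature = 0.0   # CLI default starting temperature
increment = 0.2     # CLI default temperature_increment_on_fallback
temperatures = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
print(len(temperatures))   # 6 values: 0.0, 0.2, 0.4, 0.6, 0.8, 1.0
# the 1e-6 fudge keeps the endpoint 1.0 inside arange's half-open range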