apply formatting with black (#1038)

* applying black (with the default 88-column limit)

* add flake8

* add isort

* fix isort
This commit is contained in:
Jong Wook Kim 2023-03-06 18:50:37 -05:00 committed by GitHub
parent 500d0fe966
commit b80bcf610d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 533 additions and 227 deletions

4
.flake8 Normal file
View File

@ -0,0 +1,4 @@
[flake8]
per-file-ignores =
*/__init__.py: F401

View File

@ -22,4 +22,7 @@ jobs:
- uses: actions/checkout@v2
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
- run: pip install .["dev"]
- run: black --check --diff -t py38 --include '(\.pyi?)$' .
- run: isort --check --diff .
- run: flake8 --ignore E203,W503,W504,E501,E731,E741 .
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'

8
pyproject.toml Normal file
View File

@ -0,0 +1,8 @@
[tool.black]
[tool.isort]
profile = "black"
include_trailing_comma = true
line_length = 88
multi_line_output = 3

View File

@ -2,7 +2,7 @@ import os
import sys
import pkg_resources
from setuptools import setup, find_packages
from setuptools import find_packages, setup
def read_version(fname="whisper/version.py"):
@ -16,7 +16,10 @@ if sys.platform.startswith("linux"):
try:
import re
import subprocess
version_line = subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
version_line = (
subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
)
major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
if (int(major), int(minor)) < (11, 4):
# the last version supporting CUDA < 11.4
@ -38,7 +41,8 @@ setup(
url="https://github.com/openai/whisper",
license="MIT",
packages=find_packages(exclude=["tests*"]),
install_requires=requirements + [
install_requires=requirements
+ [
str(r)
for r in pkg_resources.parse_requirements(
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
@ -48,5 +52,5 @@ setup(
"console_scripts": ["whisper=whisper.transcribe:cli"],
},
include_package_data=True,
extras_require={"dev": ["pytest", "scipy"]},
extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
)

View File

@ -2,7 +2,7 @@ import os.path
import numpy as np
from whisper.audio import load_audio, log_mel_spectrogram, SAMPLE_RATE
from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
def test_audio():

View File

@ -1,7 +1,10 @@
import pytest
from whisper.normalizers import EnglishTextNormalizer
from whisper.normalizers.english import EnglishNumberNormalizer, EnglishSpellingNormalizer
from whisper.normalizers.english import (
EnglishNumberNormalizer,
EnglishSpellingNormalizer,
)
@pytest.mark.parametrize("std", [EnglishNumberNormalizer(), EnglishTextNormalizer()])

View File

@ -1,16 +1,21 @@
import pytest
import numpy as np
import pytest
import scipy.ndimage
import torch
from whisper.timing import dtw_cpu, dtw_cuda, median_filter
sizes = [
(10, 20), (32, 16), (123, 1500), (234, 189),
(10, 20),
(32, 16),
(123, 1500),
(234, 189),
]
shapes = [
(10,), (1, 15), (4, 5, 345), (6, 12, 240, 512),
(10,),
(1, 15),
(4, 5, 345),
(6, 12, 240, 512),
]
@ -68,8 +73,12 @@ def test_median_filter(shape):
# using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
pad_width = filter_width // 2
padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect")
scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width])
padded_x = np.pad(
x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect"
)
scipy_filtered = scipy.ndimage.median_filter(
padded_x, [1] * (x.ndim - 1) + [filter_width]
)
scipy_filtered = scipy_filtered[..., pad_width:-pad_width]
assert np.allclose(filtered, scipy_filtered)

View File

@ -13,7 +13,9 @@ def test_transcribe(model_name: str):
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
language = "en" if model_name.endswith(".en") else None
result = model.transcribe(audio_path, language=language, temperature=0.0, word_timestamps=True)
result = model.transcribe(
audio_path, language=language, temperature=0.0, word_timestamps=True
)
assert result["language"] == "en"
transcription = result["text"].lower()

View File

@ -10,11 +10,10 @@ from tqdm import tqdm
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import Whisper, ModelDimensions
from .model import ModelDimensions, Whisper
from .transcribe import transcribe
from .version import __version__
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
@ -41,12 +40,11 @@ _ALIGNMENT_HEADS = {
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
"large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
}
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)
@ -62,10 +60,18 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
@ -76,7 +82,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
return model_bytes if in_memory else download_target
@ -86,7 +94,12 @@ def available_models() -> List[str]:
return list(_MODELS.keys())
def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
def load_model(
name: str,
device: Optional[Union[str, torch.device]] = None,
download_root: str = None,
in_memory: bool = False,
) -> Whisper:
"""
Load a Whisper ASR model
@ -111,15 +124,8 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
download_root = os.path.join(
os.getenv(
"XDG_CACHE_HOME",
os.path.join(
os.path.expanduser("~"), ".cache"
)
),
"whisper"
)
default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
@ -128,9 +134,13 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
checkpoint_file = open(name, "rb").read() if in_memory else name
alignment_heads = None
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
)
with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file

View File

@ -1,4 +1,3 @@
from .transcribe import cli
cli()

View File

@ -16,11 +16,13 @@ N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
N_FRAMES = exact_div(
N_SAMPLES, HOP_LENGTH
) # 3000: number of frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 100 mel frames in 1s (10ms each)
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 50 audio tokens in 1s (20ms each)
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE):
@ -59,7 +61,9 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
array = array.index_select(
dim=axis, index=torch.arange(length, device=array.device)
)
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
@ -89,11 +93,15 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
)
"""
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
with np.load(
os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
) as f:
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS
):
"""
Compute the log-Mel spectrogram of

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
@ -16,7 +16,9 @@ if TYPE_CHECKING:
@torch.no_grad()
def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
def detect_language(
model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
) -> Tuple[Tensor, List[dict]]:
"""
Detect the spoken language in the audio, and return them as list of strings, along with the ids
of the most probable language tokens and the probability distribution over all language tokens.
@ -31,8 +33,13 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None)
"""
if tokenizer is None:
tokenizer = get_tokenizer(model.is_multilingual)
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
if (
tokenizer.language is None
or tokenizer.language_token not in tokenizer.sot_sequence
):
raise ValueError(
"This model doesn't have language tokens so it can't perform lang id"
)
single = mel.ndim == 2
if single:
@ -70,31 +77,36 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None)
@dataclass(frozen=True)
class DecodingOptions:
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
language: Optional[str] = None # language that the audio is in; uses detected language if None
# whether to perform X->X "transcribe" or X->English "translate"
task: str = "transcribe"
# language that the audio is in; uses detected language if None
language: Optional[str] = None
# sampling-related options
temperature: float = 0.0
sample_len: Optional[int] = None # maximum number of tokens to sample
best_of: Optional[int] = None # number of independent samples to collect, when t > 0
beam_size: Optional[int] = None # number of beams in beam search, when t == 0
patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
best_of: Optional[int] = None # number of independent sample trajectories, if t > 0
beam_size: Optional[int] = None # number of beams in beam search, if t == 0
patience: Optional[float] = None # patience in beam search (arxiv:2204.05424)
# options for ranking generations (either beams or best-of-N samples)
length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
# "alpha" in Google NMT, or None for length norm, when ranking generations
# to select which to return among the beams or best-of-N samples
length_penalty: Optional[float] = None
# prompt, prefix, and token suppression
prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
suppress_blank: bool = True # this will suppress blank outputs
# text or tokens to feed as the prompt or the prefix; for more info:
# https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
prompt: Optional[Union[str, List[int]]] = None # for the previous context
prefix: Optional[Union[str, List[int]]] = None # to prefix the current context
# list of tokens ids (or comma-separated token ids) to suppress
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
suppress_blank: bool = True # this will suppress blank outputs
# timestamp sampling options
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
max_initial_timestamp: Optional[float] = 1.0
# implementation details
fp16: bool = True # use fp16 for most of the calculation
@ -158,7 +170,9 @@ class PyTorchInference(Inference):
class SequenceRanker:
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
def rank(
self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
) -> List[int]:
"""
Given a list of groups of samples and their cumulative log probabilities,
return the indices of the samples in each group to select as the final result
@ -196,7 +210,9 @@ class TokenDecoder:
def reset(self):
"""Initialize any stateful variables for decoding a new sequence"""
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
def update(
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
) -> Tuple[Tensor, bool]:
"""Specify how to select the next token, based on the current trace and logits
Parameters
@ -251,7 +267,9 @@ class GreedyDecoder(TokenDecoder):
self.temperature = temperature
self.eot = eot
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
def update(
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
) -> Tuple[Tensor, bool]:
if self.temperature == 0:
next_tokens = logits.argmax(dim=-1)
else:
@ -274,7 +292,13 @@ class GreedyDecoder(TokenDecoder):
class BeamSearchDecoder(TokenDecoder):
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
def __init__(
self,
beam_size: int,
eot: int,
inference: Inference,
patience: Optional[float] = None,
):
self.beam_size = beam_size
self.eot = eot
self.inference = inference
@ -282,12 +306,16 @@ class BeamSearchDecoder(TokenDecoder):
self.max_candidates: int = round(beam_size * self.patience)
self.finished_sequences = None
assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
assert (
self.max_candidates > 0
), f"Invalid beam size ({beam_size}) or patience ({patience})"
def reset(self):
self.finished_sequences = None
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
def update(
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
) -> Tuple[Tensor, bool]:
if tokens.shape[0] % self.beam_size != 0:
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
@ -331,7 +359,9 @@ class BeamSearchDecoder(TokenDecoder):
# add newly finished sequences to self.finished_sequences
assert len(self.finished_sequences) == len(finished_sequences)
for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
for previously_finished, newly_finished in zip(
self.finished_sequences, finished_sequences
):
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
if len(previously_finished) >= self.max_candidates:
break # the candidate list is full
@ -339,7 +369,8 @@ class BeamSearchDecoder(TokenDecoder):
# mark as completed if all audio has enough number of samples
completed = all(
len(sequences) >= self.max_candidates for sequences in self.finished_sequences
len(sequences) >= self.max_candidates
for sequences in self.finished_sequences
)
return tokens, completed
@ -347,7 +378,9 @@ class BeamSearchDecoder(TokenDecoder):
# collect all finished sequences, including patience, and add unfinished ones if not enough
sum_logprobs = sum_logprobs.cpu()
for i, sequences in enumerate(self.finished_sequences):
if len(sequences) < self.beam_size: # when not enough sequences are finished
if (
len(sequences) < self.beam_size
): # when not enough sequences are finished
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
sequence = preceding_tokens[i, j].tolist() + [self.eot]
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
@ -355,7 +388,8 @@ class BeamSearchDecoder(TokenDecoder):
break
tokens: List[List[Tensor]] = [
[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
[torch.tensor(seq) for seq in sequences.keys()]
for sequences in self.finished_sequences
]
sum_logprobs: List[List[float]] = [
list(sequences.values()) for sequences in self.finished_sequences
@ -399,7 +433,10 @@ class SuppressTokens(LogitFilter):
class ApplyTimestampRules(LogitFilter):
def __init__(
self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
self,
tokenizer: Tokenizer,
sample_begin: int,
max_initial_timestamp_index: Optional[int],
):
self.tokenizer = tokenizer
self.sample_begin = sample_begin
@ -414,8 +451,12 @@ class ApplyTimestampRules(LogitFilter):
for k in range(tokens.shape[0]):
sampled_tokens = tokens[k, self.sample_begin :]
seq = [t for t in sampled_tokens.tolist()]
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
last_was_timestamp = (
len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
)
penultimate_was_timestamp = (
len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
)
if last_was_timestamp:
if penultimate_was_timestamp: # has to be non-timestamp
@ -423,7 +464,9 @@ class ApplyTimestampRules(LogitFilter):
else: # cannot be normal text tokens
logits[k, : self.tokenizer.eot] = -np.inf
timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]
timestamps = sampled_tokens[
sampled_tokens.ge(self.tokenizer.timestamp_begin)
]
if timestamps.numel() > 0:
# timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
@ -434,13 +477,17 @@ class ApplyTimestampRules(LogitFilter):
# apply the `max_initial_timestamp` option
if self.max_initial_timestamp_index is not None:
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
last_allowed = (
self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
)
logits[:, last_allowed + 1 :] = -np.inf
# if sum of probability over timestamps is above any other token, sample timestamp
logprobs = F.log_softmax(logits.float(), dim=-1)
for k in range(tokens.shape[0]):
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
dim=-1
)
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
if timestamp_logprob > max_text_token_logprob:
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
@ -456,7 +503,9 @@ class DecodingTask:
self.model = model
language = options.language or "en"
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
tokenizer = get_tokenizer(
model.is_multilingual, language=language, task=options.task
)
self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)
@ -496,9 +545,13 @@ class DecodingTask:
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
max_initial_timestamp_index = None
if options.max_initial_timestamp:
max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
max_initial_timestamp_index = round(
self.options.max_initial_timestamp / precision
)
self.logit_filters.append(
ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
ApplyTimestampRules(
tokenizer, self.sample_begin, max_initial_timestamp_index
)
)
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
@ -509,7 +562,9 @@ class DecodingTask:
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
if options.patience is not None and options.beam_size is None:
raise ValueError("patience requires beam_size to be given")
if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
if options.length_penalty is not None and not (
0 <= options.length_penalty <= 1
):
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
return options
@ -519,7 +574,9 @@ class DecodingTask:
if prefix := self.options.prefix:
prefix_tokens = (
self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
self.tokenizer.encode(" " + prefix.strip())
if isinstance(prefix, str)
else prefix
)
if self.sample_len is not None:
max_prefix_len = self.n_ctx // 2 - self.sample_len
@ -528,9 +585,15 @@ class DecodingTask:
if prompt := self.options.prompt:
prompt_tokens = (
self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
self.tokenizer.encode(" " + prompt.strip())
if isinstance(prompt, str)
else prompt
)
tokens = (
[self.tokenizer.sot_prev]
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
+ tokens
)
tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
return tuple(tokens)
@ -554,7 +617,7 @@ class DecodingTask:
self.tokenizer.translate,
self.tokenizer.sot,
self.tokenizer.sot_prev,
self.tokenizer.sot_lm
self.tokenizer.sot_lm,
]
)
if self.tokenizer.no_speech is not None:
@ -567,14 +630,21 @@ class DecodingTask:
if self.options.fp16:
mel = mel.half()
if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
if mel.shape[-2:] == (
self.model.dims.n_audio_ctx,
self.model.dims.n_audio_state,
):
# encoded audio features are given; skip audio encoding
audio_features = mel
else:
audio_features = self.model.encoder(mel)
if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
if audio_features.dtype != (
torch.float16 if self.options.fp16 else torch.float32
):
return TypeError(
f"audio_features has an incorrect dtype: {audio_features.dtype}"
)
return audio_features
@ -583,7 +653,9 @@ class DecodingTask:
lang_probs = None
if self.options.language is None or self.options.task == "lang_id":
lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
lang_tokens, lang_probs = self.model.detect_language(
audio_features, self.tokenizer
)
languages = [max(probs, key=probs.get) for probs in lang_probs]
if self.options.language is None:
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
@ -600,7 +672,9 @@ class DecodingTask:
for i in range(self.sample_len):
logits = self.inference.logits(tokens, audio_features)
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
if (
i == 0 and self.tokenizer.no_speech is not None
): # save no_speech_probs
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
@ -634,8 +708,12 @@ class DecodingTask:
languages, language_probs = self._detect_language(audio_features, tokens)
if self.options.task == "lang_id":
return [
DecodingResult(audio_features=features, language=language, language_probs=probs)
for features, language, probs in zip(audio_features, languages, language_probs)
DecodingResult(
audio_features=features, language=language, language_probs=probs
)
for features, language, probs in zip(
audio_features, languages, language_probs
)
]
# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
@ -656,7 +734,8 @@ class DecodingTask:
# get the final candidates for each group, and slice between the first sampled token and EOT
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
tokens: List[List[Tensor]] = [
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
for s in tokens
]
# select the top-ranked sample in each group
@ -665,9 +744,18 @@ class DecodingTask:
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
avg_logprobs: List[float] = [
lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
]
fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
fields = (
texts,
languages,
tokens,
audio_features,
avg_logprobs,
no_speech_probs,
)
if len(set(map(len, fields))) != 1:
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
@ -682,12 +770,16 @@ class DecodingTask:
temperature=self.options.temperature,
compression_ratio=compression_ratio(text),
)
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
*fields
)
]
@torch.no_grad()
def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
def decode(
model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()
) -> Union[DecodingResult, List[DecodingResult]]:
"""
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

View File

@ -1,16 +1,15 @@
import base64
import gzip
from dataclasses import dataclass
from typing import Dict
from typing import Iterable, Optional
from typing import Dict, Iterable, Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from torch import Tensor, nn
from .decoding import detect_language as detect_language_function, decode as decode_function
from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function
@ -36,12 +35,16 @@ class LayerNorm(nn.LayerNorm):
class Linear(nn.Linear):
def forward(self, x: Tensor) -> Tensor:
return F.linear(
x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
x,
self.weight.to(x.dtype),
None if self.bias is None else self.bias.to(x.dtype),
)
class Conv1d(nn.Conv1d):
def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
def _conv_forward(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
) -> Tensor:
return super()._conv_forward(
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
)
@ -87,7 +90,9 @@ class MultiHeadAttention(nn.Module):
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), qk
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
):
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
@ -110,11 +115,15 @@ class ResidualAttentionBlock(nn.Module):
self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = LayerNorm(n_state)
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
self.cross_attn = (
MultiHeadAttention(n_state, n_head) if cross_attention else None
)
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
self.mlp = nn.Sequential(
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
)
self.mlp_ln = LayerNorm(n_state)
def forward(
@ -132,7 +141,9 @@ class ResidualAttentionBlock(nn.Module):
class AudioEncoder(nn.Module):
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
@ -163,14 +174,19 @@ class AudioEncoder(nn.Module):
class TextDecoder(nn.Module):
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
[
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
for _ in range(n_layer)
]
)
self.ln = LayerNorm(n_state)
@ -185,14 +201,19 @@ class TextDecoder(nn.Module):
the encoded audio features to be attended on
"""
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
x = (
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
x = x.to(xa.dtype)
for block in self.blocks:
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
x = self.ln(x)
logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
return logits
@ -216,13 +237,19 @@ class Whisper(nn.Module):
self.dims.n_text_layer,
)
# use the last half layers for alignment by default; see `set_alignment_heads()` below
all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
all_heads[self.dims.n_text_layer // 2:] = True
all_heads = torch.zeros(
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
)
all_heads[self.dims.n_text_layer // 2 :] = True
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
def set_alignment_heads(self, dump: bytes):
array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
array = np.frombuffer(
gzip.decompress(base64.b85decode(dump)), dtype=bool
).copy()
mask = torch.from_numpy(array).reshape(
self.dims.n_text_layer, self.dims.n_text_head
)
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
def embed_audio(self, mel: torch.Tensor):
@ -231,7 +258,9 @@ class Whisper(nn.Module):
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
return self.decoder(tokens, audio_features)
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
def forward(
self, mel: torch.Tensor, tokens: torch.Tensor
) -> Dict[str, torch.Tensor]:
return self.decoder(tokens, self.encoder(mel))
@property
@ -260,8 +289,9 @@ class Whisper(nn.Module):
hooks = []
def save_to_cache(module, _, output):
if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
cache[module] = output # save as-is, for the first token or cross attention
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
# save as-is, for the first token or cross attention
cache[module] = output
else:
cache[module] = torch.cat([cache[module], output], dim=1).detach()
return cache[module]

View File

@ -1,2 +1,2 @@
from .basic import BasicTextNormalizer
from .english import EnglishTextNormalizer
from .basic import BasicTextNormalizer as BasicTextNormalizer
from .english import EnglishTextNormalizer as EnglishTextNormalizer

View File

@ -48,13 +48,16 @@ def remove_symbols(s: str):
Replace any other markers, symbols, punctuations with a space, keeping diacritics
"""
return "".join(
" " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
" " if unicodedata.category(c)[0] in "MSP" else c
for c in unicodedata.normalize("NFKC", s)
)
class BasicTextNormalizer:
def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
self.clean = (
remove_symbols_and_diacritics if remove_diacritics else remove_symbols
)
self.split_letters = split_letters
def __call__(self, s: str):
@ -66,6 +69,8 @@ class BasicTextNormalizer:
if self.split_letters:
s = " ".join(regex.findall(r"\X", s, regex.U))
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
s = re.sub(
r"\s+", " ", s
) # replace any successive whitespace characters with a space
return s

View File

@ -84,7 +84,8 @@ class EnglishNumberNormalizer:
name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
}
self.tens_ordinal = {
name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()
name.replace("y", "ieth"): (value, "th")
for name, value in self.tens.items()
}
self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
@ -108,7 +109,10 @@ class EnglishNumberNormalizer:
self.multipliers_ordinal = {
name + "th": (value, "th") for name, value in self.multipliers.items()
}
self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal}
self.multipliers_suffixed = {
**self.multipliers_plural,
**self.multipliers_ordinal,
}
self.decimals = {*self.ones, *self.tens, *self.zeros}
self.preceding_prefixers = {
@ -128,7 +132,8 @@ class EnglishNumberNormalizer:
"cents": "¢",
}
self.prefixes = set(
list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())
list(self.preceding_prefixers.values())
+ list(self.following_prefixers.values())
)
self.suffixers = {
"per": {"cent": "%"},
@ -218,7 +223,9 @@ class EnglishNumberNormalizer:
if value is None:
value = ones
elif isinstance(value, str) or prev in self.ones:
if prev in self.tens and ones < 10: # replace the last zero with the digit
if (
prev in self.tens and ones < 10
): # replace the last zero with the digit
assert value[-1] == "0"
value = value[:-1] + str(ones)
else:
@ -522,14 +529,14 @@ class EnglishTextNormalizer:
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
s = re.sub(self.ignore_patterns, "", s)
s = re.sub(r"\s+'", "'", s) # standardize when there's a space before an apostrophe
s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe
for pattern, replacement in self.replacers.items():
s = re.sub(pattern, replacement, s)
s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep some symbols for numerics
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols
s = self.standardize_numbers(s)
s = self.standardize_spellings(s)
@ -538,6 +545,6 @@ class EnglishTextNormalizer:
s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
s = re.sub(r"([^0-9])%", r"\1 ", s)
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space
return s

View File

@ -1,7 +1,7 @@
import subprocess
import warnings
from dataclasses import dataclass
from typing import List, TYPE_CHECKING
from typing import TYPE_CHECKING, List
import numba
import numpy as np
@ -26,13 +26,16 @@ def median_filter(x: torch.Tensor, filter_width: int):
# `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
x = x[None, None, :]
assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"
assert (
filter_width > 0 and filter_width % 2 == 1
), "`filter_width` should be an odd number"
result = None
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
if x.is_cuda:
try:
from .triton_ops import median_filter_cuda
result = median_filter_cuda(x, filter_width)
except (RuntimeError, subprocess.CalledProcessError):
warnings.warn(
@ -49,6 +52,7 @@ def median_filter(x: torch.Tensor, filter_width: int):
return result
@numba.jit
def backtrace(trace: np.ndarray):
i = trace.shape[0] - 1
@ -106,7 +110,9 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
M, N = x.shape
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
x_skew = (
F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
)
x_skew = x_skew.T.contiguous()
cost = torch.ones(N + M + 2, M + 2) * np.inf
cost[0, 0] = 0
@ -122,10 +128,12 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
trace.stride(0),
N,
M,
BLOCK_SIZE=BLOCK_SIZE
BLOCK_SIZE=BLOCK_SIZE,
)
trace = trace.T.flatten()[:(M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, :N + 1]
trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
:, : N + 1
]
return backtrace(trace.cpu().numpy())
@ -181,8 +189,10 @@ def find_alignment(
with torch.no_grad():
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
token_probs = logits[len(tokenizer.sot_sequence):, :tokenizer.eot].softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist()
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
token_probs = sampled_logits.softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
text_token_probs = text_token_probs.tolist()
for hook in hooks:
hook.remove()
@ -196,7 +206,7 @@ def find_alignment(
weights = median_filter(weights, medfilt_width)
matrix = weights.mean(axis=0)
matrix = matrix[len(tokenizer.sot_sequence):-1]
matrix = matrix[len(tokenizer.sot_sequence) : -1]
text_indices, time_indices = dtw(-matrix)
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
@ -207,7 +217,8 @@ def find_alignment(
start_times = jump_times[word_boundaries[:-1]]
end_times = jump_times[word_boundaries[1:]]
word_probabilities = [
np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
np.mean(text_token_probs[i:j])
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
]
# hack: ensure the first and second word is not longer than twice the median word duration.
@ -218,7 +229,8 @@ def find_alignment(
median_duration = np.median(word_durations)
max_duration = median_duration * 2
if len(word_durations) >= 2 and word_durations[1] > max_duration:
end_times[0] = start_times[1] = max(end_times[2] / 2, end_times[2] - max_duration)
boundary = max(end_times[2] / 2, end_times[2] - max_duration)
end_times[0] = start_times[1] = boundary
if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
start_times[0] = max(0, end_times[0] - max_duration)
@ -271,19 +283,20 @@ def add_word_timestamps(
tokenizer: Tokenizer,
mel: torch.Tensor,
num_frames: int,
prepend_punctuations: str = "\"\'“¿([{-",
append_punctuations: str = "\"\'.。,!?::”)]}、",
**hyperparams,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
**kwargs,
):
if len(segments) == 0:
return
text_tokens = [t for segment in segments for t in segment["tokens"]]
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **hyperparams)
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])
segment_lengths = [len(s["tokens"]) for s in segments]
token_sources = np.repeat(np.arange(len(segments)), segment_lengths)
for segment in segments:
segment["words"] = []
@ -295,7 +308,12 @@ def add_word_timestamps(
start = round(time_offset + timing.start, 2)
end = round(time_offset + timing.end, 2)
segment["words"].append(
dict(word=timing.word, start=start, end=end, probability=timing.probability)
dict(
word=timing.word,
start=start,
end=end,
probability=timing.probability,
)
)
for segment in segments:

View File

@ -1,7 +1,7 @@
import os
import string
from dataclasses import dataclass
from functools import lru_cache, cached_property
from functools import cached_property, lru_cache
from typing import List, Optional, Tuple, Union
import numpy as np
@ -138,7 +138,9 @@ class Tokenizer:
def encode(self, text, **kwargs):
return self.tokenizer.encode(text, **kwargs)
def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
def decode(
self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs
):
return self.tokenizer.decode(token_ids, **kwargs)
def decode_with_timestamps(self, tokens) -> str:
@ -154,8 +156,9 @@ class Tokenizer:
outputs.append([])
else:
outputs[-1].append(token)
outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
return "".join(outputs)
return "".join(
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
)
@cached_property
def eot(self) -> int:
@ -197,7 +200,7 @@ class Tokenizer:
def language_token(self) -> int:
"""Returns the token id corresponding to the value of the `language` field"""
if self.language is None:
raise ValueError(f"This tokenizer does not have language token configured")
raise ValueError("This tokenizer does not have language token configured")
additional_tokens = dict(
zip(
@ -242,8 +245,10 @@ class Tokenizer:
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
symbols += (
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
)
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
@ -255,7 +260,10 @@ class Tokenizer:
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
for symbol in symbols + list(miscellaneous):
for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
for tokens in [
self.tokenizer.encode(symbol),
self.tokenizer.encode(" " + symbol),
]:
if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0])
@ -367,4 +375,6 @@ def get_tokenizer(
if task is not None:
sot_sequence.append(transcribe if task == "transcribe" else translate)
return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
return Tokenizer(
tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
)

View File

@ -1,17 +1,32 @@
import argparse
import os
import warnings
from typing import Optional, Tuple, Union, TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, Tuple, Union
import numpy as np
import torch
import tqdm
from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, log_mel_spectrogram, pad_or_trim
from .audio import (
FRAMES_PER_SECOND,
HOP_LENGTH,
N_FRAMES,
SAMPLE_RATE,
log_mel_spectrogram,
pad_or_trim,
)
from .decoding import DecodingOptions, DecodingResult
from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer
from .utils import (
exact_div,
format_timestamp,
get_writer,
make_safe,
optional_float,
optional_int,
str2bool,
)
if TYPE_CHECKING:
from .model import Whisper
@ -29,8 +44,8 @@ def transcribe(
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
word_timestamps: bool = False,
prepend_punctuations: str = "\"\'“¿([{-",
append_punctuations: str = "\"\'.。,!?::”)]}、",
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
**decode_options,
):
"""
@ -108,12 +123,16 @@ def transcribe(
decode_options["language"] = "en"
else:
if verbose:
print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
print(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
)
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
print(
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
)
language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe")
@ -123,7 +142,9 @@ def transcribe(
warnings.warn("Word-level timestamps on translations may not be reliable.")
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
temperatures = (
[temperature] if isinstance(temperature, (int, float)) else temperature
)
decode_result = None
for t in temperatures:
@ -140,9 +161,15 @@ def transcribe(
decode_result = model.decode(segment, options)
needs_fallback = False
if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
if (
compression_ratio_threshold is not None
and decode_result.compression_ratio > compression_ratio_threshold
):
needs_fallback = True # too repetitive
if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
if (
logprob_threshold is not None
and decode_result.avg_logprob < logprob_threshold
):
needs_fallback = True # average log probability is too low
if not needs_fallback:
@ -186,7 +213,9 @@ def transcribe(
# show the progress bar when verbose is False (if True, transcribed text will be printed)
num_frames = mel.shape[-1]
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
with tqdm.tqdm(
total=num_frames, unit="frames", disable=verbose is not False
) as pbar:
while seek < num_frames:
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
mel_segment = mel[:, seek:]
@ -201,7 +230,10 @@ def transcribe(
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
if (
logprob_threshold is not None
and result.avg_logprob > logprob_threshold
):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
@ -214,22 +246,35 @@ def transcribe(
current_tokens = []
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]:
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
0
].add_(1)
if (
len(consecutive) > 0
): # if the output contains two consecutive timestamp tokens
if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
False,
True,
]:
consecutive = consecutive.tolist() + [len(tokens)]
last_slice = 0
for current_slice in consecutive:
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
current_segments.append(new_segment(
start=time_offset + start_timestamp_pos * time_precision,
end=time_offset + end_timestamp_pos * time_precision,
tokens=sliced_tokens,
result=result,
))
start_timestamp_pos = (
sliced_tokens[0].item() - tokenizer.timestamp_begin
)
end_timestamp_pos = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
current_segments.append(
new_segment(
start=time_offset + start_timestamp_pos * time_precision,
end=time_offset + end_timestamp_pos * time_precision,
tokens=sliced_tokens,
result=result,
)
)
current_tokens.append(sliced_tokens.tolist())
last_slice = current_slice
@ -238,23 +283,32 @@ def transcribe(
seek += segment_size
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin
last_timestamp_pos = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_pos * input_stride
all_tokens.extend(tokens[: last_slice + 1].tolist())
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
if (
len(timestamps) > 0
and timestamps[-1].item() != tokenizer.timestamp_begin
):
# no consecutive timestamps but it has a timestamp; use the last one.
last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin
last_timestamp_pos = (
timestamps[-1].item() - tokenizer.timestamp_begin
)
duration = last_timestamp_pos * time_precision
current_segments.append(new_segment(
start=time_offset,
end=time_offset + duration,
tokens=tokens,
result=result,
))
current_segments.append(
new_segment(
start=time_offset,
end=time_offset + duration,
tokens=tokens,
result=result,
)
)
current_tokens.append(tokens.tolist())
seek += segment_size
@ -272,9 +326,13 @@ def transcribe(
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
)
word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]]
word_end_timestamps = [
w["end"] for s in current_segments for w in s["words"]
]
if len(consecutive) > 0 and len(word_end_timestamps) > 0:
seek_shift = round((word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND)
seek_shift = round(
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
)
if seek_shift > 0:
seek = previous_seek + seek_shift
@ -293,21 +351,24 @@ def transcribe(
current_tokens[i] = []
all_segments.extend(current_segments)
all_tokens.extend([token for segment in current_tokens for token in segment])
all_tokens.extend(
[token for segment in current_tokens for token in segment]
)
# update progress bar
pbar.update(min(num_frames, seek) - previous_seek)
return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
segments=all_segments,
language=language
language=language,
)
def cli():
from . import available_models
# fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
@ -339,6 +400,7 @@ def cli():
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,!?::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
# fmt: on
args = parser.parse_args().__dict__
model_name: str = args.pop("model")
@ -350,7 +412,9 @@ def cli():
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None:
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
warnings.warn(
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
)
args["language"] = "en"
temperature = args.pop("temperature")
@ -363,6 +427,7 @@ def cli():
torch.set_num_threads(threads)
from . import load_model
model = load_model(model_name, device=device, download_root=model_dir)
writer = get_writer(output_format, output_dir)
@ -371,5 +436,5 @@ def cli():
writer(result, audio_path)
if __name__ == '__main__':
if __name__ == "__main__":
cli()

View File

@ -1,8 +1,7 @@
import math
from functools import lru_cache
import numpy as np
import torch
from functools import lru_cache
try:
import triton
@ -12,7 +11,9 @@ except ImportError:
@triton.jit
def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr):
def dtw_kernel(
cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
):
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < M
@ -42,37 +43,53 @@ def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_
@lru_cache(maxsize=None)
def median_kernel(filter_width: int):
@triton.jit
def kernel(y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr): # x.shape[-1] == filter_width
def kernel(
y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
): # x.shape[-1] == filter_width
row_idx = tl.program_id(0)
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < y_stride
x_ptr = x + row_idx * x_stride
x_ptr = x + row_idx * x_stride # noqa: F841
y_ptr = y + row_idx * y_stride
LOAD_ALL_ROWS_HERE
LOAD_ALL_ROWS_HERE # noqa: F821
BUBBLESORT_HERE
BUBBLESORT_HERE # noqa: F821
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
kernel = triton.JITFunction(kernel.fn)
kernel.src = kernel.src.replace(" LOAD_ALL_ROWS_HERE", "\n".join([
f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
for i in range(filter_width)
]))
kernel.src = kernel.src.replace(" BUBBLESORT_HERE", "\n\n".join([
"\n\n".join([
"\n".join([
f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
f" row{j} = smaller",
f" row{j + 1} = larger",
])
for j in range(filter_width - i - 1)
])
for i in range(filter_width // 2 + 1)
]))
kernel.src = kernel.src.replace(
" LOAD_ALL_ROWS_HERE",
"\n".join(
[
f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
for i in range(filter_width)
]
),
)
kernel.src = kernel.src.replace(
" BUBBLESORT_HERE",
"\n\n".join(
[
"\n\n".join(
[
"\n".join(
[
f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
f" row{j} = smaller",
f" row{j + 1} = larger",
]
)
for j in range(filter_width - i - 1)
]
)
for i in range(filter_width // 2 + 1)
]
),
)
kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
return kernel

View File

@ -7,11 +7,14 @@ from typing import Callable, TextIO
system_encoding = sys.getdefaultencoding()
if system_encoding != "utf-8":
def make_safe(string):
# replaces any character not representable using the system default encoding with an '?',
# avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
return string.encode(system_encoding, errors="replace").decode(system_encoding)
else:
def make_safe(string):
# utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
return string
@ -43,7 +46,9 @@ def compression_ratio(text) -> float:
return len(text_bytes) / len(zlib.compress(text_bytes))
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
def format_timestamp(
seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
@ -57,7 +62,9 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
milliseconds -= seconds * 1_000
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
return (
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
)
class ResultWriter:
@ -68,7 +75,9 @@ class ResultWriter:
def __call__(self, result: dict, audio_path: str):
audio_basename = os.path.basename(audio_path)
output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
output_path = os.path.join(
self.output_dir, audio_basename + "." + self.extension
)
with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f)
@ -82,7 +91,7 @@ class WriteTXT(ResultWriter):
def write_result(self, result: dict, file: TextIO):
for segment in result["segments"]:
print(segment['text'].strip(), file=file, flush=True)
print(segment["text"].strip(), file=file, flush=True)
class SubtitlesWriter(ResultWriter):
@ -93,7 +102,7 @@ class SubtitlesWriter(ResultWriter):
for segment in result["segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment['text'].strip().replace('-->', '->')
segment_text = segment["text"].strip().replace("-->", "->")
if word_timings := segment.get("words", None):
all_words = [timing["word"] for timing in word_timings]
@ -106,7 +115,10 @@ class SubtitlesWriter(ResultWriter):
yield last, start, segment_text
yield start, end, "".join(
[f"<u>{word}</u>" if j == i else word for j, word in enumerate(all_words)]
[
f"<u>{word}</u>" if j == i else word
for j, word in enumerate(all_words)
]
)
last = end
@ -126,7 +138,7 @@ class SubtitlesWriter(ResultWriter):
class WriteVTT(SubtitlesWriter):
extension: str = "vtt"
always_include_hours: bool = False
decimal_marker: str = '.'
decimal_marker: str = "."
def write_result(self, result: dict, file: TextIO):
print("WEBVTT\n", file=file)
@ -137,7 +149,7 @@ class WriteVTT(SubtitlesWriter):
class WriteSRT(SubtitlesWriter):
extension: str = "srt"
always_include_hours: bool = True
decimal_marker: str = ','
decimal_marker: str = ","
def write_result(self, result: dict, file: TextIO):
for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
@ -153,14 +165,15 @@ class WriteTSV(ResultWriter):
an environment setting a language encoding that causes the decimal in a floating point number
to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
"""
extension: str = "tsv"
def write_result(self, result: dict, file: TextIO):
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment['start']), file=file, end="\t")
print(round(1000 * segment['end']), file=file, end="\t")
print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
print(round(1000 * segment["start"]), file=file, end="\t")
print(round(1000 * segment["end"]), file=file, end="\t")
print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
class WriteJSON(ResultWriter):
@ -189,4 +202,3 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO],
return write_all
return writers[output_format](output_dir)