word-level timestamps in transcribe() (#869)

* word-level timestamps in `transcribe()`

* moving to `timing.py`

* numba implementation for dtw, replacing dtw-python

* triton implementation for dtw

* add test for dtw implementations

* triton implementation of median_filter

* a simple word-level timestamps test

* add scipy as dev dependency

* installs an older version of Triton if CUDA < 11.4

* fix broken merge

* loosen nvcc version match regex

* find_alignment() function

* miscellaneous improvements

* skip median filtering when the input is too small

* Expose punctuation options in cli and transcribe() (#973)

* fix merge error

* fix merge error 2

* annotating that word_timestamps is experimental
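
A minimal usage sketch of the new Python API summarized by the bullets above (the audio
filename is a hypothetical placeholder; the segment/word fields are the ones added by this
commit):

import whisper

model = whisper.load_model("tiny")
result = model.transcribe("audio.mp3", word_timestamps=True)

for segment in result["segments"]:
    for word in segment["words"]:
        # each entry carries the word text, start/end in seconds, and a probability
        print(f'{word["start"]:6.2f} -> {word["end"]:6.2f}  {word["word"]}')

The same behavior is exposed on the command line through the experimental --word_timestamps
flag added to cli() below.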

---------

Co-authored-by: ryanheise <ryan@ryanheise.com>
Authored by Jong Wook Kim on 2023-03-06 17:00:49 -05:00, committed by GitHub
parent eab8d920ed
commit 500d0fe966
14 changed files with 768 additions and 77 deletions

.github/workflows/test.yml

@@ -21,6 +21,5 @@ jobs:
       - run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
       - uses: actions/checkout@v2
       - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
-      - run: pip install pytest
-      - run: pip install .
-      - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]'
+      - run: pip install .["dev"]
+      - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'

requirements.txt

@@ -1,3 +1,4 @@
+numba
 numpy
 torch
 tqdm

setup.py

@@ -1,4 +1,5 @@
 import os
+import sys
 import pkg_resources
 from setuptools import setup, find_packages

@@ -9,6 +10,21 @@ def read_version(fname="whisper/version.py"):
     return locals()["__version__"]

+requirements = []
+if sys.platform.startswith("linux"):
+    triton_requirement = "triton>=2.0.0.dev20221202"
+    try:
+        import re
+        import subprocess
+
+        version_line = subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
+        major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
+        if (int(major), int(minor)) < (11, 4):
+            # the last version supporting CUDA < 11.4
+            triton_requirement = "triton==2.0.0.dev20221011"
+    except (IndexError, OSError, subprocess.SubprocessError):
+        pass
+    requirements.append(triton_requirement)
+
 setup(
     name="openai-whisper",
     py_modules=["whisper"],

@@ -22,7 +38,7 @@ setup(
     url="https://github.com/openai/whisper",
     license="MIT",
     packages=find_packages(exclude=["tests*"]),
-    install_requires=[
+    install_requires=requirements + [
         str(r)
         for r in pkg_resources.parse_requirements(
             open(os.path.join(os.path.dirname(__file__), "requirements.txt"))

@@ -32,5 +48,5 @@ setup(
         "console_scripts": ["whisper=whisper.transcribe:cli"],
     },
     include_package_data=True,
-    extras_require={"dev": ["pytest"]},
+    extras_require={"dev": ["pytest", "scipy"]},
 )

tests/conftest.py (new file, 14 lines)

@@ -0,0 +1,14 @@
import random as rand

import numpy
import pytest


def pytest_configure(config):
    config.addinivalue_line("markers", "requires_cuda")


@pytest.fixture
def random():
    rand.seed(42)
    numpy.random.seed(42)

tests/test_timing.py (new file, 87 lines)

@@ -0,0 +1,87 @@
import pytest
import numpy as np
import scipy.ndimage
import torch

from whisper.timing import dtw_cpu, dtw_cuda, median_filter


sizes = [
    (10, 20), (32, 16), (123, 1500), (234, 189),
]
shapes = [
    (10,), (1, 15), (4, 5, 345), (6, 12, 240, 512),
]


@pytest.mark.parametrize("N, M", sizes)
def test_dtw(N: int, M: int):
    steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)])
    np.random.shuffle(steps)
    x = np.random.random((N, M)).astype(np.float32)

    i, j, k = 0, 0, 0
    trace = []
    while True:
        x[i, j] -= 1
        trace.append((i, j))

        if k == len(steps):
            break

        if k + 1 < len(steps) and steps[k] != steps[k + 1]:
            i += 1
            j += 1
            k += 2
            continue

        if steps[k] == 0:
            i += 1
        if steps[k] == 1:
            j += 1
        k += 1

    trace = np.array(trace).T
    dtw_trace = dtw_cpu(x)

    assert np.allclose(trace, dtw_trace)


@pytest.mark.requires_cuda
@pytest.mark.parametrize("N, M", sizes)
def test_dtw_cuda_equivalence(N: int, M: int):
    x_numpy = np.random.randn(N, M).astype(np.float32)
    x_cuda = torch.from_numpy(x_numpy).cuda()

    trace_cpu = dtw_cpu(x_numpy)
    trace_cuda = dtw_cuda(x_cuda)

    assert np.allclose(trace_cpu, trace_cuda)


@pytest.mark.parametrize("shape", shapes)
def test_median_filter(shape):
    x = torch.randn(*shape)

    for filter_width in [3, 5, 7, 13]:
        filtered = median_filter(x, filter_width)

        # using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
        pad_width = filter_width // 2
        padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect")
        scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width])
        scipy_filtered = scipy_filtered[..., pad_width:-pad_width]

        assert np.allclose(filtered, scipy_filtered)


@pytest.mark.requires_cuda
@pytest.mark.parametrize("shape", shapes)
def test_median_filter_equivalence(shape):
    x = torch.randn(*shape)

    for filter_width in [3, 5, 7, 13]:
        filtered_cpu = median_filter(x, filter_width)
        filtered_gpu = median_filter(x.cuda(), filter_width).cpu()

        assert np.allclose(filtered_cpu, filtered_gpu)

tests/test_transcribe.py

@@ -13,10 +13,22 @@ def test_transcribe(model_name: str):
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

     language = "en" if model_name.endswith(".en") else None
-    result = model.transcribe(audio_path, language=language, temperature=0.0)
+    result = model.transcribe(audio_path, language=language, temperature=0.0, word_timestamps=True)
     assert result["language"] == "en"

     transcription = result["text"].lower()
     assert "my fellow americans" in transcription
     assert "your country" in transcription
     assert "do for you" in transcription

+    timing_checked = False
+    for segment in result["segments"]:
+        for timing in segment["words"]:
+            assert timing["start"] < timing["end"]
+            if timing["word"].strip(" ,") == "Americans":
+                assert timing["start"] <= 1.8
+                assert timing["end"] >= 1.8
+                print(timing)
+                timing_checked = True
+
+    assert timing_checked

whisper/__init__.py

@@ -29,6 +29,23 @@ _MODELS = {
     "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
 }

+# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
+# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
+_ALIGNMENT_HEADS = {
+    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
+    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
+    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
+    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
+    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
+    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
+    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
+    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
+    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
+    "large-v2": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
+    "large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
+}
+

 def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     os.makedirs(root, exist_ok=True)

@@ -106,8 +123,10 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
     if name in _MODELS:
         checkpoint_file = _download(_MODELS[name], download_root, in_memory)
+        alignment_heads = _ALIGNMENT_HEADS[name]
     elif os.path.isfile(name):
         checkpoint_file = open(name, "rb").read() if in_memory else name
+        alignment_heads = None
     else:
         raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

@@ -119,4 +138,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
     model = Whisper(dims)
     model.load_state_dict(checkpoint["model_state_dict"])

+    if alignment_heads is not None:
+        model.set_alignment_heads(alignment_heads)
+
     return model.to(device)

whisper/audio.py

@@ -18,6 +18,10 @@ CHUNK_LENGTH = 30
 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
 N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000: number of frames in a mel spectrogram input

+N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
+FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 100 mel frames in 1s (10ms each)
+TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 50 audio tokens in 1s (20ms each)
+

 def load_audio(file: str, sr: int = SAMPLE_RATE):
     """

whisper/model.py

@@ -1,3 +1,5 @@
+import base64
+import gzip
 from dataclasses import dataclass
 from typing import Dict
 from typing import Iterable, Optional

@@ -8,8 +10,8 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch import nn

-from .transcribe import transcribe as transcribe_function
 from .decoding import detect_language as detect_language_function, decode as decode_function
+from .transcribe import transcribe as transcribe_function


 @dataclass

@@ -213,6 +215,15 @@
             self.dims.n_text_head,
             self.dims.n_text_layer,
         )
+        # use the last half layers for alignment by default; see `set_alignment_heads()` below
+        all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
+        all_heads[self.dims.n_text_layer // 2:] = True
+        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+
+    def set_alignment_heads(self, dump: bytes):
+        array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
+        mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
+        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

     def embed_audio(self, mel: torch.Tensor):
         return self.encoder(mel)

whisper/timing.py (new file, 305 lines)

@@ -0,0 +1,305 @@
import subprocess
import warnings
from dataclasses import dataclass
from typing import List, TYPE_CHECKING

import numba
import numpy as np
import torch
import torch.nn.functional as F

from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .tokenizer import Tokenizer

if TYPE_CHECKING:
    from .model import Whisper


def median_filter(x: torch.Tensor, filter_width: int):
    """Apply a median filter of width `filter_width` along the last dimension of `x`"""
    pad_width = filter_width // 2
    if x.shape[-1] <= pad_width:
        # F.pad requires the padding width to be smaller than the input dimension
        return x

    if (ndim := x.ndim) <= 2:
        # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
        x = x[None, None, :]

    assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"

    result = None
    x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
    if x.is_cuda:
        try:
            from .triton_ops import median_filter_cuda
            result = median_filter_cuda(x, filter_width)
        except (RuntimeError, subprocess.CalledProcessError):
            warnings.warn(
                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
                "falling back to a slower median kernel implementation..."
            )

    if result is None:
        # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
        result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]

    if ndim <= 2:
        result = result[0, 0]

    return result


@numba.jit
def backtrace(trace: np.ndarray):
    i = trace.shape[0] - 1
    j = trace.shape[1] - 1
    trace[0, :] = 2
    trace[:, 0] = 1

    result = []
    while i > 0 or j > 0:
        result.append((i - 1, j - 1))

        if trace[i, j] == 0:
            i -= 1
            j -= 1
        elif trace[i, j] == 1:
            i -= 1
        elif trace[i, j] == 2:
            j -= 1
        else:
            raise ValueError("Unexpected trace[i, j]")

    result = np.array(result)
    return result[::-1, :].T


@numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray):
    N, M = x.shape
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)

    cost[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1, j - 1]
            c1 = cost[i - 1, j]
            c2 = cost[i, j - 1]

            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                c, t = c2, 2

            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    return backtrace(trace)


def dtw_cuda(x, BLOCK_SIZE=1024):
    from .triton_ops import dtw_kernel

    M, N = x.shape
    assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"

    x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
    x_skew = x_skew.T.contiguous()
    cost = torch.ones(N + M + 2, M + 2) * np.inf
    cost[0, 0] = 0
    cost = cost.cuda()
    trace = torch.zeros_like(cost, dtype=torch.int32)

    dtw_kernel[(1,)](
        cost,
        trace,
        x_skew,
        x_skew.stride(0),
        cost.stride(0),
        trace.stride(0),
        N,
        M,
        BLOCK_SIZE=BLOCK_SIZE
    )

    trace = trace.T.flatten()[:(M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, :N + 1]
    return backtrace(trace.cpu().numpy())


def dtw(x: torch.Tensor) -> np.ndarray:
    if x.is_cuda:
        try:
            return dtw_cuda(x)
        except (RuntimeError, subprocess.CalledProcessError):
            warnings.warn(
                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
                "falling back to a slower DTW implementation..."
            )
    return dtw_cpu(x.double().cpu().numpy())


@dataclass
class WordTiming:
    word: str
    tokens: List[int]
    start: float
    end: float
    probability: float


def find_alignment(
    model: "Whisper",
    tokenizer: Tokenizer,
    text_tokens: List[int],
    mel: torch.Tensor,
    num_frames: int,
    *,
    medfilt_width: int = 7,
    qk_scale: float = 1.0,
) -> List[WordTiming]:
    tokens = torch.tensor(
        [
            *tokenizer.sot_sequence,
            tokenizer.no_timestamps,
            *text_tokens,
            tokenizer.eot,
        ]
    ).to(model.device)

    # install hooks on the cross attention layers to retrieve the attention weights
    QKs = [None] * model.dims.n_text_layer
    hooks = [
        block.cross_attn.register_forward_hook(
            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
        )
        for i, block in enumerate(model.decoder.blocks)
    ]

    with torch.no_grad():
        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
        token_probs = logits[len(tokenizer.sot_sequence):, :tokenizer.eot].softmax(dim=-1)
        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist()

    for hook in hooks:
        hook.remove()

    # heads * tokens * frames
    weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T])
    weights = weights[:, :, : num_frames // 2]
    weights = (weights * qk_scale).softmax(dim=-1)
    std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
    weights = (weights - mean) / std
    weights = median_filter(weights, medfilt_width)

    matrix = weights.mean(axis=0)
    matrix = matrix[len(tokenizer.sot_sequence):-1]
    text_indices, time_indices = dtw(-matrix)

    words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))

    jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
    jump_times = time_indices[jumps] / TOKENS_PER_SECOND
    start_times = jump_times[word_boundaries[:-1]]
    end_times = jump_times[word_boundaries[1:]]
    word_probabilities = [
        np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
    ]

    # hack: ensure the first and second word is not longer than twice the median word duration.
    # a better segmentation algorithm based on VAD should be able to replace this.
    word_durations = end_times - start_times
    word_durations = word_durations[word_durations.nonzero()]
    if len(word_durations) > 0:
        median_duration = np.median(word_durations)
        max_duration = median_duration * 2
        if len(word_durations) >= 2 and word_durations[1] > max_duration:
            end_times[0] = start_times[1] = max(end_times[2] / 2, end_times[2] - max_duration)
        if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
            start_times[0] = max(0, end_times[0] - max_duration)

    return [
        WordTiming(word, tokens, start, end, probability)
        for word, tokens, start, end, probability in zip(
            words, word_tokens, start_times, end_times, word_probabilities
        )
    ]


def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
    # merge prepended punctuations
    i = len(alignment) - 2
    j = len(alignment) - 1
    while i >= 0:
        previous = alignment[i]
        following = alignment[j]
        if previous.word.startswith(" ") and previous.word.strip() in prepended:
            # prepend it to the following word
            following.word = previous.word + following.word
            following.tokens = previous.tokens + following.tokens
            previous.word = ""
            previous.tokens = []
        else:
            j = i
        i -= 1

    # merge appended punctuations
    i = 0
    j = 1
    while j < len(alignment):
        previous = alignment[i]
        following = alignment[j]
        if not previous.word.endswith(" ") and following.word in appended:
            # append it to the previous word
            previous.word = previous.word + following.word
            previous.tokens = previous.tokens + following.tokens
            following.word = ""
            following.tokens = []
        else:
            i = j
        j += 1


def add_word_timestamps(
    *,
    segments: List[dict],
    model: "Whisper",
    tokenizer: Tokenizer,
    mel: torch.Tensor,
    num_frames: int,
    prepend_punctuations: str = "\"\'“¿([{-",
    append_punctuations: str = "\"\'.。,!?::”)]}、",
    **hyperparams,
):
    if len(segments) == 0:
        return

    text_tokens = [t for segment in segments for t in segment["tokens"]]
    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **hyperparams)
    merge_punctuations(alignment, prepend_punctuations, append_punctuations)

    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
    token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])

    for segment in segments:
        segment["words"] = []

    word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
    for i, timing in enumerate(alignment):
        if timing.word:
            segment = segments[token_sources[word_boundaries[i]]]
            start = round(time_offset + timing.start, 2)
            end = round(time_offset + timing.end, 2)
            segment["words"].append(
                dict(word=timing.word, start=start, end=end, probability=timing.probability)
            )

    for segment in segments:
        if len(words := segment["words"]) > 0:
            # adjust the segment-level timestamps based on the word-level timestamps
            segment["start"] = words[0]["start"]
            segment["end"] = words[-1]["end"]

whisper/tokenizer.py

@@ -1,4 +1,5 @@
 import os
+import string
 from dataclasses import dataclass
 from functools import lru_cache, cached_property
 from typing import List, Optional, Tuple, Union

@@ -265,6 +266,48 @@ class Tokenizer:
         assert len(tokens) == 1, f"{text} is not encoded as a single token"
         return tokens[0]

+    def split_to_word_tokens(self, tokens: List[int]):
+        if self.language in {"zh", "ja", "th", "lo", "my"}:
+            # These languages don't typically use spaces, so it is difficult to split words
+            # without morpheme analysis. Here, we instead split words at any
+            # position where the tokens are decoded as valid unicode points
+            return self.split_tokens_on_unicode(tokens)
+
+        return self.split_tokens_on_spaces(tokens)
+
+    def split_tokens_on_unicode(self, tokens: List[int]):
+        words = []
+        word_tokens = []
+        current_tokens = []
+
+        for token in tokens:
+            current_tokens.append(token)
+            decoded = self.decode_with_timestamps(current_tokens)
+            if "\ufffd" not in decoded:
+                words.append(decoded)
+                word_tokens.append(current_tokens)
+                current_tokens = []
+
+        return words, word_tokens
+
+    def split_tokens_on_spaces(self, tokens: List[int]):
+        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
+        words = []
+        word_tokens = []
+
+        for subword, subword_tokens in zip(subwords, subword_tokens_list):
+            special = subword_tokens[0] >= self.eot
+            with_space = subword.startswith(" ")
+            punctuation = subword.strip() in string.punctuation
+            if special or with_space or punctuation or len(words) == 0:
+                words.append(subword)
+                word_tokens.append(subword_tokens)
+            else:
+                words[-1] = words[-1] + subword
+                word_tokens[-1].extend(subword_tokens)
+
+        return words, word_tokens
+

 @lru_cache(maxsize=None)
 def build_tokenizer(name: str = "gpt2"):

whisper/transcribe.py

@@ -7,8 +7,9 @@ import numpy as np
 import torch
 import tqdm

-from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
+from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, log_mel_spectrogram, pad_or_trim
 from .decoding import DecodingOptions, DecodingResult
+from .timing import add_word_timestamps
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer

@@ -27,6 +28,9 @@ def transcribe(
     no_speech_threshold: Optional[float] = 0.6,
     condition_on_previous_text: bool = True,
     initial_prompt: Optional[str] = None,
+    word_timestamps: bool = False,
+    prepend_punctuations: str = "\"\'“¿([{-",
+    append_punctuations: str = "\"\'.。,!?::”)]}、",
     **decode_options,
 ):
     """

@@ -63,6 +67,21 @@
         disabling may make the text inconsistent across windows, but the model becomes less prone to
         getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

+    word_timestamps: bool
+        Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
+        and include the timestamps for each word in each segment.
+
+    prepend_punctuations: str
+        If word_timestamps is True, merge these punctuation symbols with the next word
+
+    append_punctuations: str
+        If word_timestamps is True, merge these punctuation symbols with the previous word
+
+    initial_prompt: Optional[str]
+        Optional text to provide as a prompt for the first window. This can be used to provide, or
+        "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
+        to make it more likely to predict those word correctly.
+
     decode_options: dict
         Keyword arguments to construct `DecodingOptions` instances

@@ -90,16 +109,19 @@
         else:
             if verbose:
                 print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
-            segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
-            _, probs = model.detect_language(segment)
+            mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
+            _, probs = model.detect_language(mel_segment)
             decode_options["language"] = max(probs, key=probs.get)
             if verbose is not None:
                 print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")

-    language = decode_options["language"]
-    task = decode_options.get("task", "transcribe")
+    language: str = decode_options["language"]
+    task: str = decode_options.get("task", "transcribe")
     tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)

+    if word_timestamps and task == "translate":
+        warnings.warn("Word-level timestamps on translations may not be reliable.")
+
     def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
         temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
         decode_result = None

@@ -145,42 +167,35 @@
     else:
         initial_prompt_tokens = []

-    def add_segment(
-        *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
+    def new_segment(
+        *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
     ):
-        text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
-        if len(text.strip()) == 0:  # skip empty text output
-            return
-
-        all_segments.append(
-            {
+        text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot]
+        return {
             "id": len(all_segments),
             "seek": seek,
             "start": start,
             "end": end,
-            "text": text,
-            "tokens": text_tokens.tolist(),
+            "text": tokenizer.decode(text_tokens),
+            "tokens": text_tokens,
             "temperature": result.temperature,
             "avg_logprob": result.avg_logprob,
             "compression_ratio": result.compression_ratio,
             "no_speech_prob": result.no_speech_prob,
         }
-        )
-        if verbose:
-            print(make_safe(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"))

-    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
+    # show the progress bar when verbose is False (if True, transcribed text will be printed)
     num_frames = mel.shape[-1]
-    previous_seek_value = seek

     with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
         while seek < num_frames:
-            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
-            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
+            time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
+            mel_segment = mel[:, seek:]
+            segment_size = min(mel_segment.shape[-1], N_FRAMES)
+            segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
+            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

             decode_options["prompt"] = all_tokens[prompt_reset_since:]
-            result: DecodingResult = decode_with_fallback(segment)
+            result: DecodingResult = decode_with_fallback(mel_segment)
             tokens = torch.tensor(result.tokens)

             if no_speech_threshold is not None:

@@ -191,29 +206,36 @@
                     should_skip = False

                 if should_skip:
-                    seek += segment.shape[-1]  # fast-forward to the next segment boundary
+                    seek += segment_size  # fast-forward to the next segment boundary
                     continue

+            previous_seek = seek
+            current_segments = []
+            current_tokens = []
+
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
             consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
             if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
                 if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]:
                     consecutive = consecutive.tolist() + [len(tokens)]

                 last_slice = 0
                 for current_slice in consecutive:
                     sliced_tokens = tokens[last_slice:current_slice]
                     start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
                     end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
-                    add_segment(
-                        start=timestamp_offset + start_timestamp_pos * time_precision,
-                        end=timestamp_offset + end_timestamp_pos * time_precision,
-                        text_tokens=sliced_tokens[1:-1],
+                    current_segments.append(new_segment(
+                        start=time_offset + start_timestamp_pos * time_precision,
+                        end=time_offset + end_timestamp_pos * time_precision,
+                        tokens=sliced_tokens,
                         result=result,
-                    )
+                    ))
+                    current_tokens.append(sliced_tokens.tolist())
                     last_slice = current_slice

                 if ended_with_single_timestamp:
                     # single timestamp at the end means no speech after the last timestamp.
-                    seek += segment.shape[-1]
+                    seek += segment_size
                 else:
                     # otherwise, ignore the unfinished segment and seek to the last timestamp
                     last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin

@@ -227,23 +249,54 @@
                     last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin
                     duration = last_timestamp_pos * time_precision

-                add_segment(
-                    start=timestamp_offset,
-                    end=timestamp_offset + duration,
-                    text_tokens=tokens,
+                current_segments.append(new_segment(
+                    start=time_offset,
+                    end=time_offset + duration,
+                    tokens=tokens,
                     result=result,
-                )
+                ))
+                current_tokens.append(tokens.tolist())

-                seek += segment.shape[-1]
-                all_tokens.extend(tokens.tolist())
+                seek += segment_size

             if not condition_on_previous_text or result.temperature > 0.5:
                 # do not feed the prompt tokens if a high temperature was used
                 prompt_reset_since = len(all_tokens)

+            if word_timestamps:
+                add_word_timestamps(
+                    segments=current_segments,
+                    model=model,
+                    tokenizer=tokenizer,
+                    mel=mel_segment,
+                    num_frames=segment_size,
+                    prepend_punctuations=prepend_punctuations,
+                    append_punctuations=append_punctuations,
+                )
+                word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]]
+                if len(consecutive) > 0 and len(word_end_timestamps) > 0:
+                    seek_shift = round((word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND)
+                    if seek_shift > 0:
+                        seek = previous_seek + seek_shift
+
+            if verbose:
+                for segment in current_segments:
+                    start, end, text = segment["start"], segment["end"], segment["text"]
+                    line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
+                    print(make_safe(line))
+
+            # if a segment is instantaneous or does not contain text, clear it
+            for i, segment in enumerate(current_segments):
+                if segment["start"] == segment["end"] or segment["text"].strip() == "":
+                    segment["text"] = ""
+                    segment["tokens"] = []
+                    segment["words"] = []
+                    current_tokens[i] = []
+
+            all_segments.extend(current_segments)
+            all_tokens.extend([token for segment in current_tokens for token in segment])
+
             # update progress bar
-            pbar.update(min(num_frames, seek) - previous_seek_value)
-            previous_seek_value = seek
+            pbar.update(min(num_frames, seek) - previous_seek)

     return dict(
         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),

@@ -282,6 +335,9 @@ def cli():
     parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
     parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
     parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
+    parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
+    parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
+    parser.add_argument("--append_punctuations", type=str, default="\"\'.。,!?::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")

     args = parser.parse_args().__dict__

whisper/triton_ops.py (new file, 92 lines)

@@ -0,0 +1,92 @@
import math
import numpy as np
import torch
from functools import lru_cache

try:
    import triton
    import triton.language as tl
except ImportError:
    raise RuntimeError("triton import failed; try `pip install --pre triton`")


@triton.jit
def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < M

    for k in range(1, N + M + 1):  # k = i + j
        tl.debug_barrier()
        p0 = cost + (k - 1) * cost_stride
        p1 = cost + k * cost_stride
        p2 = cost + k * cost_stride + 1

        c0 = tl.load(p0 + offsets, mask=mask)
        c1 = tl.load(p1 + offsets, mask=mask)
        c2 = tl.load(p2 + offsets, mask=mask)

        x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
        cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)

        cost_ptr = cost + (k + 1) * cost_stride + 1
        tl.store(cost_ptr + offsets, cost_row, mask=mask)

        trace_ptr = trace + (k + 1) * trace_stride + 1
        tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
        tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
        tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))


@lru_cache(maxsize=None)
def median_kernel(filter_width: int):
    @triton.jit
    def kernel(y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr):  # x.shape[-1] == filter_width
        row_idx = tl.program_id(0)
        offsets = tl.arange(0, BLOCK_SIZE)
        mask = offsets < y_stride

        x_ptr = x + row_idx * x_stride
        y_ptr = y + row_idx * y_stride

        LOAD_ALL_ROWS_HERE

        BUBBLESORT_HERE

        tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)

    kernel = triton.JITFunction(kernel.fn)
    kernel.src = kernel.src.replace(" LOAD_ALL_ROWS_HERE", "\n".join([
        f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
        for i in range(filter_width)
    ]))
    kernel.src = kernel.src.replace(" BUBBLESORT_HERE", "\n\n".join([
        "\n\n".join([
            "\n".join([
                f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
                f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
                f" row{j} = smaller",
                f" row{j + 1} = larger",
            ])
            for j in range(filter_width - i - 1)
        ])
        for i in range(filter_width // 2 + 1)
    ]))
    kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")

    return kernel


def median_filter_cuda(x: torch.Tensor, filter_width: int):
    """Apply a median filter of given width along the last dimension of x"""
    slices = x.contiguous().unfold(-1, filter_width, 1)
    grid = np.prod(slices.shape[:-2])

    kernel = median_kernel(filter_width)
    y = torch.empty_like(slices[..., 0])

    BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
    kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)

    return y

whisper/utils.py

@@ -85,34 +85,63 @@ class WriteTXT(ResultWriter):
             print(segment['text'].strip(), file=file, flush=True)


-class WriteVTT(ResultWriter):
+class SubtitlesWriter(ResultWriter):
+    always_include_hours: bool
+    decimal_marker: str
+
+    def iterate_result(self, result: dict):
+        for segment in result["segments"]:
+            segment_start = self.format_timestamp(segment["start"])
+            segment_end = self.format_timestamp(segment["end"])
+            segment_text = segment['text'].strip().replace('-->', '->')
+
+            if word_timings := segment.get("words", None):
+                all_words = [timing["word"] for timing in word_timings]
+                all_words[0] = all_words[0].strip()  # remove the leading space, if any
+                last = segment_start
+                for i, this_word in enumerate(word_timings):
+                    start = self.format_timestamp(this_word["start"])
+                    end = self.format_timestamp(this_word["end"])
+                    if last != start:
+                        yield last, start, segment_text
+
+                    yield start, end, "".join(
+                        [f"<u>{word}</u>" if j == i else word for j, word in enumerate(all_words)]
+                    )
+                    last = end
+
+                if last != segment_end:
+                    yield last, segment_end, segment_text
+            else:
+                yield segment_start, segment_end, segment_text
+
+    def format_timestamp(self, seconds: float):
+        return format_timestamp(
+            seconds=seconds,
+            always_include_hours=self.always_include_hours,
+            decimal_marker=self.decimal_marker,
+        )
+
+
+class WriteVTT(SubtitlesWriter):
     extension: str = "vtt"
+    always_include_hours: bool = False
+    decimal_marker: str = '.'

     def write_result(self, result: dict, file: TextIO):
         print("WEBVTT\n", file=file)
-        for segment in result["segments"]:
-            print(
-                f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
-                f"{segment['text'].strip().replace('-->', '->')}\n",
-                file=file,
-                flush=True,
-            )
+        for start, end, text in self.iterate_result(result):
+            print(f"{start} --> {end}\n{text}\n", file=file, flush=True)


-class WriteSRT(ResultWriter):
+class WriteSRT(SubtitlesWriter):
     extension: str = "srt"
+    always_include_hours: bool = True
+    decimal_marker: str = ','

     def write_result(self, result: dict, file: TextIO):
-        for i, segment in enumerate(result["segments"], start=1):
-            # write srt lines
-            print(
-                f"{i}\n"
-                f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
-                f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
-                f"{segment['text'].strip().replace('-->', '->')}\n",
-                file=file,
-                flush=True,
-            )
+        for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
+            print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)


 class WriteTSV(ResultWriter):