mirror of https://github.com/openai/whisper.git (synced 2025-11-23 22:15:58 +00:00)

word-level timestamps in transcribe() (#869)

* word-level timestamps in `transcribe()`
* moving to `timing.py`
* numba implementation for dtw, replacing dtw-python
* triton implementation for dtw
* add test for dtw implementations
* triton implementation of median_filter
* a simple word-level timestamps test
* add scipy as dev dependency
* installs an older version of Triton if CUDA < 11.4
* fix broken merge
* loosen nvcc version match regex
* find_alignment() function
* miscellaneous improvements
* skip median filtering when the input is too small
* Expose punctuation options in cli and transcribe() (#973)
* fix merge error
* fix merge error 2
* annotating that word_timestamps is experimental

Co-authored-by: ryanheise <ryan@ryanheise.com>

parent eab8d920ed
commit 500d0fe966
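
For context, a minimal usage sketch of the feature this commit introduces. The `word_timestamps` argument and the per-segment `"words"` entries come straight from the diff below; the model name and audio path are placeholders.

```python
import whisper

model = whisper.load_model("tiny.en")  # any model name; "tiny.en" is just an example
result = model.transcribe("audio.wav", word_timestamps=True)

for segment in result["segments"]:
    for word in segment["words"]:
        # each word dict carries "word", "start", "end", and "probability"
        print(f'{word["start"]:6.2f} -> {word["end"]:6.2f}  {word["word"]}')
```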

5  .github/workflows/test.yml

@ -21,6 +21,5 @@ jobs:
- run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
- uses: actions/checkout@v2
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
- run: pip install pytest
- run: pip install .
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]'
- run: pip install .["dev"]
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'

requirements.txt

@ -1,3 +1,4 @@
numba
numpy
torch
tqdm

20  setup.py

@ -1,4 +1,5 @@
import os
import sys

import pkg_resources
from setuptools import setup, find_packages
@ -9,6 +10,21 @@ def read_version(fname="whisper/version.py"):
    return locals()["__version__"]


requirements = []
if sys.platform.startswith("linux"):
    triton_requirement = "triton>=2.0.0.dev20221202"
    try:
        import re
        import subprocess
        version_line = subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
        major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
        if (int(major), int(minor)) < (11, 4):
            # the last version supporting CUDA < 11.4
            triton_requirement = "triton==2.0.0.dev20221011"
    except (IndexError, OSError, subprocess.SubprocessError):
        pass
    requirements.append(triton_requirement)

setup(
    name="openai-whisper",
    py_modules=["whisper"],
@ -22,7 +38,7 @@ setup(
    url="https://github.com/openai/whisper",
    license="MIT",
    packages=find_packages(exclude=["tests*"]),
    install_requires=[
    install_requires=requirements + [
        str(r)
        for r in pkg_resources.parse_requirements(
            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
@ -32,5 +48,5 @@ setup(
        "console_scripts": ["whisper=whisper.transcribe:cli"],
    },
    include_package_data=True,
    extras_require={"dev": ["pytest"]},
    extras_require={"dev": ["pytest", "scipy"]},
)

14  tests/conftest.py  Normal file

@ -0,0 +1,14 @@
import random as rand

import numpy
import pytest


def pytest_configure(config):
    config.addinivalue_line("markers", "requires_cuda")


@pytest.fixture
def random():
    rand.seed(42)
    numpy.random.seed(42)

87  tests/test_timing.py  Normal file

@ -0,0 +1,87 @@
import pytest
import numpy as np
import scipy.ndimage
import torch

from whisper.timing import dtw_cpu, dtw_cuda, median_filter


sizes = [
    (10, 20), (32, 16), (123, 1500), (234, 189),
]
shapes = [
    (10,), (1, 15), (4, 5, 345), (6, 12, 240, 512),
]


@pytest.mark.parametrize("N, M", sizes)
def test_dtw(N: int, M: int):
    steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)])
    np.random.shuffle(steps)
    x = np.random.random((N, M)).astype(np.float32)

    i, j, k = 0, 0, 0
    trace = []
    while True:
        x[i, j] -= 1
        trace.append((i, j))

        if k == len(steps):
            break

        if k + 1 < len(steps) and steps[k] != steps[k + 1]:
            i += 1
            j += 1
            k += 2
            continue

        if steps[k] == 0:
            i += 1
        if steps[k] == 1:
            j += 1
        k += 1

    trace = np.array(trace).T
    dtw_trace = dtw_cpu(x)

    assert np.allclose(trace, dtw_trace)


@pytest.mark.requires_cuda
@pytest.mark.parametrize("N, M", sizes)
def test_dtw_cuda_equivalence(N: int, M: int):
    x_numpy = np.random.randn(N, M).astype(np.float32)
    x_cuda = torch.from_numpy(x_numpy).cuda()

    trace_cpu = dtw_cpu(x_numpy)
    trace_cuda = dtw_cuda(x_cuda)

    assert np.allclose(trace_cpu, trace_cuda)


@pytest.mark.parametrize("shape", shapes)
def test_median_filter(shape):
    x = torch.randn(*shape)

    for filter_width in [3, 5, 7, 13]:
        filtered = median_filter(x, filter_width)

        # using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
        pad_width = filter_width // 2
        padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect")
        scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width])
        scipy_filtered = scipy_filtered[..., pad_width:-pad_width]

        assert np.allclose(filtered, scipy_filtered)


@pytest.mark.requires_cuda
@pytest.mark.parametrize("shape", shapes)
def test_median_filter_equivalence(shape):
    x = torch.randn(*shape)

    for filter_width in [3, 5, 7, 13]:
        filtered_cpu = median_filter(x, filter_width)
        filtered_gpu = median_filter(x.cuda(), filter_width).cpu()

        assert np.allclose(filtered_cpu, filtered_gpu)
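
Beyond the exhaustive test above, a quick way to see what `dtw_cpu` returns: the trace is a 2×K array of (text index, time index) pairs forming a monotonic path from the top-left to the bottom-right corner of the cost matrix. A minimal sketch; the input matrix here is arbitrary.

```python
import numpy as np
from whisper.timing import dtw_cpu

x = np.random.random((5, 8)).astype(np.float32)  # an arbitrary (N, M) cost matrix
text_indices, time_indices = dtw_cpu(x)

# the path starts at (0, 0), ends at (N-1, M-1), and never moves backwards on either axis
assert text_indices[0] == 0 and time_indices[0] == 0
assert text_indices[-1] == 4 and time_indices[-1] == 7
assert np.all(np.diff(text_indices) >= 0) and np.all(np.diff(time_indices) >= 0)
```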

tests/test_transcribe.py

@ -13,10 +13,22 @@ def test_transcribe(model_name: str):
    audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

    language = "en" if model_name.endswith(".en") else None
    result = model.transcribe(audio_path, language=language, temperature=0.0)
    result = model.transcribe(audio_path, language=language, temperature=0.0, word_timestamps=True)
    assert result["language"] == "en"

    transcription = result["text"].lower()
    assert "my fellow americans" in transcription
    assert "your country" in transcription
    assert "do for you" in transcription

    timing_checked = False
    for segment in result["segments"]:
        for timing in segment["words"]:
            assert timing["start"] < timing["end"]
            if timing["word"].strip(" ,") == "Americans":
                assert timing["start"] <= 1.8
                assert timing["end"] >= 1.8
                print(timing)
                timing_checked = True

    assert timing_checked

whisper/__init__.py

@ -29,6 +29,23 @@ _MODELS = {
    "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}

# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
    "large-v2": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
    "large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
}


def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
    os.makedirs(root, exist_ok=True)
@ -106,8 +123,10 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        checkpoint_file = open(name, "rb").read() if in_memory else name
        alignment_heads = None
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

@ -119,4 +138,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)
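
The `_ALIGNMENT_HEADS` dumps above are consumed by `model.set_alignment_heads()` (see the whisper/model.py hunk further down), which reverses a gzip + base85 encoding of a boolean (n_text_layer, n_text_head) mask. A hedged sketch of how such a dump could be produced: only the encoding scheme is taken from this diff, while the mask shape and values here are made up.

```python
import base64
import gzip

import numpy as np

# hypothetical mask: mark heads (2, 0) and (3, 1) of a 4-layer, 6-head decoder as alignment heads
mask = np.zeros((4, 6), dtype=bool)
mask[2, 0] = mask[3, 1] = True

dump = base64.b85encode(gzip.compress(mask.tobytes()))           # the format _ALIGNMENT_HEADS stores
decoded = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).reshape(4, 6)
assert (decoded == mask).all()                                   # round-trips back to the same mask
```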

whisper/audio.py

@ -18,6 +18,10 @@ CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000: number of frames in a mel spectrogram input

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 100 mel frames in 1s (10ms each)
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 50 audio tokens in 1s (20ms each)


def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
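
To make the 10 ms / 20 ms comments concrete, a small arithmetic sketch; it assumes the values SAMPLE_RATE = 16000 and HOP_LENGTH = 160 defined earlier in audio.py (unchanged by this commit).

```python
SAMPLE_RATE = 16000
HOP_LENGTH = 160

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2                     # 320 samples, i.e. 20 ms of audio per token
FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH            # 100 mel frames per second (10 ms each)
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN   # 50 audio tokens per second (20 ms each)

assert (FRAMES_PER_SECOND, TOKENS_PER_SECOND) == (100, 50)
```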

whisper/model.py

@ -1,3 +1,5 @@
import base64
import gzip
from dataclasses import dataclass
from typing import Dict
from typing import Iterable, Optional
@ -8,8 +10,8 @@ import torch.nn.functional as F
from torch import Tensor
from torch import nn

from .transcribe import transcribe as transcribe_function
from .decoding import detect_language as detect_language_function, decode as decode_function
from .transcribe import transcribe as transcribe_function


@dataclass
@ -213,6 +215,15 @@ class Whisper(nn.Module):
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        # use the last half layers for alignment by default; see `set_alignment_heads()` below
        all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
        all_heads[self.dims.n_text_layer // 2:] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

    def set_alignment_heads(self, dump: bytes):
        array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
        mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder(mel)

305  whisper/timing.py  Normal file

@ -0,0 +1,305 @@
import subprocess
import warnings
from dataclasses import dataclass
from typing import List, TYPE_CHECKING

import numba
import numpy as np
import torch
import torch.nn.functional as F

from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .tokenizer import Tokenizer

if TYPE_CHECKING:
    from .model import Whisper


def median_filter(x: torch.Tensor, filter_width: int):
    """Apply a median filter of width `filter_width` along the last dimension of `x`"""
    pad_width = filter_width // 2
    if x.shape[-1] <= pad_width:
        # F.pad requires the padding width to be smaller than the input dimension
        return x

    if (ndim := x.ndim) <= 2:
        # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
        x = x[None, None, :]

    assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"

    result = None
    x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
    if x.is_cuda:
        try:
            from .triton_ops import median_filter_cuda
            result = median_filter_cuda(x, filter_width)
        except (RuntimeError, subprocess.CalledProcessError):
            warnings.warn(
                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
                "falling back to a slower median kernel implementation..."
            )

    if result is None:
        # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
        result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]

    if ndim <= 2:
        result = result[0, 0]

    return result


@numba.jit
def backtrace(trace: np.ndarray):
    i = trace.shape[0] - 1
    j = trace.shape[1] - 1
    trace[0, :] = 2
    trace[:, 0] = 1

    result = []
    while i > 0 or j > 0:
        result.append((i - 1, j - 1))

        if trace[i, j] == 0:
            i -= 1
            j -= 1
        elif trace[i, j] == 1:
            i -= 1
        elif trace[i, j] == 2:
            j -= 1
        else:
            raise ValueError("Unexpected trace[i, j]")

    result = np.array(result)
    return result[::-1, :].T


@numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray):
    N, M = x.shape
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)

    cost[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1, j - 1]
            c1 = cost[i - 1, j]
            c2 = cost[i, j - 1]

            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                c, t = c2, 2

            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    return backtrace(trace)


def dtw_cuda(x, BLOCK_SIZE=1024):
    from .triton_ops import dtw_kernel

    M, N = x.shape
    assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"

    x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
    x_skew = x_skew.T.contiguous()
    cost = torch.ones(N + M + 2, M + 2) * np.inf
    cost[0, 0] = 0
    cost = cost.cuda()
    trace = torch.zeros_like(cost, dtype=torch.int32)

    dtw_kernel[(1,)](
        cost,
        trace,
        x_skew,
        x_skew.stride(0),
        cost.stride(0),
        trace.stride(0),
        N,
        M,
        BLOCK_SIZE=BLOCK_SIZE
    )

    trace = trace.T.flatten()[:(M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, :N + 1]
    return backtrace(trace.cpu().numpy())


def dtw(x: torch.Tensor) -> np.ndarray:
    if x.is_cuda:
        try:
            return dtw_cuda(x)
        except (RuntimeError, subprocess.CalledProcessError):
            warnings.warn(
                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
                "falling back to a slower DTW implementation..."
            )

    return dtw_cpu(x.double().cpu().numpy())


@dataclass
class WordTiming:
    word: str
    tokens: List[int]
    start: float
    end: float
    probability: float


def find_alignment(
    model: "Whisper",
    tokenizer: Tokenizer,
    text_tokens: List[int],
    mel: torch.Tensor,
    num_frames: int,
    *,
    medfilt_width: int = 7,
    qk_scale: float = 1.0,
) -> List[WordTiming]:
    tokens = torch.tensor(
        [
            *tokenizer.sot_sequence,
            tokenizer.no_timestamps,
            *text_tokens,
            tokenizer.eot,
        ]
    ).to(model.device)

    # install hooks on the cross attention layers to retrieve the attention weights
    QKs = [None] * model.dims.n_text_layer
    hooks = [
        block.cross_attn.register_forward_hook(
            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
        )
        for i, block in enumerate(model.decoder.blocks)
    ]

    with torch.no_grad():
        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
        token_probs = logits[len(tokenizer.sot_sequence):, :tokenizer.eot].softmax(dim=-1)
        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist()

    for hook in hooks:
        hook.remove()

    # heads * tokens * frames
    weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T])
    weights = weights[:, :, : num_frames // 2]
    weights = (weights * qk_scale).softmax(dim=-1)
    std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
    weights = (weights - mean) / std
    weights = median_filter(weights, medfilt_width)

    matrix = weights.mean(axis=0)
    matrix = matrix[len(tokenizer.sot_sequence):-1]
    text_indices, time_indices = dtw(-matrix)

    words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))

    jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
    jump_times = time_indices[jumps] / TOKENS_PER_SECOND
    start_times = jump_times[word_boundaries[:-1]]
    end_times = jump_times[word_boundaries[1:]]
    word_probabilities = [
        np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
    ]

    # hack: ensure the first and second word is not longer than twice the median word duration.
    # a better segmentation algorithm based on VAD should be able to replace this.
    word_durations = end_times - start_times
    word_durations = word_durations[word_durations.nonzero()]
    if len(word_durations) > 0:
        median_duration = np.median(word_durations)
        max_duration = median_duration * 2
        if len(word_durations) >= 2 and word_durations[1] > max_duration:
            end_times[0] = start_times[1] = max(end_times[2] / 2, end_times[2] - max_duration)
        if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
            start_times[0] = max(0, end_times[0] - max_duration)

    return [
        WordTiming(word, tokens, start, end, probability)
        for word, tokens, start, end, probability in zip(
            words, word_tokens, start_times, end_times, word_probabilities
        )
    ]


def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
    # merge prepended punctuations
    i = len(alignment) - 2
    j = len(alignment) - 1
    while i >= 0:
        previous = alignment[i]
        following = alignment[j]
        if previous.word.startswith(" ") and previous.word.strip() in prepended:
            # prepend it to the following word
            following.word = previous.word + following.word
            following.tokens = previous.tokens + following.tokens
            previous.word = ""
            previous.tokens = []
        else:
            j = i
        i -= 1

    # merge appended punctuations
    i = 0
    j = 1
    while j < len(alignment):
        previous = alignment[i]
        following = alignment[j]
        if not previous.word.endswith(" ") and following.word in appended:
            # append it to the previous word
            previous.word = previous.word + following.word
            previous.tokens = previous.tokens + following.tokens
            following.word = ""
            following.tokens = []
        else:
            i = j
        j += 1


def add_word_timestamps(
    *,
    segments: List[dict],
    model: "Whisper",
    tokenizer: Tokenizer,
    mel: torch.Tensor,
    num_frames: int,
    prepend_punctuations: str = "\"\'“¿([{-",
    append_punctuations: str = "\"\'.。,,!!??::”)]}、",
    **hyperparams,
):
    if len(segments) == 0:
        return

    text_tokens = [t for segment in segments for t in segment["tokens"]]
    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **hyperparams)
    merge_punctuations(alignment, prepend_punctuations, append_punctuations)

    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
    token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])

    for segment in segments:
        segment["words"] = []

    word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
    for i, timing in enumerate(alignment):
        if timing.word:
            segment = segments[token_sources[word_boundaries[i]]]
            start = round(time_offset + timing.start, 2)
            end = round(time_offset + timing.end, 2)
            segment["words"].append(
                dict(word=timing.word, start=start, end=end, probability=timing.probability)
            )

    for segment in segments:
        if len(words := segment["words"]) > 0:
            # adjust the segment-level timestamps based on the word-level timestamps
            segment["start"] = words[0]["start"]
            segment["end"] = words[-1]["end"]
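
As a quick illustration of `merge_punctuations` above, a hedged sketch: the token ids and timings are placeholders, but the behavior follows the function as written. A leading quote is folded into the word after it and a trailing comma into the word before it, so the punctuation-only entries end up empty.

```python
from whisper.timing import WordTiming, merge_punctuations

alignment = [
    WordTiming(word=' "', tokens=[1], start=0.0, end=0.1, probability=1.0),
    WordTiming(word="Hello", tokens=[2], start=0.1, end=0.4, probability=1.0),
    WordTiming(word=",", tokens=[3], start=0.4, end=0.5, probability=1.0),
    WordTiming(word=" world", tokens=[4], start=0.5, end=0.9, probability=1.0),
]
merge_punctuations(alignment, prepended="\"'“¿([{-", appended="\"'.。,,!!??::”)]}、")

# the quote and comma entries are emptied, and their tokens merge into the neighboring words
assert [w.word for w in alignment if w.word] == [' "Hello,', " world"]
```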

whisper/tokenizer.py

@ -1,4 +1,5 @@
import os
import string
from dataclasses import dataclass
from functools import lru_cache, cached_property
from typing import List, Optional, Tuple, Union
@ -265,6 +266,48 @@ class Tokenizer:
        assert len(tokens) == 1, f"{text} is not encoded as a single token"
        return tokens[0]

    def split_to_word_tokens(self, tokens: List[int]):
        if self.language in {"zh", "ja", "th", "lo", "my"}:
            # These languages don't typically use spaces, so it is difficult to split words
            # without morpheme analysis. Here, we instead split words at any
            # position where the tokens are decoded as valid unicode points
            return self.split_tokens_on_unicode(tokens)

        return self.split_tokens_on_spaces(tokens)

    def split_tokens_on_unicode(self, tokens: List[int]):
        words = []
        word_tokens = []
        current_tokens = []

        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode_with_timestamps(current_tokens)
            if "\ufffd" not in decoded:
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []

        return words, word_tokens

    def split_tokens_on_spaces(self, tokens: List[int]):
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words = []
        word_tokens = []

        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            special = subword_tokens[0] >= self.eot
            with_space = subword.startswith(" ")
            punctuation = subword.strip() in string.punctuation
            if special or with_space or punctuation or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)

        return words, word_tokens


@lru_cache(maxsize=None)
def build_tokenizer(name: str = "gpt2"):
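
A rough sketch of what the new word-splitting helpers return for English text. It assumes the package is installed with this commit and that `get_tokenizer` and `Tokenizer.encode` behave as elsewhere in the codebase, so the grouping shown in the comment is illustrative rather than exact.

```python
from whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(multilingual=False)
tokens = tokenizer.encode(" Hello, world!")

words, word_tokens = tokenizer.split_to_word_tokens(tokens)
# entries keep their leading space, and punctuation-only subwords stay as standalone entries
# (merge_punctuations() in whisper/timing.py later folds those into neighboring words),
# e.g. something like [" Hello", ",", " world", "!"]
print(words, word_tokens)
```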

whisper/transcribe.py

@ -7,8 +7,9 @@ import numpy as np
import torch
import tqdm

from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult
from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer

@ -27,6 +28,9 @@ def transcribe(
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
    initial_prompt: Optional[str] = None,
    word_timestamps: bool = False,
    prepend_punctuations: str = "\"\'“¿([{-",
    append_punctuations: str = "\"\'.。,,!!??::”)]}、",
    **decode_options,
):
    """
@ -63,6 +67,21 @@ def transcribe(
        disabling may make the text inconsistent across windows, but the model becomes less prone to
        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

    word_timestamps: bool
        Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
        and include the timestamps for each word in each segment.

    prepend_punctuations: str
        If word_timestamps is True, merge these punctuation symbols with the next word

    append_punctuations: str
        If word_timestamps is True, merge these punctuation symbols with the previous word

    initial_prompt: Optional[str]
        Optional text to provide as a prompt for the first window. This can be used to provide, or
        "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
        to make it more likely to predict those word correctly.

    decode_options: dict
        Keyword arguments to construct `DecodingOptions` instances

@ -90,16 +109,19 @@ def transcribe(
    else:
        if verbose:
            print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
        segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
        _, probs = model.detect_language(segment)
        mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
        _, probs = model.detect_language(mel_segment)
        decode_options["language"] = max(probs, key=probs.get)
        if verbose is not None:
            print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")

    language = decode_options["language"]
    task = decode_options.get("task", "transcribe")
    language: str = decode_options["language"]
    task: str = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)

    if word_timestamps and task == "translate":
        warnings.warn("Word-level timestamps on translations may not be reliable.")

    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
        temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
        decode_result = None
@ -145,42 +167,35 @@ def transcribe(
    else:
        initial_prompt_tokens = []

    def add_segment(
        *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
    def new_segment(
        *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
    ):
        text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
        if len(text.strip()) == 0:  # skip empty text output
            return

        all_segments.append(
            {
        text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot]
        return {
            "id": len(all_segments),
            "seek": seek,
            "start": start,
            "end": end,
            "text": text,
            "tokens": text_tokens.tolist(),
            "text": tokenizer.decode(text_tokens),
            "tokens": text_tokens,
            "temperature": result.temperature,
            "avg_logprob": result.avg_logprob,
            "compression_ratio": result.compression_ratio,
            "no_speech_prob": result.no_speech_prob,
        }
        )
        if verbose:
            print(make_safe(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"))

    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
    # show the progress bar when verbose is False (if True, transcribed text will be printed)
    num_frames = mel.shape[-1]
    previous_seek_value = seek

    with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
        while seek < num_frames:
            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
            time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            mel_segment = mel[:, seek:]
            segment_size = min(mel_segment.shape[-1], N_FRAMES)
            segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

            decode_options["prompt"] = all_tokens[prompt_reset_since:]
            result: DecodingResult = decode_with_fallback(segment)
            result: DecodingResult = decode_with_fallback(mel_segment)
            tokens = torch.tensor(result.tokens)

            if no_speech_threshold is not None:
@ -191,29 +206,36 @@ def transcribe(
                    should_skip = False

                if should_skip:
                    seek += segment.shape[-1]  # fast-forward to the next segment boundary
                    seek += segment_size  # fast-forward to the next segment boundary
                    continue

            previous_seek = seek
            current_segments = []
            current_tokens = []

            timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
            if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
                if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]:
                    consecutive = consecutive.tolist() + [len(tokens)]

                last_slice = 0
                for current_slice in consecutive:
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
                    end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    add_segment(
                        start=timestamp_offset + start_timestamp_pos * time_precision,
                        end=timestamp_offset + end_timestamp_pos * time_precision,
                        text_tokens=sliced_tokens[1:-1],
                    current_segments.append(new_segment(
                        start=time_offset + start_timestamp_pos * time_precision,
                        end=time_offset + end_timestamp_pos * time_precision,
                        tokens=sliced_tokens,
                        result=result,
                    )
                    ))
                    current_tokens.append(sliced_tokens.tolist())
                    last_slice = current_slice

                if ended_with_single_timestamp:
                    # single timestamp at the end means no speech after the last timestamp.
                    seek += segment.shape[-1]
                    seek += segment_size
                else:
                    # otherwise, ignore the unfinished segment and seek to the last timestamp
                    last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin
@ -227,23 +249,54 @@ def transcribe(
                    last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin
                    duration = last_timestamp_pos * time_precision

                add_segment(
                    start=timestamp_offset,
                    end=timestamp_offset + duration,
                    text_tokens=tokens,
                current_segments.append(new_segment(
                    start=time_offset,
                    end=time_offset + duration,
                    tokens=tokens,
                    result=result,
                )

                seek += segment.shape[-1]
                all_tokens.extend(tokens.tolist())
                ))
                current_tokens.append(tokens.tolist())
                seek += segment_size

            if not condition_on_previous_text or result.temperature > 0.5:
                # do not feed the prompt tokens if a high temperature was used
                prompt_reset_since = len(all_tokens)

            if word_timestamps:
                add_word_timestamps(
                    segments=current_segments,
                    model=model,
                    tokenizer=tokenizer,
                    mel=mel_segment,
                    num_frames=segment_size,
                    prepend_punctuations=prepend_punctuations,
                    append_punctuations=append_punctuations,
                )
                word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]]
                if len(consecutive) > 0 and len(word_end_timestamps) > 0:
                    seek_shift = round((word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND)
                    if seek_shift > 0:
                        seek = previous_seek + seek_shift

            if verbose:
                for segment in current_segments:
                    start, end, text = segment["start"], segment["end"], segment["text"]
                    line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
                    print(make_safe(line))

            # if a segment is instantaneous or does not contain text, clear it
            for i, segment in enumerate(current_segments):
                if segment["start"] == segment["end"] or segment["text"].strip() == "":
                    segment["text"] = ""
                    segment["tokens"] = []
                    segment["words"] = []
                    current_tokens[i] = []

            all_segments.extend(current_segments)
            all_tokens.extend([token for segment in current_tokens for token in segment])

            # update progress bar
            pbar.update(min(num_frames, seek) - previous_seek_value)
            previous_seek_value = seek
            pbar.update(min(num_frames, seek) - previous_seek)

    return dict(
        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
@ -282,6 +335,9 @@ def cli():
    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
    parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
    parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
    parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
    parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")

    args = parser.parse_args().__dict__
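
One subtle part of the transcribe() change above: when word_timestamps is enabled and the window ended with consecutive timestamp tokens, the next seek position is derived from the last aligned word's end time rather than the full 30-second window. A small arithmetic sketch of that rule, using the same formula as the diff; the numbers are made up.

```python
FRAMES_PER_SECOND = 100          # from whisper/audio.py: one mel frame per 10 ms

previous_seek = 3000             # frame index where this window started
time_offset = previous_seek / FRAMES_PER_SECOND   # 30.0 seconds
last_word_end = 52.84            # hypothetical end time of the last aligned word

seek_shift = round((last_word_end - time_offset) * FRAMES_PER_SECOND)  # 2284 frames
seek = previous_seek + seek_shift                                       # resume at frame 5284, i.e. 52.84 s
assert seek == 5284
```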

92  whisper/triton_ops.py  Normal file

@ -0,0 +1,92 @@
import math

import numpy as np
import torch
from functools import lru_cache

try:
    import triton
    import triton.language as tl
except ImportError:
    raise RuntimeError("triton import failed; try `pip install --pre triton`")


@triton.jit
def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < M

    for k in range(1, N + M + 1):  # k = i + j
        tl.debug_barrier()

        p0 = cost + (k - 1) * cost_stride
        p1 = cost + k * cost_stride
        p2 = cost + k * cost_stride + 1

        c0 = tl.load(p0 + offsets, mask=mask)
        c1 = tl.load(p1 + offsets, mask=mask)
        c2 = tl.load(p2 + offsets, mask=mask)

        x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
        cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)

        cost_ptr = cost + (k + 1) * cost_stride + 1
        tl.store(cost_ptr + offsets, cost_row, mask=mask)

        trace_ptr = trace + (k + 1) * trace_stride + 1
        tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
        tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
        tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))


@lru_cache(maxsize=None)
def median_kernel(filter_width: int):
    @triton.jit
    def kernel(y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr):  # x.shape[-1] == filter_width
        row_idx = tl.program_id(0)
        offsets = tl.arange(0, BLOCK_SIZE)
        mask = offsets < y_stride

        x_ptr = x + row_idx * x_stride
        y_ptr = y + row_idx * y_stride

        LOAD_ALL_ROWS_HERE

        BUBBLESORT_HERE

        tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)

    kernel = triton.JITFunction(kernel.fn)
    kernel.src = kernel.src.replace(" LOAD_ALL_ROWS_HERE", "\n".join([
        f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
        for i in range(filter_width)
    ]))
    kernel.src = kernel.src.replace(" BUBBLESORT_HERE", "\n\n".join([
        "\n\n".join([
            "\n".join([
                f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
                f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
                f" row{j} = smaller",
                f" row{j + 1} = larger",
            ])
            for j in range(filter_width - i - 1)
        ])
        for i in range(filter_width // 2 + 1)
    ]))
    kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")

    return kernel


def median_filter_cuda(x: torch.Tensor, filter_width: int):
    """Apply a median filter of given width along the last dimension of x"""
    slices = x.contiguous().unfold(-1, filter_width, 1)
    grid = np.prod(slices.shape[:-2])

    kernel = median_kernel(filter_width)
    y = torch.empty_like(slices[..., 0])

    BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
    kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)

    return y

whisper/utils.py

@ -85,34 +85,63 @@ class WriteTXT(ResultWriter):
            print(segment['text'].strip(), file=file, flush=True)


class WriteVTT(ResultWriter):
class SubtitlesWriter(ResultWriter):
    always_include_hours: bool
    decimal_marker: str

    def iterate_result(self, result: dict):
        for segment in result["segments"]:
            segment_start = self.format_timestamp(segment["start"])
            segment_end = self.format_timestamp(segment["end"])
            segment_text = segment['text'].strip().replace('-->', '->')

            if word_timings := segment.get("words", None):
                all_words = [timing["word"] for timing in word_timings]
                all_words[0] = all_words[0].strip()  # remove the leading space, if any
                last = segment_start
                for i, this_word in enumerate(word_timings):
                    start = self.format_timestamp(this_word["start"])
                    end = self.format_timestamp(this_word["end"])
                    if last != start:
                        yield last, start, segment_text

                    yield start, end, "".join(
                        [f"<u>{word}</u>" if j == i else word for j, word in enumerate(all_words)]
                    )
                    last = end

                if last != segment_end:
                    yield last, segment_end, segment_text
            else:
                yield segment_start, segment_end, segment_text

    def format_timestamp(self, seconds: float):
        return format_timestamp(
            seconds=seconds,
            always_include_hours=self.always_include_hours,
            decimal_marker=self.decimal_marker,
        )


class WriteVTT(SubtitlesWriter):
    extension: str = "vtt"
    always_include_hours: bool = False
    decimal_marker: str = '.'

    def write_result(self, result: dict, file: TextIO):
        print("WEBVTT\n", file=file)
        for segment in result["segments"]:
            print(
                f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
                f"{segment['text'].strip().replace('-->', '->')}\n",
                file=file,
                flush=True,
            )
        for start, end, text in self.iterate_result(result):
            print(f"{start} --> {end}\n{text}\n", file=file, flush=True)


class WriteSRT(ResultWriter):
class WriteSRT(SubtitlesWriter):
    extension: str = "srt"
    always_include_hours: bool = True
    decimal_marker: str = ','

    def write_result(self, result: dict, file: TextIO):
        for i, segment in enumerate(result["segments"], start=1):
            # write srt lines
            print(
                f"{i}\n"
                f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
                f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
                f"{segment['text'].strip().replace('-->', '->')}\n",
                file=file,
                flush=True,
            )
        for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
            print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)


class WriteTSV(ResultWriter):
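
To make the karaoke-style output of `iterate_result` concrete, a hedged sketch with a hand-built result dict; the segment text and times are invented, and the ResultWriter constructor is assumed to take only an output directory.

```python
from whisper.utils import WriteVTT

result = {
    "segments": [{
        "start": 0.0, "end": 1.1, "text": " Hello world",
        "words": [
            {"word": " Hello", "start": 0.0, "end": 0.5, "probability": 0.9},
            {"word": " world", "start": 0.5, "end": 1.1, "probability": 0.9},
        ],
    }]
}

writer = WriteVTT(output_dir=".")
for start, end, text in writer.iterate_result(result):
    # each cue underlines the word being spoken: "<u>Hello</u> world", then "Hello<u> world</u>"
    print(f"{start} --> {end}\n{text}\n")
```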