Merge ad84a5f266e4e274a0076fa8fcf98aec3ab7580f into cdb81479623391f0651f4f9175ad986e85777f31

This commit is contained in:
Kent Slaney 2024-10-26 14:41:15 +02:00 committed by GitHub
commit 4fbca552cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 1510 additions and 451 deletions

View File

@ -4,7 +4,9 @@ import pytest
import torch import torch
import whisper import whisper
from whisper.audio import CHUNK_LENGTH
from whisper.tokenizer import get_tokenizer from whisper.tokenizer import get_tokenizer
from whisper.transcribe import Transcriber
@pytest.mark.parametrize("model_name", whisper.available_models()) @pytest.mark.parametrize("model_name", whisper.available_models())
@ -40,3 +42,79 @@ def test_transcribe(model_name: str):
timing_checked = True timing_checked = True
assert timing_checked assert timing_checked
class MockTokenizer:
    """Stand-in tokenizer that records its constructor arguments.

    Every keyword argument becomes an attribute, and encode() simply echoes
    the language, the tokenizer instance, and the prompt so tests can unpack
    and inspect them.
    """

    def __init__(self, language, **kw):
        self.language = language
        self._kw = kw
        for name, value in kw.items():
            setattr(self, name, value)

    def encode(self, prompt):
        # No real token ids: return the inputs for later inspection.
        return [self.language, self, prompt]
class OnDemand:
    """Sequence with a known prefix that falls back to interactive input.

    With relative=True (the default) lookups ignore the requested key and
    step through the stored prefix in call order; otherwise the key indexes
    the prefix directly.  Past the end of the prefix, a value is read from
    stdin, defaulting to the previously returned value.
    """

    def __init__(self, seq=(), relative=True):
        self.seq = seq
        self.relative = relative
        self.prev = 0
        self.given = 0

    def __getitem__(self, key):
        _key = self.given if self.relative else key
        if _key < len(self.seq):
            value = self.seq[_key]
        else:
            value = int(input(f"lang @ {_key}: ") or self.prev)
        self.prev = value
        self.given += 1
        return value

    def __len__(self):
        # In relative mode, pretend to span one full chunk plus two extras.
        if self.relative:
            return CHUNK_LENGTH + 2
        return len(self.seq)
class TranscriberTest(Transcriber):
    """Harness exercising Transcriber's language re-detection cadence.

    Replays a scripted sequence of detected languages and records, for each
    simulated seek position, either a detect_language call (marked with the
    `sample` sentinel) or the prompt tokens built without re-detection.
    """

    # Sentinel marking entries in `result` that came from detect_language.
    sample = object()
    dtype = torch.float32
    # Minimal model stand-in; only the attributes read during prompting and
    # language detection are provided — TODO confirm against Transcriber.
    model = type(
        "MockModel",
        (),
        {"is_multilingual": True, "num_languages": None, "device": torch.device("cpu")},
    )()
    _seek = 0

    def __init__(self, seq=None):
        super().__init__(self.model, initial_prompt="")
        self.seq = OnDemand(seq or ())
        self.result = []
        self.latest = torch.zeros((0,))
        # Walk the simulated seek positions (CHUNK_LENGTH + 2 of them when
        # seq is interactive; see OnDemand.__len__).
        for i in range(len(self.seq)):
            self._seek = i
            self.frame_offset = max(0, i + 1 - CHUNK_LENGTH)
            res = self.initial_prompt_tokens
            # MockTokenizer.encode puts the language first; it must match the
            # most recently detected language.
            assert res[0] == self.seq.prev
            self.result.append(res[1:])
            if seq is None:
                # Interactive mode: languages come from stdin, so echo results.
                print(res)

    def detect_language(self, mel=None):
        # Record that a detection happened and return the scripted language.
        self.result.append([self.sample, mel])
        return self.seq[self._seek]

    def get_tokenizer(self, multilingual, language, **kw):
        return MockTokenizer(language, **{"multilingual": multilingual, **kw})

    @property
    def rle(self):
        """Run lengths: prompt builds between consecutive detection calls."""
        res = []
        for i, *j in self.result:
            if i is self.sample:
                res.append(0)
            else:
                res[-1] += 1
        return res
def test_language():
    """Language re-detection cadence yields the expected run lengths."""
    runs = TranscriberTest([0, 0, 1, 0, 0, 0, 0, 0, 0]).rle
    assert runs == [1, 2, 1, 1, 2, 4, 8, 11, 2]

239
whisper/batching.py Normal file
View File

@ -0,0 +1,239 @@
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
from typing import Generic, TypeVar, Union
import numpy as np
import torch
# Array element type: either a NumPy ndarray or a torch Tensor.
A = TypeVar("A", bound=Union[np.ndarray, torch.Tensor])


class ArrayWrapper(Generic[A]):
    """Marker base class for lightweight views over a single array."""


# Either a bare array/tensor or one of its wrappers.
ArrayTypes = Union[A, ArrayWrapper[A]]
class LoopbackIterator(Generic[A]):
    """Base async iterable: subclasses implement iter() as an async generator.

    Instances can then be driven directly with `async for`, restarting the
    generator each time __aiter__ is called.
    """

    async def iter(self):
        raise NotImplementedError

    def __aiter__(self):
        # Start (or restart) from a fresh generator.
        self._iter = self.iter()
        return self

    async def __anext__(self) -> ArrayTypes:
        try:
            source = self._iter
        except AttributeError:
            # Caller skipped __aiter__; start iteration implicitly.
            source = self.__aiter__()._iter
        return await anext(source)
async def empty():
    """Async generator that terminates immediately without yielding anything."""
    return
    yield  # unreachable; makes this a generator rather than a plain coroutine
class Unwrap(LoopbackIterator):
    """Async iterator that caches its first element.

    Caching the head lets stream-wide properties (shape, dtype, which concat
    function to use) be inspected before iteration proper begins; iter()
    re-yields the cached head first.
    """

    # First element, or the awaitable that will produce it.
    _initial: Union[ArrayTypes, Awaitable[ArrayTypes]]
    # Whether the cached head has already been re-yielded by iter().
    started: bool
    iterator: AsyncIterable[ArrayTypes]

    def __init__(self, iterator: AsyncIterable[ArrayTypes]):
        # Collapse chains of PassthroughTransforms down to their source.
        while isinstance(iterator, PassthroughTransform):
            iterator = iterator.handoff()
        if isinstance(iterator, Unwrap):
            # Adopt another Unwrap's cached state instead of double-wrapping.
            self._initial, self.started = iterator.initial(), iterator.started
            self.iterator = iterator.iterator
            return
        elif not isinstance(iterator, AsyncIterator):
            iterator = aiter(iterator)
        try:
            self._initial = anext(iterator)
            self.iterator, self.started = iterator, False
        except StopAsyncIteration:
            # NOTE(review): anext() without an await returns an awaitable and
            # does not raise synchronously, so this branch looks unreachable;
            # exhaustion would instead surface when initial() is awaited.
            self.iterator, self.started = empty(), True

    async def initial(self) -> ArrayTypes:
        """Resolve and memoize the first element."""
        while isinstance(self._initial, Awaitable):
            self._initial = await self._initial
        return self._initial

    async def iter(self) -> AsyncIterator[ArrayTypes]:
        if not self.started:
            self.started = True
            yield await self.initial()
        async for i in self.iterator:
            yield i

    async def prop(self, key: str, default):
        # NOTE(review): hasattr(self, "initial") is always True — `initial`
        # is a method; presumably "_initial" was intended, which would make
        # `default` reachable for empty streams.  Confirm against callers.
        if hasattr(self, "initial"):
            return getattr(await self.initial(), key)
        else:
            return default

    @property
    def shape(self):
        # Awaitable: the first element's shape, or () for an empty stream.
        return self.prop("shape", ())

    @property
    def dtype(self):
        # Awaitable: the first element's dtype, or None for an empty stream.
        return self.prop("dtype", None)

    @property
    async def concat(self):
        """Concatenation function matching the stream's array library."""
        return np.concatenate if isinstance(await self.dtype, np.dtype) else torch.cat
class PassthroughTransform(LoopbackIterator):
    """A transform that can surrender its wrapped iterator untransformed."""

    def handoff(self) -> AsyncIterable[ArrayTypes]:
        """Return the underlying iterator; this wrapper must not be used after."""
        raise NotImplementedError
class BoxedIterator(PassthroughTransform):
    """Enforces single ownership of an async iterator.

    Each iter() installs a fresh sentinel in `flag`; a competing iterator (or
    a handoff) replaces it, which makes the previous consumer fail on its
    next step rather than silently sharing the source.
    """

    def __init__(self, iterator):
        self.iterator = iterator
        self.flag = object()

    def handoff(self) -> AsyncIterable[ArrayTypes]:
        # Mark the box as emptied; in-flight iteration now errors.
        self.flag = None
        return self.iterator

    async def iter(self) -> AsyncIterator[ArrayTypes]:
        if self.flag is None:
            raise Exception("iterator source removed")
        self.flag = flag = object()
        async for i in self.iterator:
            yield i
            # Someone else re-claimed the source while we were suspended.
            if self.flag != flag:
                raise Exception("source can only be used by one iterator")
def LookAlong(axis: int):
    """Build a wrapper class exposing shape and indexing along a fixed axis.

    The returned class wraps one array; `shape` is its length along `axis`
    and item access indexes that axis only, leaving leading axes untouched.
    """
    assert axis >= 0
    lead = (slice(None),) * axis

    class LookAlong(ArrayWrapper):
        def __init__(self, value: A):
            self.value = value

        @property
        def shape(self):
            # Length along the chosen axis only.
            return self.value.shape[axis]

        def __getitem__(self, idx):
            # Index along `axis`; leading axes pass through unchanged.
            return self.value[lead + (idx,)]

        def __next__(self):
            return self.value

    return LookAlong
class PassthroughMap(PassthroughTransform):
    """Applies a function to each element of a wrapped async iterator."""

    def __init__(self, apply: Callable[[A], ArrayTypes], iterator: AsyncIterator[A]):
        self.iterator = iterator
        self.apply = apply

    def handoff(self) -> AsyncIterator[A]:
        # Surrender the raw iterator, without the mapping applied.
        return self.iterator

    async def iter(self) -> AsyncIterator[ArrayTypes]:
        async for item in self.iterator:
            yield self.apply(item)
class Group:
    """Accumulates wrapped chunks and serves contiguous slices on demand.

    `shape` tracks the total unserved length along the batching axis;
    `consumed` tracks how much of the first held chunk was already served.
    """

    def __init__(self, concat, axis=-1):
        # `concat` joins a list of raw arrays along the batching axis.
        # NOTE(review): `axis` is never stored or read — the bound `concat`
        # is expected to close over the axis already; confirm.
        self.concat = concat
        self.holding = []
        self.consumed = 0
        self.shape = 0

    def add(self, value):
        """Append a wrapped chunk; `value.shape` is its length along the axis."""
        self.holding.append(value)
        self.shape += value.shape

    def take(self, amount, exact=True):
        """Remove and return `amount` elements from the front.

        With exact=False the result may be larger: it is rounded up to the
        next chunk boundary, and the bookkeeping is adjusted accordingly.
        """
        assert amount > 0 and amount <= self.shape
        self.shape -= amount
        # `taking` sweeps the cumulative length served so far; `start` is the
        # offset already consumed from the first held chunk.
        taking, start = -self.consumed, self.consumed
        for i, x in enumerate(self.holding):
            taking += x.shape
            if taking >= amount:
                # Offset within chunk i where the request ends.
                self.consumed = amount - taking + x.shape
                break
        if taking == amount or not exact:
            # Boundary case (or inexact mode): hand over whole chunks.
            self.shape += amount - taking  # inexact mode may overshoot
            self.consumed = 0
            res = self.concat(
                [self.holding[0][start:]] + [i.value for i in self.holding[1 : i + 1]]
            )
            self.holding = self.holding[i + 1 :]
            return res
        if i == 0:
            # Request fits inside the first chunk; no concatenation needed.
            return self.holding[0][start : self.consumed]
        res = self.concat(
            [self.holding[0][start:]]
            + [i.value for i in self.holding[1:i]]
            + [self.holding[i][: self.consumed]]
        )
        self.holding = self.holding[i:]
        return res

    def all(self):
        """Flush everything held into one array and reset the group.

        NOTE(review): unlike take(), this ignores `consumed`, so a partially
        served first chunk is re-emitted from its start — confirm intended.
        """
        res = self.concat([i.value for i in self.holding])
        self.shape = 0
        self.consumed = 0
        self.holding = []
        return res
class Taken:
    """Placeholder left behind once a batch queue has been handed off."""

    def take(self, *args, **kwargs):
        # Any further use of a moved queue is a programming error.
        raise Exception("batch queue moved")
class Batcher(PassthroughTransform):
    """Re-chunks an async stream of arrays into batches of `size` along `axis`.

    The element iterator is built lazily because the concatenation function
    and axis normalization both need the first element (via Unwrap).
    """

    def __init__(self, iterator, size, axis=-1, exact=False):
        assert isinstance(size, int) and size > 0
        self.size, self._axis, self.exact = size, axis, exact
        # NOTE(review): Unwrap instances never define `group`; presumably
        # this was meant to recover accumulated state when re-wrapping a
        # handed-off Batcher — confirm.
        if isinstance(iterator, Unwrap) and hasattr(iterator, "group"):
            self.group = iterator.group
        self.preview = Unwrap(iterator)

    async def concat(self):
        # Bind the resolved axis into the stream's concatenation function.
        f = await self.preview.concat
        return lambda tensors: f(tensors, self.axis)

    # Lazily built element iterator; None until first __anext__.
    _iterator = None

    async def iterator(self):
        """Build (once) and return the per-element iterator."""
        if self._iterator is None:
            # Normalize a negative axis against the stream's rank.
            self.axis = (
                len(await self.preview.shape) + self._axis
                if self._axis < 0
                else self._axis
            )
            if not hasattr(self, "group"):
                self.group = Group(await self.concat())
            self._iterator = PassthroughMap(
                LookAlong(self.axis), BoxedIterator(self.preview)
            )
        return self._iterator

    def handoff(self):
        # Invalidate this Batcher; further take() calls raise via Taken.
        self.group = Taken()
        return self.preview if self._iterator is None else self._iterator

    def __aiter__(self):
        return self

    async def __anext__(self):
        iterator = aiter(await self.iterator())
        while self.group.shape < self.size:
            try:
                self.group.add(await anext(iterator))
            except StopAsyncIteration:
                if self.group.shape > 0:
                    # Final, possibly short batch.
                    return self.group.all()
                raise
        return self.group.take(self.size, self.exact)

309
whisper/buffer.py Normal file
View File

@ -0,0 +1,309 @@
import asyncio
import json
import subprocess
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Coroutine
from typing import IO, Optional, Union
import numpy as np
import torch
from .audio import HOP_LENGTH, N_FFT, N_FRAMES, SAMPLE_RATE, mel_filters
from .batching import Batcher
from .utils import PathType, ceildiv
class AudioSink:
    """Abstract endpoint for audio data, tagged with a sample rate."""

    def __init__(self, *, rate: int = SAMPLE_RATE, **kw):
        # Cooperative __init__ so classes later in the MRO still initialize.
        super().__init__(**kw)
        self.rate = rate

    def read(self):
        """Produce audio; subclasses must override."""
        raise NotImplementedError

    def write(self, data):
        """Consume audio; subclasses must override."""
        raise NotImplementedError
class ArrayStream(AudioSink):
    """Streams PCM16 bytes through an asyncio queue into a log-mel spectrogram.

    Bytes are write()n into a bounded queue, decoded to float32 amplitudes,
    re-chunked to STFT hops, and accumulated into `self.spectogram` (sic),
    which pull()/push()/request() expose padded out to N_FRAMES.
    """

    # Bounded queue of raw PCM16 byte chunks awaiting decode.
    q: asyncio.Queue

    def __init__(
        self,
        *,
        device: Optional[Union[str, torch.device]] = None,
        batch: int = 1,
        n_mels: int = 80,
        capacity: int = 1_000_000,
        **kw,
    ):
        super().__init__(**kw)
        self.q = asyncio.Queue(capacity)
        self.finished = asyncio.Event()
        self.device, self.batch, self.n_mels = device, batch, n_mels
        # Residual samples not yet consumed by a whole number of hops.
        self.sees = self.zeros((0,))
        self.spectogram = self.zeros((n_mels, 0))
        self.hann = torch.hann_window(N_FFT).to(self.sees.device)
        self.filters = mel_filters(self.sees.device, n_mels)

    def zeros(self, shape):
        """float32 zeros on the configured device."""
        return torch.zeros(shape, dtype=torch.float32, device=self.device)

    # When True, write() returns a coroutine that blocks until space frees up.
    write_blockable: bool = True

    def write(self, data: bytes) -> Optional[Coroutine]:
        """Enqueue raw bytes; returns an awaitable iff write_blockable."""
        if self.write_blockable:
            return self.q.put(data)
        else:
            self.q.put_nowait(data)
            return None

    def load(self, data: bytes) -> np.ndarray:
        """Decode little-endian PCM16 bytes to float32 in [-1, 1)."""
        return np.frombuffer(data, np.int16).flatten().astype(np.float32) / 32768.0

    async def loader(self, iterator: AsyncIterable[bytes]) -> AsyncIterator[np.ndarray]:
        async for data in iterator:
            yield self.load(data)

    async def buffer(self) -> AsyncIterator[bytes]:
        """Drain the queue until `finished` is set, then flush what remains."""
        waiter = asyncio.create_task(self.finished.wait())
        while not self.finished.is_set():
            getter = asyncio.create_task(self.q.get())
            done, pending = await asyncio.wait(
                (waiter, getter), return_when=asyncio.FIRST_COMPLETED
            )
            if getter in done:
                yield getter.result()
        # NOTE(review): a still-pending getter task is abandoned here and may
        # later consume an item that the drain below then misses — confirm
        # writers stop before setting `finished`.
        while not self.q.empty():
            yield self.q.get_nowait()

    async def buffer_nowait(self) -> AsyncIterator[bytes]:
        """Yield whatever is queued right now, without waiting."""
        try:
            while True:
                yield self.q.get_nowait()
        except asyncio.QueueEmpty:
            pass

    # Carried-over Batcher so repeated pulls resume mid-stream.
    loading: Optional[Batcher] = None

    async def fft_offset(
        self, iterator: AsyncIterable[bytes]
    ) -> AsyncIterator[np.ndarray]:
        """Rebatch decoded samples to HOP_LENGTH and reflect-pad the start
        so the first frame matches whisper's centered reference STFT."""
        init = self.loader(iterator) if self.loading is None else self.loading
        self.loading = Batcher(init, HOP_LENGTH)
        _iterator = aiter(self.loading)
        window = np.zeros((0,), dtype=np.float32)
        while window.size < ceildiv(N_FFT, 2):
            try:
                window = np.concatenate((window, await anext(_iterator)))
            except StopAsyncIteration:
                return
        window = np.pad(window, (N_FFT // 2, 0), "reflect")
        yield window
        async for data in _iterator:
            yield data
        # for _ in range(N_FFT // HOP_LENGTH):
        #     yield np.zeros((HOP_LENGTH,), dtype=np.float32)
        # (done by runoff)

    def seeing(self, sees: torch.Tensor) -> torch.Tensor:
        """Tail of `sees` not yet consumed by a whole number of hops."""
        hopped = ((sees.shape[0] - N_FFT) // HOP_LENGTH + 1) * HOP_LENGTH
        return sees[hopped:]

    async def window(
        self, iterator: AsyncIterable[bytes]
    ) -> AsyncIterator[torch.Tensor]:
        """Yield log-mel frame batches of at least `self.batch` hops each."""
        _iterator = self.fft_offset(iterator)
        async for data in _iterator:
            _data = torch.from_numpy(data)
            prev = self.sees.shape[0] - N_FFT
            # Accumulate until at least `batch` hops' worth of samples is ready.
            while (_data.shape[0] + prev) // HOP_LENGTH < self.batch - 1:
                try:
                    adding = torch.from_numpy(await anext(_iterator))
                except StopAsyncIteration:
                    break
                _data = torch.cat((_data, adding))
            if self.device is not None:
                # BUG FIX: Tensor.to is out-of-place; the moved tensor was
                # previously discarded, leaving _data on the CPU.
                _data = _data.to(self.device)
            res = torch.cat((self.sees, _data))
            self.sees = self.seeing(res)
            yield self.transform(self.dft(res))

    def dft(self, amp: torch.Tensor) -> torch.Tensor:
        """Uncentered STFT; centering is emulated by fft_offset/runoff padding."""
        return torch.stft(
            amp, N_FFT, HOP_LENGTH, window=self.hann, center=False, return_complex=True
        )

    # Running maximum of the log spectrogram, used as the clamp reference.
    log_spec_bound: Optional[torch.Tensor] = None

    def transform(self, stft: torch.Tensor) -> torch.Tensor:
        """Power spectrum -> mel -> log10, clamped and scaled like whisper."""
        magnitudes = stft.abs() ** 2
        mel_spec = self.filters @ magnitudes
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        # causes values to not precisely match the original: the reference
        # clamps against the max over the whole clip, but streaming only
        # knows the running max seen so far.
        self.log_spec_bound = (
            log_spec.max()
            if self.log_spec_bound is None
            else torch.maximum(log_spec.max(), self.log_spec_bound)
        )
        log_spec = torch.maximum(log_spec, self.log_spec_bound - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec

    def padding(self, content_frames: int) -> int:
        # Always pad one full window, regardless of content length.
        return N_FRAMES

    # dft_pad: add ending content frames to match padding from a centered STFT
    dft_pad: bool = False

    def runoff(self, dft_pad: Optional[bool] = None) -> torch.Tensor:
        """Flush residual samples and return the padded spectrogram window."""
        dft_pad = self.dft_pad if dft_pad is None else dft_pad
        if dft_pad:
            overrun = (ceildiv(N_FFT, HOP_LENGTH) - 1) * HOP_LENGTH
            spectogram = torch.cat((self.sees, self.zeros(overrun)))
        else:
            # BUG FIX: `spectogram` was unbound when dft_pad is falsy (the
            # default), raising NameError below; flush the residue as-is.
            spectogram = self.sees
        if spectogram.shape[-1] >= N_FFT:
            spectogram = self.transform(self.dft(spectogram))
        else:
            spectogram = torch.zeros(0)
        padding = self.padding(self.spectogram.shape[-1] + spectogram.shape[-1])
        pad = self.zeros((self.n_mels, max(0, padding)))
        spectogram = torch.cat((self.spectogram, spectogram, pad), -1)
        # NOTE(review): negative padding indexes dim 0 (mels), not frames —
        # unreachable with the default padding() but looks like it should be
        # spectogram[:, -padding:]; confirm.
        return spectogram if padding >= 0 else spectogram[-padding:]

    # Frames already dropped from the front of self.spectogram.
    offset: int = 0

    async def pull(self) -> torch.Tensor:
        """Consume everything currently queued and return the padded window."""
        context = self.spectogram.shape[-1]
        iterator = self.window(self.buffer_nowait())
        async for frame in iterator:
            self.spectogram = torch.cat((self.spectogram, frame), -1)
        # Drop old context so at most N_FRAMES remain, but never new frames.
        cutoff = min(context, max(self.spectogram.shape[-1] - N_FRAMES, 0))
        self.offset += cutoff
        self.spectogram = self.spectogram[:, cutoff:]
        return self.runoff()

    # Carried-over Batcher so repeated pushes resume mid-stream.
    staging: Optional[Batcher] = None

    async def _push(
        self, sec: float, exact: bool = False
    ) -> AsyncIterator[torch.Tensor]:
        """Yield a padded window roughly every `sec` seconds of audio."""
        batching = int(sec * SAMPLE_RATE // HOP_LENGTH)
        init = self.window(self.buffer()) if self.staging is None else self.staging
        self.staging = Batcher(init, batching, exact=exact)
        async for frame in self.staging:
            batched = batching if exact else frame.shape[-1]
            cutoff = max(self.spectogram.shape[-1] + batched - N_FRAMES, 0)
            self.offset += cutoff
            self.spectogram = torch.cat((self.spectogram[:, cutoff:], frame), -1)
            yield self.runoff()

    # Background task draining read() into the queue.
    reader: Optional[Awaitable] = None

    def start(self, **kw) -> None:
        """Kick off the read() task exactly once."""
        if self.reader is None:
            self.reader = asyncio.create_task(self.read(**kw))

    async def push(
        self, sec: float, exact: bool = False, **kw
    ) -> AsyncIterator[torch.Tensor]:
        """Like _push, but owns the reader task and awaits it at the end."""
        self.start(**kw)
        async for i in self._push(sec, exact):
            yield i
        assert self.reader is not None
        await self.reader

    async def request(self, sec: float, exact: bool = True, **kw) -> torch.Tensor:
        """One window of `sec` seconds, or empty mels when exhausted."""
        try:
            return await anext(self.push(sec, exact))
        except StopAsyncIteration:
            if self.reader is not None:
                await self.reader
            return self.zeros((self.n_mels, 0))

    async def full(self, **kw) -> torch.Tensor:
        """Read the entire source, then return the final padded window."""
        await self.read(**kw)
        return await self.pull()

    def sequential(self, **kw) -> torch.Tensor:
        """Synchronous wrapper around full()."""
        return asyncio.run(self.full(**kw))

    async def amplitudes(self, **kw) -> np.ndarray:
        """Decode the whole stream into a single float32 amplitude array."""
        self.start(**kw)
        res = []
        async for data in self.loader(self.buffer()):
            res.append(data)
        assert self.reader is not None
        await self.reader
        return np.concatenate(res)

    def all_amplitudes(self, **kw) -> np.ndarray:
        """Synchronous wrapper around amplitudes()."""
        return asyncio.run(self.amplitudes(**kw))
class RawAudioFile(ArrayStream):
    """Feeds raw PCM bytes from a file (or a supplied stream) into the queue."""

    def __init__(self, *, period: int = HOP_LENGTH, fname: PathType = "out.raw", **kw):
        super().__init__(**kw)
        self.fname = fname
        # Chunk size (in bytes) per blocking write into the queue.
        self.period = period

    # Pre-opened byte stream; when None, read() opens self.fname itself.
    fp: Optional[IO[bytes]] = None

    async def read(self) -> None:
        """Pump period-sized chunks into the queue, then signal completion."""
        opened = self.fp is None
        fp = open(self.fname, "rb") if opened else self.fp
        try:
            data = fp.read(self.period)
            while len(data) != 0:
                io_hold = self.write(data)
                # write() only returns an awaitable in blocking mode.
                assert io_hold is not None and self.write_blockable is True
                await io_hold
                data = fp.read(self.period)
        finally:
            # BUG FIX: the handle was previously leaked; close it only if we
            # opened it (a caller-supplied self.fp stays the caller's to close).
            if opened:
                fp.close()
        self.finished.set()
class AudioFile(RawAudioFile):
    """Decodes any audio file to PCM16 mono via ffmpeg and streams it."""

    def __init__(self, *, period: int = SAMPLE_RATE, fname: PathType = "out.wav", **kw):
        # Fail fast if ffmpeg is not on PATH.
        # NOTE(review): `which` is POSIX-only; shutil.which would be portable.
        assert not subprocess.run(
            ["which", "ffmpeg"], stdout=subprocess.PIPE
        ).returncode
        super().__init__(period=period or -1, fname=fname, **kw)

    async def read(self) -> None:
        """Spawn ffmpeg decoding self.fname to s16le mono at self.rate and
        stream its stdout through RawAudioFile.read."""
        cmd = [
            "ffmpeg",
            "-nostdin",
            "-threads",
            "0",
            "-i",
            self.fname,
            "-f",
            "s16le",
            "-ac",
            "1",
            "-acodec",
            "pcm_s16le",
            "-ar",
            str(self.rate),
            "-",
        ]
        ps = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        self.fp = ps.stdout
        await super().read()
        # stdout is already exhausted; communicate() collects stderr and waits.
        _, stderr = ps.communicate()
        if ps.returncode not in (None, 0):
            raise RuntimeError(f"Failed to load audio: {stderr.decode()}")

    @property
    def duration(self):
        """Duration in seconds, as reported by ffprobe's format metadata."""
        cmd = [
            "ffprobe",
            "-hide_banner",
            "-show_format",
            "-of",
            "json",
            "-i",
            self.fname,
        ]
        ps = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = ps.communicate()
        if ps.returncode not in (None, 0):
            raise RuntimeError(f"Failed to load audio: {stderr.decode()}")
        return float(json.loads(stdout)["format"]["duration"])

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,11 @@
import json import json
import os import os
import pathlib
import re import re
import sys import sys
import time
import zlib import zlib
from typing import Callable, List, Optional, TextIO from typing import Any, Callable, Generic, List, Optional, TextIO, TypeVar, Union
system_encoding = sys.getdefaultencoding() system_encoding = sys.getdefaultencoding()
@ -21,11 +23,19 @@ else:
return string return string
# File-system paths may be given as plain strings or pathlib objects.
PathType = Union[str, pathlib.Path]
def exact_div(x, y): def exact_div(x, y):
assert x % y == 0 assert x % y == 0
return x // y return x // y
# https://stackoverflow.com/a/17511341/3476782
def ceildiv(a: Union[int, float], b: Union[int, float]) -> int:
    """Ceiling division, via floor division on the negated dividend."""
    quotient = -(a // -b)
    return int(quotient)
def str2bool(string): def str2bool(string):
str2val = {"True": True, "False": False} str2val = {"True": True, "False": False}
if string in str2val: if string in str2val:
@ -68,6 +78,20 @@ def format_timestamp(
) )
def hms(sec: float) -> str:
    """Format seconds as a fixed-width '[h:][mm:]ss.cc' timestamp.

    Hours are omitted below one hour; minutes collapse to a space below one
    minute so columns still line up.  BUG FIX: centiseconds are rounded
    first and the carry propagated, so e.g. 59.999 renders " 1:00.00"
    instead of the malformed " 59.100" the old per-field rounding produced.
    """
    total_cs = round(sec * 100)
    whole, cs = divmod(total_cs, 100)  # whole seconds, leftover centiseconds
    trim = whole < 3600
    h = "" if trim else str(whole // 3600) + ":"
    m_fill = " " if trim else "0"
    m = " " if whole < 60 else str(whole // 60 % 60).rjust(2, m_fill) + ":"
    s = str(whole % 60).rjust(2, "0") + "."
    c = str(cs).rjust(2, "0")
    return h + m + s + c
def tod(seconds: float) -> str:
return time.strftime("%H:%M:%S", time.localtime(seconds))
def get_start(segments: List[dict]) -> Optional[float]: def get_start(segments: List[dict]) -> Optional[float]:
return next( return next(
(w["start"] for s in segments for w in s["words"]), (w["start"] for s in segments for w in s["words"]),
@ -314,3 +338,47 @@ def get_writer(
return write_all return write_all
return writers[output_format](output_dir) return writers[output_format](output_dir)
T = TypeVar("T")


# boilerplate for property with _{name} storage and passthrough getter/setter
class PassthroughProperty(Generic[T]):
    """Declarative marker consumed by PassthroughPropertyDefaults.

    Holds a default value plus optionally a custom setter (via setter()) and
    a custom getter property (via property()); both register fluently.
    """

    def __init__(self, default: T):
        self.value = default

    # Custom setter, if registered; otherwise the metaclass supplies one.
    f: Optional[Callable[[Any, T], None]] = None

    def setter(self, f: Callable[[Any, T], None]):
        self.f = f
        return self

    # Custom getter property, if registered.
    g: Optional[property] = None

    def property(self, g: Callable[[Any], T]):
        self.g = property(g)
        return self
class PassthroughPropertyDefaults(type):
    """Metaclass materializing PassthroughProperty declarations.

    For each class attribute `k` declared as PassthroughProperty(default),
    installs `_k` holding the default plus a property `k` whose getter and
    setter either pass through to `_k` or use the custom ones registered on
    the declaration.
    """

    def __new__(cls, clsname, bases, attrs):
        def closure(f, v):
            # Build passthrough accessors bound to the private slot name `v`;
            # a closure per attribute avoids the late-binding loop pitfall.
            def prop(self):
                return getattr(self, v)

            def setter(self, value):
                setattr(self, v, value)

            prop.__name__ = setter.__name__ = f
            return property(prop), setter

        updates = {}
        for k, v in attrs.items():
            if not isinstance(v, PassthroughProperty):
                continue
            private = "_" + k
            updates[private] = v.value
            getter, setter = closure(k, private)
            # Prefer custom getter/setter when supplied on the declaration;
            # property.setter returns a new property combining the two.
            updates[k] = (v.g or getter).setter(v.f or setter)
        return super().__new__(cls, clsname, bases, {**attrs, **updates})