mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 06:26:03 +00:00
Merge ad84a5f266e4e274a0076fa8fcf98aec3ab7580f into cdb81479623391f0651f4f9175ad986e85777f31
This commit is contained in:
commit
4fbca552cb
@ -4,7 +4,9 @@ import pytest
|
||||
import torch
|
||||
|
||||
import whisper
|
||||
from whisper.audio import CHUNK_LENGTH
|
||||
from whisper.tokenizer import get_tokenizer
|
||||
from whisper.transcribe import Transcriber
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", whisper.available_models())
|
||||
@ -40,3 +42,79 @@ def test_transcribe(model_name: str):
|
||||
timing_checked = True
|
||||
|
||||
assert timing_checked
|
||||
|
||||
|
||||
class MockTokenizer:
    """Stand-in tokenizer: records its language and mirrors extra kwargs as attributes."""

    def __init__(self, language, **kw):
        self.language = language
        self._kw = kw
        for name in kw:
            setattr(self, name, kw[name])

    def encode(self, prompt):
        # Echo enough context for callers to assert exactly what was encoded.
        return [self.language, self, prompt]
|
||||
|
||||
|
||||
class OnDemand:
    """Sequence that serves stored values in order, falling back to interactive input."""

    def __init__(self, seq=(), relative=True):
        self.seq = seq
        self.relative = relative
        self.prev = 0
        self.given = 0

    def __getitem__(self, key):
        # Relative mode ignores the requested key and serves positions in call order.
        idx = self.given if self.relative else key
        if idx < len(self.seq):
            self.prev = self.seq[idx]
        else:
            # Past the stored values: ask the user, defaulting to the last value.
            self.prev = int(input(f"lang @ {idx}: ") or self.prev)
        self.given += 1
        return self.prev

    def __len__(self):
        # Relative mode pretends to span a whole chunk plus padding.
        return CHUNK_LENGTH + 2 if self.relative else len(self.seq)
|
||||
|
||||
|
||||
class TranscriberTest(Transcriber):
    """Transcriber harness that replays a scripted language-detection sequence."""

    # Sentinel recorded by detect_language so ``rle`` can mark detection calls.
    sample = object()
    dtype = torch.float32
    # Minimal model stub exposing only the attributes Transcriber touches.
    model = type(
        "MockModel",
        (),
        {"is_multilingual": True, "num_languages": None, "device": torch.device("cpu")},
    )()
    _seek = 0

    def __init__(self, seq=None):
        super().__init__(self.model, initial_prompt="")
        # Scripted language ids; OnDemand falls back to interactive input when exhausted.
        self.seq = OnDemand(seq or ())
        self.result = []
        self.latest = torch.zeros((0,))
        for i in range(len(self.seq)):
            self._seek = i
            self.frame_offset = max(0, i + 1 - CHUNK_LENGTH)
            res = self.initial_prompt_tokens
            # MockTokenizer.encode puts the language first; it must match the script.
            assert res[0] == self.seq.prev
            self.result.append(res[1:])
            if seq is None:
                print(res)

    def detect_language(self, mel=None):
        # Record the call (tagged with ``sample``) and serve the scripted id.
        self.result.append([self.sample, mel])
        return self.seq[self._seek]

    def get_tokenizer(self, multilingual, language, **kw):
        return MockTokenizer(language, **{"multilingual": multilingual, **kw})

    @property
    def rle(self):
        """Run lengths: number of prompt results between detect_language calls."""
        res = []
        for i, *j in self.result:
            if i is self.sample:
                res.append(0)
            else:
                res[-1] += 1
        return res
|
||||
|
||||
|
||||
def test_language():
    # Feed a fixed language-id sequence and verify how often detection re-runs:
    # ``rle`` is the count of prompt evaluations between detect_language calls.
    res = TranscriberTest([0, 0, 1, 0, 0, 0, 0, 0, 0]).rle
    assert res == [1, 2, 1, 1, 2, 4, 8, 11, 2]
|
||||
|
||||
239
whisper/batching.py
Normal file
239
whisper/batching.py
Normal file
@ -0,0 +1,239 @@
|
||||
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
|
||||
from typing import Generic, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
A = TypeVar("A", bound=Union[np.ndarray, torch.Tensor])
|
||||
|
||||
|
||||
class ArrayWrapper(Generic[A]):
    """Marker base class for thin wrappers around ndarray/Tensor values."""

    pass
|
||||
|
||||
|
||||
ArrayTypes = Union[A, ArrayWrapper[A]]
|
||||
|
||||
|
||||
class LoopbackIterator(Generic[A]):
    """Async-iterator base: subclasses implement ``iter``; ``__anext__`` lazily
    creates the underlying generator so the object itself can be iterated."""

    async def iter(self):
        raise NotImplementedError

    def __aiter__(self):
        # Restart from a fresh generator each time iteration begins.
        self._iter = self.iter()
        return self

    async def __anext__(self) -> ArrayTypes:
        # Allow ``anext`` without a prior ``aiter`` call.
        if not hasattr(self, "_iter"):
            self.__aiter__()
        return await anext(self._iter)
|
||||
|
||||
|
||||
async def empty():
    """An async generator that terminates immediately, yielding nothing."""
    # The ``yield`` makes this an async generator; guarding it behind an
    # always-false condition means iteration ends at once.
    if False:
        yield
|
||||
|
||||
|
||||
class Unwrap(LoopbackIterator):
    """Wraps an async iterable while caching its first element, so stream
    properties (shape, dtype) can be inspected before consuming it."""

    # First element; may still be an un-awaited awaitable from ``anext``.
    _initial: Union[ArrayTypes, Awaitable[ArrayTypes]]
    # True once the cached first element has been re-yielded (or none exists).
    started: bool
    iterator: AsyncIterable[ArrayTypes]

    def __init__(self, iterator: AsyncIterable[ArrayTypes]):
        # Unwrap any passthrough layers to reach the true source.
        while isinstance(iterator, PassthroughTransform):
            iterator = iterator.handoff()
        if isinstance(iterator, Unwrap):
            # Adopt the other Unwrap's cached state instead of double-wrapping.
            self._initial, self.started = iterator.initial(), iterator.started
            self.iterator = iterator.iterator
            return
        elif not isinstance(iterator, AsyncIterator):
            iterator = aiter(iterator)
        try:
            # NOTE(review): ``anext`` returns an awaitable without advancing
            # the iterator, so StopAsyncIteration surfaces on ``await`` — the
            # except branch below looks unreachable; confirm intent.
            self._initial = anext(iterator)
            self.iterator, self.started = iterator, False
        except StopAsyncIteration:
            self.iterator, self.started = empty(), True

    async def initial(self) -> ArrayTypes:
        # Await the cached awaitable at most once, then memoize the value.
        while isinstance(self._initial, Awaitable):
            self._initial = await self._initial
        return self._initial

    async def iter(self) -> AsyncIterator[ArrayTypes]:
        if not self.started:
            # Re-yield the cached first element before the rest of the stream.
            self.started = True
            yield await self.initial()
        async for i in self.iterator:
            yield i

    async def prop(self, key: str, default):
        # NOTE(review): ``initial`` is a method, so this hasattr is always
        # True; presumably "_initial" was intended — confirm upstream.
        if hasattr(self, "initial"):
            return getattr(await self.initial(), key)
        else:
            return default

    @property
    def shape(self):
        # Returns a coroutine; callers ``await`` it.
        return self.prop("shape", ())

    @property
    def dtype(self):
        # Returns a coroutine; callers ``await`` it.
        return self.prop("dtype", None)

    @property
    async def concat(self):
        # Pick the concatenation routine matching the element type.
        return np.concatenate if isinstance(await self.dtype, np.dtype) else torch.cat
|
||||
|
||||
|
||||
class PassthroughTransform(LoopbackIterator):
    """A LoopbackIterator that can surrender its wrapped source iterator."""

    def handoff(self) -> AsyncIterable[ArrayTypes]:
        # Return the underlying iterator, invalidating this wrapper.
        raise NotImplementedError
|
||||
|
||||
|
||||
class BoxedIterator(PassthroughTransform):
    """Guards a source iterator so only one consumer may iterate it at a time."""

    def __init__(self, iterator):
        self.iterator = iterator
        # Identity token: replaced on each ``iter`` call, None once handed off.
        self.flag = object()

    def handoff(self) -> AsyncIterable[ArrayTypes]:
        self.flag = None
        return self.iterator

    async def iter(self) -> AsyncIterator[ArrayTypes]:
        if self.flag is None:
            raise Exception("iterator source removed")
        # A fresh token per iteration detects a competing consumer mid-stream.
        self.flag = flag = object()
        async for i in self.iterator:
            yield i
            if self.flag != flag:
                raise Exception("source can only be used by one iterator")
|
||||
|
||||
|
||||
def LookAlong(axis: int):
    """Build an ArrayWrapper class that views arrays along ``axis``."""
    assert axis >= 0
    # Full slices for every leading dimension before ``axis``.
    lead = (slice(None),) * axis

    class LookAlong(ArrayWrapper):
        def __init__(self, value: A):
            self.value = value

        @property
        def shape(self):
            # Length of the wrapped array along the chosen axis only.
            return self.value.shape[axis]

        def __getitem__(self, idx):
            # Index along ``axis``, keeping all leading dimensions intact.
            return self.value[lead + (idx,)]

        def __next__(self):
            # Unwrap back to the raw array.
            return self.value

    return LookAlong
|
||||
|
||||
|
||||
class PassthroughMap(PassthroughTransform):
    """Applies ``apply`` to each element of a wrapped async iterator."""

    def __init__(self, apply: Callable[[A], ArrayTypes], iterator: AsyncIterator[A]):
        self.iterator, self.apply = iterator, apply

    def handoff(self) -> AsyncIterator[A]:
        # Surrender the raw (un-mapped) source.
        return self.iterator

    async def iter(self) -> AsyncIterator[ArrayTypes]:
        async for i in self.iterator:
            yield self.apply(i)
|
||||
|
||||
|
||||
class Group:
    """Accumulates wrapped array chunks and serves slices of a requested length.

    Chunks are LookAlong-style wrappers: ``.shape`` is an int length along the
    batching axis, ``.value`` the raw array, and slicing returns raw arrays.
    ``consumed`` tracks how much of ``holding[0]`` was already taken.
    """

    def __init__(self, concat, axis=-1):
        # ``concat`` joins a list of raw arrays along the batching axis.
        self.concat = concat
        self.holding = []
        self.consumed = 0
        # Total unconsumed length currently held.
        self.shape = 0

    def add(self, value):
        self.holding.append(value)
        self.shape += value.shape

    def take(self, amount, exact=True):
        assert amount > 0 and amount <= self.shape
        self.shape -= amount
        # ``taking`` counts length gathered so far, offset by what was already
        # consumed from the first held chunk; ``start`` is where to resume.
        taking, start = -self.consumed, self.consumed
        for i, x in enumerate(self.holding):
            taking += x.shape
            if taking >= amount:
                # Position inside chunk ``i`` where this take ends.
                self.consumed = amount - taking + x.shape
                break
        if taking == amount or not exact:
            # Chunk boundary hit (or overshoot allowed): hand over chunks
            # 0..i whole, adjusting shape for any overshoot beyond ``amount``.
            self.shape += amount - taking
            self.consumed = 0
            res = self.concat(
                [self.holding[0][start:]] + [i.value for i in self.holding[1 : i + 1]]
            )
            self.holding = self.holding[i + 1 :]
            return res
        if i == 0:
            # The take lands inside the first chunk: a single slice suffices.
            return self.holding[0][start : self.consumed]
        res = self.concat(
            [self.holding[0][start:]]
            + [i.value for i in self.holding[1:i]]
            + [self.holding[i][: self.consumed]]
        )
        self.holding = self.holding[i:]
        return res

    def all(self):
        # Flush everything held as one concatenated array and reset state.
        res = self.concat([i.value for i in self.holding])
        self.shape = 0
        self.consumed = 0
        self.holding = []
        return res
|
||||
|
||||
|
||||
class Taken:
    """Placeholder for a Group whose contents were handed off elsewhere."""

    def take(self, *args, **kwargs):
        # Any further use of the moved queue is a programming error.
        raise Exception("batch queue moved")
|
||||
|
||||
|
||||
class Batcher(PassthroughTransform):
    """Re-chunks an async stream of arrays into batches of ``size`` along an axis."""

    def __init__(self, iterator, size, axis=-1, exact=False):
        assert isinstance(size, int) and size > 0
        self.size, self._axis, self.exact = size, axis, exact
        # Adopt an existing Group when re-wrapping a previous Batcher's source.
        if isinstance(iterator, Unwrap) and hasattr(iterator, "group"):
            self.group = iterator.group
        self.preview = Unwrap(iterator)

    async def concat(self):
        # Bind the source's concatenate function to the batching axis.
        f = await self.preview.concat
        return lambda tensors: f(tensors, self.axis)

    _iterator = None

    async def iterator(self):
        """Lazily build the wrapped per-element iterator (first call only)."""
        if self._iterator is None:
            # Normalize a negative axis against the source's rank.
            self.axis = (
                len(await self.preview.shape) + self._axis
                if self._axis < 0
                else self._axis
            )
            if not hasattr(self, "group"):
                self.group = Group(await self.concat())
            self._iterator = PassthroughMap(
                LookAlong(self.axis), BoxedIterator(self.preview)
            )
        return self._iterator

    def handoff(self):
        # Invalidate our queue; later use of it raises via Taken.
        self.group = Taken()
        return self.preview if self._iterator is None else self._iterator

    def __aiter__(self):
        return self

    async def __anext__(self):
        iterator = aiter(await self.iterator())
        # Fill the group until one batch is available, then slice it off;
        # a final partial batch is flushed with ``all()``.
        while self.group.shape < self.size:
            try:
                self.group.add(await anext(iterator))
            except StopAsyncIteration:
                if self.group.shape > 0:
                    return self.group.all()
                raise
        return self.group.take(self.size, self.exact)
|
||||
309
whisper/buffer.py
Normal file
309
whisper/buffer.py
Normal file
@ -0,0 +1,309 @@
|
||||
import asyncio
|
||||
import json
|
||||
import subprocess
|
||||
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Coroutine
|
||||
from typing import IO, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .audio import HOP_LENGTH, N_FFT, N_FRAMES, SAMPLE_RATE, mel_filters
|
||||
from .batching import Batcher
|
||||
from .utils import PathType, ceildiv
|
||||
|
||||
|
||||
class AudioSink:
    """Abstract audio endpoint parameterized by sample ``rate`` in Hz."""

    def __init__(self, *, rate: int = SAMPLE_RATE, **kw):
        # Cooperative init: surplus keywords continue up the MRO.
        super().__init__(**kw)
        self.rate = rate

    def read(self):
        """Produce audio; implemented by subclasses."""
        raise NotImplementedError

    def write(self, data):
        """Consume audio; implemented by subclasses."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class ArrayStream(AudioSink):
    """Streams PCM bytes through an asyncio queue into a rolling mel spectrogram."""

    # Pending raw PCM chunks awaiting spectrogram conversion.
    q: asyncio.Queue

    def __init__(
        self,
        *,
        device: Optional[Union[str, torch.device]] = None,
        batch: int = 1,
        n_mels: int = 80,
        capacity: int = 1_000_000,
        **kw,
    ):
        super().__init__(**kw)
        self.q = asyncio.Queue(capacity)
        self.finished = asyncio.Event()
        self.device, self.batch, self.n_mels = device, batch, n_mels
        # Raw samples not yet consumed by a full FFT hop.
        self.sees = self.zeros((0,))
        self.spectogram = self.zeros((n_mels, 0))
        self.hann = torch.hann_window(N_FFT).to(self.sees.device)
        self.filters = mel_filters(self.sees.device, n_mels)

    def zeros(self, shape):
        return torch.zeros(shape, dtype=torch.float32, device=self.device)

    # When True, write() returns an awaitable put that applies backpressure.
    write_blockable: bool = True

    def write(self, data: bytes) -> Optional[Coroutine]:
        """Enqueue raw PCM bytes; returns a coroutine to await when blocking."""
        if self.write_blockable:
            return self.q.put(data)
        else:
            self.q.put_nowait(data)
            return None

    def load(self, data: bytes) -> np.ndarray:
        """Decode little-endian s16 PCM into float32 samples in [-1, 1)."""
        return np.frombuffer(data, np.int16).flatten().astype(np.float32) / 32768.0

    async def loader(self, iterator: AsyncIterable[bytes]) -> AsyncIterator[np.ndarray]:
        async for data in iterator:
            yield self.load(data)

    async def buffer(self) -> AsyncIterator[bytes]:
        """Yield queued chunks until ``finished`` is set, then drain the rest."""
        waiter = asyncio.create_task(self.finished.wait())
        while not self.finished.is_set():
            getter = asyncio.create_task(self.q.get())
            done, pending = await asyncio.wait(
                (waiter, getter), return_when=asyncio.FIRST_COMPLETED
            )
            if getter in done:
                yield getter.result()
        while not self.q.empty():
            yield self.q.get_nowait()

    async def buffer_nowait(self) -> AsyncIterator[bytes]:
        """Yield only chunks already queued; never waits."""
        try:
            while True:
                yield self.q.get_nowait()
        except asyncio.QueueEmpty:
            pass

    # Batcher re-chunking the decoded samples on hop boundaries.
    loading: Optional[Batcher] = None

    async def fft_offset(
        self, iterator: AsyncIterable[bytes]
    ) -> AsyncIterator[np.ndarray]:
        """Re-batch samples on HOP_LENGTH boundaries, reflect-padding the start."""
        init = self.loader(iterator) if self.loading is None else self.loading
        self.loading = Batcher(init, HOP_LENGTH)
        _iterator = aiter(self.loading)
        window = np.zeros((0,), dtype=np.float32)
        # Accumulate at least half an FFT window before the initial padding.
        while window.size < ceildiv(N_FFT, 2):
            try:
                window = np.concatenate((window, await anext(_iterator)))
            except StopAsyncIteration:
                return
        window = np.pad(window, (N_FFT // 2, 0), "reflect")
        yield window
        async for data in _iterator:
            yield data
        # for _ in range(N_FFT // HOP_LENGTH):
        #     yield np.zeros((HOP_LENGTH,), dtype=np.float32)
        # (done by runoff)

    def seeing(self, sees: torch.Tensor) -> torch.Tensor:
        """Trailing samples that do not yet fill a whole FFT hop."""
        hopped = ((sees.shape[0] - N_FFT) // HOP_LENGTH + 1) * HOP_LENGTH
        return sees[hopped:]

    async def window(
        self, iterator: AsyncIterable[bytes]
    ) -> AsyncIterator[torch.Tensor]:
        """Yield mel-spectrogram frame batches of at least ``self.batch`` hops."""
        _iterator = self.fft_offset(iterator)
        async for data in _iterator:
            _data = torch.from_numpy(data)
            prev = self.sees.shape[0] - N_FFT
            # Keep pulling until enough samples exist for a full batch of hops.
            while (_data.shape[0] + prev) // HOP_LENGTH < self.batch - 1:
                try:
                    adding = torch.from_numpy(await anext(_iterator))
                except StopAsyncIteration:
                    break
                _data = torch.cat((_data, adding))
            if self.device is not None:
                # NOTE(review): Tensor.to is not in-place; this result is
                # discarded — presumably meant ``_data = _data.to(...)``.
                _data.to(self.device)
            res = torch.cat((self.sees, _data))
            self.sees = self.seeing(res)
            yield self.transform(self.dft(res))

    def dft(self, amp: torch.Tensor) -> torch.Tensor:
        """Uncentered STFT of the amplitude signal."""
        return torch.stft(
            amp, N_FFT, HOP_LENGTH, window=self.hann, center=False, return_complex=True
        )

    # Running maximum of the log-spectrogram, used as the clamping ceiling.
    log_spec_bound: Optional[torch.Tensor] = None

    def transform(self, stft: torch.Tensor) -> torch.Tensor:
        """Convert an STFT to a normalized log-mel spectrogram."""
        magnitudes = stft.abs() ** 2
        mel_spec = self.filters @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        # causes values to not precisely match the original
        self.log_spec_bound = (
            log_spec.max()
            if self.log_spec_bound is None
            else torch.maximum(log_spec.max(), self.log_spec_bound)
        )
        log_spec = torch.maximum(log_spec, self.log_spec_bound - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec

    def padding(self, content_frames: int) -> int:
        # Always pad a full window's worth; subclasses may shrink this.
        return N_FRAMES

    # dft_pad: add ending content frames to match padding from a centered STFT
    dft_pad: bool = False

    def runoff(self, dft_pad: Optional[bool] = None) -> torch.Tensor:
        """Flush pending samples and return the padded spectrogram window."""
        dft_pad = self.dft_pad if dft_pad is None else dft_pad
        if dft_pad:
            overrun = (ceildiv(N_FFT, HOP_LENGTH) - 1) * HOP_LENGTH
            spectogram = torch.cat((self.sees, self.zeros(overrun)))
        # NOTE(review): as transcribed, ``spectogram`` is unbound when
        # ``dft_pad`` is False — an else-branch binding (e.g. ``self.sees``)
        # may have been lost in extraction; confirm against upstream.
        if spectogram.shape[-1] >= N_FFT:
            spectogram = self.transform(self.dft(spectogram))
        else:
            spectogram = torch.zeros(0)
        padding = self.padding(self.spectogram.shape[-1] + spectogram.shape[-1])
        pad = self.zeros((self.n_mels, max(0, padding)))
        spectogram = torch.cat((self.spectogram, spectogram, pad), -1)
        return spectogram if padding >= 0 else spectogram[-padding:]

    # Frames dropped from the front of the rolling spectrogram so far.
    offset: int = 0

    async def pull(self) -> torch.Tensor:
        """Consume whatever is queued right now and return the current window."""
        context = self.spectogram.shape[-1]
        iterator = self.window(self.buffer_nowait())
        async for frame in iterator:
            self.spectogram = torch.cat((self.spectogram, frame), -1)
        # Trim old context down to N_FRAMES, never dropping frames that
        # arrived during this call.
        cutoff = min(context, max(self.spectogram.shape[-1] - N_FRAMES, 0))
        self.offset += cutoff
        self.spectogram = self.spectogram[:, cutoff:]
        return self.runoff()

    staging: Optional[Batcher] = None

    async def _push(
        self, sec: float, exact: bool = False
    ) -> AsyncIterator[torch.Tensor]:
        """Yield a window per ``sec`` seconds' worth of incoming audio."""
        batching = int(sec * SAMPLE_RATE // HOP_LENGTH)
        init = self.window(self.buffer()) if self.staging is None else self.staging
        self.staging = Batcher(init, batching, exact=exact)
        async for frame in self.staging:
            batched = batching if exact else frame.shape[-1]
            cutoff = max(self.spectogram.shape[-1] + batched - N_FRAMES, 0)
            self.offset += cutoff
            self.spectogram = torch.cat((self.spectogram[:, cutoff:], frame), -1)
            yield self.runoff()

    # Background task running read(); created once by start().
    reader: Optional[Awaitable] = None

    def start(self, **kw) -> None:
        if self.reader is None:
            self.reader = asyncio.create_task(self.read(**kw))

    async def push(
        self, sec: float, exact: bool = False, **kw
    ) -> AsyncIterator[torch.Tensor]:
        """Start reading and stream windows; awaits the reader when done."""
        self.start(**kw)
        async for i in self._push(sec, exact):
            yield i
        assert self.reader is not None
        await self.reader

    async def request(self, sec: float, exact: bool = True, **kw) -> torch.Tensor:
        """Return the next ``sec``-second window, or an empty one at stream end."""
        try:
            return await anext(self.push(sec, exact))
        except StopAsyncIteration:
            if self.reader is not None:
                await self.reader
            return self.zeros((self.n_mels, 0))

    async def full(self, **kw) -> torch.Tensor:
        """Read the whole source to completion, then return the final window."""
        await self.read(**kw)
        return await self.pull()

    def sequential(self, **kw) -> torch.Tensor:
        # Synchronous convenience wrapper around full().
        return asyncio.run(self.full(**kw))

    async def amplitudes(self, **kw) -> np.ndarray:
        """Collect the decoded float32 samples for the entire stream."""
        self.start(**kw)
        res = []
        async for data in self.loader(self.buffer()):
            res.append(data)
        assert self.reader is not None
        await self.reader
        return np.concatenate(res)

    def all_amplitudes(self, **kw) -> np.ndarray:
        # Synchronous convenience wrapper around amplitudes().
        return asyncio.run(self.amplitudes(**kw))
|
||||
|
||||
|
||||
class RawAudioFile(ArrayStream):
    """Feeds raw 16-bit PCM from a file into the stream in ``period``-byte chunks."""

    def __init__(self, *, period: int = HOP_LENGTH, fname: PathType = "out.raw", **kw):
        super().__init__(**kw)
        self.fname = fname
        self.period = period

    # Optional pre-opened byte source; when None, ``fname`` is opened by read().
    fp: Optional[IO[bytes]] = None

    async def read(self) -> None:
        fp = open(self.fname, "rb") if self.fp is None else self.fp
        data = fp.read(self.period)
        while len(data) != 0:
            io_hold = self.write(data)
            # Blocking writes return an awaitable queue put; honor backpressure.
            assert io_hold is not None and self.write_blockable is True
            await io_hold
            data = fp.read(self.period)
        self.finished.set()
|
||||
|
||||
|
||||
class AudioFile(RawAudioFile):
    """Decodes any container/codec to mono s16 PCM at ``self.rate`` via ffmpeg."""

    def __init__(self, *, period: int = SAMPLE_RATE, fname: PathType = "out.wav", **kw):
        # Fail fast if ffmpeg is missing from PATH.
        assert not subprocess.run(
            ["which", "ffmpeg"], stdout=subprocess.PIPE
        ).returncode
        super().__init__(period=period or -1, fname=fname, **kw)

    async def read(self) -> None:
        cmd = [
            "ffmpeg",
            "-nostdin",
            "-threads",
            "0",
            "-i",
            self.fname,
            "-f",
            "s16le",
            "-ac",
            "1",
            "-acodec",
            "pcm_s16le",
            "-ar",
            str(self.rate),
            "-",
        ]
        ps = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        # Stream ffmpeg's stdout through RawAudioFile.read's chunk loop.
        self.fp = ps.stdout
        await super().read()
        _, stderr = ps.communicate()
        if ps.returncode not in (None, 0):
            raise RuntimeError(f"Failed to load audio: {stderr.decode()}")

    @property
    def duration(self):
        """Source duration in seconds, as reported by ffprobe."""
        cmd = [
            "ffprobe",
            "-hide_banner",
            "-show_format",
            "-of",
            "json",
            "-i",
            self.fname,
        ]
        ps = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = ps.communicate()
        if ps.returncode not in (None, 0):
            raise RuntimeError(f"Failed to load audio: {stderr.decode()}")
        return float(json.loads(stdout)["format"]["duration"])
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,9 +1,11 @@
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import zlib
|
||||
from typing import Callable, List, Optional, TextIO
|
||||
from typing import Any, Callable, Generic, List, Optional, TextIO, TypeVar, Union
|
||||
|
||||
system_encoding = sys.getdefaultencoding()
|
||||
|
||||
@ -21,11 +23,19 @@ else:
|
||||
return string
|
||||
|
||||
|
||||
PathType = Union[str, pathlib.Path]
|
||||
|
||||
|
||||
def exact_div(x, y):
    """Divide ``x`` by ``y``, asserting the division leaves no remainder."""
    quotient, remainder = divmod(x, y)
    assert remainder == 0
    return quotient
|
||||
|
||||
|
||||
# https://stackoverflow.com/a/17511341/3476782
def ceildiv(a: Union[int, float], b: Union[int, float]) -> int:
    """Ceiling division via negated floor division (exact for ints)."""
    return int(-(-a // b))
|
||||
|
||||
|
||||
def str2bool(string):
|
||||
str2val = {"True": True, "False": False}
|
||||
if string in str2val:
|
||||
@ -68,6 +78,20 @@ def format_timestamp(
|
||||
)
|
||||
|
||||
|
||||
def hms(sec: float) -> str:
    """Render seconds as ``[H:]MM:SS.CC``, blank-padded while under an hour."""
    whole = int(sec)
    under_hour = sec < 3600
    hours = "" if under_hour else str(whole // 3600) + ":"
    # Minutes pad with a space (alignment) until hours appear, then with zero.
    minute_pad = " " if under_hour else "0"
    minutes = " " if sec < 60 else str(whole // 60 % 60).rjust(2, minute_pad) + ":"
    seconds = str(whole % 60).rjust(2, "0")
    centis = str(round((sec % 1) * 100)).rjust(2, "0")
    return hours + minutes + seconds + "." + centis
|
||||
|
||||
|
||||
def tod(seconds: float) -> str:
|
||||
return time.strftime("%H:%M:%S", time.localtime(seconds))
|
||||
|
||||
|
||||
def get_start(segments: List[dict]) -> Optional[float]:
|
||||
return next(
|
||||
(w["start"] for s in segments for w in s["words"]),
|
||||
@ -314,3 +338,47 @@ def get_writer(
|
||||
return write_all
|
||||
|
||||
return writers[output_format](output_dir)
|
||||
|
||||
|
||||
T = TypeVar("T")


# boilerplate for property with _{name} storage and passthrough getter/setter
class PassthroughProperty(Generic[T]):
    """Declarative marker expanded by PassthroughPropertyDefaults.

    Stores a default value plus an optional custom setter ``f`` and getter
    property ``g``; the metaclass turns it into ``_name`` storage behind a
    passthrough property.
    """

    def __init__(self, default: T):
        self.value = default

    # Optional custom setter; passthrough assignment is used when None.
    f: Optional[Callable[[Any, T], None]] = None

    def setter(self, f: Callable[[Any, T], None]):
        """Register a custom setter; chainable."""
        self.f = f
        return self

    # Optional custom getter property; passthrough access is used when None.
    g: Optional[property] = None

    def property(self, g: Callable[[Any], T]):
        """Register a custom getter; chainable."""
        self.g = property(g)
        return self
|
||||
|
||||
|
||||
class PassthroughPropertyDefaults(type):
    """Metaclass expanding PassthroughProperty attributes into full properties.

    For each class attribute ``k`` that is a PassthroughProperty, installs a
    private ``_k`` slot holding the default and a ``k`` property whose getter
    and setter default to passthrough access unless custom ones were supplied.
    """

    def __new__(cls, clsname, bases, attrs):
        def closure(f, v):
            # Build a passthrough getter/setter pair over private name ``v``.
            def prop(self):
                return getattr(self, v)

            def setter(self, value):
                setattr(self, v, value)

            prop.__name__ = setter.__name__ = f
            return property(prop), setter

        updates = {}
        for k, v in attrs.items():
            if not isinstance(v, PassthroughProperty):
                continue
            private = "_" + k
            updates[private] = v.value
            getter, setter = closure(k, private)
            # Prefer user-supplied getter/setter over the passthroughs.
            updates[k] = (v.g or getter).setter(v.f or setter)
        return super().__new__(cls, clsname, bases, {**attrs, **updates})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user