pre-commit formatting

Kent Slaney 2024-07-14 19:07:06 -06:00
parent e0704ddeba
commit 0621ed8094
4 changed files with 328 additions and 245 deletions
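
The commit message names the tool but not the hooks. The changes below (88-column wrapping with magic trailing commas, double-quoted strings, stdlib/third-party/local import groups) match black and isort defaults, so a config along these lines would reproduce them; the hook list and revs are assumptions, since the actual .pre-commit-config.yaml is not part of this diff:

repos:
  - repo: https://github.com/psf/black
    rev: 24.4.2  # assumed; any recent rev formats equivalently
    hooks:
      - id: black
  - repo: https://github.com/PyCQA/isort
    rev: 5.13.2  # assumed
    hooks:
      - id: isort
        args: ["--profile", "black"]  # keep isort's wrapping compatible with black

Applied to the tree with: pre-commit run --all-files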

View File: batching.py

@@ -1,15 +1,19 @@
-import torch
-import numpy as np
-from collections.abc import Callable, AsyncIterable, AsyncIterator, Awaitable
+from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
 from typing import Generic, TypeVar, Union
+
+import numpy as np
+import torch

 A = TypeVar("A", bound=Union[np.ndarray, torch.Tensor])

+
 class ArrayWrapper(Generic[A]):
     pass

+
 ArrayTypes = Union[A, ArrayWrapper[A]]

+
 class LoopbackIterator(Generic[A]):
     async def iter(self):
         raise NotImplementedError
@@ -23,14 +27,17 @@ class LoopbackIterator(Generic[A]):
             self.__aiter__()
         return await anext(self._iter)

+
 async def empty():
     return
     yield

+
 class Unwrap(LoopbackIterator):
     _initial: Union[ArrayTypes, Awaitable[ArrayTypes]]
     started: bool
     iterator: AsyncIterable[ArrayTypes]
+
     def __init__(self, iterator: AsyncIterable[ArrayTypes]):
         while isinstance(iterator, PassthroughTransform):
             iterator = iterator.handoff()
@@ -74,13 +81,14 @@ class Unwrap(LoopbackIterator):
     @property
     async def concat(self):
-        return np.concatenate if isinstance(await self.dtype, np.dtype) \
-                else torch.cat
+        return np.concatenate if isinstance(await self.dtype, np.dtype) else torch.cat

+
 class PassthroughTransform(LoopbackIterator):
     def handoff(self) -> AsyncIterable[ArrayTypes]:
         raise NotImplementedError

+
 class BoxedIterator(PassthroughTransform):
     def __init__(self, iterator):
         self.iterator = iterator
@@ -99,6 +107,7 @@ class BoxedIterator(PassthroughTransform):
         if self.flag != flag:
             raise Exception("source can only be used by one iterator")

+
 def LookAlong(axis: int):
     assert axis >= 0
     empties = (slice(None),) * axis
@@ -119,10 +128,9 @@ def LookAlong(axis: int):
     return LookAlong

+
 class PassthroughMap(PassthroughTransform):
-    def __init__(
-            self, apply: Callable[[A], ArrayTypes],
-            iterator: AsyncIterator[A]):
+    def __init__(self, apply: Callable[[A], ArrayTypes], iterator: AsyncIterator[A]):
         self.iterator, self.apply = iterator, apply

     def handoff(self) -> AsyncIterator[A]:
@@ -132,6 +140,7 @@ class PassthroughMap(PassthroughTransform):
         async for i in self.iterator:
             yield self.apply(i)

+
 class Group:
     def __init__(self, concat, axis=-1):
         self.concat = concat
@@ -155,16 +164,18 @@ class Group:
             if taking == amount or not exact:
                 self.shape += amount - taking
                 self.consumed = 0
-                res = self.concat([self.holding[0][start:]] + [
-                    i.value for i in self.holding[1 : i + 1]])
-                self.holding = self.holding[i + 1:]
+                res = self.concat(
+                    [self.holding[0][start:]] + [i.value for i in self.holding[1 : i + 1]]
+                )
+                self.holding = self.holding[i + 1 :]
                 return res
         if i == 0:
-            return self.holding[0][start:self.consumed]
+            return self.holding[0][start : self.consumed]
         res = self.concat(
-            [self.holding[0][start:]] +
-            [i.value for i in self.holding[1 : i]] +
-            [self.holding[i][:self.consumed]])
+            [self.holding[0][start:]]
+            + [i.value for i in self.holding[1:i]]
+            + [self.holding[i][: self.consumed]]
+        )
         self.holding = self.holding[i:]
         return res
@@ -175,10 +186,12 @@ class Group:
         self.holding = []
         return res

+
 class Taken:
     def take(self, *a, **kw):
         raise Exception("batch queue moved")

+
 class Batcher(PassthroughTransform):
     def __init__(self, iterator, size, axis=-1, exact=False):
         assert isinstance(size, int) and size > 0
@@ -192,14 +205,19 @@ class Batcher(PassthroughTransform):
         return lambda tensors: f(tensors, self.axis)

     _iterator = None
+
     async def iterator(self):
         if self._iterator is None:
-            self.axis = len(await self.preview.shape) + self._axis \
-                if self._axis < 0 else self._axis
+            self.axis = (
+                len(await self.preview.shape) + self._axis
+                if self._axis < 0
+                else self._axis
+            )
             if not hasattr(self, "group"):
                 self.group = Group(await self.concat())
             self._iterator = PassthroughMap(
-                LookAlong(self.axis), BoxedIterator(self.preview))
+                LookAlong(self.axis), BoxedIterator(self.preview)
+            )
         return self._iterator

     def handoff(self):
@@ -219,4 +237,3 @@ class Batcher(PassthroughTransform):
                 return self.group.all()
             raise
         return self.group.take(self.size, self.exact)
-

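Aside, not part of the commit: Batcher is this module's entry point. It wraps an async iterator of arrays and re-chunks it into fixed-size groups along an axis, which is how the audio code below turns arbitrary byte reads into hop-aligned windows. A minimal usage sketch against the reformatted API above (tail handling is an assumption; the Group.all() flush path is only partly visible in this diff):

import asyncio
import numpy as np

async def chunks():
    # uneven chunks, as a streaming decoder might produce them
    for n in (3, 5, 2):
        yield np.arange(n, dtype=np.float32)

async def main():
    # size=4, exact=True: each yielded array has exactly 4 elements
    async for batch in Batcher(chunks(), 4, exact=True):
        print(batch.shape)

asyncio.run(main())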
View File: buffer.py

@@ -1,18 +1,16 @@
-import numpy as np
-import asyncio, pathlib, subprocess, torch, json
-from .audio import (
-    SAMPLE_RATE,
-    N_FFT,
-    HOP_LENGTH,
-    N_FRAMES,
-    mel_filters,
-)
-from .utils import PathType, ceildiv
+import asyncio
+import json
+import subprocess
+from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Coroutine
+from typing import IO, Optional, Union
+
+import numpy as np
+import torch
+
+from .audio import HOP_LENGTH, N_FFT, N_FRAMES, SAMPLE_RATE, mel_filters
 from .batching import Batcher
-from typing import Optional, Union, IO, Tuple, Any, Type
-from collections.abc import Coroutine, AsyncIterable, AsyncIterator, Awaitable
+from .utils import PathType, ceildiv


 class AudioSink:
     def __init__(self, *, rate: int = SAMPLE_RATE, **kw):
@@ -25,11 +23,19 @@ class AudioSink:
     def write(self, data):
         raise NotImplementedError

+
 class ArrayStream(AudioSink):
     q: asyncio.Queue
+
     def __init__(
-            self, *, device: Optional[Union[str, torch.device]] = None,
-            batch: int = 1, n_mels: int = 80, capacity: int = 1_000_000, **kw):
+        self,
+        *,
+        device: Optional[Union[str, torch.device]] = None,
+        batch: int = 1,
+        n_mels: int = 80,
+        capacity: int = 1_000_000,
+        **kw,
+    ):
         super().__init__(**kw)
         self.q = asyncio.Queue(capacity)
         self.finished = asyncio.Event()
@@ -43,6 +49,7 @@ class ArrayStream(AudioSink):
         return torch.zeros(shape, dtype=torch.float32, device=self.device)

     write_blockable: bool = True
+
     def write(self, data: bytes) -> Optional[Coroutine]:
         if self.write_blockable:
             return self.q.put(data)
@@ -51,11 +58,9 @@ class ArrayStream(AudioSink):
         return None

     def load(self, data: bytes) -> np.ndarray:
-        return np.frombuffer(
-            data, np.int16).flatten().astype(np.float32) / 32768.0
+        return np.frombuffer(data, np.int16).flatten().astype(np.float32) / 32768.0

-    async def loader(self, iterator: AsyncIterable[bytes]) -> \
-            AsyncIterator[np.ndarray]:
+    async def loader(self, iterator: AsyncIterable[bytes]) -> AsyncIterator[np.ndarray]:
         async for data in iterator:
             yield self.load(data)
@@ -64,7 +69,8 @@ class ArrayStream(AudioSink):
         while not self.finished.is_set():
             getter = asyncio.create_task(self.q.get())
             done, pending = await asyncio.wait(
-                (waiter, getter), return_when=asyncio.FIRST_COMPLETED)
+                (waiter, getter), return_when=asyncio.FIRST_COMPLETED
+            )
             if getter in done:
                 yield getter.result()
                 while not self.q.empty():
@@ -78,8 +84,10 @@ class ArrayStream(AudioSink):
             pass

     loading: Optional[Batcher] = None
-    async def fft_offset(self, iterator: AsyncIterable[bytes]) -> \
-            AsyncIterator[np.ndarray]:
+
+    async def fft_offset(
+        self, iterator: AsyncIterable[bytes]
+    ) -> AsyncIterator[np.ndarray]:
         init = self.loader(iterator) if self.loading is None else self.loading
         self.loading = Batcher(init, HOP_LENGTH)
         _iterator = aiter(self.loading)
@@ -89,7 +97,7 @@ class ArrayStream(AudioSink):
                 window = np.concatenate((window, await anext(_iterator)))
             except StopAsyncIteration:
                 return
-        window = np.pad(window, (N_FFT // 2, 0), 'reflect')
+        window = np.pad(window, (N_FFT // 2, 0), "reflect")
         yield window
         async for data in _iterator:
             yield data
@@ -101,8 +109,9 @@ class ArrayStream(AudioSink):
         hopped = ((sees.shape[0] - N_FFT) // HOP_LENGTH + 1) * HOP_LENGTH
         return sees[hopped:]

-    async def window(self, iterator: AsyncIterable[bytes]) -> \
-            AsyncIterator[torch.Tensor]:
+    async def window(
+        self, iterator: AsyncIterable[bytes]
+    ) -> AsyncIterator[torch.Tensor]:
         _iterator = self.fft_offset(iterator)
         async for data in _iterator:
             _data = torch.from_numpy(data)
@@ -121,19 +130,23 @@ class ArrayStream(AudioSink):
     def dft(self, amp: torch.Tensor) -> torch.Tensor:
         return torch.stft(
-            amp, N_FFT, HOP_LENGTH, window=self.hann, center=False,
-            return_complex=True)
+            amp, N_FFT, HOP_LENGTH, window=self.hann, center=False, return_complex=True
+        )

     # https://github.com/openai/whisper/blob/c5d4256/whisper/audio.py#L149
     log_spec_bound: Optional[torch.Tensor] = None
+
     def transform(self, stft: torch.Tensor) -> torch.Tensor:
         magnitudes = stft.abs() ** 2
         mel_spec = self.filters @ magnitudes
         log_spec = torch.clamp(mel_spec, min=1e-10).log10()
         # causes values to not precisely match the original
-        self.log_spec_bound = log_spec.max() if self.log_spec_bound is None \
-            else torch.maximum(log_spec.max(), self.log_spec_bound)
+        self.log_spec_bound = (
+            log_spec.max()
+            if self.log_spec_bound is None
+            else torch.maximum(log_spec.max(), self.log_spec_bound)
+        )
         log_spec = torch.maximum(log_spec, self.log_spec_bound - 8.0)
         log_spec = (log_spec + 4.0) / 4.0
         return log_spec
@@ -143,6 +156,7 @@ class ArrayStream(AudioSink):
     # dft_pad: add ending content frames to match padding from a centered STFT
     dft_pad: bool = False
+
     def runoff(self, dft_pad: Optional[bool] = None) -> torch.Tensor:
         dft_pad = self.dft_pad if dft_pad is None else dft_pad
         if dft_pad:
@@ -170,34 +184,36 @@ class ArrayStream(AudioSink):
         return self.runoff()

     staging: Optional[Batcher] = None
-    async def _push(self, sec: float, exact: bool = False) -> \
-            AsyncIterator[torch.Tensor]:
+
+    async def _push(
+        self, sec: float, exact: bool = False
+    ) -> AsyncIterator[torch.Tensor]:
         batching = int(sec * SAMPLE_RATE // HOP_LENGTH)
-        init = self.window(self.buffer()) if self.staging is None \
-            else self.staging
+        init = self.window(self.buffer()) if self.staging is None else self.staging
         self.staging = Batcher(init, batching, exact=exact)
         async for frame in self.staging:
             batched = batching if exact else frame.shape[-1]
             cutoff = max(self.spectogram.shape[-1] + batched - N_FRAMES, 0)
             self.offset += cutoff
-            self.spectogram = torch.cat((
-                self.spectogram[:, cutoff:], frame), -1)
+            self.spectogram = torch.cat((self.spectogram[:, cutoff:], frame), -1)
             yield self.runoff()

     reader: Optional[Awaitable] = None
+
     def start(self, **kw) -> None:
         if self.reader is None:
             self.reader = asyncio.create_task(self.read(**kw))

-    async def push(self, sec: float, exact: bool=False, **kw) -> \
-            AsyncIterator[torch.Tensor]:
+    async def push(
+        self, sec: float, exact: bool = False, **kw
+    ) -> AsyncIterator[torch.Tensor]:
         self.start(**kw)
         async for i in self._push(sec, exact):
             yield i
         assert self.reader is not None
         await self.reader

-    async def request(self, sec: float, exact: bool=True, **kw) -> torch.Tensor:
+    async def request(self, sec: float, exact: bool = True, **kw) -> torch.Tensor:
         try:
             return await anext(self.push(sec, exact))
         except StopAsyncIteration:
@@ -224,17 +240,17 @@ class ArrayStream(AudioSink):
     def all_amplitudes(self, **kw) -> np.ndarray:
         return asyncio.run(self.amplitudes(**kw))

+
 class RawAudioFile(ArrayStream):
-    def __init__(
-            self, *, period: int = HOP_LENGTH, fname: PathType = 'out.raw',
-            **kw):
+    def __init__(self, *, period: int = HOP_LENGTH, fname: PathType = "out.raw", **kw):
         super().__init__(**kw)
         self.fname = fname
         self.period = period

     fp: Optional[IO[bytes]] = None
+
     async def read(self) -> None:
-        fp = open(self.fname, 'rb') if self.fp is None else self.fp
+        fp = open(self.fname, "rb") if self.fp is None else self.fp
         data = fp.read(self.period)
         while len(data) != 0:
             io_hold = self.write(data)
@@ -243,28 +259,33 @@ class RawAudioFile(ArrayStream):
             data = fp.read(self.period)
         self.finished.set()

+
 class AudioFile(RawAudioFile):
-    def __init__(
-            self, *, period: int = SAMPLE_RATE, fname: PathType = 'out.wav',
-            **kw):
+    def __init__(self, *, period: int = SAMPLE_RATE, fname: PathType = "out.wav", **kw):
         assert not subprocess.run(
-            ["which", "ffmpeg"], stdout=subprocess.PIPE).returncode
+            ["which", "ffmpeg"], stdout=subprocess.PIPE
+        ).returncode
         super().__init__(period=period or -1, fname=fname, **kw)

     async def read(self) -> None:
         cmd = [
             "ffmpeg",
             "-nostdin",
-            "-threads", "0",
-            "-i", self.fname,
-            "-f", "s16le",
-            "-ac", "1",
-            "-acodec", "pcm_s16le",
-            "-ar", str(self.rate),
-            "-"
+            "-threads",
+            "0",
+            "-i",
+            self.fname,
+            "-f",
+            "s16le",
+            "-ac",
+            "1",
+            "-acodec",
+            "pcm_s16le",
+            "-ar",
+            str(self.rate),
+            "-",
         ]
-        ps = subprocess.Popen(
-            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        ps = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
         self.fp = ps.stdout
         await super().read()
         _, stderr = ps.communicate()
@@ -277,13 +298,13 @@ class AudioFile(RawAudioFile):
             "ffprobe",
             "-hide_banner",
             "-show_format",
-            "-of", "json",
-            "-i", self.fname,
+            "-of",
+            "json",
+            "-i",
+            self.fname,
         ]
-        ps = subprocess.Popen(
-            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        ps = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
         stdout, stderr = ps.communicate()
         if ps.returncode not in (None, 0):
             raise RuntimeError(f"Failed to load audio: {stderr.decode()}")
-        return float(json.loads(stdout)['format']['duration'])
+        return float(json.loads(stdout)["format"]["duration"])

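Aside, not part of the commit: the classes above are driven through request()/push(). A sketch, assuming ffmpeg is on PATH (the constructor asserts this) and an out.wav exists; the returned tensor should be a log-mel spectrogram of shape (n_mels, frames), per runoff() and transform() above:

import asyncio

async def main():
    stream = AudioFile(fname="out.wav")
    # request() starts the background reader task and awaits the first window
    mel = await stream.request(30.0)
    print(mel.shape)

asyncio.run(main())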
View File: transcribe.py

@@ -1,27 +1,30 @@
 import argparse
+import asyncio
 import os
 import traceback
 import warnings
-import asyncio
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Dict
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
 import tqdm

 from .audio import (
+    CHUNK_LENGTH,
     FRAMES_PER_SECOND,
     HOP_LENGTH,
     N_FRAMES,
     SAMPLE_RATE,
-    CHUNK_LENGTH,
     pad_or_trim,
 )
+from .buffer import ArrayStream, AudioFile
 from .decoding import DecodingOptions, DecodingResult
 from .timing import add_word_timestamps
-from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer, Tokenizer
+from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, Tokenizer, get_tokenizer
 from .utils import (
+    PassthroughProperty,
+    PassthroughPropertyDefaults,
     exact_div,
     format_timestamp,
     get_end,
@@ -30,14 +33,12 @@ from .utils import (
     optional_float,
     optional_int,
     str2bool,
-    PassthroughProperty,
-    PassthroughPropertyDefaults,
 )
-from .buffer import ArrayStream, AudioFile

 if TYPE_CHECKING:
     from .model import Whisper

+
 @dataclass
 class LanguageHypothesis:
     language: Optional[str] = None
@@ -45,15 +46,17 @@ class LanguageHypothesis:
     evidence: int = 0
     last: int = -1

+
 class Transcriber(metaclass=PassthroughPropertyDefaults):
-    prefix: str = '''"'\u201c\u00bf([{-'''
-    postfix: str = '''"'.\u3002,\uff0c!\uff01?\uff1f:\uff1a\u201d)]}\u3001'''
+    prefix: str = """"'\u201c\u00bf([{-"""
+    postfix: str = """"'.\u3002,\uff0c!\uff01?\uff1f:\uff1a\u201d)]}\u3001"""
     punctuation: str = prefix + postfix

     verbose: Optional[bool] = None
     _decode_options: dict = {}
     decode_props: Tuple[str, ...] = ("fp16", "language", "task")
+
     @property
     def decode_options(self) -> dict:
         for k in self.decode_props:
@@ -68,6 +71,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             setattr(self, k, value[k])

     dtype: torch.dtype = torch.float16
+
     @property
     def fp16(self) -> bool:
         return self.dtype == torch.float16
@@ -93,8 +97,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         self._device = value
         if value == torch.device("cpu"):
             if torch.cuda.is_available():
-                warnings.warn(
-                    "Performing inference on CPU when CUDA is available")
+                warnings.warn("Performing inference on CPU when CUDA is available")
             self.fp16device()

     def fp16device(self) -> None:
@@ -110,6 +113,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     prev: Optional[torch.Tensor] = None
     _latest: Optional[torch.Tensor] = None
+
     @PassthroughProperty[Optional[torch.Tensor]](None).setter
     def latest(self, value: Optional[torch.Tensor]) -> None:
         self.prev = self._latest
@@ -117,6 +121,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     _hypothesis: LanguageHypothesis = LanguageHypothesis()
     _language: Optional[str]
+
     @PassthroughProperty[Optional[str]](None).property
     def language(self) -> Optional[str]:
         if self._language is not None:
@@ -126,19 +131,23 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             if self.verbose:
                 print(
                     "Detecting language using up to the first 30 seconds."
-                    "Use `--language` to specify the language")
+                    "Use `--language` to specify the language"
+                )
             if self.latest is None:
                 return None
         if self._seek == self._hypothesis.last:
             return self._hypothesis.language
         if self.frame_offset > 0 or self.latest.shape[-1] == N_FRAMES * 2:
-            mel = self.latest if self.prev is None else torch.cat(
-                (self.prev[:self.frame_offset], self.latest), -1)
+            mel = (
+                self.latest
+                if self.prev is None
+                else torch.cat((self.prev[: self.frame_offset], self.latest), -1)
+            )
             self._language = self.detect_language(mel)
             return self._language
         self._hypothesis.last = self._seek or 0
         self._hypothesis.since += 1
-        if 2 ** self._hypothesis.evidence < self._hypothesis.since:
+        if 2**self._hypothesis.evidence < self._hypothesis.since:
             return self._hypothesis.language
         self._hypothesis.since = 0
         guess = self.detect_language()
@@ -152,22 +161,25 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     def clip_timestamps(self, value: Union[str, List[float], Tuple[float]]):
         self._seek_clips = None
         if isinstance(value, str):
-            self._clip_timestamps = tuple(map(float, value.split(","))) \
-                if value else (0,)
+            self._clip_timestamps = (
+                tuple(map(float, value.split(","))) if value else (0,)
+            )
         else:
             self._clip_timestamps = tuple(value) or (0,)

     _seek_clips: Optional[List[Tuple[int, Optional[int]]]] = None
+
     @property
     def seek_clips(self) -> List[Tuple[int, Optional[int]]]:
         if self._seek_clips is None:
             seek_points = tuple(
-                round(ts * FRAMES_PER_SECOND)
-                for ts in self.clip_timestamps) + (None,)
+                round(ts * FRAMES_PER_SECOND) for ts in self.clip_timestamps
+            ) + (None,)
             self._seek_clips = list(zip(seek_points[::2], seek_points[1::2]))
         return self._seek_clips

     _seek: Optional[int]
+
     @PassthroughProperty[Optional[int]](None).property
     def seek(self) -> Optional[int]:
         return self.seek_clips[0][0] if self._seek is None else self._seek
@@ -179,25 +191,30 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         if value < len(clips):
             self.seek = clips[value][0]

-    time_offset = property(lambda self: float(
-        self.seek * HOP_LENGTH / SAMPLE_RATE))
-    window_end_time = property(lambda self: float(
-        (self.seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE))
+    time_offset = property(lambda self: float(self.seek * HOP_LENGTH / SAMPLE_RATE))
+    window_end_time = property(
+        lambda self: float((self.seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
+    )

     _temperature: Union[Optional[float], Tuple[float, ...]]
+
     @PassthroughProperty[Union[Optional[float], Tuple[float, ...]]](
-        (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)).setter
+        (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
+    ).setter
     def temperature(self, value: Union[Optional[float], Tuple[float, ...]]):
-        self._temperature = (value,) if isinstance(value, (int, float)) else (
-            Transcriber._temperature if value is None else value)
+        self._temperature = (
+            (value,)
+            if isinstance(value, (int, float))
+            else (Transcriber._temperature if value is None else value)
+        )

     @PassthroughProperty("transcribe").setter
     def task(self, value: str):
         self._task = value
         if self.word_timestamps and value == "translate":
             warnings.warn(
-                "Word-level timestamps on translations may not be "
-                "reliable.")
+                "Word-level timestamps on translations may not be " "reliable."
+            )

     @PassthroughProperty(False).setter
     def word_timestamps(self, value: bool):
@@ -207,6 +224,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     get_tokenizer = staticmethod(get_tokenizer)
     _tokenizer: Optional[Tokenizer] = None
     _tokenizer_cache: Dict[str, Tokenizer] = {}
+
     @property
     def tokenizer(self) -> Tokenizer:
         if self._tokenizer is None:
@@ -235,6 +253,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     _initial_prompt_tokens: Optional[List[int]] = None
     _initial_prompt_cache: Dict[Tokenizer, List[int]] = {}
+
     @property
     def initial_prompt_tokens(self) -> List[int]:
         if self._initial_prompt_tokens is None:
@@ -246,9 +265,9 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 return []
             if tokenizer not in self._initial_prompt_cache:
                 self._initial_prompt_cache[tokenizer] = tokenizer.encode(
-                    " " + self.initial_prompt.strip())
-                self._initial_prompt_tokens = \
-                    self._initial_prompt_cache[tokenizer]
+                    " " + self.initial_prompt.strip()
+                )
+                self._initial_prompt_tokens = self._initial_prompt_cache[tokenizer]
             return self._initial_prompt_cache[tokenizer]
         return self._initial_prompt_tokens
@@ -256,6 +275,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     last_speech_timestamp: float = 0.0
     frame_offset: int = 0
     all_segments: List[dict]
+
     def __init__(
         self,
         model: "Whisper",
@@ -272,7 +292,8 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         append_punctuations: str = postfix,
         clip_timestamps: Union[str, List[float]] = "0",
         hallucination_silence_threshold: Optional[float] = None,
-        **decode_options):
+        **decode_options,
+    ):
         """
         Transcribe an audio file using Whisper
@@ -374,14 +395,16 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             needs_fallback = False
             if self.compression_ratio_threshold is not None and (
-                decode_result.compression_ratio >
-                    self.compression_ratio_threshold):
+                decode_result.compression_ratio > self.compression_ratio_threshold
+            ):
                 needs_fallback = True  # too repetitive
             if self.logprob_threshold is not None and (
-                    decode_result.avg_logprob < self.logprob_threshold):
+                decode_result.avg_logprob < self.logprob_threshold
+            ):
                 needs_fallback = True  # average log probability is too low
             if self.no_speech_threshold is not None and (
-                    decode_result.no_speech_prob > self.no_speech_threshold):
+                decode_result.no_speech_prob > self.no_speech_threshold
+            ):
                 needs_fallback = False  # silence
             if not needs_fallback:
                 break
@@ -389,8 +412,8 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         return decode_result

     def new_segment(
-        self, *, start: float, end: float, tokens: torch.Tensor,
-        result: DecodingResult) -> dict:
+        self, *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
+    ) -> dict:
         _tokens = tokens.tolist()
         text_tokens = [token for token in _tokens if token < self.tokenizer.eot]
         return {
@@ -422,9 +445,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     def is_segment_anomaly(self, segment: Optional[dict]) -> bool:
         if segment is None or not segment["words"]:
             return False
-        words = [
-            w for w in segment["words"]
-            if w["word"] not in self.punctuation][:8]
+        words = [w for w in segment["words"] if w["word"] not in self.punctuation][:8]
         score = sum(self.word_anomaly_score(w) for w in words)
         return score >= 3 or score + 0.01 >= len(words)
@@ -433,11 +454,15 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         return next((s for s in segments if s["words"]), None)

     def reseek(
-        self, current_segments: List[dict], segment_size: int,
-        single_timestamp_ending: bool, tokens: torch.Tensor,
-        timestamp_tokens: torch.Tensor, result: DecodingResult):
-        consecutive = torch.where(
-            timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
+        self,
+        current_segments: List[dict],
+        segment_size: int,
+        single_timestamp_ending: bool,
+        tokens: torch.Tensor,
+        timestamp_tokens: torch.Tensor,
+        result: DecodingResult,
+    ):
+        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
         consecutive.add_(1)
         if len(consecutive) > 0:
             # if the output contains two consecutive timestamp tokens
@@ -449,17 +474,16 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             for current_slice in slices:
                 sliced_tokens = tokens[last_slice:current_slice]
                 start_timestamp_pos = (
-                    sliced_tokens[0].item() -
-                    self.tokenizer.timestamp_begin)
+                    sliced_tokens[0].item() - self.tokenizer.timestamp_begin
+                )
                 end_timestamp_pos = (
-                    sliced_tokens[-1].item() -
-                    self.tokenizer.timestamp_begin)
+                    sliced_tokens[-1].item() - self.tokenizer.timestamp_begin
+                )
                 current_segments.append(
                     self.new_segment(
-                        start=self.time_offset + \
-                            start_timestamp_pos * self.time_precision,
-                        end=self.time_offset + \
-                            end_timestamp_pos * self.time_precision,
+                        start=self.time_offset
+                        + start_timestamp_pos * self.time_precision,
+                        end=self.time_offset + end_timestamp_pos * self.time_precision,
                         tokens=sliced_tokens,
                         result=result,
                     )
@@ -474,31 +498,42 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             # otherwise, ignore the unfinished segment and seek to the last
             # timestamp
             last_timestamp_pos = (
-                tokens[last_slice - 1].item() -
-                self.tokenizer.timestamp_begin)
+                tokens[last_slice - 1].item() - self.tokenizer.timestamp_begin
+            )
             self.seek += last_timestamp_pos * self.input_stride
         else:
             duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             timestamps = tokens[timestamp_tokens.nonzero().flatten()]
-            if len(timestamps) > 0 and \
-                    timestamps[-1].item() != self.tokenizer.timestamp_begin:
+            if (
+                len(timestamps) > 0
+                and timestamps[-1].item() != self.tokenizer.timestamp_begin
+            ):
                 # no consecutive timestamps but it has a timestamp; use the last
                 # one.
-                last_timestamp_pos = \
-                    timestamps[-1].item() - self.tokenizer.timestamp_begin
+                last_timestamp_pos = (
+                    timestamps[-1].item() - self.tokenizer.timestamp_begin
+                )
                 duration = last_timestamp_pos * self.time_precision
-            current_segments.append(self.new_segment(
-                start=self.time_offset,
-                end=self.time_offset + duration,
-                tokens=tokens,
-                result=result))
+            current_segments.append(
+                self.new_segment(
+                    start=self.time_offset,
+                    end=self.time_offset + duration,
+                    tokens=tokens,
+                    result=result,
+                )
+            )
             self.seek += segment_size

     def timestamp(
-        self, current_segments: List[dict], segment_size: int,
-        single_timestamp_ending: bool, mel_segment: torch.Tensor,
-        previous_seek: int, content_frames: int) -> bool:
+        self,
+        current_segments: List[dict],
+        segment_size: int,
+        single_timestamp_ending: bool,
+        mel_segment: torch.Tensor,
+        previous_seek: int,
+        content_frames: int,
+    ) -> bool:
         add_word_timestamps(
             segments=current_segments,
             model=self.model,
@@ -512,8 +547,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         if not single_timestamp_ending:
             last_word_end = get_end(current_segments)
-            if last_word_end is not None and \
-                    last_word_end > self.time_offset:
+            if last_word_end is not None and last_word_end > self.time_offset:
                 self.seek = round(last_word_end * FRAMES_PER_SECOND)

         # skip silence before possible hallucinations
@@ -521,24 +555,19 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             threshold = self.hallucination_silence_threshold
             if not single_timestamp_ending:
                 last_word_end = get_end(current_segments)
-                if last_word_end is not None and \
-                        last_word_end > self.time_offset:
-                    remaining_duration = \
-                        self.window_end_time - last_word_end
+                if last_word_end is not None and last_word_end > self.time_offset:
+                    remaining_duration = self.window_end_time - last_word_end
                     if remaining_duration > threshold:
-                        self.seek = round(
-                            last_word_end * FRAMES_PER_SECOND)
+                        self.seek = round(last_word_end * FRAMES_PER_SECOND)
                     else:
                         self.seek = previous_seek + segment_size

             # if first segment might be a hallucination, skip leading silence
             first_segment = self.next_words_segment(current_segments)
-            if first_segment is not None and self.is_segment_anomaly(
-                    first_segment):
+            if first_segment is not None and self.is_segment_anomaly(first_segment):
                 gap = first_segment["start"] - self.time_offset
                 if gap > threshold:
-                    self.seek = previous_seek + round(
-                        gap * FRAMES_PER_SECOND)
+                    self.seek = previous_seek + round(gap * FRAMES_PER_SECOND)
                     return True

             # skip silence before any possible hallucination that is
@@ -550,14 +579,13 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 if not segment["words"]:
                     continue
                 if self.is_segment_anomaly(segment):
-                    next_segment = self.next_words_segment(
-                        current_segments[si + 1 :])
+                    next_segment = self.next_words_segment(current_segments[si + 1 :])
                     if next_segment is not None:
-                        hal_next_start = \
-                            next_segment["words"][0]["start"]
+                        hal_next_start = next_segment["words"][0]["start"]
                     else:
-                        hal_next_start = self.time_offset + \
-                            segment_size * HOP_LENGTH / SAMPLE_RATE
+                        hal_next_start = (
+                            self.time_offset + segment_size * HOP_LENGTH / SAMPLE_RATE
+                        )
                     silence_before = (
                         segment["start"] - hal_last_end > threshold
                         or segment["start"] < threshold
@@ -573,8 +601,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                         max(self.time_offset + 1, segment["start"])
                         * FRAMES_PER_SECOND
                     )
-                    if content_duration - segment["end"] < \
-                            threshold:
+                    if content_duration - segment["end"] < threshold:
                         self.seek = content_frames
                         current_segments[si:] = []
                         break
@@ -586,19 +613,17 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         return False

     def __call__(
-        self, mel: torch.Tensor, offset: int = 0,
-        single_pass: bool = False) -> dict:
+        self, mel: torch.Tensor, offset: int = 0, single_pass: bool = False
+    ) -> dict:
         self.latest, self.frame_offset = mel, offset
         content_frames = mel.shape[-1] - N_FRAMES + offset
-        content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
         # NOTE: This loop is obscurely flattened to make the diff readable.
         # A later commit should turn this into a simpler nested loop.
         # for seek_clip_start, seek_clip_end in seek_clips:
         #     while seek < seek_clip_end
         while self.clip_idx < len(self.seek_clips):
             seek_clip_start, seek_clip_end = self.seek_clips[self.clip_idx]
-            seek_clip_end = content_frames if seek_clip_end is None else \
-                seek_clip_end
+            seek_clip_end = content_frames if seek_clip_end is None else seek_clip_end
             if self.seek < seek_clip_start:
                 self.seek = seek_clip_start
             if self.seek >= seek_clip_end:
@@ -607,22 +632,23 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 self.clip_idx += 1
                 continue
             segment_size = min(
-                N_FRAMES, content_frames - self.seek,
-                seek_clip_end - self.seek)
-            mel_segment = mel[
-                :, self.seek - offset : self.seek + segment_size - offset]
-            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(
-                self.device).to(self.dtype)
-            self.decode_options["prompt"] = \
-                self.all_tokens[self.prompt_reset_since:]
+                N_FRAMES, content_frames - self.seek, seek_clip_end - self.seek
+            )
+            mel_segment = mel[:, self.seek - offset : self.seek + segment_size - offset]
+            mel_segment = (
+                pad_or_trim(mel_segment, N_FRAMES).to(self.device).to(self.dtype)
+            )
+            self.decode_options["prompt"] = self.all_tokens[self.prompt_reset_since :]

             result: DecodingResult = self.decode_with_fallback(mel_segment)

             if self.no_speech_threshold is not None:
                 # no voice activity check
                 should_skip = result.no_speech_prob > self.no_speech_threshold
-                if self.logprob_threshold is not None and \
-                        result.avg_logprob > self.logprob_threshold:
+                if (
+                    self.logprob_threshold is not None
+                    and result.avg_logprob > self.logprob_threshold
+                ):
                     # don't skip if the logprob is high enough, despite the
                     # no_speech_prob
                     should_skip = False
@@ -636,19 +662,27 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             current_segments: List[dict] = []
             tokens = torch.tensor(result.tokens)
-            timestamp_tokens: torch.Tensor = tokens.ge(
-                self.tokenizer.timestamp_begin)
-            single_timestamp_ending = (
-                timestamp_tokens[-2:].tolist() == [False, True])
+            timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin)
+            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

             self.reseek(
-                current_segments, segment_size, single_timestamp_ending,
-                tokens, timestamp_tokens, result)
+                current_segments,
+                segment_size,
+                single_timestamp_ending,
+                tokens,
+                timestamp_tokens,
+                result,
+            )

             if self.word_timestamps:
                 if self.timestamp(
-                        current_segments, segment_size, single_timestamp_ending,
-                        mel_segment, previous_seek, content_frames):
+                    current_segments,
+                    segment_size,
+                    single_timestamp_ending,
+                    mel_segment,
+                    previous_seek,
+                    content_frames,
+                ):
                     continue

             if self.verbose:
@@ -657,24 +691,28 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                     text = segment["text"]
                     line = (
                         f"[{format_timestamp(start)} --> "
-                        f"{format_timestamp(end)}] {text}")
+                        f"{format_timestamp(end)}] {text}"
+                    )
                     print(make_safe(line))

             # if a segment is instantaneous or does not contain text, clear it
             for i, segment in enumerate(current_segments):
-                if segment["start"] == segment["end"] or \
-                        segment["text"].strip() == "":
+                if segment["start"] == segment["end"] or segment["text"].strip() == "":
                     segment["text"] = ""
                     segment["tokens"] = []
                     segment["words"] = []

-            self.all_segments.extend([
-                {"id": i, **segment}
-                for i, segment in enumerate(
-                    current_segments, start=len(self.all_segments))])
-            self.all_tokens.extend([
-                token for segment in current_segments
-                for token in segment["tokens"]])
+            self.all_segments.extend(
+                [
+                    {"id": i, **segment}
+                    for i, segment in enumerate(
+                        current_segments, start=len(self.all_segments)
+                    )
+                ]
+            )
+            self.all_tokens.extend(
+                [token for segment in current_segments for token in segment["tokens"]]
+            )

             if not self.condition_on_previous_text or result.temperature > 0.5:
                 # do not feed the prompt tokens if a high temperature was used
@@ -686,9 +724,12 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             break

         self.result = dict(
-            segments=self.all_segments, language=self.language,
+            segments=self.all_segments,
+            language=self.language,
             text=self.tokenizer.decode(
-                self.all_tokens[len(self.initial_prompt_tokens):]))
+                self.all_tokens[len(self.initial_prompt_tokens) :]
+            ),
+        )
         self.latest = None
         return self.result
@@ -699,16 +740,17 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         processing, seconds = 0, offset * HOP_LENGTH / SAMPLE_RATE
         while len(self.all_segments) > 0 and (
             self.all_segments[-1]["start"] >= seconds
-            if len(self.all_segments) == 1 else
-            self.all_segments[-2]["end"] > seconds):
+            if len(self.all_segments) == 1
+            else self.all_segments[-2]["end"] > seconds
+        ):
             rewriting = self.all_segments.pop()
             processing += len(rewriting["tokens"])
-        self.all_tokens = self.all_tokens[:len(self.all_tokens) - processing]
+        self.all_tokens = self.all_tokens[: len(self.all_tokens) - processing]
         if len(self.all_segments) > 0 and (
-            self.all_segments[-1]["start"] < seconds and
-            self.all_segments[-1]["end"] >= seconds):
-            self.seek = round(
-                self.all_segments[-1]["end"] * SAMPLE_RATE / HOP_LENGTH)
+            self.all_segments[-1]["start"] < seconds
+            and self.all_segments[-1]["end"] >= seconds
+        ):
+            self.seek = round(self.all_segments[-1]["end"] * SAMPLE_RATE / HOP_LENGTH)
         else:
             self.seek = offset
@@ -728,12 +770,16 @@ def audio_tensor(audio: Union[str, np.ndarray, torch.Tensor]) -> torch.Tensor:
 class MinimalTranscriber(Transcriber):
     exact: bool = True
     chlen: float = CHUNK_LENGTH
+
     async def process(self, stream: ArrayStream, **kw) -> dict:
         data = await stream.request(self.chlen, self.exact)
         while data.shape[-1] > 0:
             self(data, stream.offset, True)
-            t = self.chlen - (stream.offset + data.shape[-1] - self.seek) \
-                / FRAMES_PER_SECOND + CHUNK_LENGTH
+            t = (
+                self.chlen
+                - (stream.offset + data.shape[-1] - self.seek) / FRAMES_PER_SECOND
+                + CHUNK_LENGTH
+            )
             data = await stream.request(t, self.exact)
         return self.result
@@ -755,12 +801,16 @@ class ProgressTranscriber(MinimalTranscriber):
     @PassthroughProperty(None).property
     def pbar(self):
         if self._pbar is None:
-            n = self.latest.shape[-1] if self.duration is None \
-                else -int(self.duration * -FRAMES_PER_SECOND)
+            n = (
+                self.latest.shape[-1]
+                if self.duration is None
+                else -int(self.duration * -FRAMES_PER_SECOND)
+            )
             # show the progress bar when verbose is False
             # (if True, transcribed text will be printed)
             self._pbar = tqdm.tqdm(
-                total=n, unit="frames", disable=self.verbose is not False)
+                total=n, unit="frames", disable=self.verbose is not False
+            )
             self._pbar.__enter__()
         return self._pbar
@@ -784,10 +834,7 @@ class ProgressTranscriber(MinimalTranscriber):
         return await self.process(stream, **kw)

-def transcribe(
-        model: "Whisper",
-        audio: Union[str, np.ndarray, torch.Tensor],
-        **kw):
+def transcribe(model: "Whisper", audio: Union[str, np.ndarray, torch.Tensor], **kw):
     """
     Transcribe an audio file using Whisper

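Aside, not part of the commit: transcribe() keeps upstream Whisper's call shape, so the streaming classes above stay behind the familiar one-liner. A sketch, assuming this fork is installed in place of openai-whisper and still exposes load_model like upstream:

import whisper

model = whisper.load_model("base")
# extra keyword arguments travel through **kw to the transcriber
result = transcribe(model, "out.wav", language="en")
print(result["text"])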
View File: utils.py

@@ -5,16 +5,7 @@ import re
 import sys
 import time
 import zlib
-from typing import (
-    Callable,
-    List,
-    Optional,
-    TextIO,
-    Union,
-    TypeVar,
-    Generic,
-    Any
-)
+from typing import Any, Callable, Generic, List, Optional, TextIO, TypeVar, Union

 system_encoding = sys.getdefaultencoding()
@@ -86,15 +77,17 @@ def format_timestamp(
         f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
     )

+
 def hms(sec: float) -> str:
     trim = sec < 3600
     h = "" if trim else str(int(sec) // 3600) + ":"
     m_fill = " " if trim else "0"
     m = " " if sec < 60 else str(int(sec) // 60 % 60).rjust(2, m_fill) + ":"
-    s = str(int(sec) % 60).rjust(2, '0') + "."
-    c = str(round((sec % 1) * 100)).rjust(2, '0')
+    s = str(int(sec) % 60).rjust(2, "0") + "."
+    c = str(round((sec % 1) * 100)).rjust(2, "0")
     return h + m + s + c

+
 def tod(seconds: float) -> str:
     return time.strftime("%H:%M:%S", time.localtime(seconds))
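
Aside, not part of the commit, worked from the code above: hms() renders a seconds count compactly, space-padding minutes until the first full hour.

print(hms(75.5))    # " 1:15.50"
print(hms(3725.0))  # "1:02:05.00"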
@@ -349,28 +342,34 @@ def get_writer(
 T = TypeVar("T")

+
 # boilerplate for property with _{name} storage and passthrough getter/setter
 class PassthroughProperty(Generic[T]):
     def __init__(self, default: T):
         self.value = default

     f: Optional[Callable[[Any, T], None]] = None
+
     def setter(self, f: Callable[[Any, T], None]):
         self.f = f
         return self

     g: Optional[property] = None
+
     def property(self, g: Callable[[Any], T]):
         self.g = property(g)
         return self

+
 class PassthroughPropertyDefaults(type):
     def __new__(cls, clsname, bases, attrs):
         def closure(f, v):
             def prop(self):
                 return getattr(self, v)
+
             def setter(self, value):
                 setattr(self, v, value)
+
             prop.__name__ = setter.__name__ = f
             return property(prop), setter
@@ -383,4 +382,3 @@ class PassthroughPropertyDefaults(type):
             getter, setter = closure(k, private)
             updates[k] = (v.g or getter).setter(v.f or setter)
         return super().__new__(cls, clsname, bases, {**attrs, **updates})
-
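
Aside, not part of the commit: the two helpers above let a class declare a default once and still hook the setter, with the metaclass generating the missing getter against the underscored attribute. Mirroring how transcribe.py uses the pair (the backing _volume default is declared explicitly here, matching that file's style, since this diff does not show whether the metaclass seeds it):

class Config(metaclass=PassthroughPropertyDefaults):
    _volume: int = 0

    @PassthroughProperty[int](0).setter
    def volume(self, value: int) -> None:
        self._volume = max(0, value)  # custom setter; getter is generated

c = Config()
c.volume = -5
print(c.volume)  # 0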