pre-commit formatting

Kent Slaney 2024-07-14 19:07:06 -06:00
parent e0704ddeba
commit 0621ed8094
4 changed files with 328 additions and 245 deletions

View File

@ -1,15 +1,19 @@
import torch
import numpy as np
from collections.abc import Callable, AsyncIterable, AsyncIterator, Awaitable
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
from typing import Generic, TypeVar, Union
import numpy as np
import torch
A = TypeVar("A", bound=Union[np.ndarray, torch.Tensor])
class ArrayWrapper(Generic[A]):
pass
ArrayTypes = Union[A, ArrayWrapper[A]]
class LoopbackIterator(Generic[A]):
async def iter(self):
raise NotImplementedError
@ -23,14 +27,17 @@ class LoopbackIterator(Generic[A]):
self.__aiter__()
return await anext(self._iter)
async def empty():
return
yield
class Unwrap(LoopbackIterator):
_initial: Union[ArrayTypes, Awaitable[ArrayTypes]]
started: bool
iterator: AsyncIterable[ArrayTypes]
def __init__(self, iterator: AsyncIterable[ArrayTypes]):
while isinstance(iterator, PassthroughTransform):
iterator = iterator.handoff()
@ -74,13 +81,14 @@ class Unwrap(LoopbackIterator):
@property
async def concat(self):
return np.concatenate if isinstance(await self.dtype, np.dtype) \
else torch.cat
return np.concatenate if isinstance(await self.dtype, np.dtype) else torch.cat
class PassthroughTransform(LoopbackIterator):
def handoff(self) -> AsyncIterable[ArrayTypes]:
raise NotImplementedError
class BoxedIterator(PassthroughTransform):
def __init__(self, iterator):
self.iterator = iterator
@ -99,6 +107,7 @@ class BoxedIterator(PassthroughTransform):
if self.flag != flag:
raise Exception("source can only be used by one iterator")
def LookAlong(axis: int):
assert axis >= 0
empties = (slice(None),) * axis
@ -119,10 +128,9 @@ def LookAlong(axis: int):
return LookAlong
class PassthroughMap(PassthroughTransform):
def __init__(
self, apply: Callable[[A], ArrayTypes],
iterator: AsyncIterator[A]):
def __init__(self, apply: Callable[[A], ArrayTypes], iterator: AsyncIterator[A]):
self.iterator, self.apply = iterator, apply
def handoff(self) -> AsyncIterator[A]:
@ -132,6 +140,7 @@ class PassthroughMap(PassthroughTransform):
async for i in self.iterator:
yield self.apply(i)
class Group:
def __init__(self, concat, axis=-1):
self.concat = concat
@ -155,16 +164,18 @@ class Group:
if taking == amount or not exact:
self.shape += amount - taking
self.consumed = 0
res = self.concat([self.holding[0][start:]] + [
i.value for i in self.holding[1 : i + 1]])
res = self.concat(
[self.holding[0][start:]] + [i.value for i in self.holding[1 : i + 1]]
)
self.holding = self.holding[i + 1 :]
return res
if i == 0:
return self.holding[0][start : self.consumed]
res = self.concat(
[self.holding[0][start:]] +
[i.value for i in self.holding[1 : i]] +
[self.holding[i][:self.consumed]])
[self.holding[0][start:]]
+ [i.value for i in self.holding[1:i]]
+ [self.holding[i][: self.consumed]]
)
self.holding = self.holding[i:]
return res
@ -175,10 +186,12 @@ class Group:
self.holding = []
return res
class Taken:
def take(self, *a, **kw):
raise Exception("batch queue moved")
class Batcher(PassthroughTransform):
def __init__(self, iterator, size, axis=-1, exact=False):
assert isinstance(size, int) and size > 0
@ -192,14 +205,19 @@ class Batcher(PassthroughTransform):
return lambda tensors: f(tensors, self.axis)
_iterator = None
async def iterator(self):
if self._iterator is None:
self.axis = len(await self.preview.shape) + self._axis \
if self._axis < 0 else self._axis
self.axis = (
len(await self.preview.shape) + self._axis
if self._axis < 0
else self._axis
)
if not hasattr(self, "group"):
self.group = Group(await self.concat())
self._iterator = PassthroughMap(
LookAlong(self.axis), BoxedIterator(self.preview))
LookAlong(self.axis), BoxedIterator(self.preview)
)
return self._iterator
def handoff(self):
@ -219,4 +237,3 @@ class Batcher(PassthroughTransform):
return self.group.all()
raise
return self.group.take(self.size, self.exact)
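
Not part of the commit: a minimal usage sketch of the Batcher interface defined above. The import path and the chunk sizes are assumptions for illustration.

import asyncio
import numpy as np
from whisper.batching import Batcher  # assumed package/module path

async def chunks():
    # an async source of uneven 1-D arrays, 16 samples in total
    for n in (3, 5, 2, 6):
        yield np.arange(n, dtype=np.float32)

async def main():
    # re-chunk the stream along the last axis; with exact=True each yielded
    # batch should hold exactly 4 samples (four batches of shape (4,) here)
    async for batch in Batcher(chunks(), 4, exact=True):
        print(batch.shape)

asyncio.run(main())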

View File

@ -1,18 +1,16 @@
import asyncio
import json
import subprocess
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Coroutine
from typing import IO, Optional, Union
import numpy as np
import asyncio, pathlib, subprocess, torch, json
import torch
from .audio import (
SAMPLE_RATE,
N_FFT,
HOP_LENGTH,
N_FRAMES,
mel_filters,
)
from .utils import PathType, ceildiv
from .audio import HOP_LENGTH, N_FFT, N_FRAMES, SAMPLE_RATE, mel_filters
from .batching import Batcher
from typing import Optional, Union, IO, Tuple, Any, Type
from collections.abc import Coroutine, AsyncIterable, AsyncIterator, Awaitable
from .utils import PathType, ceildiv
class AudioSink:
def __init__(self, *, rate: int = SAMPLE_RATE, **kw):
@ -25,11 +23,19 @@ class AudioSink:
def write(self, data):
raise NotImplementedError
class ArrayStream(AudioSink):
q: asyncio.Queue
def __init__(
self, *, device: Optional[Union[str, torch.device]] = None,
batch: int = 1, n_mels: int = 80, capacity: int = 1_000_000, **kw):
self,
*,
device: Optional[Union[str, torch.device]] = None,
batch: int = 1,
n_mels: int = 80,
capacity: int = 1_000_000,
**kw,
):
super().__init__(**kw)
self.q = asyncio.Queue(capacity)
self.finished = asyncio.Event()
@ -43,6 +49,7 @@ class ArrayStream(AudioSink):
return torch.zeros(shape, dtype=torch.float32, device=self.device)
write_blockable: bool = True
def write(self, data: bytes) -> Optional[Coroutine]:
if self.write_blockable:
return self.q.put(data)
@ -51,11 +58,9 @@ class ArrayStream(AudioSink):
return None
def load(self, data: bytes) -> np.ndarray:
return np.frombuffer(
data, np.int16).flatten().astype(np.float32) / 32768.0
return np.frombuffer(data, np.int16).flatten().astype(np.float32) / 32768.0
async def loader(self, iterator: AsyncIterable[bytes]) -> \
AsyncIterator[np.ndarray]:
async def loader(self, iterator: AsyncIterable[bytes]) -> AsyncIterator[np.ndarray]:
async for data in iterator:
yield self.load(data)
@ -64,7 +69,8 @@ class ArrayStream(AudioSink):
while not self.finished.is_set():
getter = asyncio.create_task(self.q.get())
done, pending = await asyncio.wait(
(waiter, getter), return_when=asyncio.FIRST_COMPLETED)
(waiter, getter), return_when=asyncio.FIRST_COMPLETED
)
if getter in done:
yield getter.result()
while not self.q.empty():
@ -78,8 +84,10 @@ class ArrayStream(AudioSink):
pass
loading: Optional[Batcher] = None
async def fft_offset(self, iterator: AsyncIterable[bytes]) -> \
AsyncIterator[np.ndarray]:
async def fft_offset(
self, iterator: AsyncIterable[bytes]
) -> AsyncIterator[np.ndarray]:
init = self.loader(iterator) if self.loading is None else self.loading
self.loading = Batcher(init, HOP_LENGTH)
_iterator = aiter(self.loading)
@ -89,7 +97,7 @@ class ArrayStream(AudioSink):
window = np.concatenate((window, await anext(_iterator)))
except StopAsyncIteration:
return
window = np.pad(window, (N_FFT // 2, 0), 'reflect')
window = np.pad(window, (N_FFT // 2, 0), "reflect")
yield window
async for data in _iterator:
yield data
@ -101,8 +109,9 @@ class ArrayStream(AudioSink):
hopped = ((sees.shape[0] - N_FFT) // HOP_LENGTH + 1) * HOP_LENGTH
return sees[hopped:]
async def window(self, iterator: AsyncIterable[bytes]) -> \
AsyncIterator[torch.Tensor]:
async def window(
self, iterator: AsyncIterable[bytes]
) -> AsyncIterator[torch.Tensor]:
_iterator = self.fft_offset(iterator)
async for data in _iterator:
_data = torch.from_numpy(data)
@ -121,19 +130,23 @@ class ArrayStream(AudioSink):
def dft(self, amp: torch.Tensor) -> torch.Tensor:
return torch.stft(
amp, N_FFT, HOP_LENGTH, window=self.hann, center=False,
return_complex=True)
amp, N_FFT, HOP_LENGTH, window=self.hann, center=False, return_complex=True
)
# https://github.com/openai/whisper/blob/c5d4256/whisper/audio.py#L149
log_spec_bound: Optional[torch.Tensor] = None
def transform(self, stft: torch.Tensor) -> torch.Tensor:
magnitudes = stft.abs() ** 2
mel_spec = self.filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
# causes values to not precisely match the original
self.log_spec_bound = log_spec.max() if self.log_spec_bound is None \
self.log_spec_bound = (
log_spec.max()
if self.log_spec_bound is None
else torch.maximum(log_spec.max(), self.log_spec_bound)
)
log_spec = torch.maximum(log_spec, self.log_spec_bound - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
@ -143,6 +156,7 @@ class ArrayStream(AudioSink):
# dft_pad: add ending content frames to match padding from a centered STFT
dft_pad: bool = False
def runoff(self, dft_pad: Optional[bool] = None) -> torch.Tensor:
dft_pad = self.dft_pad if dft_pad is None else dft_pad
if dft_pad:
@ -170,27 +184,29 @@ class ArrayStream(AudioSink):
return self.runoff()
staging: Optional[Batcher] = None
async def _push(self, sec: float, exact: bool = False) -> \
AsyncIterator[torch.Tensor]:
async def _push(
self, sec: float, exact: bool = False
) -> AsyncIterator[torch.Tensor]:
batching = int(sec * SAMPLE_RATE // HOP_LENGTH)
init = self.window(self.buffer()) if self.staging is None \
else self.staging
init = self.window(self.buffer()) if self.staging is None else self.staging
self.staging = Batcher(init, batching, exact=exact)
async for frame in self.staging:
batched = batching if exact else frame.shape[-1]
cutoff = max(self.spectogram.shape[-1] + batched - N_FRAMES, 0)
self.offset += cutoff
self.spectogram = torch.cat((
self.spectogram[:, cutoff:], frame), -1)
self.spectogram = torch.cat((self.spectogram[:, cutoff:], frame), -1)
yield self.runoff()
reader: Optional[Awaitable] = None
def start(self, **kw) -> None:
if self.reader is None:
self.reader = asyncio.create_task(self.read(**kw))
async def push(self, sec: float, exact: bool=False, **kw) -> \
AsyncIterator[torch.Tensor]:
async def push(
self, sec: float, exact: bool = False, **kw
) -> AsyncIterator[torch.Tensor]:
self.start(**kw)
async for i in self._push(sec, exact):
yield i
@ -224,17 +240,17 @@ class ArrayStream(AudioSink):
def all_amplitudes(self, **kw) -> np.ndarray:
return asyncio.run(self.amplitudes(**kw))
class RawAudioFile(ArrayStream):
def __init__(
self, *, period: int = HOP_LENGTH, fname: PathType = 'out.raw',
**kw):
def __init__(self, *, period: int = HOP_LENGTH, fname: PathType = "out.raw", **kw):
super().__init__(**kw)
self.fname = fname
self.period = period
fp: Optional[IO[bytes]] = None
async def read(self) -> None:
fp = open(self.fname, 'rb') if self.fp is None else self.fp
fp = open(self.fname, "rb") if self.fp is None else self.fp
data = fp.read(self.period)
while len(data) != 0:
io_hold = self.write(data)
@ -243,28 +259,33 @@ class RawAudioFile(ArrayStream):
data = fp.read(self.period)
self.finished.set()
class AudioFile(RawAudioFile):
def __init__(
self, *, period: int = SAMPLE_RATE, fname: PathType = 'out.wav',
**kw):
def __init__(self, *, period: int = SAMPLE_RATE, fname: PathType = "out.wav", **kw):
assert not subprocess.run(
["which", "ffmpeg"], stdout=subprocess.PIPE).returncode
["which", "ffmpeg"], stdout=subprocess.PIPE
).returncode
super().__init__(period=period or -1, fname=fname, **kw)
async def read(self) -> None:
cmd = [
"ffmpeg",
"-nostdin",
"-threads", "0",
"-i", self.fname,
"-f", "s16le",
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(self.rate),
"-"
"-threads",
"0",
"-i",
self.fname,
"-f",
"s16le",
"-ac",
"1",
"-acodec",
"pcm_s16le",
"-ar",
str(self.rate),
"-",
]
ps = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
ps = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
self.fp = ps.stdout
await super().read()
_, stderr = ps.communicate()
@ -277,13 +298,13 @@ class AudioFile(RawAudioFile):
"ffprobe",
"-hide_banner",
"-show_format",
"-of", "json",
"-i", self.fname,
"-of",
"json",
"-i",
self.fname,
]
ps = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
ps = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = ps.communicate()
if ps.returncode not in (None, 0):
raise RuntimeError(f"Failed to load audio: {stderr.decode()}")
return float(json.loads(stdout)['format']['duration'])
return float(json.loads(stdout)["format"]["duration"])
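
Not part of the commit: a rough sketch of the streaming entry point above, assuming the module imports as whisper.buffer, that out.wav exists, and that ffmpeg is on PATH.

import asyncio
from whisper.buffer import AudioFile  # assumed package/module path

async def main():
    stream = AudioFile(fname="out.wav")  # decoded by ffmpeg to 16 kHz mono s16le
    # push() starts the background reader task and yields the running log-mel
    # spectrogram roughly once per second of consumed audio
    async for mel in stream.push(1.0):
        print(mel.shape)  # about (n_mels, frames seen so far), capped near N_FRAMES

asyncio.run(main())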

View File

@ -1,27 +1,30 @@
import argparse
import asyncio
import os
import traceback
import warnings
import asyncio
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Dict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import tqdm
from .audio import (
CHUNK_LENGTH,
FRAMES_PER_SECOND,
HOP_LENGTH,
N_FRAMES,
SAMPLE_RATE,
CHUNK_LENGTH,
pad_or_trim,
)
from .buffer import ArrayStream, AudioFile
from .decoding import DecodingOptions, DecodingResult
from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer, Tokenizer
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, Tokenizer, get_tokenizer
from .utils import (
PassthroughProperty,
PassthroughPropertyDefaults,
exact_div,
format_timestamp,
get_end,
@ -30,14 +33,12 @@ from .utils import (
optional_float,
optional_int,
str2bool,
PassthroughProperty,
PassthroughPropertyDefaults,
)
from .buffer import ArrayStream, AudioFile
if TYPE_CHECKING:
from .model import Whisper
@dataclass
class LanguageHypothesis:
language: Optional[str] = None
@ -45,15 +46,17 @@ class LanguageHypothesis:
evidence: int = 0
last: int = -1
class Transcriber(metaclass=PassthroughPropertyDefaults):
prefix: str = '''"'\u201c\u00bf([{-'''
postfix: str = '''"'.\u3002,\uff0c!\uff01?\uff1f:\uff1a\u201d)]}\u3001'''
prefix: str = """"'\u201c\u00bf([{-"""
postfix: str = """"'.\u3002,\uff0c!\uff01?\uff1f:\uff1a\u201d)]}\u3001"""
punctuation: str = prefix + postfix
verbose: Optional[bool] = None
_decode_options: dict = {}
decode_props: Tuple[str, ...] = ("fp16", "language", "task")
@property
def decode_options(self) -> dict:
for k in self.decode_props:
@ -68,6 +71,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
setattr(self, k, value[k])
dtype: torch.dtype = torch.float16
@property
def fp16(self) -> bool:
return self.dtype == torch.float16
@ -93,8 +97,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
self._device = value
if value == torch.device("cpu"):
if torch.cuda.is_available():
warnings.warn(
"Performing inference on CPU when CUDA is available")
warnings.warn("Performing inference on CPU when CUDA is available")
self.fp16device()
def fp16device(self) -> None:
@ -110,6 +113,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
prev: Optional[torch.Tensor] = None
_latest: Optional[torch.Tensor] = None
@PassthroughProperty[Optional[torch.Tensor]](None).setter
def latest(self, value: Optional[torch.Tensor]) -> None:
self.prev = self._latest
@ -117,6 +121,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
_hypothesis: LanguageHypothesis = LanguageHypothesis()
_language: Optional[str]
@PassthroughProperty[Optional[str]](None).property
def language(self) -> Optional[str]:
if self._language is not None:
@ -126,14 +131,18 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
if self.verbose:
print(
"Detecting language using up to the first 30 seconds."
"Use `--language` to specify the language")
"Use `--language` to specify the language"
)
if self.latest is None:
return None
if self._seek == self._hypothesis.last:
return self._hypothesis.language
if self.frame_offset > 0 or self.latest.shape[-1] == N_FRAMES * 2:
mel = self.latest if self.prev is None else torch.cat(
(self.prev[:self.frame_offset], self.latest), -1)
mel = (
self.latest
if self.prev is None
else torch.cat((self.prev[: self.frame_offset], self.latest), -1)
)
self._language = self.detect_language(mel)
return self._language
self._hypothesis.last = self._seek or 0
@ -152,22 +161,25 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
def clip_timestamps(self, value: Union[str, List[float], Tuple[float]]):
self._seek_clips = None
if isinstance(value, str):
self._clip_timestamps = tuple(map(float, value.split(","))) \
if value else (0,)
self._clip_timestamps = (
tuple(map(float, value.split(","))) if value else (0,)
)
else:
self._clip_timestamps = tuple(value) or (0,)
_seek_clips: Optional[List[Tuple[int, Optional[int]]]] = None
@property
def seek_clips(self) -> List[Tuple[int, Optional[int]]]:
if self._seek_clips is None:
seek_points = tuple(
round(ts * FRAMES_PER_SECOND)
for ts in self.clip_timestamps) + (None,)
round(ts * FRAMES_PER_SECOND) for ts in self.clip_timestamps
) + (None,)
self._seek_clips = list(zip(seek_points[::2], seek_points[1::2]))
return self._seek_clips
_seek: Optional[int]
@PassthroughProperty[Optional[int]](None).property
def seek(self) -> Optional[int]:
return self.seek_clips[0][0] if self._seek is None else self._seek
@ -179,25 +191,30 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
if value < len(clips):
self.seek = clips[value][0]
time_offset = property(lambda self: float(
self.seek * HOP_LENGTH / SAMPLE_RATE))
window_end_time = property(lambda self: float(
(self.seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE))
time_offset = property(lambda self: float(self.seek * HOP_LENGTH / SAMPLE_RATE))
window_end_time = property(
lambda self: float((self.seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
)
_temperature: Union[Optional[float], Tuple[float, ...]]
@PassthroughProperty[Union[Optional[float], Tuple[float, ...]]](
(0.0, 0.2, 0.4, 0.6, 0.8, 1.0)).setter
(0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
).setter
def temperature(self, value: Union[Optional[float], Tuple[float, ...]]):
self._temperature = (value,) if isinstance(value, (int, float)) else (
Transcriber._temperature if value is None else value)
self._temperature = (
(value,)
if isinstance(value, (int, float))
else (Transcriber._temperature if value is None else value)
)
@PassthroughProperty("transcribe").setter
def task(self, value: str):
self._task = value
if self.word_timestamps and value == "translate":
warnings.warn(
"Word-level timestamps on translations may not be "
"reliable.")
"Word-level timestamps on translations may not be " "reliable."
)
@PassthroughProperty(False).setter
def word_timestamps(self, value: bool):
@ -207,6 +224,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
get_tokenizer = staticmethod(get_tokenizer)
_tokenizer: Optional[Tokenizer] = None
_tokenizer_cache: Dict[str, Tokenizer] = {}
@property
def tokenizer(self) -> Tokenizer:
if self._tokenizer is None:
@ -235,6 +253,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
_initial_prompt_tokens: Optional[List[int]] = None
_initial_prompt_cache: Dict[Tokenizer, List[int]] = {}
@property
def initial_prompt_tokens(self) -> List[int]:
if self._initial_prompt_tokens is None:
@ -246,9 +265,9 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
return []
if tokenizer not in self._initial_prompt_cache:
self._initial_prompt_cache[tokenizer] = tokenizer.encode(
" " + self.initial_prompt.strip())
self._initial_prompt_tokens = \
self._initial_prompt_cache[tokenizer]
" " + self.initial_prompt.strip()
)
self._initial_prompt_tokens = self._initial_prompt_cache[tokenizer]
return self._initial_prompt_cache[tokenizer]
return self._initial_prompt_tokens
@ -256,6 +275,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
last_speech_timestamp: float = 0.0
frame_offset: int = 0
all_segments: List[dict]
def __init__(
self,
model: "Whisper",
@ -272,7 +292,8 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
append_punctuations: str = postfix,
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
**decode_options):
**decode_options,
):
"""
Transcribe an audio file using Whisper
@ -374,14 +395,16 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
needs_fallback = False
if self.compression_ratio_threshold is not None and (
decode_result.compression_ratio >
self.compression_ratio_threshold):
decode_result.compression_ratio > self.compression_ratio_threshold
):
needs_fallback = True # too repetitive
if self.logprob_threshold is not None and (
decode_result.avg_logprob < self.logprob_threshold):
decode_result.avg_logprob < self.logprob_threshold
):
needs_fallback = True # average log probability is too low
if self.no_speech_threshold is not None and (
decode_result.no_speech_prob > self.no_speech_threshold):
decode_result.no_speech_prob > self.no_speech_threshold
):
needs_fallback = False # silence
if not needs_fallback:
break
@ -389,8 +412,8 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
return decode_result
def new_segment(
self, *, start: float, end: float, tokens: torch.Tensor,
result: DecodingResult) -> dict:
self, *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
) -> dict:
_tokens = tokens.tolist()
text_tokens = [token for token in _tokens if token < self.tokenizer.eot]
return {
@ -422,9 +445,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
def is_segment_anomaly(self, segment: Optional[dict]) -> bool:
if segment is None or not segment["words"]:
return False
words = [
w for w in segment["words"]
if w["word"] not in self.punctuation][:8]
words = [w for w in segment["words"] if w["word"] not in self.punctuation][:8]
score = sum(self.word_anomaly_score(w) for w in words)
return score >= 3 or score + 0.01 >= len(words)
@ -433,11 +454,15 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
return next((s for s in segments if s["words"]), None)
def reseek(
self, current_segments: List[dict], segment_size: int,
single_timestamp_ending: bool, tokens: torch.Tensor,
timestamp_tokens: torch.Tensor, result: DecodingResult):
consecutive = torch.where(
timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
self,
current_segments: List[dict],
segment_size: int,
single_timestamp_ending: bool,
tokens: torch.Tensor,
timestamp_tokens: torch.Tensor,
result: DecodingResult,
):
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
consecutive.add_(1)
if len(consecutive) > 0:
# if the output contains two consecutive timestamp tokens
@ -449,17 +474,16 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
for current_slice in slices:
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_pos = (
sliced_tokens[0].item() -
self.tokenizer.timestamp_begin)
sliced_tokens[0].item() - self.tokenizer.timestamp_begin
)
end_timestamp_pos = (
sliced_tokens[-1].item() -
self.tokenizer.timestamp_begin)
sliced_tokens[-1].item() - self.tokenizer.timestamp_begin
)
current_segments.append(
self.new_segment(
start=self.time_offset + \
start_timestamp_pos * self.time_precision,
end=self.time_offset + \
end_timestamp_pos * self.time_precision,
start=self.time_offset
+ start_timestamp_pos * self.time_precision,
end=self.time_offset + end_timestamp_pos * self.time_precision,
tokens=sliced_tokens,
result=result,
)
@ -474,31 +498,42 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
# otherwise, ignore the unfinished segment and seek to the last
# timestamp
last_timestamp_pos = (
tokens[last_slice - 1].item() -
self.tokenizer.timestamp_begin)
tokens[last_slice - 1].item() - self.tokenizer.timestamp_begin
)
self.seek += last_timestamp_pos * self.input_stride
else:
duration = segment_size * HOP_LENGTH / SAMPLE_RATE
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
if len(timestamps) > 0 and \
timestamps[-1].item() != self.tokenizer.timestamp_begin:
if (
len(timestamps) > 0
and timestamps[-1].item() != self.tokenizer.timestamp_begin
):
# no consecutive timestamps but it has a timestamp; use the last
# one.
last_timestamp_pos = \
last_timestamp_pos = (
timestamps[-1].item() - self.tokenizer.timestamp_begin
)
duration = last_timestamp_pos * self.time_precision
current_segments.append(self.new_segment(
current_segments.append(
self.new_segment(
start=self.time_offset,
end=self.time_offset + duration,
tokens=tokens,
result=result))
result=result,
)
)
self.seek += segment_size
def timestamp(
self, current_segments: List[dict], segment_size: int,
single_timestamp_ending: bool, mel_segment: torch.Tensor,
previous_seek: int, content_frames: int) -> bool:
self,
current_segments: List[dict],
segment_size: int,
single_timestamp_ending: bool,
mel_segment: torch.Tensor,
previous_seek: int,
content_frames: int,
) -> bool:
add_word_timestamps(
segments=current_segments,
model=self.model,
@ -512,8 +547,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
if not single_timestamp_ending:
last_word_end = get_end(current_segments)
if last_word_end is not None and \
last_word_end > self.time_offset:
if last_word_end is not None and last_word_end > self.time_offset:
self.seek = round(last_word_end * FRAMES_PER_SECOND)
# skip silence before possible hallucinations
@ -521,24 +555,19 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
threshold = self.hallucination_silence_threshold
if not single_timestamp_ending:
last_word_end = get_end(current_segments)
if last_word_end is not None and \
last_word_end > self.time_offset:
remaining_duration = \
self.window_end_time - last_word_end
if last_word_end is not None and last_word_end > self.time_offset:
remaining_duration = self.window_end_time - last_word_end
if remaining_duration > threshold:
self.seek = round(
last_word_end * FRAMES_PER_SECOND)
self.seek = round(last_word_end * FRAMES_PER_SECOND)
else:
self.seek = previous_seek + segment_size
# if first segment might be a hallucination, skip leading silence
first_segment = self.next_words_segment(current_segments)
if first_segment is not None and self.is_segment_anomaly(
first_segment):
if first_segment is not None and self.is_segment_anomaly(first_segment):
gap = first_segment["start"] - self.time_offset
if gap > threshold:
self.seek = previous_seek + round(
gap * FRAMES_PER_SECOND)
self.seek = previous_seek + round(gap * FRAMES_PER_SECOND)
return True
# skip silence before any possible hallucination that is
@ -550,14 +579,13 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
if not segment["words"]:
continue
if self.is_segment_anomaly(segment):
next_segment = self.next_words_segment(
current_segments[si + 1 :])
next_segment = self.next_words_segment(current_segments[si + 1 :])
if next_segment is not None:
hal_next_start = \
next_segment["words"][0]["start"]
hal_next_start = next_segment["words"][0]["start"]
else:
hal_next_start = self.time_offset + \
segment_size * HOP_LENGTH / SAMPLE_RATE
hal_next_start = (
self.time_offset + segment_size * HOP_LENGTH / SAMPLE_RATE
)
silence_before = (
segment["start"] - hal_last_end > threshold
or segment["start"] < threshold
@ -573,8 +601,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
max(self.time_offset + 1, segment["start"])
* FRAMES_PER_SECOND
)
if content_duration - segment["end"] < \
threshold:
if content_duration - segment["end"] < threshold:
self.seek = content_frames
current_segments[si:] = []
break
@ -586,19 +613,17 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
return False
def __call__(
self, mel: torch.Tensor, offset: int = 0,
single_pass: bool = False) -> dict:
self, mel: torch.Tensor, offset: int = 0, single_pass: bool = False
) -> dict:
self.latest, self.frame_offset = mel, offset
content_frames = mel.shape[-1] - N_FRAMES + offset
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
# NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
# for seek_clip_start, seek_clip_end in seek_clips:
# while seek < seek_clip_end
while self.clip_idx < len(self.seek_clips):
seek_clip_start, seek_clip_end = self.seek_clips[self.clip_idx]
seek_clip_end = content_frames if seek_clip_end is None else \
seek_clip_end
seek_clip_end = content_frames if seek_clip_end is None else seek_clip_end
if self.seek < seek_clip_start:
self.seek = seek_clip_start
if self.seek >= seek_clip_end:
@ -607,22 +632,23 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
self.clip_idx += 1
continue
segment_size = min(
N_FRAMES, content_frames - self.seek,
seek_clip_end - self.seek)
mel_segment = mel[
:, self.seek - offset : self.seek + segment_size - offset]
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(
self.device).to(self.dtype)
N_FRAMES, content_frames - self.seek, seek_clip_end - self.seek
)
mel_segment = mel[:, self.seek - offset : self.seek + segment_size - offset]
mel_segment = (
pad_or_trim(mel_segment, N_FRAMES).to(self.device).to(self.dtype)
)
self.decode_options["prompt"] = \
self.all_tokens[self.prompt_reset_since:]
self.decode_options["prompt"] = self.all_tokens[self.prompt_reset_since :]
result: DecodingResult = self.decode_with_fallback(mel_segment)
if self.no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > self.no_speech_threshold
if self.logprob_threshold is not None and \
result.avg_logprob > self.logprob_threshold:
if (
self.logprob_threshold is not None
and result.avg_logprob > self.logprob_threshold
):
# don't skip if the logprob is high enough, despite the
# no_speech_prob
should_skip = False
@ -636,19 +662,27 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
current_segments: List[dict] = []
tokens = torch.tensor(result.tokens)
timestamp_tokens: torch.Tensor = tokens.ge(
self.tokenizer.timestamp_begin)
single_timestamp_ending = (
timestamp_tokens[-2:].tolist() == [False, True])
timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin)
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
self.reseek(
current_segments, segment_size, single_timestamp_ending,
tokens, timestamp_tokens, result)
current_segments,
segment_size,
single_timestamp_ending,
tokens,
timestamp_tokens,
result,
)
if self.word_timestamps:
if self.timestamp(
current_segments, segment_size, single_timestamp_ending,
mel_segment, previous_seek, content_frames):
current_segments,
segment_size,
single_timestamp_ending,
mel_segment,
previous_seek,
content_frames,
):
continue
if self.verbose:
@ -657,24 +691,28 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
text = segment["text"]
line = (
f"[{format_timestamp(start)} --> "
f"{format_timestamp(end)}] {text}")
f"{format_timestamp(end)}] {text}"
)
print(make_safe(line))
# if a segment is instantaneous or does not contain text, clear it
for i, segment in enumerate(current_segments):
if segment["start"] == segment["end"] or \
segment["text"].strip() == "":
if segment["start"] == segment["end"] or segment["text"].strip() == "":
segment["text"] = ""
segment["tokens"] = []
segment["words"] = []
self.all_segments.extend([
self.all_segments.extend(
[
{"id": i, **segment}
for i, segment in enumerate(
current_segments, start=len(self.all_segments))])
self.all_tokens.extend([
token for segment in current_segments
for token in segment["tokens"]])
current_segments, start=len(self.all_segments)
)
]
)
self.all_tokens.extend(
[token for segment in current_segments for token in segment["tokens"]]
)
if not self.condition_on_previous_text or result.temperature > 0.5:
# do not feed the prompt tokens if a high temperature was used
@ -686,9 +724,12 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
break
self.result = dict(
segments=self.all_segments, language=self.language,
segments=self.all_segments,
language=self.language,
text=self.tokenizer.decode(
self.all_tokens[len(self.initial_prompt_tokens):]))
self.all_tokens[len(self.initial_prompt_tokens) :]
),
)
self.latest = None
return self.result
@ -699,16 +740,17 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
processing, seconds = 0, offset * HOP_LENGTH / SAMPLE_RATE
while len(self.all_segments) > 0 and (
self.all_segments[-1]["start"] >= seconds
if len(self.all_segments) == 1 else
self.all_segments[-2]["end"] > seconds):
if len(self.all_segments) == 1
else self.all_segments[-2]["end"] > seconds
):
rewriting = self.all_segments.pop()
processing += len(rewriting["tokens"])
self.all_tokens = self.all_tokens[: len(self.all_tokens) - processing]
if len(self.all_segments) > 0 and (
self.all_segments[-1]["start"] < seconds and
self.all_segments[-1]["end"] >= seconds):
self.seek = round(
self.all_segments[-1]["end"] * SAMPLE_RATE / HOP_LENGTH)
self.all_segments[-1]["start"] < seconds
and self.all_segments[-1]["end"] >= seconds
):
self.seek = round(self.all_segments[-1]["end"] * SAMPLE_RATE / HOP_LENGTH)
else:
self.seek = offset
@ -728,12 +770,16 @@ def audio_tensor(audio: Union[str, np.ndarray, torch.Tensor]) -> torch.Tensor:
class MinimalTranscriber(Transcriber):
exact: bool = True
chlen: float = CHUNK_LENGTH
async def process(self, stream: ArrayStream, **kw) -> dict:
data = await stream.request(self.chlen, self.exact)
while data.shape[-1] > 0:
self(data, stream.offset, True)
t = self.chlen - (stream.offset + data.shape[-1] - self.seek) \
/ FRAMES_PER_SECOND + CHUNK_LENGTH
t = (
self.chlen
- (stream.offset + data.shape[-1] - self.seek) / FRAMES_PER_SECOND
+ CHUNK_LENGTH
)
data = await stream.request(t, self.exact)
return self.result
@ -755,12 +801,16 @@ class ProgressTranscriber(MinimalTranscriber):
@PassthroughProperty(None).property
def pbar(self):
if self._pbar is None:
n = self.latest.shape[-1] if self.duration is None \
n = (
self.latest.shape[-1]
if self.duration is None
else -int(self.duration * -FRAMES_PER_SECOND)
)
# show the progress bar when verbose is False
# (if True, transcribed text will be printed)
self._pbar = tqdm.tqdm(
total=n, unit="frames", disable=self.verbose is not False)
total=n, unit="frames", disable=self.verbose is not False
)
self._pbar.__enter__()
return self._pbar
@ -784,10 +834,7 @@ class ProgressTranscriber(MinimalTranscriber):
return await self.process(stream, **kw)
def transcribe(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
**kw):
def transcribe(model: "Whisper", audio: Union[str, np.ndarray, torch.Tensor], **kw):
"""
Transcribe an audio file using Whisper
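
Not part of the commit: the file keeps a transcribe(model, audio, **kw) entry point with the upstream whisper signature, so the usual call pattern should still apply; the package name and load_model usage are assumed from upstream.

import whisper  # assumes the fork installs under the upstream package name

model = whisper.load_model("base")
result = whisper.transcribe(model, "audio.wav")

print(result["language"])
print(result["text"])
for segment in result["segments"]:
    print(f'[{segment["start"]:.2f} --> {segment["end"]:.2f}] {segment["text"]}')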

View File

@ -5,16 +5,7 @@ import re
import sys
import time
import zlib
from typing import (
Callable,
List,
Optional,
TextIO,
Union,
TypeVar,
Generic,
Any
)
from typing import Any, Callable, Generic, List, Optional, TextIO, TypeVar, Union
system_encoding = sys.getdefaultencoding()
@ -86,15 +77,17 @@ def format_timestamp(
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
)
def hms(sec: float) -> str:
trim = sec < 3600
h = "" if trim else str(int(sec) // 3600) + ":"
m_fill = " " if trim else "0"
m = " " if sec < 60 else str(int(sec) // 60 % 60).rjust(2, m_fill) + ":"
s = str(int(sec) % 60).rjust(2, '0') + "."
c = str(round((sec % 1) * 100)).rjust(2, '0')
s = str(int(sec) % 60).rjust(2, "0") + "."
c = str(round((sec % 1) * 100)).rjust(2, "0")
return h + m + s + c
def tod(seconds: float) -> str:
return time.strftime("%H:%M:%S", time.localtime(seconds))
@ -349,28 +342,34 @@ def get_writer(
T = TypeVar("T")
# boilerplate for property with _{name} storage and passthrough getter/setter
class PassthroughProperty(Generic[T]):
def __init__(self, default: T):
self.value = default
f: Optional[Callable[[Any, T], None]] = None
def setter(self, f: Callable[[Any, T], None]):
self.f = f
return self
g: Optional[property] = None
def property(self, g: Callable[[Any], T]):
self.g = property(g)
return self
class PassthroughPropertyDefaults(type):
def __new__(cls, clsname, bases, attrs):
def closure(f, v):
def prop(self):
return getattr(self, v)
def setter(self, value):
setattr(self, v, value)
prop.__name__ = setter.__name__ = f
return property(prop), setter
@ -383,4 +382,3 @@ class PassthroughPropertyDefaults(type):
getter, setter = closure(k, private)
updates[k] = (v.g or getter).setter(v.f or setter)
return super().__new__(cls, clsname, bases, {**attrs, **updates})
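
Not part of the commit: a small sketch of the property pattern these helpers implement, mirroring how transcribe.py pairs a _latest backing slot with a custom setter; the Example class and import path are hypothetical.

from whisper.utils import PassthroughProperty, PassthroughPropertyDefaults  # assumed path

class Example(metaclass=PassthroughPropertyDefaults):
    _count: int = 0  # explicit backing slot, as transcribe.py declares _latest

    @PassthroughProperty(0).setter
    def count(self, value: int) -> None:
        # custom setter; the metaclass supplies the passthrough getter for _count
        self._count = max(0, value)

e = Example()
print(e.count)  # 0, the default
e.count = -5
print(e.count)  # 0, clamped by the setter above
e.count = 3
print(e.count)  # 3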