mirror of https://github.com/openai/whisper.git (synced 2025-11-24 14:35:57 +00:00)

commit 0621ed8094 — pre-commit formatting
parent: e0704ddeba
@@ -1,15 +1,19 @@
-import torch
-import numpy as np
-from collections.abc import Callable, AsyncIterable, AsyncIterator, Awaitable
+from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
 from typing import Generic, TypeVar, Union
+
+import numpy as np
+import torch
 
 A = TypeVar("A", bound=Union[np.ndarray, torch.Tensor])
 
+
 class ArrayWrapper(Generic[A]):
     pass
 
+
 ArrayTypes = Union[A, ArrayWrapper[A]]
 
+
 class LoopbackIterator(Generic[A]):
     async def iter(self):
         raise NotImplementedError
@@ -23,14 +27,17 @@ class LoopbackIterator(Generic[A]):
             self.__aiter__()
         return await anext(self._iter)
 
+
 async def empty():
     return
     yield
 
+
 class Unwrap(LoopbackIterator):
     _initial: Union[ArrayTypes, Awaitable[ArrayTypes]]
     started: bool
     iterator: AsyncIterable[ArrayTypes]
+
     def __init__(self, iterator: AsyncIterable[ArrayTypes]):
         while isinstance(iterator, PassthroughTransform):
             iterator = iterator.handoff()
@@ -74,13 +81,14 @@ class Unwrap(LoopbackIterator):
 
     @property
     async def concat(self):
-        return np.concatenate if isinstance(await self.dtype, np.dtype) \
-                else torch.cat
+        return np.concatenate if isinstance(await self.dtype, np.dtype) else torch.cat
 
+
 class PassthroughTransform(LoopbackIterator):
     def handoff(self) -> AsyncIterable[ArrayTypes]:
         raise NotImplementedError
 
+
 class BoxedIterator(PassthroughTransform):
     def __init__(self, iterator):
         self.iterator = iterator
@@ -99,6 +107,7 @@ class BoxedIterator(PassthroughTransform):
         if self.flag != flag:
             raise Exception("source can only be used by one iterator")
 
+
 def LookAlong(axis: int):
     assert axis >= 0
     empties = (slice(None),) * axis
@@ -119,10 +128,9 @@ def LookAlong(axis: int):
 
     return LookAlong
 
+
 class PassthroughMap(PassthroughTransform):
-    def __init__(
-            self, apply: Callable[[A], ArrayTypes],
-            iterator: AsyncIterator[A]):
+    def __init__(self, apply: Callable[[A], ArrayTypes], iterator: AsyncIterator[A]):
         self.iterator, self.apply = iterator, apply
 
     def handoff(self) -> AsyncIterator[A]:
@@ -132,6 +140,7 @@ class PassthroughMap(PassthroughTransform):
         async for i in self.iterator:
             yield self.apply(i)
 
+
 class Group:
     def __init__(self, concat, axis=-1):
         self.concat = concat
@@ -155,16 +164,18 @@ class Group:
             if taking == amount or not exact:
                 self.shape += amount - taking
                 self.consumed = 0
-                res = self.concat([self.holding[0][start:]] + [
-                    i.value for i in self.holding[1 : i + 1]])
+                res = self.concat(
+                    [self.holding[0][start:]] + [i.value for i in self.holding[1 : i + 1]]
+                )
                 self.holding = self.holding[i + 1 :]
                 return res
         if i == 0:
             return self.holding[0][start : self.consumed]
         res = self.concat(
-            [self.holding[0][start:]] +
-            [i.value for i in self.holding[1 : i]] +
-            [self.holding[i][:self.consumed]])
+            [self.holding[0][start:]]
+            + [i.value for i in self.holding[1:i]]
+            + [self.holding[i][: self.consumed]]
+        )
         self.holding = self.holding[i:]
         return res
 
@@ -175,10 +186,12 @@ class Group:
         self.holding = []
         return res
 
+
 class Taken:
     def take(self, *a, **kw):
         raise Exception("batch queue moved")
 
+
 class Batcher(PassthroughTransform):
     def __init__(self, iterator, size, axis=-1, exact=False):
         assert isinstance(size, int) and size > 0
@@ -192,14 +205,19 @@ class Batcher(PassthroughTransform):
         return lambda tensors: f(tensors, self.axis)
 
     _iterator = None
+
     async def iterator(self):
         if self._iterator is None:
-            self.axis = len(await self.preview.shape) + self._axis \
-                    if self._axis < 0 else self._axis
+            self.axis = (
+                len(await self.preview.shape) + self._axis
+                if self._axis < 0
+                else self._axis
+            )
             if not hasattr(self, "group"):
                 self.group = Group(await self.concat())
             self._iterator = PassthroughMap(
-                LookAlong(self.axis), BoxedIterator(self.preview))
+                LookAlong(self.axis), BoxedIterator(self.preview)
+            )
         return self._iterator
 
     def handoff(self):
@@ -219,4 +237,3 @@ class Batcher(PassthroughTransform):
                 return self.group.all()
             raise
         return self.group.take(self.size, self.exact)
-
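The hunks above reformat the async stream-batching helpers (apparently whisper/batching.py, to judge by `from .batching import Batcher` in the next file). For orientation, a minimal usage sketch of `Batcher`, which re-chunks an async iterator of arrays into fixed-size pieces along an axis; the import path and the exact end-of-stream behaviour are assumptions inferred from the diff, not confirmed by it:

```python
import asyncio

import numpy as np

from whisper.batching import Batcher  # assumed module path for this fork


async def chunks():
    # a source that yields irregular 7-sample pieces
    for _ in range(5):
        yield np.arange(7, dtype=np.float32)


async def main():
    # size=10, exact=True: Group.take above re-cuts the stream into
    # batches of exactly 10 samples along the last axis
    async for batch in Batcher(chunks(), 10, exact=True):
        print(batch.shape)  # expected (10,); tail handling is a guess


asyncio.run(main())
```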
@@ -1,18 +1,16 @@
+import asyncio
+import json
+import subprocess
+from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Coroutine
+from typing import IO, Optional, Union
+
 import numpy as np
-import asyncio, pathlib, subprocess, torch, json
+import torch
 
-from .audio import (
-    SAMPLE_RATE,
-    N_FFT,
-    HOP_LENGTH,
-    N_FRAMES,
-    mel_filters,
-)
-
-from .utils import PathType, ceildiv
+from .audio import HOP_LENGTH, N_FFT, N_FRAMES, SAMPLE_RATE, mel_filters
 from .batching import Batcher
-from typing import Optional, Union, IO, Tuple, Any, Type
-from collections.abc import Coroutine, AsyncIterable, AsyncIterator, Awaitable
+from .utils import PathType, ceildiv
 
+
 class AudioSink:
     def __init__(self, *, rate: int = SAMPLE_RATE, **kw):
@@ -25,11 +23,19 @@ class AudioSink:
     def write(self, data):
         raise NotImplementedError
 
+
 class ArrayStream(AudioSink):
     q: asyncio.Queue
+
     def __init__(
-            self, *, device: Optional[Union[str, torch.device]] = None,
-            batch: int = 1, n_mels: int = 80, capacity: int = 1_000_000, **kw):
+        self,
+        *,
+        device: Optional[Union[str, torch.device]] = None,
+        batch: int = 1,
+        n_mels: int = 80,
+        capacity: int = 1_000_000,
+        **kw,
+    ):
         super().__init__(**kw)
         self.q = asyncio.Queue(capacity)
         self.finished = asyncio.Event()
@@ -43,6 +49,7 @@ class ArrayStream(AudioSink):
         return torch.zeros(shape, dtype=torch.float32, device=self.device)
 
     write_blockable: bool = True
+
     def write(self, data: bytes) -> Optional[Coroutine]:
         if self.write_blockable:
             return self.q.put(data)
@@ -51,11 +58,9 @@ class ArrayStream(AudioSink):
         return None
 
     def load(self, data: bytes) -> np.ndarray:
-        return np.frombuffer(
-            data, np.int16).flatten().astype(np.float32) / 32768.0
+        return np.frombuffer(data, np.int16).flatten().astype(np.float32) / 32768.0
 
-    async def loader(self, iterator: AsyncIterable[bytes]) -> \
-            AsyncIterator[np.ndarray]:
+    async def loader(self, iterator: AsyncIterable[bytes]) -> AsyncIterator[np.ndarray]:
         async for data in iterator:
             yield self.load(data)
 
@@ -64,7 +69,8 @@ class ArrayStream(AudioSink):
         while not self.finished.is_set():
             getter = asyncio.create_task(self.q.get())
             done, pending = await asyncio.wait(
-                (waiter, getter), return_when=asyncio.FIRST_COMPLETED)
+                (waiter, getter), return_when=asyncio.FIRST_COMPLETED
+            )
             if getter in done:
                 yield getter.result()
                 while not self.q.empty():
@@ -78,8 +84,10 @@ class ArrayStream(AudioSink):
                 pass
 
     loading: Optional[Batcher] = None
-    async def fft_offset(self, iterator: AsyncIterable[bytes]) -> \
-            AsyncIterator[np.ndarray]:
+
+    async def fft_offset(
+        self, iterator: AsyncIterable[bytes]
+    ) -> AsyncIterator[np.ndarray]:
         init = self.loader(iterator) if self.loading is None else self.loading
         self.loading = Batcher(init, HOP_LENGTH)
         _iterator = aiter(self.loading)
@@ -89,7 +97,7 @@ class ArrayStream(AudioSink):
                 window = np.concatenate((window, await anext(_iterator)))
             except StopAsyncIteration:
                 return
-        window = np.pad(window, (N_FFT // 2, 0), 'reflect')
+        window = np.pad(window, (N_FFT // 2, 0), "reflect")
         yield window
         async for data in _iterator:
             yield data
@@ -101,8 +109,9 @@ class ArrayStream(AudioSink):
         hopped = ((sees.shape[0] - N_FFT) // HOP_LENGTH + 1) * HOP_LENGTH
         return sees[hopped:]
 
-    async def window(self, iterator: AsyncIterable[bytes]) -> \
-            AsyncIterator[torch.Tensor]:
+    async def window(
+        self, iterator: AsyncIterable[bytes]
+    ) -> AsyncIterator[torch.Tensor]:
         _iterator = self.fft_offset(iterator)
         async for data in _iterator:
             _data = torch.from_numpy(data)
@@ -121,19 +130,23 @@ class ArrayStream(AudioSink):
 
     def dft(self, amp: torch.Tensor) -> torch.Tensor:
         return torch.stft(
-            amp, N_FFT, HOP_LENGTH, window=self.hann, center=False,
-            return_complex=True)
+            amp, N_FFT, HOP_LENGTH, window=self.hann, center=False, return_complex=True
+        )
 
     # https://github.com/openai/whisper/blob/c5d4256/whisper/audio.py#L149
    log_spec_bound: Optional[torch.Tensor] = None
 
     def transform(self, stft: torch.Tensor) -> torch.Tensor:
         magnitudes = stft.abs() ** 2
         mel_spec = self.filters @ magnitudes
 
         log_spec = torch.clamp(mel_spec, min=1e-10).log10()
         # causes values to not precisely match the original
-        self.log_spec_bound = log_spec.max() if self.log_spec_bound is None \
+        self.log_spec_bound = (
+            log_spec.max()
+            if self.log_spec_bound is None
             else torch.maximum(log_spec.max(), self.log_spec_bound)
+        )
         log_spec = torch.maximum(log_spec, self.log_spec_bound - 8.0)
         log_spec = (log_spec + 4.0) / 4.0
         return log_spec
@@ -143,6 +156,7 @@ class ArrayStream(AudioSink):
 
     # dft_pad: add ending content frames to match padding from a centered STFT
     dft_pad: bool = False
+
     def runoff(self, dft_pad: Optional[bool] = None) -> torch.Tensor:
         dft_pad = self.dft_pad if dft_pad is None else dft_pad
         if dft_pad:
@@ -170,27 +184,29 @@ class ArrayStream(AudioSink):
         return self.runoff()
 
     staging: Optional[Batcher] = None
-    async def _push(self, sec: float, exact: bool = False) -> \
-            AsyncIterator[torch.Tensor]:
+
+    async def _push(
+        self, sec: float, exact: bool = False
+    ) -> AsyncIterator[torch.Tensor]:
         batching = int(sec * SAMPLE_RATE // HOP_LENGTH)
-        init = self.window(self.buffer()) if self.staging is None \
-                else self.staging
+        init = self.window(self.buffer()) if self.staging is None else self.staging
         self.staging = Batcher(init, batching, exact=exact)
         async for frame in self.staging:
             batched = batching if exact else frame.shape[-1]
             cutoff = max(self.spectogram.shape[-1] + batched - N_FRAMES, 0)
             self.offset += cutoff
-            self.spectogram = torch.cat((
-                self.spectogram[:, cutoff:], frame), -1)
+            self.spectogram = torch.cat((self.spectogram[:, cutoff:], frame), -1)
             yield self.runoff()
 
     reader: Optional[Awaitable] = None
+
     def start(self, **kw) -> None:
         if self.reader is None:
             self.reader = asyncio.create_task(self.read(**kw))
 
-    async def push(self, sec: float, exact: bool=False, **kw) -> \
-            AsyncIterator[torch.Tensor]:
+    async def push(
+        self, sec: float, exact: bool = False, **kw
+    ) -> AsyncIterator[torch.Tensor]:
         self.start(**kw)
         async for i in self._push(sec, exact):
             yield i
@@ -224,17 +240,17 @@ class ArrayStream(AudioSink):
     def all_amplitudes(self, **kw) -> np.ndarray:
         return asyncio.run(self.amplitudes(**kw))
 
+
 class RawAudioFile(ArrayStream):
-    def __init__(
-            self, *, period: int = HOP_LENGTH, fname: PathType = 'out.raw',
-            **kw):
+    def __init__(self, *, period: int = HOP_LENGTH, fname: PathType = "out.raw", **kw):
         super().__init__(**kw)
         self.fname = fname
         self.period = period
 
     fp: Optional[IO[bytes]] = None
+
     async def read(self) -> None:
-        fp = open(self.fname, 'rb') if self.fp is None else self.fp
+        fp = open(self.fname, "rb") if self.fp is None else self.fp
         data = fp.read(self.period)
         while len(data) != 0:
             io_hold = self.write(data)
@@ -243,28 +259,33 @@ class RawAudioFile(ArrayStream):
             data = fp.read(self.period)
         self.finished.set()
 
+
 class AudioFile(RawAudioFile):
-    def __init__(
-            self, *, period: int = SAMPLE_RATE, fname: PathType = 'out.wav',
-            **kw):
+    def __init__(self, *, period: int = SAMPLE_RATE, fname: PathType = "out.wav", **kw):
         assert not subprocess.run(
-            ["which", "ffmpeg"], stdout=subprocess.PIPE).returncode
+            ["which", "ffmpeg"], stdout=subprocess.PIPE
+        ).returncode
         super().__init__(period=period or -1, fname=fname, **kw)
 
     async def read(self) -> None:
         cmd = [
             "ffmpeg",
             "-nostdin",
-            "-threads", "0",
-            "-i", self.fname,
-            "-f", "s16le",
-            "-ac", "1",
-            "-acodec", "pcm_s16le",
-            "-ar", str(self.rate),
-            "-"
+            "-threads",
+            "0",
+            "-i",
+            self.fname,
+            "-f",
+            "s16le",
+            "-ac",
+            "1",
+            "-acodec",
+            "pcm_s16le",
+            "-ar",
+            str(self.rate),
+            "-",
         ]
-        ps = subprocess.Popen(
-            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        ps = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
         self.fp = ps.stdout
         await super().read()
         _, stderr = ps.communicate()
@@ -277,13 +298,13 @@ class AudioFile(RawAudioFile):
             "ffprobe",
             "-hide_banner",
             "-show_format",
-            "-of", "json",
-            "-i", self.fname,
+            "-of",
+            "json",
+            "-i",
+            self.fname,
         ]
-        ps = subprocess.Popen(
-            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        ps = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
         stdout, stderr = ps.communicate()
         if ps.returncode not in (None, 0):
             raise RuntimeError(f"Failed to load audio: {stderr.decode()}")
-        return float(json.loads(stdout)['format']['duration'])
-
+        return float(json.loads(stdout)["format"]["duration"])
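These hunks reformat the streaming audio buffer (apparently whisper/buffer.py). As a sketch of the API they leave intact: `AudioFile` shells out to ffmpeg for decoding, and `push(sec)` yields the running log-mel spectrogram as roughly `sec` seconds of audio accumulate. The module path and the shape of what `push` yields are inferred from the code above, not confirmed:

```python
import asyncio

from whisper.buffer import AudioFile  # assumed module path for this fork


async def main():
    # read() above decodes the file via ffmpeg to 16-bit mono PCM
    # and write() feeds it into the ArrayStream queue
    stream = AudioFile(fname="sample.wav")
    # push() starts the reader task, then yields spectrogram tensors
    async for mel in stream.push(5.0):
        print(mel.shape)  # presumably (n_mels, <= N_FRAMES)


asyncio.run(main())
```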
@@ -1,27 +1,30 @@
 import argparse
+import asyncio
 import os
 import traceback
 import warnings
-import asyncio
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Dict
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 import tqdm
 
 from .audio import (
+    CHUNK_LENGTH,
     FRAMES_PER_SECOND,
     HOP_LENGTH,
     N_FRAMES,
     SAMPLE_RATE,
-    CHUNK_LENGTH,
     pad_or_trim,
 )
+from .buffer import ArrayStream, AudioFile
 from .decoding import DecodingOptions, DecodingResult
 from .timing import add_word_timestamps
-from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer, Tokenizer
+from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, Tokenizer, get_tokenizer
 from .utils import (
+    PassthroughProperty,
+    PassthroughPropertyDefaults,
     exact_div,
     format_timestamp,
     get_end,
@@ -30,14 +33,12 @@ from .utils import (
     optional_float,
     optional_int,
     str2bool,
-    PassthroughProperty,
-    PassthroughPropertyDefaults,
 )
-from .buffer import ArrayStream, AudioFile
 
 if TYPE_CHECKING:
     from .model import Whisper
 
+
 @dataclass
 class LanguageHypothesis:
     language: Optional[str] = None
@@ -45,15 +46,17 @@ class LanguageHypothesis:
     evidence: int = 0
     last: int = -1
 
+
 class Transcriber(metaclass=PassthroughPropertyDefaults):
-    prefix: str = '''"'\u201c\u00bf([{-'''
-    postfix: str = '''"'.\u3002,\uff0c!\uff01?\uff1f:\uff1a\u201d)]}\u3001'''
+    prefix: str = """"'\u201c\u00bf([{-"""
+    postfix: str = """"'.\u3002,\uff0c!\uff01?\uff1f:\uff1a\u201d)]}\u3001"""
     punctuation: str = prefix + postfix
 
     verbose: Optional[bool] = None
+
     _decode_options: dict = {}
     decode_props: Tuple[str, ...] = ("fp16", "language", "task")
 
     @property
     def decode_options(self) -> dict:
         for k in self.decode_props:
@@ -68,6 +71,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             setattr(self, k, value[k])
 
     dtype: torch.dtype = torch.float16
+
     @property
     def fp16(self) -> bool:
         return self.dtype == torch.float16
@@ -93,8 +97,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         self._device = value
         if value == torch.device("cpu"):
             if torch.cuda.is_available():
-                warnings.warn(
-                    "Performing inference on CPU when CUDA is available")
+                warnings.warn("Performing inference on CPU when CUDA is available")
         self.fp16device()
 
     def fp16device(self) -> None:
@@ -110,6 +113,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
 
     prev: Optional[torch.Tensor] = None
     _latest: Optional[torch.Tensor] = None
+
     @PassthroughProperty[Optional[torch.Tensor]](None).setter
     def latest(self, value: Optional[torch.Tensor]) -> None:
         self.prev = self._latest
@@ -117,6 +121,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
 
     _hypothesis: LanguageHypothesis = LanguageHypothesis()
     _language: Optional[str]
+
     @PassthroughProperty[Optional[str]](None).property
     def language(self) -> Optional[str]:
         if self._language is not None:
@@ -126,14 +131,18 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             if self.verbose:
                 print(
                     "Detecting language using up to the first 30 seconds."
-                    "Use `--language` to specify the language")
+                    "Use `--language` to specify the language"
+                )
             if self.latest is None:
                 return None
         if self._seek == self._hypothesis.last:
             return self._hypothesis.language
         if self.frame_offset > 0 or self.latest.shape[-1] == N_FRAMES * 2:
-            mel = self.latest if self.prev is None else torch.cat(
-                (self.prev[:self.frame_offset], self.latest), -1)
+            mel = (
+                self.latest
+                if self.prev is None
+                else torch.cat((self.prev[: self.frame_offset], self.latest), -1)
+            )
             self._language = self.detect_language(mel)
             return self._language
         self._hypothesis.last = self._seek or 0
@@ -152,22 +161,25 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     def clip_timestamps(self, value: Union[str, List[float], Tuple[float]]):
         self._seek_clips = None
         if isinstance(value, str):
-            self._clip_timestamps = tuple(map(float, value.split(","))) \
-                    if value else (0,)
+            self._clip_timestamps = (
+                tuple(map(float, value.split(","))) if value else (0,)
+            )
         else:
             self._clip_timestamps = tuple(value) or (0,)
 
     _seek_clips: Optional[List[Tuple[int, Optional[int]]]] = None
+
     @property
     def seek_clips(self) -> List[Tuple[int, Optional[int]]]:
         if self._seek_clips is None:
             seek_points = tuple(
-                round(ts * FRAMES_PER_SECOND)
-                for ts in self.clip_timestamps) + (None,)
+                round(ts * FRAMES_PER_SECOND) for ts in self.clip_timestamps
+            ) + (None,)
             self._seek_clips = list(zip(seek_points[::2], seek_points[1::2]))
         return self._seek_clips
 
     _seek: Optional[int]
+
     @PassthroughProperty[Optional[int]](None).property
     def seek(self) -> Optional[int]:
         return self.seek_clips[0][0] if self._seek is None else self._seek
@@ -179,25 +191,30 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         if value < len(clips):
             self.seek = clips[value][0]
 
-    time_offset = property(lambda self: float(
-        self.seek * HOP_LENGTH / SAMPLE_RATE))
-    window_end_time = property(lambda self: float(
-        (self.seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE))
+    time_offset = property(lambda self: float(self.seek * HOP_LENGTH / SAMPLE_RATE))
+    window_end_time = property(
+        lambda self: float((self.seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
+    )
 
     _temperature: Union[Optional[float], Tuple[float, ...]]
+
     @PassthroughProperty[Union[Optional[float], Tuple[float, ...]]](
-        (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)).setter
+        (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
+    ).setter
     def temperature(self, value: Union[Optional[float], Tuple[float, ...]]):
-        self._temperature = (value,) if isinstance(value, (int, float)) else (
-            Transcriber._temperature if value is None else value)
+        self._temperature = (
+            (value,)
+            if isinstance(value, (int, float))
+            else (Transcriber._temperature if value is None else value)
+        )
 
     @PassthroughProperty("transcribe").setter
     def task(self, value: str):
         self._task = value
         if self.word_timestamps and value == "translate":
             warnings.warn(
-                "Word-level timestamps on translations may not be "
-                "reliable.")
+                "Word-level timestamps on translations may not be " "reliable."
+            )
 
     @PassthroughProperty(False).setter
     def word_timestamps(self, value: bool):
@@ -207,6 +224,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     get_tokenizer = staticmethod(get_tokenizer)
     _tokenizer: Optional[Tokenizer] = None
     _tokenizer_cache: Dict[str, Tokenizer] = {}
+
     @property
     def tokenizer(self) -> Tokenizer:
         if self._tokenizer is None:
@@ -235,6 +253,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
 
     _initial_prompt_tokens: Optional[List[int]] = None
     _initial_prompt_cache: Dict[Tokenizer, List[int]] = {}
+
     @property
     def initial_prompt_tokens(self) -> List[int]:
         if self._initial_prompt_tokens is None:
@@ -246,9 +265,9 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 return []
             if tokenizer not in self._initial_prompt_cache:
                 self._initial_prompt_cache[tokenizer] = tokenizer.encode(
-                    " " + self.initial_prompt.strip())
-                self._initial_prompt_tokens = \
-                    self._initial_prompt_cache[tokenizer]
+                    " " + self.initial_prompt.strip()
+                )
+                self._initial_prompt_tokens = self._initial_prompt_cache[tokenizer]
             return self._initial_prompt_cache[tokenizer]
         return self._initial_prompt_tokens
 
@@ -256,6 +275,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     last_speech_timestamp: float = 0.0
     frame_offset: int = 0
     all_segments: List[dict]
+
     def __init__(
         self,
         model: "Whisper",
@@ -272,7 +292,8 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         append_punctuations: str = postfix,
         clip_timestamps: Union[str, List[float]] = "0",
         hallucination_silence_threshold: Optional[float] = None,
-        **decode_options):
+        **decode_options,
+    ):
         """
         Transcribe an audio file using Whisper
 
@@ -374,14 +395,16 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
 
             needs_fallback = False
             if self.compression_ratio_threshold is not None and (
-                    decode_result.compression_ratio >
-                    self.compression_ratio_threshold):
+                decode_result.compression_ratio > self.compression_ratio_threshold
+            ):
                 needs_fallback = True  # too repetitive
             if self.logprob_threshold is not None and (
-                    decode_result.avg_logprob < self.logprob_threshold):
+                decode_result.avg_logprob < self.logprob_threshold
+            ):
                 needs_fallback = True  # average log probability is too low
             if self.no_speech_threshold is not None and (
-                    decode_result.no_speech_prob > self.no_speech_threshold):
+                decode_result.no_speech_prob > self.no_speech_threshold
+            ):
                 needs_fallback = False  # silence
             if not needs_fallback:
                 break
@@ -389,8 +412,8 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         return decode_result
 
     def new_segment(
-            self, *, start: float, end: float, tokens: torch.Tensor,
-            result: DecodingResult) -> dict:
+        self, *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
+    ) -> dict:
         _tokens = tokens.tolist()
         text_tokens = [token for token in _tokens if token < self.tokenizer.eot]
         return {
@@ -422,9 +445,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
     def is_segment_anomaly(self, segment: Optional[dict]) -> bool:
         if segment is None or not segment["words"]:
             return False
-        words = [
-            w for w in segment["words"]
-            if w["word"] not in self.punctuation][:8]
+        words = [w for w in segment["words"] if w["word"] not in self.punctuation][:8]
         score = sum(self.word_anomaly_score(w) for w in words)
         return score >= 3 or score + 0.01 >= len(words)
@@ -433,11 +454,15 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         return next((s for s in segments if s["words"]), None)
 
     def reseek(
-            self, current_segments: List[dict], segment_size: int,
-            single_timestamp_ending: bool, tokens: torch.Tensor,
-            timestamp_tokens: torch.Tensor, result: DecodingResult):
-        consecutive = torch.where(
-            timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
+        self,
+        current_segments: List[dict],
+        segment_size: int,
+        single_timestamp_ending: bool,
+        tokens: torch.Tensor,
+        timestamp_tokens: torch.Tensor,
+        result: DecodingResult,
+    ):
+        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
         consecutive.add_(1)
         if len(consecutive) > 0:
             # if the output contains two consecutive timestamp tokens
@@ -449,17 +474,16 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             for current_slice in slices:
                 sliced_tokens = tokens[last_slice:current_slice]
                 start_timestamp_pos = (
-                    sliced_tokens[0].item() -
-                    self.tokenizer.timestamp_begin)
+                    sliced_tokens[0].item() - self.tokenizer.timestamp_begin
+                )
                 end_timestamp_pos = (
-                    sliced_tokens[-1].item() -
-                    self.tokenizer.timestamp_begin)
+                    sliced_tokens[-1].item() - self.tokenizer.timestamp_begin
+                )
                 current_segments.append(
                     self.new_segment(
-                        start=self.time_offset + \
-                            start_timestamp_pos * self.time_precision,
-                        end=self.time_offset + \
-                            end_timestamp_pos * self.time_precision,
+                        start=self.time_offset
+                        + start_timestamp_pos * self.time_precision,
+                        end=self.time_offset + end_timestamp_pos * self.time_precision,
                         tokens=sliced_tokens,
                         result=result,
                     )
@@ -474,31 +498,42 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             # otherwise, ignore the unfinished segment and seek to the last
             # timestamp
             last_timestamp_pos = (
-                tokens[last_slice - 1].item() -
-                self.tokenizer.timestamp_begin)
+                tokens[last_slice - 1].item() - self.tokenizer.timestamp_begin
+            )
             self.seek += last_timestamp_pos * self.input_stride
         else:
             duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             timestamps = tokens[timestamp_tokens.nonzero().flatten()]
-            if len(timestamps) > 0 and \
-                    timestamps[-1].item() != self.tokenizer.timestamp_begin:
+            if (
+                len(timestamps) > 0
+                and timestamps[-1].item() != self.tokenizer.timestamp_begin
+            ):
                 # no consecutive timestamps but it has a timestamp; use the last
                 # one.
-                last_timestamp_pos = \
+                last_timestamp_pos = (
                     timestamps[-1].item() - self.tokenizer.timestamp_begin
+                )
                 duration = last_timestamp_pos * self.time_precision
 
-            current_segments.append(self.new_segment(
+            current_segments.append(
+                self.new_segment(
                     start=self.time_offset,
                     end=self.time_offset + duration,
                     tokens=tokens,
-                    result=result))
+                    result=result,
+                )
+            )
             self.seek += segment_size
 
     def timestamp(
-            self, current_segments: List[dict], segment_size: int,
-            single_timestamp_ending: bool, mel_segment: torch.Tensor,
-            previous_seek: int, content_frames: int) -> bool:
+        self,
+        current_segments: List[dict],
+        segment_size: int,
+        single_timestamp_ending: bool,
+        mel_segment: torch.Tensor,
+        previous_seek: int,
+        content_frames: int,
+    ) -> bool:
         add_word_timestamps(
             segments=current_segments,
             model=self.model,
@@ -512,8 +547,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
 
         if not single_timestamp_ending:
             last_word_end = get_end(current_segments)
-            if last_word_end is not None and \
-                    last_word_end > self.time_offset:
+            if last_word_end is not None and last_word_end > self.time_offset:
                 self.seek = round(last_word_end * FRAMES_PER_SECOND)
 
         # skip silence before possible hallucinations
@@ -521,24 +555,19 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             threshold = self.hallucination_silence_threshold
             if not single_timestamp_ending:
                 last_word_end = get_end(current_segments)
-                if last_word_end is not None and \
-                        last_word_end > self.time_offset:
-                    remaining_duration = \
-                        self.window_end_time - last_word_end
+                if last_word_end is not None and last_word_end > self.time_offset:
+                    remaining_duration = self.window_end_time - last_word_end
                     if remaining_duration > threshold:
-                        self.seek = round(
-                            last_word_end * FRAMES_PER_SECOND)
+                        self.seek = round(last_word_end * FRAMES_PER_SECOND)
                     else:
                         self.seek = previous_seek + segment_size
 
             # if first segment might be a hallucination, skip leading silence
             first_segment = self.next_words_segment(current_segments)
-            if first_segment is not None and self.is_segment_anomaly(
-                    first_segment):
+            if first_segment is not None and self.is_segment_anomaly(first_segment):
                 gap = first_segment["start"] - self.time_offset
                 if gap > threshold:
-                    self.seek = previous_seek + round(
-                        gap * FRAMES_PER_SECOND)
+                    self.seek = previous_seek + round(gap * FRAMES_PER_SECOND)
                     return True
 
             # skip silence before any possible hallucination that is
@@ -550,14 +579,13 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 if not segment["words"]:
                     continue
                 if self.is_segment_anomaly(segment):
-                    next_segment = self.next_words_segment(
-                        current_segments[si + 1 :])
+                    next_segment = self.next_words_segment(current_segments[si + 1 :])
                     if next_segment is not None:
-                        hal_next_start = \
-                            next_segment["words"][0]["start"]
+                        hal_next_start = next_segment["words"][0]["start"]
                     else:
-                        hal_next_start = self.time_offset + \
-                            segment_size * HOP_LENGTH / SAMPLE_RATE
+                        hal_next_start = (
+                            self.time_offset + segment_size * HOP_LENGTH / SAMPLE_RATE
+                        )
                     silence_before = (
                         segment["start"] - hal_last_end > threshold
                         or segment["start"] < threshold
@@ -573,8 +601,7 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                         max(self.time_offset + 1, segment["start"])
                         * FRAMES_PER_SECOND
                     )
-                    if content_duration - segment["end"] < \
-                            threshold:
+                    if content_duration - segment["end"] < threshold:
                         self.seek = content_frames
                         current_segments[si:] = []
                         break
@@ -586,19 +613,17 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         return False
 
     def __call__(
-            self, mel: torch.Tensor, offset: int = 0,
-            single_pass: bool = False) -> dict:
+        self, mel: torch.Tensor, offset: int = 0, single_pass: bool = False
+    ) -> dict:
         self.latest, self.frame_offset = mel, offset
         content_frames = mel.shape[-1] - N_FRAMES + offset
         content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
         # NOTE: This loop is obscurely flattened to make the diff readable.
         # A later commit should turn this into a simpler nested loop.
         # for seek_clip_start, seek_clip_end in seek_clips:
         #     while seek < seek_clip_end
         while self.clip_idx < len(self.seek_clips):
             seek_clip_start, seek_clip_end = self.seek_clips[self.clip_idx]
-            seek_clip_end = content_frames if seek_clip_end is None else \
-                    seek_clip_end
+            seek_clip_end = content_frames if seek_clip_end is None else seek_clip_end
             if self.seek < seek_clip_start:
                 self.seek = seek_clip_start
             if self.seek >= seek_clip_end:
@@ -607,22 +632,23 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 self.clip_idx += 1
                 continue
             segment_size = min(
-                N_FRAMES, content_frames - self.seek,
-                seek_clip_end - self.seek)
-            mel_segment = mel[
-                :, self.seek - offset : self.seek + segment_size - offset]
-            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(
-                self.device).to(self.dtype)
+                N_FRAMES, content_frames - self.seek, seek_clip_end - self.seek
+            )
+            mel_segment = mel[:, self.seek - offset : self.seek + segment_size - offset]
+            mel_segment = (
+                pad_or_trim(mel_segment, N_FRAMES).to(self.device).to(self.dtype)
+            )
 
-            self.decode_options["prompt"] = \
-                self.all_tokens[self.prompt_reset_since:]
+            self.decode_options["prompt"] = self.all_tokens[self.prompt_reset_since :]
            result: DecodingResult = self.decode_with_fallback(mel_segment)
 
             if self.no_speech_threshold is not None:
                 # no voice activity check
                 should_skip = result.no_speech_prob > self.no_speech_threshold
-                if self.logprob_threshold is not None and \
-                        result.avg_logprob > self.logprob_threshold:
+                if (
+                    self.logprob_threshold is not None
+                    and result.avg_logprob > self.logprob_threshold
+                ):
                     # don't skip if the logprob is high enough, despite the
                     # no_speech_prob
                     should_skip = False
@@ -636,19 +662,27 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
             current_segments: List[dict] = []
 
             tokens = torch.tensor(result.tokens)
-            timestamp_tokens: torch.Tensor = tokens.ge(
-                self.tokenizer.timestamp_begin)
-            single_timestamp_ending = (
-                timestamp_tokens[-2:].tolist() == [False, True])
+            timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin)
+            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
 
             self.reseek(
-                current_segments, segment_size, single_timestamp_ending,
-                tokens, timestamp_tokens, result)
+                current_segments,
+                segment_size,
+                single_timestamp_ending,
+                tokens,
+                timestamp_tokens,
+                result,
+            )
 
             if self.word_timestamps:
                 if self.timestamp(
-                    current_segments, segment_size, single_timestamp_ending,
-                    mel_segment, previous_seek, content_frames):
+                    current_segments,
+                    segment_size,
+                    single_timestamp_ending,
+                    mel_segment,
+                    previous_seek,
+                    content_frames,
+                ):
                     continue
 
             if self.verbose:
@@ -657,24 +691,28 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                     text = segment["text"]
                     line = (
                         f"[{format_timestamp(start)} --> "
-                        f"{format_timestamp(end)}] {text}")
+                        f"{format_timestamp(end)}] {text}"
+                    )
                     print(make_safe(line))
 
             # if a segment is instantaneous or does not contain text, clear it
             for i, segment in enumerate(current_segments):
-                if segment["start"] == segment["end"] or \
-                        segment["text"].strip() == "":
+                if segment["start"] == segment["end"] or segment["text"].strip() == "":
                     segment["text"] = ""
                     segment["tokens"] = []
                     segment["words"] = []
 
-            self.all_segments.extend([
+            self.all_segments.extend(
+                [
                     {"id": i, **segment}
                     for i, segment in enumerate(
-                        current_segments, start=len(self.all_segments))])
-            self.all_tokens.extend([
-                token for segment in current_segments
-                for token in segment["tokens"]])
+                        current_segments, start=len(self.all_segments)
+                    )
+                ]
+            )
+            self.all_tokens.extend(
+                [token for segment in current_segments for token in segment["tokens"]]
+            )
 
             if not self.condition_on_previous_text or result.temperature > 0.5:
                 # do not feed the prompt tokens if a high temperature was used
@@ -686,9 +724,12 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
                 break
 
         self.result = dict(
-            segments=self.all_segments, language=self.language,
+            segments=self.all_segments,
+            language=self.language,
             text=self.tokenizer.decode(
-                self.all_tokens[len(self.initial_prompt_tokens):]))
+                self.all_tokens[len(self.initial_prompt_tokens) :]
+            ),
+        )
         self.latest = None
         return self.result
 
@@ -699,16 +740,17 @@ class Transcriber(metaclass=PassthroughPropertyDefaults):
         processing, seconds = 0, offset * HOP_LENGTH / SAMPLE_RATE
         while len(self.all_segments) > 0 and (
             self.all_segments[-1]["start"] >= seconds
-            if len(self.all_segments) == 1 else
-            self.all_segments[-2]["end"] > seconds):
+            if len(self.all_segments) == 1
+            else self.all_segments[-2]["end"] > seconds
+        ):
             rewriting = self.all_segments.pop()
             processing += len(rewriting["tokens"])
         self.all_tokens = self.all_tokens[: len(self.all_tokens) - processing]
         if len(self.all_segments) > 0 and (
-                self.all_segments[-1]["start"] < seconds and
-                self.all_segments[-1]["end"] >= seconds):
-            self.seek = round(
-                self.all_segments[-1]["end"] * SAMPLE_RATE / HOP_LENGTH)
+            self.all_segments[-1]["start"] < seconds
+            and self.all_segments[-1]["end"] >= seconds
+        ):
+            self.seek = round(self.all_segments[-1]["end"] * SAMPLE_RATE / HOP_LENGTH)
         else:
            self.seek = offset
 
@@ -728,12 +770,16 @@ def audio_tensor(audio: Union[str, np.ndarray, torch.Tensor]) -> torch.Tensor:
 
 class MinimalTranscriber(Transcriber):
     exact: bool = True
     chlen: float = CHUNK_LENGTH
+
     async def process(self, stream: ArrayStream, **kw) -> dict:
         data = await stream.request(self.chlen, self.exact)
         while data.shape[-1] > 0:
             self(data, stream.offset, True)
-            t = self.chlen - (stream.offset + data.shape[-1] - self.seek) \
-                    / FRAMES_PER_SECOND + CHUNK_LENGTH
+            t = (
+                self.chlen
+                - (stream.offset + data.shape[-1] - self.seek) / FRAMES_PER_SECOND
+                + CHUNK_LENGTH
+            )
             data = await stream.request(t, self.exact)
         return self.result
 
@@ -755,12 +801,16 @@ class ProgressTranscriber(MinimalTranscriber):
 
     @PassthroughProperty(None).property
     def pbar(self):
         if self._pbar is None:
-            n = self.latest.shape[-1] if self.duration is None \
+            n = (
+                self.latest.shape[-1]
+                if self.duration is None
                 else -int(self.duration * -FRAMES_PER_SECOND)
+            )
             # show the progress bar when verbose is False
             # (if True, transcribed text will be printed)
             self._pbar = tqdm.tqdm(
-                total=n, unit="frames", disable=self.verbose is not False)
+                total=n, unit="frames", disable=self.verbose is not False
+            )
             self._pbar.__enter__()
         return self._pbar
 
@@ -784,10 +834,7 @@ class ProgressTranscriber(MinimalTranscriber):
         return await self.process(stream, **kw)
 
 
-def transcribe(
-        model: "Whisper",
-        audio: Union[str, np.ndarray, torch.Tensor],
-        **kw):
+def transcribe(model: "Whisper", audio: Union[str, np.ndarray, torch.Tensor], **kw):
     """
     Transcribe an audio file using Whisper
 
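The transcriber hunks (apparently whisper/transcribe.py) keep the module-level `transcribe(model, audio, **kw)` entry point that the last hunk collapses onto one line. A sketch of calling it; `whisper.load_model` and the pass-through keyword options are assumed to behave as in upstream whisper:

```python
import whisper  # assumes this fork keeps the upstream package layout
from whisper.transcribe import transcribe

model = whisper.load_model("base")
# keyword options are forwarded to the Transcriber defined above
result = transcribe(model, "sample.wav", word_timestamps=True, verbose=False)
print(result["text"])
```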
@@ -5,16 +5,7 @@ import re
 import sys
 import time
 import zlib
-from typing import (
-    Callable,
-    List,
-    Optional,
-    TextIO,
-    Union,
-    TypeVar,
-    Generic,
-    Any
-)
+from typing import Any, Callable, Generic, List, Optional, TextIO, TypeVar, Union
 
 system_encoding = sys.getdefaultencoding()
 
@@ -86,15 +77,17 @@ def format_timestamp(
         f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
     )
 
+
 def hms(sec: float) -> str:
     trim = sec < 3600
     h = "" if trim else str(int(sec) // 3600) + ":"
     m_fill = " " if trim else "0"
     m = " " if sec < 60 else str(int(sec) // 60 % 60).rjust(2, m_fill) + ":"
-    s = str(int(sec) % 60).rjust(2, '0') + "."
-    c = str(round((sec % 1) * 100)).rjust(2, '0')
+    s = str(int(sec) % 60).rjust(2, "0") + "."
+    c = str(round((sec % 1) * 100)).rjust(2, "0")
     return h + m + s + c
 
+
 def tod(seconds: float) -> str:
     return time.strftime("%H:%M:%S", time.localtime(seconds))
 
@@ -349,28 +342,34 @@ def get_writer(
 
 T = TypeVar("T")
 
+
 # boilerplate for property with _{name} storage and passthrough getter/setter
 class PassthroughProperty(Generic[T]):
     def __init__(self, default: T):
         self.value = default
 
     f: Optional[Callable[[Any, T], None]] = None
+
     def setter(self, f: Callable[[Any, T], None]):
         self.f = f
         return self
 
     g: Optional[property] = None
+
     def property(self, g: Callable[[Any], T]):
         self.g = property(g)
         return self
 
+
 class PassthroughPropertyDefaults(type):
     def __new__(cls, clsname, bases, attrs):
         def closure(f, v):
             def prop(self):
                 return getattr(self, v)
+
             def setter(self, value):
                 setattr(self, v, value)
+
             prop.__name__ = setter.__name__ = f
             return property(prop), setter
 
@@ -383,4 +382,3 @@ class PassthroughPropertyDefaults(type):
             getter, setter = closure(k, private)
             updates[k] = (v.g or getter).setter(v.f or setter)
         return super().__new__(cls, clsname, bases, {**attrs, **updates})
-
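The utils hunks only reflow imports, quotes, and blank lines around the `PassthroughProperty` boilerplate; its behaviour is unchanged. A minimal sketch of the pattern as `PassthroughPropertyDefaults.__new__` above implies it: each `PassthroughProperty` class attribute is rewritten into a real `property` backed by a `_name` slot, with `.setter`/`.property` supplying custom halves. The import path and the backing-slot initialisation are inferred, not confirmed.

```python
from whisper.utils import (  # assumed module path for this fork
    PassthroughProperty,
    PassthroughPropertyDefaults,
)


class Example(metaclass=PassthroughPropertyDefaults):
    # plain declaration: generated getter/setter over self._count
    count = PassthroughProperty(0)

    # custom setter plus generated getter reading self._clamped
    @PassthroughProperty(0).setter
    def clamped(self, value: int) -> None:
        self._clamped = max(0, value)


e = Example()
e.count = 5
e.clamped = -3
print(e.count, e.clamped)  # expected: 5 0
```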