Support longer audio files, reducing memory usage with chunking

Gustavo Garcia 2024-07-01 19:46:15 +02:00
parent ba3f3cd54b
commit 20e323895d
3 changed files with 359 additions and 339 deletions

tests/test_audio.py · View File

@@ -7,13 +7,13 @@ from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram

 def test_audio():
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
-    audio = load_audio(audio_path)
+    audio = next(load_audio(audio_path))
     assert audio.ndim == 1
     assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12
     assert 0 < audio.std() < 1

-    mel_from_audio = log_mel_spectrogram(audio)
-    mel_from_file = log_mel_spectrogram(audio_path)
+    mel_from_audio = next(log_mel_spectrogram(audio))
+    mel_from_file = next(log_mel_spectrogram(audio_path))

     assert np.allclose(mel_from_audio, mel_from_file)
     assert mel_from_audio.max() - mel_from_audio.min() <= 2.0
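Both helpers now return generators, which is why the test wraps each call in next(...): a short file yields exactly one chunk, so next() recovers the previous single-array behavior. A minimal consumption sketch under that assumption (the input path is hypothetical):

import numpy as np

from whisper.audio import load_audio, log_mel_spectrogram

# load_audio now yields float32 chunks instead of returning one array;
# a file shorter than the chunk bound produces a single chunk.
audio = next(load_audio("speech.wav"))  # hypothetical file
assert audio.dtype == np.float32 and audio.ndim == 1

# log_mel_spectrogram mirrors the pattern: one spectrogram per chunk.
mel = next(log_mel_spectrogram(audio))
print(mel.shape)  # (80, n_frames)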

whisper/audio.py · View File

@@ -1,7 +1,8 @@
 import os
+import subprocess
 from functools import lru_cache
 from subprocess import CalledProcessError, run
-from typing import Optional, Union
+from typing import Generator, Optional, Union

 import numpy as np
 import torch
@@ -21,6 +22,7 @@ N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
 FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
 TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token
+MAX_CHUNK_DURATION = 2 * 60 * 60  # 2 hour maximum chunk duration


 def load_audio(file: str, sr: int = SAMPLE_RATE):
     """
@ -55,11 +57,15 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
] ]
# fmt: on # fmt: on
try: try:
out = run(cmd, capture_output=True, check=True).stdout process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 while True:
out = process.stdout.read(MAX_CHUNK_DURATION * sr * 2)
if not out:
break
yield np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
except Exception as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
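Two details of this hunk are worth spelling out. First, the read size: MAX_CHUNK_DURATION * sr * 2 bounds each chunk in bytes, since ffmpeg emits 16-bit (2-byte) PCM samples. A quick check of the footprint at whisper's default rate:

SAMPLE_RATE = 16_000               # whisper's default sampling rate
MAX_CHUNK_DURATION = 2 * 60 * 60   # seconds, as defined above

pcm_bytes = MAX_CHUNK_DURATION * SAMPLE_RATE * 2  # int16 = 2 bytes per sample
print(pcm_bytes / 1e6)      # ~230 MB of raw PCM read per chunk
print(pcm_bytes * 2 / 1e6)  # ~461 MB once converted to float32 (4 bytes/sample)

Second, the error path: a generic Exception carries no stderr attribute, so the new e.stderr.decode() would itself raise AttributeError for anything other than a CalledProcessError. A hedged sketch of the same streaming loop that reports the decoder's own stderr instead (the helper name and structure are illustrative, not part of this commit):

import subprocess

import numpy as np


def stream_pcm_chunks(cmd, chunk_bytes):
    """Yield float32 audio chunks from a command writing int16 PCM to stdout."""
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    try:
        while True:
            out = process.stdout.read(chunk_bytes)
            if not out:
                break
            yield np.frombuffer(out, np.int16).astype(np.float32) / 32768.0
    finally:
        process.stdout.close()
        # surface the decoder's own error text rather than e.stderr; note a
        # very verbose stderr could still fill its pipe before stdout drains
        if process.wait() != 0:
            raise RuntimeError("Failed to load audio: " + process.stderr.read().decode())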
@@ -108,7 +114,7 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:

 def log_mel_spectrogram(
-    audio: Union[str, np.ndarray, torch.Tensor],
+    audio: Union[str, np.ndarray, torch.Tensor, Generator[np.ndarray, None, None]],
     n_mels: int = 80,
     padding: int = 0,
     device: Optional[Union[str, torch.device]] = None,
@@ -135,13 +141,26 @@ def log_mel_spectrogram(
     torch.Tensor, shape = (80, n_frames)
         A Tensor that contains the Mel spectrogram
     """
-    if not torch.is_tensor(audio):
-        if isinstance(audio, str):
-            audio = load_audio(audio)
-        audio = torch.from_numpy(audio)
+    if isinstance(audio, str):
+        audio = load_audio(audio)
+    elif isinstance(audio, np.ndarray):
+        audio = [audio]
+    elif isinstance(audio, torch.Tensor):
+        audio = [audio]

-    if device is not None:
-        audio = audio.to(device)
+    for chunk in audio:
+        if not isinstance(chunk, torch.Tensor):
+            chunk = torch.from_numpy(chunk)
+        if device is not None:
+            chunk = chunk.to(device)
+        yield _log_mel_spectrogram(chunk, n_mels, padding)
+
+
+def _log_mel_spectrogram(
+    audio: torch.Tensor,
+    n_mels: int = 80,
+    padding: int = 0,
+):
     if padding > 0:
         audio = F.pad(audio, (0, padding))
     window = torch.hann_window(N_FFT).to(audio.device)
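The public function is now a thin generator: a path is routed through the load_audio generator, an in-memory array or tensor is wrapped in a one-element list, and the actual mel computation is deferred to the private per-chunk helper. A small consumption sketch (silence as a stand-in input):

import numpy as np
import torch

from whisper.audio import log_mel_spectrogram

audio = np.zeros(16_000, dtype=np.float32)  # 1 second of silence
# in-memory audio becomes a one-element list, so this loop runs once;
# a path or chunk generator would yield one spectrogram per chunk instead
for mel in log_mel_spectrogram(audio, n_mels=80):
    assert isinstance(mel, torch.Tensor)
    print(mel.shape)  # (80, n_frames for this chunk)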

whisper/transcribe.py · View File

@@ -2,7 +2,7 @@ import argparse
 import os
 import traceback
 import warnings
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -37,7 +37,7 @@ if TYPE_CHECKING:

 def transcribe(
     model: "Whisper",
-    audio: Union[str, np.ndarray, torch.Tensor],
+    audio: Union[str, np.ndarray, torch.Tensor, Generator[np.ndarray, None, None]],
     *,
     verbose: Optional[bool] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
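Since transcribe() now accepts the chunk generator type directly, and a plain path flows through the same generators internally, a multi-hour recording never has to be materialized as one waveform. A usage sketch (the file name is illustrative):

import whisper
from whisper.audio import load_audio

model = whisper.load_model("base")

# either form streams: a path is chunked inside log_mel_spectrogram, and
# an explicit load_audio generator can be passed straight through
result = model.transcribe(load_audio("long_recording.mp3"))  # hypothetical file
print(result["text"])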
@@ -129,8 +129,13 @@ def transcribe(
     if dtype == torch.float32:
         decode_options["fp16"] = False

+    all_tokens = []
+    all_segments = []
+    punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"
+
     # Pad 30-seconds of silence to the input audio, for slicing
-    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
-    content_frames = mel.shape[-1] - N_FRAMES
-    content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
+    mels = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
+    for mel in mels:
+        content_frames = mel.shape[-1] - N_FRAMES
+        content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
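With decoding now running once per spectrogram chunk, the result accumulators and the punctuation constant move above the for mel in mels: loop; the two hunks below remove their old definitions from inside the loop body, where re-initializing them each iteration would discard every chunk's output except the last. The pattern in miniature, with a stand-in for whisper's per-window decode step:

from typing import Iterable, List, Tuple


def decode_chunk(mel) -> Tuple[List[int], List[dict]]:
    # stand-in for the real decode loop over one chunk's spectrogram
    return [50364], [{"text": "..."}]


def transcribe_chunks(mels: Iterable) -> Tuple[List[int], List[dict]]:
    all_tokens: List[int] = []    # hoisted: shared across every chunk
    all_segments: List[dict] = []
    for mel in mels:
        tokens, segments = decode_chunk(mel)
        all_tokens.extend(tokens)  # appended, never reset, between chunks
        all_segments.extend(segments)
    return all_tokens, all_segments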
@@ -170,8 +175,6 @@ def transcribe(
         seek_points.append(content_frames)
     seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))

-    punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"
-
     if word_timestamps and task == "translate":
         warnings.warn("Word-level timestamps on translations may not be reliable.")
@@ -223,8 +226,6 @@ def transcribe(
     time_precision = (
         input_stride * HOP_LENGTH / SAMPLE_RATE
     )  # time per output token: 0.02 (seconds)
-    all_tokens = []
-    all_segments = []
     prompt_reset_since = 0

     if initial_prompt is not None: