Mirror of https://github.com/openai/whisper.git (synced 2025-11-24 14:35:57 +00:00)
Support longer audio files by reducing memory usage with chunking
parent ba3f3cd54b
commit 20e323895d
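The change threads a generator through load_audio → log_mel_spectrogram → transcribe, so a multi-hour file is decoded and mel-transformed in chunks of at most two hours instead of being held in memory whole. The top-level API is unchanged; a minimal usage sketch, where the model size and file name are hypothetical:

    import whisper

    model = whisper.load_model("base")
    # audio is decoded and mel-transformed chunk by chunk under the hood
    result = model.transcribe("three_hour_lecture.mp3")
    print(result["text"])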
tests/test_audio.py
@@ -7,13 +7,13 @@ from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
 def test_audio():
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
-    audio = load_audio(audio_path)
+    audio = next(load_audio(audio_path))
     assert audio.ndim == 1
     assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12
     assert 0 < audio.std() < 1

-    mel_from_audio = log_mel_spectrogram(audio)
-    mel_from_file = log_mel_spectrogram(audio_path)
+    mel_from_audio = next(log_mel_spectrogram(audio))
+    mel_from_file = next(log_mel_spectrogram(audio_path))

     assert np.allclose(mel_from_audio, mel_from_file)
     assert mel_from_audio.max() - mel_from_audio.min() <= 2.0
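Because both functions now return generators, the test pulls the first (and, for the 11-second jfk.flac clip, only) chunk with next(). A sketch of draining every chunk of a longer file, assuming the chunked API above and a hypothetical path:

    import numpy as np
    from whisper.audio import load_audio

    # Reassemble all chunks into one array; only sensible when the
    # concatenated result still fits in memory.
    chunks = list(load_audio("long_recording.wav"))
    audio = np.concatenate(chunks)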
whisper/audio.py
@@ -1,7 +1,8 @@
 import os
+import subprocess
 from functools import lru_cache
 from subprocess import CalledProcessError, run
-from typing import Optional, Union
+from typing import Generator, Optional, Union

 import numpy as np
 import torch
@@ -21,6 +22,7 @@ N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
 FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
 TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token

+MAX_CHUNK_DURATION = 2 * 60 * 60  # 2 hour maximum chunk duration

 def load_audio(file: str, sr: int = SAMPLE_RATE):
     """
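MAX_CHUNK_DURATION caps each yielded chunk at two hours of audio. Since load_audio reads 16-bit PCM, each read below requests MAX_CHUNK_DURATION * sr * 2 bytes; at the 16 kHz default that is roughly 220 MiB of raw samples per chunk, before the float32 conversion doubles it. The arithmetic, as a sketch:

    SAMPLE_RATE = 16000               # whisper's fixed decode rate
    MAX_CHUNK_DURATION = 2 * 60 * 60  # seconds

    chunk_bytes = MAX_CHUNK_DURATION * SAMPLE_RATE * 2  # int16 = 2 bytes/sample
    print(chunk_bytes / 2**20)        # ~219.7 MiB of raw PCM per chunk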
@@ -55,11 +57,15 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
     ]
     # fmt: on
     try:
-        out = run(cmd, capture_output=True, check=True).stdout
-    except CalledProcessError as e:
-        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
-
-    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
+        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        while True:
+            out = process.stdout.read(MAX_CHUNK_DURATION * sr * 2)
+            if not out:
+                break
+            yield np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
+    except Exception as e:
+        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e


 def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
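Swapping run(..., capture_output=True) for Popen means ffmpeg's stdout is streamed in fixed-size reads rather than buffered in full. The streaming pattern in isolation, with a simplified, hypothetical ffmpeg command (the real cmd is assembled earlier in load_audio):

    import subprocess
    import numpy as np

    cmd = ["ffmpeg", "-nostdin", "-i", "input.wav",
           "-f", "s16le", "-ac", "1", "-ar", "16000", "-"]
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    while True:
        out = process.stdout.read(16000 * 2)  # 1 second per read, for illustration
        if not out:
            break
        samples = np.frombuffer(out, np.int16).astype(np.float32) / 32768.0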
@@ -108,7 +114,7 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:


 def log_mel_spectrogram(
-    audio: Union[str, np.ndarray, torch.Tensor],
+    audio: Union[str, np.ndarray, torch.Tensor, Generator[np.ndarray, None, None]],
     n_mels: int = 80,
     padding: int = 0,
     device: Optional[Union[str, torch.device]] = None,
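Generator[np.ndarray, None, None] reads as "yields np.ndarray, accepts nothing via send(), returns nothing". For callers that only iterate, Iterator[np.ndarray] is an equivalent, shorter spelling; a sketch with hypothetical function names:

    from typing import Generator, Iterator
    import numpy as np

    def chunks() -> Generator[np.ndarray, None, None]:  # yield / send / return types
        yield np.zeros(16000, dtype=np.float32)

    def chunks_short() -> Iterator[np.ndarray]:  # equivalent for iteration-only use
        yield np.zeros(16000, dtype=np.float32)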
@@ -135,13 +141,26 @@ def log_mel_spectrogram(
     torch.Tensor, shape = (80, n_frames)
         A Tensor that contains the Mel spectrogram
     """
-    if not torch.is_tensor(audio):
-        if isinstance(audio, str):
-            audio = load_audio(audio)
-        audio = torch.from_numpy(audio)
-
-    if device is not None:
-        audio = audio.to(device)
+    if isinstance(audio, str):
+        audio = load_audio(audio)
+    elif isinstance(audio, np.ndarray):
+        audio = [audio]
+    elif isinstance(audio, torch.Tensor):
+        audio = [audio]
+
+    for chunk in audio:
+        if not isinstance(chunk, torch.Tensor):
+            chunk = torch.from_numpy(chunk)
+        if device is not None:
+            chunk = chunk.to(device)
+        yield _log_mel_spectrogram(chunk, n_mels, padding)
+
+
+def _log_mel_spectrogram(
+    audio: torch.Tensor,
+    n_mels: int = 80,
+    padding: int = 0,
+):
     if padding > 0:
         audio = F.pad(audio, (0, padding))
     window = torch.hann_window(N_FFT).to(audio.device)
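log_mel_spectrogram is now itself a generator: the per-chunk math moves into the private helper _log_mel_spectrogram, and the public function yields one spectrogram per audio chunk. A consumption sketch, with a hypothetical path:

    from whisper.audio import log_mel_spectrogram

    # One (80, n_frames) spectrogram per ~2 h chunk; nothing spans the whole file.
    for mel in log_mel_spectrogram("very_long_podcast.mp3", n_mels=80):
        print(mel.shape)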
whisper/transcribe.py
@@ -2,7 +2,7 @@ import argparse
 import os
 import traceback
 import warnings
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -37,7 +37,7 @@ if TYPE_CHECKING:

 def transcribe(
     model: "Whisper",
-    audio: Union[str, np.ndarray, torch.Tensor],
+    audio: Union[str, np.ndarray, torch.Tensor, Generator[np.ndarray, None, None]],
     *,
     verbose: Optional[bool] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
@@ -129,8 +129,13 @@ def transcribe(
     if dtype == torch.float32:
         decode_options["fp16"] = False

+    all_tokens = []
+    all_segments = []
+    punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
+
     # Pad 30-seconds of silence to the input audio, for slicing
-    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
-    content_frames = mel.shape[-1] - N_FRAMES
-    content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
+    mels = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
+    for mel in mels:
+        content_frames = mel.shape[-1] - N_FRAMES
+        content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
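all_tokens, all_segments, and punctuation are hoisted above the new for mel in mels: loop (their later duplicate initializations are removed below) so results accumulate across chunks instead of being reset each iteration. The pattern in miniature, with a hypothetical helper name:

    all_segments = []                    # shared across chunks
    for mel in mels:                     # one iteration per ~2 h chunk
        segments = decode_chunk(mel)     # hypothetical per-chunk work
        all_segments.extend(segments)    # accumulate, never reset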
@@ -170,8 +175,6 @@ def transcribe(
         seek_points.append(content_frames)
     seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))

-    punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
-
     if word_timestamps and task == "translate":
         warnings.warn("Word-level timestamps on translations may not be reliable.")

@@ -223,8 +226,6 @@ def transcribe(
     time_precision = (
         input_stride * HOP_LENGTH / SAMPLE_RATE
     )  # time per output token: 0.02 (seconds)
-    all_tokens = []
-    all_segments = []
     prompt_reset_since = 0

     if initial_prompt is not None: