allow processing already-in-memory audio file content

nleve 2024-01-27 02:15:23 -05:00
parent ba3f3cd54b
commit e000892e13
3 changed files with 29 additions and 17 deletions
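
For context, a minimal usage sketch of the new code path (not part of the commit; the file path is only illustrative):

    from whisper.audio import load_audio

    # Audio content that is already in memory, e.g. received over a network
    # or pulled from object storage (the path below is just an example).
    with open("jfk.flac", "rb") as f:
        audio_bytes = f.read()

    # With this change, load_audio() accepts raw bytes as well as a path;
    # ffmpeg decodes the bytes from stdin instead of opening a file on disk.
    waveform = load_audio(audio_bytes)  # float32 mono waveform at 16 kHz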

tests/test_audio.py

@@ -1,19 +1,24 @@
 import os.path
 import numpy as np
+import pytest
 from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
-def test_audio():
+@pytest.mark.parametrize("read_bytes", [True, False])
+def test_audio(read_bytes):
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
-    audio = load_audio(audio_path)
+    audio_input = audio_path
+    if (read_bytes):
+        with open(audio_path, 'rb') as f:
+            audio_input = f.read()
+    audio = load_audio(audio_input)
     assert audio.ndim == 1
     assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12
     assert 0 < audio.std() < 1
     mel_from_audio = log_mel_spectrogram(audio)
-    mel_from_file = log_mel_spectrogram(audio_path)
+    mel_from_file = log_mel_spectrogram(audio_input)
     assert np.allclose(mel_from_audio, mel_from_file)
     assert mel_from_audio.max() - mel_from_audio.min() <= 2.0

whisper/audio.py

@@ -1,6 +1,6 @@
 import os
 from functools import lru_cache
-from subprocess import CalledProcessError, run
+from subprocess import CalledProcessError, run, PIPE
 from typing import Optional, Union
 import numpy as np
@@ -22,14 +22,14 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
 TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token
-def load_audio(file: str, sr: int = SAMPLE_RATE):
+def load_audio(file: Union[str, bytes], sr: int = SAMPLE_RATE):
     """
     Open an audio file and read as mono waveform, resampling as necessary
     Parameters
     ----------
-    file: str
-        The audio file to open
+    file: Union[str, bytes]
+        The audio file to open, or the bytes content of an audio file
     sr: int
         The sample rate to resample the audio if necessary
@@ -46,7 +46,6 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
         "ffmpeg",
         "-nostdin",
         "-threads", "0",
-        "-i", file,
         "-f", "s16le",
         "-ac", "1",
         "-acodec", "pcm_s16le",
@@ -54,10 +53,18 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
         "-"
     ]
     # fmt: on
-    try:
-        out = run(cmd, capture_output=True, check=True).stdout
-    except CalledProcessError as e:
-        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+    if isinstance(file, str):
+        cmd += ["-i", file]
+        try:
+            out = run(cmd, capture_output=True, check=True).stdout
+        except CalledProcessError as e:
+            raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+    else:
+        cmd += ["-i", "-"]
+        try:
+            out = run(cmd, input=file, stdout=PIPE, stderr=PIPE, check=True).stdout
+        except CalledProcessError as e:
+            raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
     return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
@@ -108,7 +115,7 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:
 def log_mel_spectrogram(
-    audio: Union[str, np.ndarray, torch.Tensor],
+    audio: Union[str, bytes, np.ndarray, torch.Tensor],
     n_mels: int = 80,
     padding: int = 0,
     device: Optional[Union[str, torch.device]] = None,
@@ -136,7 +143,7 @@ def log_mel_spectrogram(
         A Tensor that contains the Mel spectrogram
     """
     if not torch.is_tensor(audio):
-        if isinstance(audio, str):
+        if isinstance(audio, str) or isinstance(audio, bytes):
             audio = load_audio(audio)
         audio = torch.from_numpy(audio)

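The key mechanism in the new else branch above is handing the in-memory bytes to ffmpeg over stdin ("-i -") and capturing the decoded PCM from stdout. A standalone sketch of that pattern, independent of whisper (the helper name is hypothetical; ffmpeg must be on PATH):

    from subprocess import CalledProcessError, PIPE, run

    import numpy as np

    def decode_bytes_with_ffmpeg(data: bytes, sr: int = 16000) -> np.ndarray:
        # "-i -" makes ffmpeg read the container format from stdin; the
        # trailing "-" sends the decoded 16-bit mono PCM to stdout.
        cmd = [
            "ffmpeg", "-threads", "0",
            "-i", "-",
            "-f", "s16le", "-ac", "1", "-acodec", "pcm_s16le", "-ar", str(sr),
            "-",
        ]
        try:
            out = run(cmd, input=data, stdout=PIPE, stderr=PIPE, check=True).stdout
        except CalledProcessError as e:
            raise RuntimeError(f"ffmpeg failed: {e.stderr.decode()}") from e
        # s16le PCM -> float32 in [-1, 1), the same scaling load_audio() uses
        return np.frombuffer(out, np.int16).astype(np.float32) / 32768.0
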
whisper/transcribe.py

@@ -37,7 +37,7 @@ if TYPE_CHECKING:
 def transcribe(
     model: "Whisper",
-    audio: Union[str, np.ndarray, torch.Tensor],
+    audio: Union[str, bytes, np.ndarray, torch.Tensor],
     *,
     verbose: Optional[bool] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
@@ -62,7 +62,7 @@ def transcribe(
         The Whisper model instance
     audio: Union[str, np.ndarray, torch.Tensor]
-        The path to the audio file to open, or the audio waveform
+        The path to the audio file to open, or the audio waveform, or the bytes content of an audio file
     verbose: bool
         Whether to display the text being decoded to the console. If True, displays all the details,
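
Because transcribe() forwards its audio argument to log_mel_spectrogram(), which now routes str and bytes through load_audio(), in-memory content can be transcribed without writing a temporary file. A minimal end-to-end sketch (the "base" checkpoint and file name are only examples, not part of the commit):

    import whisper

    model = whisper.load_model("base")

    with open("jfk.flac", "rb") as f:
        audio_bytes = f.read()

    # The bytes flow transcribe() -> log_mel_spectrogram() -> load_audio(),
    # where ffmpeg decodes them from stdin.
    result = model.transcribe(audio_bytes)
    print(result["text"])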