Merge e000892e137fb17d4ab43046c77853274eaf9ed2 into 517a43ecd132a2089d85f4ebc044728a71d49f6e

nlev authored 2025-01-07 12:45:09 +01:00, committed by GitHub
commit adf281518a
3 changed files with 29 additions and 17 deletions

tests/test_audio.py

@@ -1,19 +1,24 @@
 import os.path
 import numpy as np
+import pytest
 from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
-def test_audio():
+@pytest.mark.parametrize("read_bytes", [True, False])
+def test_audio(read_bytes):
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
-    audio = load_audio(audio_path)
+    audio_input = audio_path
+    if (read_bytes):
+        with open(audio_path, 'rb') as f:
+            audio_input = f.read()
+    audio = load_audio(audio_input)
     assert audio.ndim == 1
     assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12
     assert 0 < audio.std() < 1
     mel_from_audio = log_mel_spectrogram(audio)
-    mel_from_file = log_mel_spectrogram(audio_path)
+    mel_from_file = log_mel_spectrogram(audio_input)
     assert np.allclose(mel_from_audio, mel_from_file)
     assert mel_from_audio.max() - mel_from_audio.min() <= 2.0

whisper/audio.py

@@ -1,6 +1,6 @@
 import os
 from functools import lru_cache
-from subprocess import CalledProcessError, run
+from subprocess import CalledProcessError, run, PIPE
 from typing import Optional, Union
 import numpy as np
@@ -22,14 +22,14 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
 TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
-def load_audio(file: str, sr: int = SAMPLE_RATE):
+def load_audio(file: Union[str, bytes], sr: int = SAMPLE_RATE):
     """
     Open an audio file and read as mono waveform, resampling as necessary
     Parameters
     ----------
-    file: str
-        The audio file to open
+    file: Union[str, bytes]
+        The audio file to open, or the bytes content of an audio file
     sr: int
         The sample rate to resample the audio if necessary
@@ -46,7 +46,6 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
         "ffmpeg",
         "-nostdin",
         "-threads", "0",
-        "-i", file,
         "-f", "s16le",
         "-ac", "1",
         "-acodec", "pcm_s16le",
@@ -54,10 +53,18 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
         "-"
     ]
     # fmt: on
-    try:
-        out = run(cmd, capture_output=True, check=True).stdout
-    except CalledProcessError as e:
-        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+    if isinstance(file, str):
+        cmd += ["-i", file]
+        try:
+            out = run(cmd, capture_output=True, check=True).stdout
+        except CalledProcessError as e:
+            raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+    else:
+        cmd += ["-i", "-"]
+        try:
+            out = run(cmd, input=file, stdout=PIPE, stderr=PIPE, check=True).stdout
+        except CalledProcessError as e:
+            raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
     return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
@@ -108,7 +115,7 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:
 def log_mel_spectrogram(
-    audio: Union[str, np.ndarray, torch.Tensor],
+    audio: Union[str, bytes, np.ndarray, torch.Tensor],
     n_mels: int = 80,
     padding: int = 0,
     device: Optional[Union[str, torch.device]] = None,
@@ -136,7 +143,7 @@ def log_mel_spectrogram(
         A Tensor that contains the Mel spectrogram
     """
     if not torch.is_tensor(audio):
-        if isinstance(audio, str):
+        if isinstance(audio, str) or isinstance(audio, bytes):
            audio = load_audio(audio)
         audio = torch.from_numpy(audio)
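
For context, a minimal usage sketch of the widened load_audio / log_mel_spectrogram API after this change (the fixture path is an assumption; any audio file readable by ffmpeg works). Passing a str makes ffmpeg open the file itself ("-i <path>"), while passing bytes pipes them to ffmpeg's stdin ("-i -"):

# Sketch only: the fixture path is assumed; both call forms should decode to the
# same 16 kHz mono waveform, since only the ffmpeg input source differs.
import numpy as np

from whisper.audio import load_audio, log_mel_spectrogram

audio_path = "tests/jfk.flac"
with open(audio_path, "rb") as f:
    audio_bytes = f.read()

waveform_from_path = load_audio(audio_path)    # ffmpeg reads the file directly
waveform_from_bytes = load_audio(audio_bytes)  # bytes are piped to ffmpeg's stdin
assert np.allclose(waveform_from_path, waveform_from_bytes)

mel = log_mel_spectrogram(audio_bytes)  # bytes are forwarded to load_audio internally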

whisper/transcribe.py

@@ -37,7 +37,7 @@ if TYPE_CHECKING:
 def transcribe(
     model: "Whisper",
-    audio: Union[str, np.ndarray, torch.Tensor],
+    audio: Union[str, bytes, np.ndarray, torch.Tensor],
     *,
     verbose: Optional[bool] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
@@ -63,7 +63,7 @@ def transcribe(
         The Whisper model instance
     audio: Union[str, np.ndarray, torch.Tensor]
-        The path to the audio file to open, or the audio waveform
+        The path to the audio file to open, or the audio waveform, or the bytes content of an audio file
     verbose: bool
         Whether to display the text being decoded to the console. If True, displays all the details,
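
Since transcribe() forwards its audio argument to log_mel_spectrogram, the widened signature lets callers transcribe in-memory audio without writing a temporary file; a hedged usage sketch (model size and file name are placeholders):

# Sketch only: "base" and "speech.mp3" are placeholders, not part of this change.
import whisper

model = whisper.load_model("base")

with open("speech.mp3", "rb") as f:  # e.g. bytes received from an upload or a message queue
    audio_bytes = f.read()

result = model.transcribe(audio_bytes)  # bytes flow through log_mel_spectrogram -> load_audio
print(result["text"])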