From e000892e137fb17d4ab43046c77853274eaf9ed2 Mon Sep 17 00:00:00 2001
From: nleve
Date: Sat, 27 Jan 2024 02:15:23 -0500
Subject: [PATCH] allow processing already-in-memory audio file content

---
Review notes (ignored by `git am`):
- load_audio keeps a single ffmpeg invocation; only bytes/bytearray content is
  piped over stdin, so str and path-like inputs still reach ffmpeg as "-i <path>"
  with the argument order unchanged.
- The commit id above and the post-image blob ids on the "index" lines predate
  these review fixes; regenerate the patch with `git format-patch` before
  relying on them (e.g. for `git am -3`).

 tests/test_audio.py   | 12 +++++++++---
 whisper/audio.py      | 16 +++++++++-------
 whisper/transcribe.py |  6 +++---
 3 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/tests/test_audio.py b/tests/test_audio.py
index dfd78bc..2f97938 100644
--- a/tests/test_audio.py
+++ b/tests/test_audio.py
@@ -1,19 +1,25 @@
 import os.path
 
 import numpy as np
+import pytest
 
 from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
 
 
-def test_audio():
+@pytest.mark.parametrize("read_bytes", [True, False])
+def test_audio(read_bytes):
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
-    audio = load_audio(audio_path)
+    audio_input = audio_path
+    if read_bytes:
+        with open(audio_path, "rb") as f:
+            audio_input = f.read()
+    audio = load_audio(audio_input)
     assert audio.ndim == 1
     assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12
     assert 0 < audio.std() < 1
 
     mel_from_audio = log_mel_spectrogram(audio)
-    mel_from_file = log_mel_spectrogram(audio_path)
+    mel_from_file = log_mel_spectrogram(audio_input)
 
     assert np.allclose(mel_from_audio, mel_from_file)
     assert mel_from_audio.max() - mel_from_audio.min() <= 2.0
diff --git a/whisper/audio.py b/whisper/audio.py
index cf6c66a..fedc194 100644
--- a/whisper/audio.py
+++ b/whisper/audio.py
@@ -22,14 +22,14 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
 TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token
 
 
-def load_audio(file: str, sr: int = SAMPLE_RATE):
+def load_audio(file: Union[str, bytes], sr: int = SAMPLE_RATE):
     """
     Open an audio file and read as mono waveform, resampling as necessary
 
     Parameters
     ----------
-    file: str
-        The audio file to open
+    file: Union[str, bytes]
+        The audio file to open, or the bytes content of an audio file
 
     sr: int
         The sample rate to resample the audio if necessary
@@ -46,7 +46,7 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
         "ffmpeg",
         "-nostdin",
         "-threads", "0",
-        "-i", file,
+        "-i", "-" if isinstance(file, (bytes, bytearray)) else file,
         "-f", "s16le",
         "-ac", "1",
         "-acodec", "pcm_s16le",
@@ -54,8 +54,10 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
         "-"
     ]
     # fmt: on
+    # Raw file content is piped to ffmpeg's stdin, matching the "-i -" input above.
+    stdin_data = file if isinstance(file, (bytes, bytearray)) else None
     try:
-        out = run(cmd, capture_output=True, check=True).stdout
+        out = run(cmd, input=stdin_data, capture_output=True, check=True).stdout
     except CalledProcessError as e:
         raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
 
@@ -108,7 +110,7 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:
 
 
 def log_mel_spectrogram(
-    audio: Union[str, np.ndarray, torch.Tensor],
+    audio: Union[str, bytes, np.ndarray, torch.Tensor],
     n_mels: int = 80,
     padding: int = 0,
     device: Optional[Union[str, torch.device]] = None,
@@ -136,7 +138,7 @@ def log_mel_spectrogram(
     A Tensor that contains the Mel spectrogram
     """
     if not torch.is_tensor(audio):
-        if isinstance(audio, str):
+        if isinstance(audio, (str, bytes)):
             audio = load_audio(audio)
         audio = torch.from_numpy(audio)
 
diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 1c075a2..fa94417 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
 
 def transcribe(
     model: "Whisper",
-    audio: Union[str, np.ndarray, torch.Tensor],
+    audio: Union[str, bytes, np.ndarray, torch.Tensor],
     *,
     verbose: Optional[bool] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
@@ -62,7 +62,7 @@ def transcribe(
         The Whisper model instance
 
-    audio: Union[str, np.ndarray, torch.Tensor]
-        The path to the audio file to open, or the audio waveform
+    audio: Union[str, bytes, np.ndarray, torch.Tensor]
+        The path to the audio file to open, or the audio waveform, or the bytes content of an audio file
 
     verbose: bool
         Whether to display the text being decoded to the console. If True, displays all the details,