Merge e000892e137fb17d4ab43046c77853274eaf9ed2 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2

This commit is contained in:
nleve 2025-06-27 02:27:48 +00:00 committed by GitHub
commit 4701f00b91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 29 additions and 17 deletions

View File

@ -1,19 +1,24 @@
import os.path import os.path
import numpy as np import numpy as np
import pytest
from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
@pytest.mark.parametrize("read_bytes", [True, False])
def test_audio(): def test_audio(read_bytes):
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac") audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
audio = load_audio(audio_path) audio_input = audio_path
if (read_bytes):
with open(audio_path, 'rb') as f:
audio_input = f.read()
audio = load_audio(audio_input)
assert audio.ndim == 1 assert audio.ndim == 1
assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12 assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12
assert 0 < audio.std() < 1 assert 0 < audio.std() < 1
mel_from_audio = log_mel_spectrogram(audio) mel_from_audio = log_mel_spectrogram(audio)
mel_from_file = log_mel_spectrogram(audio_path) mel_from_file = log_mel_spectrogram(audio_input)
assert np.allclose(mel_from_audio, mel_from_file) assert np.allclose(mel_from_audio, mel_from_file)
assert mel_from_audio.max() - mel_from_audio.min() <= 2.0 assert mel_from_audio.max() - mel_from_audio.min() <= 2.0

View File

@ -1,6 +1,6 @@
import os import os
from functools import lru_cache from functools import lru_cache
from subprocess import CalledProcessError, run from subprocess import CalledProcessError, run, PIPE
from typing import Optional, Union from typing import Optional, Union
import numpy as np import numpy as np
@ -22,14 +22,14 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE): def load_audio(file: Union[str, bytes], sr: int = SAMPLE_RATE):
""" """
Open an audio file and read as mono waveform, resampling as necessary Open an audio file and read as mono waveform, resampling as necessary
Parameters Parameters
---------- ----------
file: str file: Union[str, bytes]
The audio file to open The audio file to open, or the bytes content of an audio file
sr: int sr: int
The sample rate to resample the audio if necessary The sample rate to resample the audio if necessary
@ -46,7 +46,6 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
"ffmpeg", "ffmpeg",
"-nostdin", "-nostdin",
"-threads", "0", "-threads", "0",
"-i", file,
"-f", "s16le", "-f", "s16le",
"-ac", "1", "-ac", "1",
"-acodec", "pcm_s16le", "-acodec", "pcm_s16le",
@ -54,10 +53,18 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
"-" "-"
] ]
# fmt: on # fmt: on
try: if isinstance(file, str):
out = run(cmd, capture_output=True, check=True).stdout cmd += ["-i", file]
except CalledProcessError as e: try:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
else:
cmd += ["-i", "-"]
try:
out = run(cmd, input=file, stdout=PIPE, stderr=PIPE, check=True).stdout
except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
@ -108,7 +115,7 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:
def log_mel_spectrogram( def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor], audio: Union[str, bytes, np.ndarray, torch.Tensor],
n_mels: int = 80, n_mels: int = 80,
padding: int = 0, padding: int = 0,
device: Optional[Union[str, torch.device]] = None, device: Optional[Union[str, torch.device]] = None,
@ -136,7 +143,7 @@ def log_mel_spectrogram(
A Tensor that contains the Mel spectrogram A Tensor that contains the Mel spectrogram
""" """
if not torch.is_tensor(audio): if not torch.is_tensor(audio):
if isinstance(audio, str): if isinstance(audio, str) or isinstance(audio, bytes):
audio = load_audio(audio) audio = load_audio(audio)
audio = torch.from_numpy(audio) audio = torch.from_numpy(audio)

View File

@ -37,7 +37,7 @@ if TYPE_CHECKING:
def transcribe( def transcribe(
model: "Whisper", model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor], audio: Union[str, bytes, np.ndarray, torch.Tensor],
*, *,
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
@ -63,7 +63,7 @@ def transcribe(
The Whisper model instance The Whisper model instance
audio: Union[str, np.ndarray, torch.Tensor] audio: Union[str, bytes, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform The path to the audio file to open, or the audio waveform, or the bytes content of an audio file
verbose: bool verbose: bool
Whether to display the text being decoded to the console. If True, displays all the details, Whether to display the text being decoded to the console. If True, displays all the details,