mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 14:35:57 +00:00
Merge e000892e137fb17d4ab43046c77853274eaf9ed2 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2
This commit is contained in:
commit
4701f00b91
@ -1,19 +1,24 @@
|
|||||||
import os.path
|
import os.path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
|
from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("read_bytes", [True, False])
|
||||||
def test_audio():
|
def test_audio(read_bytes):
|
||||||
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
|
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
|
||||||
audio = load_audio(audio_path)
|
audio_input = audio_path
|
||||||
|
if (read_bytes):
|
||||||
|
with open(audio_path, 'rb') as f:
|
||||||
|
audio_input = f.read()
|
||||||
|
audio = load_audio(audio_input)
|
||||||
assert audio.ndim == 1
|
assert audio.ndim == 1
|
||||||
assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12
|
assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12
|
||||||
assert 0 < audio.std() < 1
|
assert 0 < audio.std() < 1
|
||||||
|
|
||||||
mel_from_audio = log_mel_spectrogram(audio)
|
mel_from_audio = log_mel_spectrogram(audio)
|
||||||
mel_from_file = log_mel_spectrogram(audio_path)
|
mel_from_file = log_mel_spectrogram(audio_input)
|
||||||
|
|
||||||
assert np.allclose(mel_from_audio, mel_from_file)
|
assert np.allclose(mel_from_audio, mel_from_file)
|
||||||
assert mel_from_audio.max() - mel_from_audio.min() <= 2.0
|
assert mel_from_audio.max() - mel_from_audio.min() <= 2.0
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from subprocess import CalledProcessError, run
|
from subprocess import CalledProcessError, run, PIPE
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -22,14 +22,14 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
|
|||||||
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
|
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
|
||||||
|
|
||||||
|
|
||||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
def load_audio(file: Union[str, bytes], sr: int = SAMPLE_RATE):
|
||||||
"""
|
"""
|
||||||
Open an audio file and read as mono waveform, resampling as necessary
|
Open an audio file and read as mono waveform, resampling as necessary
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
file: str
|
file: Union[str, bytes]
|
||||||
The audio file to open
|
The audio file to open, or the bytes content of an audio file
|
||||||
|
|
||||||
sr: int
|
sr: int
|
||||||
The sample rate to resample the audio if necessary
|
The sample rate to resample the audio if necessary
|
||||||
@ -46,7 +46,6 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
|
|||||||
"ffmpeg",
|
"ffmpeg",
|
||||||
"-nostdin",
|
"-nostdin",
|
||||||
"-threads", "0",
|
"-threads", "0",
|
||||||
"-i", file,
|
|
||||||
"-f", "s16le",
|
"-f", "s16le",
|
||||||
"-ac", "1",
|
"-ac", "1",
|
||||||
"-acodec", "pcm_s16le",
|
"-acodec", "pcm_s16le",
|
||||||
@ -54,10 +53,18 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
|
|||||||
"-"
|
"-"
|
||||||
]
|
]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
try:
|
if isinstance(file, str):
|
||||||
out = run(cmd, capture_output=True, check=True).stdout
|
cmd += ["-i", file]
|
||||||
except CalledProcessError as e:
|
try:
|
||||||
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
out = run(cmd, capture_output=True, check=True).stdout
|
||||||
|
except CalledProcessError as e:
|
||||||
|
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
||||||
|
else:
|
||||||
|
cmd += ["-i", "-"]
|
||||||
|
try:
|
||||||
|
out = run(cmd, input=file, stdout=PIPE, stderr=PIPE, check=True).stdout
|
||||||
|
except CalledProcessError as e:
|
||||||
|
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
||||||
|
|
||||||
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
||||||
|
|
||||||
@ -108,7 +115,7 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:
|
|||||||
|
|
||||||
|
|
||||||
def log_mel_spectrogram(
|
def log_mel_spectrogram(
|
||||||
audio: Union[str, np.ndarray, torch.Tensor],
|
audio: Union[str, bytes, np.ndarray, torch.Tensor],
|
||||||
n_mels: int = 80,
|
n_mels: int = 80,
|
||||||
padding: int = 0,
|
padding: int = 0,
|
||||||
device: Optional[Union[str, torch.device]] = None,
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
@ -136,7 +143,7 @@ def log_mel_spectrogram(
|
|||||||
A Tensor that contains the Mel spectrogram
|
A Tensor that contains the Mel spectrogram
|
||||||
"""
|
"""
|
||||||
if not torch.is_tensor(audio):
|
if not torch.is_tensor(audio):
|
||||||
if isinstance(audio, str):
|
if isinstance(audio, str) or isinstance(audio, bytes):
|
||||||
audio = load_audio(audio)
|
audio = load_audio(audio)
|
||||||
audio = torch.from_numpy(audio)
|
audio = torch.from_numpy(audio)
|
||||||
|
|
||||||
|
|||||||
@ -37,7 +37,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
def transcribe(
|
def transcribe(
|
||||||
model: "Whisper",
|
model: "Whisper",
|
||||||
audio: Union[str, np.ndarray, torch.Tensor],
|
audio: Union[str, bytes, np.ndarray, torch.Tensor],
|
||||||
*,
|
*,
|
||||||
verbose: Optional[bool] = None,
|
verbose: Optional[bool] = None,
|
||||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||||
@ -63,7 +63,7 @@ def transcribe(
|
|||||||
The Whisper model instance
|
The Whisper model instance
|
||||||
|
|
||||||
audio: Union[str, np.ndarray, torch.Tensor]
|
audio: Union[str, np.ndarray, torch.Tensor]
|
||||||
The path to the audio file to open, or the audio waveform
|
The path to the audio file to open, or the audio waveform, or the bytes content of an audio file
|
||||||
|
|
||||||
verbose: bool
|
verbose: bool
|
||||||
Whether to display the text being decoded to the console. If True, displays all the details,
|
Whether to display the text being decoded to the console. If True, displays all the details,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user