diff --git a/whisper/audio.py b/whisper/audio.py
index cf6c66a..f28e0fd 100644
--- a/whisper/audio.py
+++ b/whisper/audio.py
@@ -24,15 +24,15 @@ TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audi
 
 def load_audio(file: str, sr: int = SAMPLE_RATE):
     """
-    Open an audio file and read as mono waveform, resampling as necessary
+    Open an audio file and read as mono waveform, resampling as necessary.
 
     Parameters
     ----------
     file: str
-        The audio file to open
+        The audio file to open.
 
     sr: int
-        The sample rate to resample the audio if necessary
+        The sample rate to resample the audio if necessary.
 
     Returns
     -------
@@ -65,6 +65,21 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
 def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
     """
     Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
+
+    Parameters
+    ----------
+    array: Union[np.ndarray, torch.Tensor]
+        The audio array to pad or trim.
+
+    length: int
+        The desired length of the audio array.
+
+    axis: int
+        The axis along which to pad or trim.
+
+    Returns
+    -------
+    A padded or trimmed array.
     """
     if torch.is_tensor(array):
         if array.shape[axis] > length:
@@ -91,14 +106,20 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
 @lru_cache(maxsize=None)
 def mel_filters(device, n_mels: int) -> torch.Tensor:
     """
-    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
-    Allows decoupling librosa dependency; saved using:
+    Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
 
-        np.savez_compressed(
-            "mel_filters.npz",
-            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
-            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
-        )
+    Parameters
+    ----------
+    device: torch.device
+        The device to load the filters on.
+
+    n_mels: int
+        The number of Mel-frequency filters; 80 and 128 are supported.
+
+    Returns
+    -------
+    torch.Tensor
+        The Mel filterbank matrix.
     """
     assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
 
@@ -114,26 +135,26 @@ def log_mel_spectrogram(
     device: Optional[Union[str, torch.device]] = None,
 ):
     """
-    Compute the log-Mel spectrogram of
+    Compute the log-Mel spectrogram of the audio.
 
     Parameters
     ----------
-    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
-        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
+    audio: Union[str, np.ndarray, torch.Tensor]
+        The path to an audio file, or a NumPy array or Tensor containing the audio waveform in 16 kHz.
 
     n_mels: int
-        The number of Mel-frequency filters, only 80 is supported
+        The number of Mel-frequency filters; 80 and 128 are supported.
 
     padding: int
-        Number of zero samples to pad to the right
+        Number of zero samples to pad to the right.
 
     device: Optional[Union[str, torch.device]]
-        If given, the audio tensor is moved to this device before STFT
+        If given, the audio tensor is moved to this device before STFT.
 
     Returns
     -------
-    torch.Tensor, shape = (80, n_frames)
-        A Tensor that contains the Mel spectrogram
+    torch.Tensor, shape = (n_mels, n_frames)
+        A Tensor that contains the Mel spectrogram.
""" - if not torch.is_tensor(audio): - if isinstance(audio, str): - audio = load_audio(audio) - audio = torch.from_numpy(audio) + try: + if not torch.is_tensor(audio): + if isinstance(audio, str): + audio = load_audio(audio) + audio = torch.from_numpy(audio) - if device is not None: - audio = audio.to(device) - if padding > 0: - audio = F.pad(audio, (0, padding)) - window = torch.hann_window(N_FFT).to(audio.device) - stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) - magnitudes = stft[..., :-1].abs() ** 2 + if device is not None: + audio = audio.to(device) + if padding > 0: + audio = F.pad(audio, (0, padding)) + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 - filters = mel_filters(audio.device, n_mels) - mel_spec = filters @ magnitudes + filters = mel_filters(audio.device, n_mels) + mel_spec = filters @ magnitudes - log_spec = torch.clamp(mel_spec, min=1e-10).log10() - log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) - log_spec = (log_spec + 4.0) / 4.0 - return log_spec + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec + except Exception as e: + print(f"Error computing log-mel spectrogram: {e}") + return None