diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 0a4cc36..ee55060 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -52,6 +52,7 @@ def transcribe( append_punctuations: str = "\"'.。,,!!??::”)]}、", clip_timestamps: Union[str, List[float]] = "0", hallucination_silence_threshold: Optional[float] = None, + mel_spectrogram_device: Optional[Union[str, torch.device]] = None, **decode_options, ): """ @@ -119,6 +120,9 @@ def transcribe( When word_timestamps is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected + mel_spectrogram_device: Optional[Union[str, torch.device]] + If given, the audio tensor is moved to this device before STFT + Returns ------- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and @@ -136,7 +140,9 @@ def transcribe( decode_options["fp16"] = False # Pad 30-seconds of silence to the input audio, for slicing - mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES) + mel = log_mel_spectrogram( + audio, model.dims.n_mels, padding=N_SAMPLES, device=mel_spectrogram_device + ) content_frames = mel.shape[-1] - N_FRAMES content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)