From c1031a5787e7f21b789e9b84309d443d2fc7188a Mon Sep 17 00:00:00 2001 From: take0x <89313929+take0x@users.noreply.github.com> Date: Mon, 23 Sep 2024 08:06:27 +0900 Subject: [PATCH] Add mel_spectrogram_device parameter --- whisper/transcribe.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index d341528..d3a6283 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -51,6 +51,7 @@ def transcribe( append_punctuations: str = "\"'.。,,!!??::”)]}、", clip_timestamps: Union[str, List[float]] = "0", hallucination_silence_threshold: Optional[float] = None, + mel_spectrogram_device: Optional[Union[str, torch.device]] = None, **decode_options, ): """ @@ -113,6 +114,9 @@ def transcribe( When word_timestamps is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected + mel_spectrogram_device: Optional[Union[str, torch.device]] + If given, the audio tensor is moved to this device before STFT + Returns ------- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and @@ -131,7 +135,7 @@ def transcribe( # Pad 30-seconds of silence to the input audio, for slicing mel = log_mel_spectrogram( - audio, model.dims.n_mels, padding=N_SAMPLES, device=model.device + audio, model.dims.n_mels, padding=N_SAMPLES, device=mel_spectrogram_device ) content_frames = mel.shape[-1] - N_FRAMES content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)