Add mel_spectrogram_device parameter

This commit is contained in:
take0x 2024-09-23 08:06:27 +09:00
parent 834662c956
commit c1031a5787

View File

@@ -51,6 +51,7 @@ def transcribe(
append_punctuations: str = "\"'.。,!?::”)]}、", append_punctuations: str = "\"'.。,!?::”)]}、",
clip_timestamps: Union[str, List[float]] = "0", clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None, hallucination_silence_threshold: Optional[float] = None,
mel_spectrogram_device: Optional[Union[str, torch.device]] = None,
**decode_options, **decode_options,
): ):
""" """
@@ -113,6 +114,9 @@ def transcribe(
When word_timestamps is True, skip silent periods longer than this threshold (in seconds) When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
when a possible hallucination is detected when a possible hallucination is detected
mel_spectrogram_device: Optional[Union[str, torch.device]]
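If given, the audio tensor is moved to this device before the STFT is computed; the value is passed through as the `device` argument of log_mel_spectrogram (default: None)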
Returns Returns
------- -------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
@@ -131,7 +135,7 @@ def transcribe(
# Pad 30-seconds of silence to the input audio, for slicing # Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram( mel = log_mel_spectrogram(
audio, model.dims.n_mels, padding=N_SAMPLES, device=model.device audio, model.dims.n_mels, padding=N_SAMPLES, device=mel_spectrogram_device
) )
content_frames = mel.shape[-1] - N_FRAMES content_frames = mel.shape[-1] - N_FRAMES
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE) content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)