From e53b617de1a6b5971a1ae94f139ac59e04324695 Mon Sep 17 00:00:00 2001
From: khaliladib11 <73353537+Khaliladib11@users.noreply.github.com>
Date: Sun, 26 May 2024 19:10:24 +0100
Subject: [PATCH] add stt function to the transcribe module

---
 whisper/transcribe.py | 41 ++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 40 insertions(+), 1 deletion(-)

diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 1c075a2..6f2529d 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -17,7 +17,7 @@ from .audio import (
     log_mel_spectrogram,
     pad_or_trim,
 )
-from .decoding import DecodingOptions, DecodingResult
+from .decoding import DecodingOptions, DecodingResult, decode
 from .timing import add_word_timestamps
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import (
@@ -498,6 +498,45 @@ def transcribe(
     )
 
 
+def stt(model: "Whisper",
+        audio: Union[str, np.ndarray, torch.Tensor],
+        language: str = "en",
+        fp16: bool = True):
+    """
+    Transcribe the first 30-second window of an audio file using Whisper, returning the probability of each token
+
+    Parameters
+    ----------
+    model: Whisper
+        The Whisper model instance
+
+    audio: Union[str, np.ndarray, torch.Tensor]
+        The path to the audio file to open, or the audio waveform
+
+    language: str
+        The language spoken in the audio
+
+    fp16: bool
+        Whether to run inference in torch.float16; if False, torch.float32 is used
+
+    Returns
+    -------
+    A list of [token_text, probability] pairs, one per decoded token
+    """
+    dtype = torch.float16 if fp16 else torch.float32
+    # log_mel_spectrogram accepts a file path or a waveform; padding by N_SAMPLES
+    # guarantees a full 30-second window even for shorter audio
+    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
+    mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
+
+    options = DecodingOptions(language=language, fp16=fp16)
+    result = decode(model, mel_segment, options)
+    tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages, language=language, task=options.task)
+    # pair each decoded token's text with its probability
+    token_texts = [tokenizer.decode([t]) for t in result.tokens]
+    return [[text, prob] for text, prob in zip(token_texts, result.token_probs)]
+
+
 def cli():
     from . import available_models
 
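
A minimal usage sketch for the new stt function, assuming a build of whisper
where DecodingResult exposes token_probs as this patch expects; the "base"
checkpoint name and "audio.wav" path are placeholders:

    import whisper
    from whisper.transcribe import stt

    # "base" is a placeholder; any available checkpoint works
    model = whisper.load_model("base")

    # "audio.wav" is a placeholder path; a waveform
    # (np.ndarray / torch.Tensor) is accepted as well
    pairs = stt(model, "audio.wav", language="en", fp16=True)
    for token_text, prob in pairs:
        print(f"{token_text!r}: {prob:.3f}")

Each entry pairs one token's decoded text with its probability, so downstream
code can, for example, flag or filter low-confidence tokens.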