add stt method to the transcribe function

This commit is contained in:
khaliladib11 2024-05-26 19:10:24 +01:00
parent 0fd17c99c8
commit e53b617de1
No known key found for this signature in database

View File

@ -17,7 +17,7 @@ from .audio import (
log_mel_spectrogram, log_mel_spectrogram,
pad_or_trim, pad_or_trim,
) )
from .decoding import DecodingOptions, DecodingResult from .decoding import DecodingOptions, DecodingResult, decode
from .timing import add_word_timestamps from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import ( from .utils import (
@ -498,6 +498,40 @@ def transcribe(
) )
def stt(model: "Whisper",
        audio: Union[str, np.ndarray, torch.Tensor],
        language: str = "en",
        f16: bool = True) -> list:
    """
    Transcribe an audio clip with Whisper and return each decoded token
    together with its probability.

    Only a single window is decoded: the waveform is padded or trimmed with
    ``pad_or_trim`` before the spectrogram is computed, so audio beyond that
    window is ignored.

    Parameters
    ----------
    model: Whisper
        The Whisper model instance

    audio: Union[str, np.ndarray, torch.Tensor]
        The path to the audio file to open, or the audio waveform

    language: str
        Language spoken in the audio; used both for decoding and for the
        tokenizer that turns token ids back into text

    f16: bool
        Run in torch.float16 when True, otherwise torch.float32.
        NOTE(review): half precision may not be supported on CPU devices;
        pass ``f16=False`` there -- confirm against the target torch build.

    Returns
    -------
    A list of ``[token_text, probability]`` pairs, one per entry of
    ``result.tokens``, paired positionally with ``result.token_probs``.
    """
    if isinstance(audio, str):
        # pad_or_trim operates on arrays/tensors, so a file path has to be
        # loaded into a waveform before it can be padded or trimmed.
        from .audio import load_audio

        audio = load_audio(audio)

    dtype = torch.float16 if f16 else torch.float32

    audio = pad_or_trim(audio)
    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
    mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)

    # Forward the caller's choices to the decoder; default-constructed options
    # would ignore `language` and keep their own precision setting regardless
    # of `f16`, leaving the decoder out of sync with the tokenizer below.
    options = DecodingOptions(language=language, fp16=f16)
    result = decode(model, mel_segment, options)

    tokenizer = get_tokenizer(
        multilingual=model.is_multilingual,
        language=language,
        task=options.task,
    )

    # Decode ids one at a time so every text piece lines up with exactly one
    # probability.
    return [
        [tokenizer.decode([token]), prob]
        for token, prob in zip(result.tokens, result.token_probs)
    ]
def cli(): def cli():
from . import available_models from . import available_models