mirror of
https://github.com/openai/whisper.git
synced 2025-11-28 08:11:11 +00:00
add stt method to the transcribe function
This commit is contained in:
parent
0fd17c99c8
commit
e53b617de1
@ -17,7 +17,7 @@ from .audio import (
|
||||
log_mel_spectrogram,
|
||||
pad_or_trim,
|
||||
)
|
||||
from .decoding import DecodingOptions, DecodingResult
|
||||
from .decoding import DecodingOptions, DecodingResult, decode
|
||||
from .timing import add_word_timestamps
|
||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||
from .utils import (
|
||||
@ -498,6 +498,40 @@ def transcribe(
|
||||
)
|
||||
|
||||
|
||||
def stt(model: "Whisper",
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
language : str = "en",
|
||||
f16: bool =True):
|
||||
"""
|
||||
Transcribe an audio file using Whisper while getting the probability for each token
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Whisper
|
||||
The Whisper model instance
|
||||
|
||||
audio: Union[str, np.ndarray, torch.Tensor]
|
||||
The path to the audio file to open, or the audio waveform
|
||||
|
||||
language: string
|
||||
language used in the audio
|
||||
|
||||
f16: bool
|
||||
check if using torch.float16 otherwise use torch.float32
|
||||
"""
|
||||
dtype = torch.float16 if f16 else torch.float32
|
||||
audio = pad_or_trim(audio)
|
||||
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES).to(model.device)
|
||||
mel_segment =pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||
|
||||
options = DecodingOptions()
|
||||
result = decode(model, mel_segment, options)
|
||||
tokenizer = get_tokenizer(multilingual=model.is_multilingual, language=language, task=options.task)
|
||||
text = [tokenizer.decode([t]) for t in result.tokens]
|
||||
output = [ [text, prob] for text, prob in zip(text, result.token_probs) ]
|
||||
return output
|
||||
|
||||
|
||||
def cli():
|
||||
from . import available_models
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user