Mirror of https://github.com/openai/whisper.git
add stt method to the transcribe function
This commit is contained in:
parent 0fd17c99c8
commit e53b617de1
@@ -17,7 +17,7 @@ from .audio import (
     log_mel_spectrogram,
     pad_or_trim,
 )
-from .decoding import DecodingOptions, DecodingResult
+from .decoding import DecodingOptions, DecodingResult, decode
 from .timing import add_word_timestamps
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import (
@@ -498,6 +498,40 @@ def transcribe(
     )
 
 
+def stt(model: "Whisper",
+        audio: Union[str, np.ndarray, torch.Tensor],
+        language: str = "en",
+        f16: bool = True):
+    """
+    Transcribe an audio file using Whisper and return the probability of each token.
+
+    Parameters
+    ----------
+    model: Whisper
+        The Whisper model instance
+
+    audio: Union[str, np.ndarray, torch.Tensor]
+        The path to the audio file to open, or the audio waveform
+
+    language: str
+        The language spoken in the audio
+
+    f16: bool
+        Whether to use torch.float16; if False, torch.float32 is used
+    """
+    dtype = torch.float16 if f16 else torch.float32
+    audio = pad_or_trim(audio)
+    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES).to(model.device)
+    mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
+
+    options = DecodingOptions()
+    result = decode(model, mel_segment, options)
+    tokenizer = get_tokenizer(multilingual=model.is_multilingual, language=language, task=options.task)
+    texts = [tokenizer.decode([t]) for t in result.tokens]
+    output = [[text, prob] for text, prob in zip(texts, result.token_probs)]
+    return output
+
+
 def cli():
     from . import available_models
 
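Below is a minimal usage sketch of the new stt() helper; it is not part of the commit. It assumes the function lands in whisper/transcribe.py as the diff suggests, and that this fork's DecodingResult exposes the token_probs field used above. load_model and load_audio are standard whisper helpers, and "speech.wav" is a placeholder path.

    import whisper
    from whisper.transcribe import stt  # assumes the diff applies to whisper/transcribe.py

    model = whisper.load_model("base")        # any Whisper checkpoint
    audio = whisper.load_audio("speech.wav")  # placeholder path; yields a 16 kHz float32 waveform

    # stt() scores only the first 30-second window and returns [token_text, probability] pairs.
    pairs = stt(model, audio, language="en", f16=True)
    for token_text, prob in pairs:
        print(f"{token_text!r}: {prob:.3f}")

Note that pad_or_trim() operates on arrays, so despite the Union type hint, audio is best passed as an ndarray or Tensor (e.g. from load_audio) rather than a file path. The f16 flag only controls the dtype of the mel segment; float16 is generally only worthwhile on GPU.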