mirror of https://github.com/openai/whisper.git (synced 2025-03-30 14:28:27 +00:00)
Merge b76dcf36630a5bd7bbc1851924e23e6128db693f into 517a43ecd132a2089d85f4ebc044728a71d49f6e
This commit is contained in commit 546c82147a.
examples/test_prob.py (new file, 78 lines)
@@ -0,0 +1,78 @@
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

import argparse
import colorsys
from typing import List

import torch

import whisper
from whisper.tokenizer import get_tokenizer
from whisper.utils import exact_div

from colorama import init, Style


# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token


def load_audio_from_source(audio_source):
    # load the audio file and pad/trim it to a 30-second chunk
    audio = whisper.load_audio(audio_source)
    audio = whisper.pad_or_trim(audio)
    return audio


def decode_audio(model, audio, language="en", f16=True):
    # compute the mel spectrogram, decode a single 30-second segment, and
    # return the decoded token strings alongside their probabilities
    dtype = torch.float16 if f16 else torch.float32
    mel = whisper.log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES).to(model.device)
    mel_segment = whisper.pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)

    print('Decoding audio')
    options = whisper.DecodingOptions(language=language, fp16=f16)
    result = whisper.decode(model, mel_segment, options)

    tokenizer = get_tokenizer(multilingual=model.is_multilingual, language=language, task=options.task)
    text_tokens = [tokenizer.decode([t]) for t in result.tokens]

    return text_tokens, result.token_probs


def get_colored_text(text_tokens: List[str], token_probs: List[float]):
    init(autoreset=False)  # colors are reset manually with Style.RESET_ALL after each token
    output_text = ""
    for token, prob in zip(text_tokens, token_probs):
        # interpolate between red (hue 0) and green (hue 1/3) in the HSV color space
        r, g, b = colorsys.hsv_to_rgb(prob * (1 / 3), 1, 1)
        r, g, b = int(r * 255), int(g * 255), int(b * 255)
        color_code = f"\033[38;2;{r};{g};{b}m"
        colored_token = f"{color_code}{Style.BRIGHT}{token}{Style.RESET_ALL}"
        output_text += colored_token
    return output_text


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--audio', type=str, required=True, help='the path of the audio file')
    parser.add_argument('--model', type=str, default="large", help='the version of the model to be used')
    args = parser.parse_args()

    model = whisper.load_model(args.model)
    audio = load_audio_from_source(audio_source=args.audio)
    text, proba = decode_audio(model=model, audio=audio)
    print(get_colored_text(text, proba))
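For intuition on the coloring: hsv_to_rgb maps probability 0 to hue 0 (red) and probability 1 to hue 1/3 (green), passing through yellow. A quick standalone check of the interpolation used above:

import colorsys

for p in (0.0, 0.5, 1.0):
    r, g, b = colorsys.hsv_to_rgb(p * (1 / 3), 1, 1)
    print(p, (int(r * 255), int(g * 255), int(b * 255)))
# 0.0 (255, 0, 0)   -> red
# 0.5 (255, 255, 0) -> yellow
# 1.0 (0, 255, 0)   -> green

Assuming this branch is installed, the script would be invoked as, for example, python examples/test_prob.py --audio path/to/audio.wav --model base (both argument values are placeholders).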
whisper/__init__.py
@@ -11,7 +11,7 @@ from tqdm import tqdm
 from .audio import load_audio, log_mel_spectrogram, pad_or_trim
 from .decoding import DecodingOptions, DecodingResult, decode, detect_language
 from .model import ModelDimensions, Whisper
-from .transcribe import transcribe
+from .transcribe import transcribe, stt
 from .version import __version__
 
 _MODELS = {
whisper/decoding.py
@@ -120,6 +120,7 @@ class DecodingResult:
     language: str
     language_probs: Optional[Dict[str, float]] = None
     tokens: List[int] = field(default_factory=list)
+    token_probs: List[float] = field(default_factory=list)
     text: str = ""
     avg_logprob: float = np.nan
     no_speech_prob: float = np.nan
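With the new field in place, each DecodingResult carries one probability per kept token. A rough sketch of how that surfaces to a caller, assuming this branch is installed (the audio path and model size are placeholders):

import whisper
from whisper.tokenizer import get_tokenizer

model = whisper.load_model("base")
audio = whisper.pad_or_trim(whisper.load_audio("speech.wav"))  # placeholder path
mel = whisper.log_mel_spectrogram(audio, model.dims.n_mels).to(model.device)

result = whisper.decode(model, mel, whisper.DecodingOptions(fp16=False))
tokenizer = get_tokenizer(model.is_multilingual, language="en", task="transcribe")
for tok, prob in zip(result.tokens, result.token_probs):
    print(f"{prob:.3f}  {tokenizer.decode([tok])!r}")  # probability next to each token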
@@ -218,7 +219,7 @@ class TokenDecoder:
         """Initialize any stateful variables for decoding a new sequence"""
 
     def update(
-        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, token_probs: Tensor
     ) -> Tuple[Tensor, bool]:
         """Specify how to select the next token, based on the current trace and logits
 
@@ -245,7 +246,7 @@ class TokenDecoder:
         raise NotImplementedError
 
     def finalize(
-        self, tokens: Tensor, sum_logprobs: Tensor
+        self, tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor
     ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
         """Finalize search and return the final candidate sequences
 
@@ -275,7 +276,7 @@ class GreedyDecoder(TokenDecoder):
         self.eot = eot
 
     def update(
-        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, token_probs: Tensor
     ) -> Tuple[Tensor, bool]:
         if self.temperature == 0:
             next_tokens = logits.argmax(dim=-1)
@@ -283,19 +284,24 @@ class GreedyDecoder(TokenDecoder):
             next_tokens = Categorical(logits=logits / self.temperature).sample()
 
         logprobs = F.log_softmax(logits.float(), dim=-1)
+        probs = torch.exp(logprobs)
         current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
         sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
 
         next_tokens[tokens[:, -1] == self.eot] = self.eot
         tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
 
-        completed = (tokens[:, -1] == self.eot).all()
-        return tokens, completed
+        current_token_probs = probs[torch.arange(probs.shape[0]), next_tokens]
+        token_probs = torch.cat([token_probs, current_token_probs[:, None]], dim=-1)
 
-    def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
+        completed = (tokens[:, -1] == self.eot).all()
+        return tokens, completed, token_probs
+
+    def finalize(self, tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor):
         # make sure each sequence has at least one EOT token at the end
         tokens = F.pad(tokens, (0, 1), value=self.eot)
-        return tokens, sum_logprobs.tolist()
+        token_probs = F.pad(token_probs, (0, 1), value=0)
+        return tokens, sum_logprobs.tolist(), token_probs.tolist()
 
 
 class BeamSearchDecoder(TokenDecoder):
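The two added lines in update() mirror the existing current_logprobs gather: take a softmax over the vocabulary, then index the probability of the token each sequence picked. A minimal self-contained illustration of that indexing pattern (shapes and values are made-up examples):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, 0.1], [0.2, 3.0, 0.3]])  # (n_batch=2, vocab=3)
next_tokens = logits.argmax(dim=-1)                         # tensor([0, 1])
probs = torch.exp(F.log_softmax(logits.float(), dim=-1))    # each row sums to 1
current_token_probs = probs[torch.arange(probs.shape[0]), next_tokens]
print(current_token_probs)  # probability of each chosen token, shape (n_batch,)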
@@ -381,7 +387,7 @@ class BeamSearchDecoder(TokenDecoder):
         )
         return tokens, completed
 
-    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
+    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor):
         # collect all finished sequences, including patience, and add unfinished ones if not enough
         sum_logprobs = sum_logprobs.cpu()
         for i, sequences in enumerate(self.finished_sequences):
@@ -682,6 +688,7 @@ class DecodingTask:
         sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
         no_speech_probs = [np.nan] * n_batch
 
+        token_probs = torch.zeros_like(tokens).float()
         try:
             for i in range(self.sample_len):
                 logits = self.inference.logits(tokens, audio_features)
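The buffer starts with one zero for every token already in the sequence (the initial/prompt tokens) and grows by one column per sampled step, which is why the result construction further down keeps only the tail via token_probs[-len(tokens):]. A toy sketch of that shape bookkeeping (the token IDs are arbitrary examples):

import torch

tokens = torch.tensor([[50258, 50259, 50359]])   # e.g. three initial prompt tokens
token_probs = torch.zeros_like(tokens).float()   # shape (1, 3), all zeros
step_probs = torch.tensor([0.91])                # probability of one sampled token
token_probs = torch.cat([token_probs, step_probs[:, None]], dim=-1)
print(token_probs)  # tensor([[0.0000, 0.0000, 0.0000, 0.9100]])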
@@ -700,14 +707,14 @@ class DecodingTask:
                 logit_filter.apply(logits, tokens)
 
                 # expand the tokens tensor with the selected next tokens
-                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
+                tokens, completed, token_probs = self.decoder.update(tokens, logits, sum_logprobs, token_probs)
 
                 if completed or tokens.shape[-1] > self.n_ctx:
                     break
         finally:
             self.inference.cleanup_caching()
 
-        return tokens, sum_logprobs, no_speech_probs
+        return tokens, sum_logprobs, no_speech_probs, token_probs
 
     @torch.no_grad()
     def run(self, mel: Tensor) -> List[DecodingResult]:
@@ -734,7 +741,7 @@ class DecodingTask:
         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
 
         # call the main sampling loop
-        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
+        tokens, sum_logprobs, no_speech_probs, token_probs = self._main_loop(audio_features, tokens)
 
         # reshape the tensors to have (n_audio, n_group) as the first two dimensions
         audio_features = audio_features[:: self.n_group]
@@ -745,7 +752,7 @@ class DecodingTask:
         sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
 
         # get the final candidates for each group, and slice between the first sampled token and EOT
-        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
+        tokens, sum_logprobs, token_probs = self.decoder.finalize(tokens, sum_logprobs, token_probs)
         tokens: List[List[Tensor]] = [
             [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
             for s in tokens
@@ -768,6 +775,7 @@ class DecodingTask:
             audio_features,
             avg_logprobs,
             no_speech_probs,
+            token_probs
         )
         if len(set(map(len, fields))) != 1:
             raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
@@ -782,8 +790,9 @@ class DecodingTask:
                 no_speech_prob=no_speech_prob,
                 temperature=self.options.temperature,
                 compression_ratio=compression_ratio(text),
+                token_probs=token_probs[-len(tokens):]
             )
-            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
+            for text, language, tokens, features, avg_logprob, no_speech_prob, token_probs in zip(
                 *fields
             )
         ]
whisper/transcribe.py
@@ -17,7 +17,7 @@ from .audio import (
     log_mel_spectrogram,
     pad_or_trim,
 )
-from .decoding import DecodingOptions, DecodingResult
+from .decoding import DecodingOptions, DecodingResult, decode
 from .timing import add_word_timestamps
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import (
@@ -514,6 +514,40 @@ def transcribe(
     )
 
 
+def stt(model: "Whisper",
+        audio: Union[str, np.ndarray, torch.Tensor],
+        language: str = "en",
+        f16: bool = True):
+    """
+    Transcribe an audio file using Whisper while getting the probability of each token
+
+    Parameters
+    ----------
+    model: Whisper
+        The Whisper model instance
+
+    audio: Union[str, np.ndarray, torch.Tensor]
+        The path to the audio file to open, or the audio waveform
+
+    language: str
+        The language spoken in the audio
+
+    f16: bool
+        Whether to decode in torch.float16; torch.float32 is used otherwise
+    """
+    dtype = torch.float16 if f16 else torch.float32
+    # log_mel_spectrogram accepts either a path or a waveform, so the raw audio
+    # is not pre-trimmed here; the mel is padded/trimmed to one 30-second window
+    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES).to(model.device)
+    mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
+
+    options = DecodingOptions(language=language, fp16=f16)
+    result = decode(model, mel_segment, options)
+    tokenizer = get_tokenizer(multilingual=model.is_multilingual, language=language, task=options.task)
+    text_tokens = [tokenizer.decode([t]) for t in result.tokens]
+    output = [[tok, prob] for tok, prob in zip(text_tokens, result.token_probs)]
+    return output
+
+
 def cli():
     from . import available_models
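Since stt is re-exported from the package root (see the whisper/__init__.py hunk above), calling it could look roughly like this, assuming this branch is installed; the audio path and model size are placeholders:

import whisper

model = whisper.load_model("base")
audio = whisper.load_audio("speech.wav")  # placeholder path; returns a float32 waveform
pairs = whisper.stt(model, audio, language="en", f16=False)
for token_text, prob in pairs:
    print(f"{prob:.3f}  {token_text!r}")

Each entry pairs one decoded token string with its probability, so low-confidence spans can be flagged or sent back for re-transcription.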