From a4aadd95a1ed06da953bd741ed7e1194263b6930 Mon Sep 17 00:00:00 2001 From: khaliladib11 <73353537+Khaliladib11@users.noreply.github.com> Date: Sat, 25 May 2024 21:45:13 +0100 Subject: [PATCH 1/5] add test_proba file and change decoder --- examples/test_prob.py | 48 +++++++++++++++++++++++++++++++++++++++++++ whisper/decoding.py | 35 +++++++++++++++++++------------ 2 files changed, 70 insertions(+), 13 deletions(-) create mode 100644 examples/test_prob.py diff --git a/examples/test_prob.py b/examples/test_prob.py new file mode 100644 index 0000000..babfd0d --- /dev/null +++ b/examples/test_prob.py @@ -0,0 +1,48 @@ +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +import whisper +import argparse +import colorsys +from typing import List +from whisper.tokenizer import get_tokenizer +from colorama import init, Style + + + +def load_audio_from_source(audio_source): + audio = whisper.load_audio(audio_source) + audio = whisper.pad_or_trim(audio) + return audio + + +def decode_audio(model, audio, language="en"): + mel = whisper.log_mel_spectrogram(audio).to(model.device) + print('Decoding audio') # decode the audio + options = whisper.DecodingOptions() + result = whisper.decode(model, mel, options) + + tokenizer = get_tokenizer(multilingual=model.is_multilingual, language=language, task=options.task) + + text_tokens = [tokenizer.decode([t]) for t in result.tokens] + + return text_tokens, result.token_probs + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--audio', type=str, help='the path of the audio file') + parser.add_argument('--model', type=str, default="large", help='The version of the model to be used') + + + args = parser.parse_args() + + model = args.model + audio = args.audio + + # Load model + model = whisper.load_model(model) + audio = load_audio_from_source(audio_source=audio) + text, proba = decode_audio(model=model, audio=audio) + print(text) + print(proba) diff 
--git a/whisper/decoding.py b/whisper/decoding.py index 49485d0..c287f33 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -120,6 +120,7 @@ class DecodingResult: language: str language_probs: Optional[Dict[str, float]] = None tokens: List[int] = field(default_factory=list) + token_probs: List[float] = field(default_factory=list) text: str = "" avg_logprob: float = np.nan no_speech_prob: float = np.nan @@ -218,7 +219,7 @@ class TokenDecoder: """Initialize any stateful variables for decoding a new sequence""" def update( - self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor + self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, token_probs: Tensor ) -> Tuple[Tensor, bool]: """Specify how to select the next token, based on the current trace and logits @@ -245,7 +246,7 @@ class TokenDecoder: raise NotImplementedError def finalize( - self, tokens: Tensor, sum_logprobs: Tensor + self, tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]: """Finalize search and return the final candidate sequences @@ -275,7 +276,7 @@ class GreedyDecoder(TokenDecoder): self.eot = eot def update( - self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor + self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, token_probs: Tensor ) -> Tuple[Tensor, bool]: if self.temperature == 0: next_tokens = logits.argmax(dim=-1) @@ -283,19 +284,24 @@ class GreedyDecoder(TokenDecoder): next_tokens = Categorical(logits=logits / self.temperature).sample() logprobs = F.log_softmax(logits.float(), dim=-1) + probs = torch.exp(logprobs) current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens] sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot) next_tokens[tokens[:, -1] == self.eot] = self.eot tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1) - completed = (tokens[:, -1] == self.eot).all() - return tokens, completed + current_token_probs = probs[torch.arange(probs.shape[0]), 
next_tokens] + token_probs = torch.cat([token_probs, current_token_probs[:, None]], dim=-1) - def finalize(self, tokens: Tensor, sum_logprobs: Tensor): + completed = (tokens[:, -1] == self.eot).all() + return tokens, completed, token_probs + + def finalize(self, tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor): # make sure each sequence has at least one EOT token at the end tokens = F.pad(tokens, (0, 1), value=self.eot) - return tokens, sum_logprobs.tolist() + token_probs = F.pad(token_probs, (0, 1), value=0) + return tokens, sum_logprobs.tolist(), token_probs.tolist() class BeamSearchDecoder(TokenDecoder): @@ -381,7 +387,7 @@ class BeamSearchDecoder(TokenDecoder): ) return tokens, completed - def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor): + def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor): # collect all finished sequences, including patience, and add unfinished ones if not enough sum_logprobs = sum_logprobs.cpu() for i, sequences in enumerate(self.finished_sequences): @@ -682,6 +688,7 @@ class DecodingTask: sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device) no_speech_probs = [np.nan] * n_batch + token_probs = torch.zeros_like(tokens).float() try: for i in range(self.sample_len): logits = self.inference.logits(tokens, audio_features) @@ -700,14 +707,14 @@ class DecodingTask: logit_filter.apply(logits, tokens) # expand the tokens tensor with the selected next tokens - tokens, completed = self.decoder.update(tokens, logits, sum_logprobs) + tokens, completed, token_probs = self.decoder.update(tokens, logits, sum_logprobs, token_probs) if completed or tokens.shape[-1] > self.n_ctx: break finally: self.inference.cleanup_caching() - return tokens, sum_logprobs, no_speech_probs + return tokens, sum_logprobs, no_speech_probs, token_probs @torch.no_grad() def run(self, mel: Tensor) -> List[DecodingResult]: @@ -734,7 +741,7 @@ class DecodingTask: tokens = 
tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device) # call the main sampling loop - tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens) + tokens, sum_logprobs, no_speech_probs, token_probs = self._main_loop(audio_features, tokens) # reshape the tensors to have (n_audio, n_group) as the first two dimensions audio_features = audio_features[:: self.n_group] @@ -745,7 +752,7 @@ class DecodingTask: sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group) # get the final candidates for each group, and slice between the first sampled token and EOT - tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs) + tokens, sum_logprobs, token_probs = self.decoder.finalize(tokens, sum_logprobs, token_probs) tokens: List[List[Tensor]] = [ [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens @@ -768,6 +775,7 @@ class DecodingTask: audio_features, avg_logprobs, no_speech_probs, + token_probs ) if len(set(map(len, fields))) != 1: raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}") @@ -782,8 +790,9 @@ class DecodingTask: no_speech_prob=no_speech_prob, temperature=self.options.temperature, compression_ratio=compression_ratio(text), + token_probs=token_probs[-len(tokens):] ) - for text, language, tokens, features, avg_logprob, no_speech_prob in zip( + for text, language, tokens, features, avg_logprob, no_speech_prob, token_probs in zip( *fields ) ] From 34d9a9b6a3c2c6d8c12f886c127c6a7ebe2ef8e4 Mon Sep 17 00:00:00 2001 From: khaliladib11 <73353537+Khaliladib11@users.noreply.github.com> Date: Sun, 26 May 2024 11:03:22 +0100 Subject: [PATCH 2/5] modify the mel to be in the same shape as model's --- examples/test_prob.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/examples/test_prob.py b/examples/test_prob.py index babfd0d..130762b 100644 --- a/examples/test_prob.py +++ b/examples/test_prob.py @@ -2,26 +2,45 @@ import sys 
from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1])) +import torch import whisper import argparse import colorsys +from whisper.utils import exact_div from typing import List from whisper.tokenizer import get_tokenizer from colorama import init, Style +# hard-coded audio hyperparameters +SAMPLE_RATE = 16000 +N_FFT = 400 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk +N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input + +N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 +FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame +TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token + + def load_audio_from_source(audio_source): audio = whisper.load_audio(audio_source) audio = whisper.pad_or_trim(audio) return audio -def decode_audio(model, audio, language="en"): - mel = whisper.log_mel_spectrogram(audio).to(model.device) +def decode_audio(model, audio, language="en", f16=True): + dtype = torch.float16 if f16 else torch.float32 + # mel = whisper.log_mel_spectrogram(audio).to(model.device) + mel = whisper.log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES).to(model.device) + mel_segment =whisper.pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) + print('Decoding audio') # decode the audio options = whisper.DecodingOptions() - result = whisper.decode(model, mel, options) + result = whisper.decode(model, mel_segment, options) tokenizer = get_tokenizer(multilingual=model.is_multilingual, language=language, task=options.task) From 0fd17c99c8f064f4cda750f9db7b788530d80c6f Mon Sep 17 00:00:00 2001 From: khaliladib11 <73353537+Khaliladib11@users.noreply.github.com> Date: Sun, 26 May 2024 11:28:35 +0100 Subject: [PATCH 3/5] add colorized text --- examples/test_prob.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git 
a/examples/test_prob.py b/examples/test_prob.py index 130762b..f6beebd 100644 --- a/examples/test_prob.py +++ b/examples/test_prob.py @@ -48,6 +48,18 @@ def decode_audio(model, audio, language="en", f16=True): return text_tokens, result.token_probs +def get_colored_text(text_tokens: List[int], token_probs: List[float]): + init(autoreset=False) # Initialize colorama with autoreset disabled; each token is reset explicitly with Style.RESET_ALL + output_text = "" + for i, (token, prob) in enumerate(zip(text_tokens, token_probs)): + # Interpolate between red and green in the HSV color space + r, g, b = colorsys.hsv_to_rgb(prob * (1/3), 1, 1) + r, g, b = int(r * 255), int(g * 255), int(b * 255) + color_code = f"\033[38;2;{r};{g};{b}m" + colored_token = f"{color_code}{Style.BRIGHT}{str(token)}{Style.RESET_ALL}" + output_text += colored_token + return output_text + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--audio', type=str, help='the path of the audio file') @@ -63,5 +75,4 @@ if __name__ == '__main__': model = whisper.load_model(model) audio = load_audio_from_source(audio_source=audio) text, proba = decode_audio(model=model, audio=audio) - print(text) - print(proba) + print(get_colored_text(text, proba)) \ No newline at end of file From e53b617de1a6b5971a1ae94f139ac59e04324695 Mon Sep 17 00:00:00 2001 From: khaliladib11 <73353537+Khaliladib11@users.noreply.github.com> Date: Sun, 26 May 2024 19:10:24 +0100 Subject: [PATCH 4/5] add stt method to the transcribe function --- whisper/transcribe.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 1c075a2..6f2529d 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -17,7 +17,7 @@ from .audio import ( log_mel_spectrogram, pad_or_trim, ) -from .decoding import DecodingOptions, DecodingResult +from .decoding import DecodingOptions, DecodingResult, decode from .timing import add_word_timestamps 
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer from .utils import ( @@ -498,6 +498,40 @@ def transcribe( ) +def stt(model: "Whisper", + audio: Union[str, np.ndarray, torch.Tensor], + language : str = "en", + f16: bool =True): + """ + Decode audio with Whisper and return each decoded token's text paired with its probability + + Parameters + ---------- + model: Whisper + The Whisper model instance + + audio: Union[str, np.ndarray, torch.Tensor] + The path to the audio file to open, or the audio waveform + + language: str + Language passed to the tokenizer that decodes the result tokens (default "en") + + f16: bool + If True, cast the mel segment to torch.float16; otherwise use torch.float32 + """ + dtype = torch.float16 if f16 else torch.float32 + audio = pad_or_trim(audio) + mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES).to(model.device) + mel_segment =pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) + + options = DecodingOptions() + result = decode(model, mel_segment, options) + tokenizer = get_tokenizer(multilingual=model.is_multilingual, language=language, task=options.task) + text = [tokenizer.decode([t]) for t in result.tokens] + output = [ [text, prob] for text, prob in zip(text, result.token_probs) ] + return output + + def cli(): from . 
import available_models From b76dcf36630a5bd7bbc1851924e23e6128db693f Mon Sep 17 00:00:00 2001 From: khaliladib11 <73353537+Khaliladib11@users.noreply.github.com> Date: Sun, 26 May 2024 19:10:51 +0100 Subject: [PATCH 5/5] export stt method, add it to the __init__ function --- whisper/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/whisper/__init__.py b/whisper/__init__.py index d7fbba3..caa36de 100644 --- a/whisper/__init__.py +++ b/whisper/__init__.py @@ -11,7 +11,7 @@ from tqdm import tqdm from .audio import load_audio, log_mel_spectrogram, pad_or_trim from .decoding import DecodingOptions, DecodingResult, decode, detect_language from .model import ModelDimensions, Whisper -from .transcribe import transcribe +from .transcribe import transcribe, stt from .version import __version__ _MODELS = {