From 6750a98bdd2072626ff352474950f5f36f1919e5 Mon Sep 17 00:00:00 2001
From: SinanAkkoyun
Date: Thu, 23 Mar 2023 02:42:25 +0100
Subject: [PATCH] Fixed token_prob length! :)

---
 examples/confidence_per_token.py | 1 -
 whisper/decoding.py              | 5 +----
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/examples/confidence_per_token.py b/examples/confidence_per_token.py
index 8d8ed50..1150c30 100644
--- a/examples/confidence_per_token.py
+++ b/examples/confidence_per_token.py
@@ -40,7 +40,6 @@ result = whisper.decode(model, mel, options)
 def get_colored_text(tokens: List[int], token_probs: List[float], tokenizer, prompt: str=""):
     init(autoreset=True)  # Initialize colorama
     text_tokens = [tokenizer.decode([t]) for t in tokens]
-    # token_probs = token_probs[-len(text_tokens):]
 
     output_text = ""
     for i, (token, prob) in enumerate(zip(text_tokens, token_probs)):
diff --git a/whisper/decoding.py b/whisper/decoding.py
index 29c7801..4ded71b 100644
--- a/whisper/decoding.py
+++ b/whisper/decoding.py
@@ -750,9 +750,6 @@ class DecodingTask:
             for s in tokens
         ]
 
-        # fix token_probs length
-        token_probs = token_probs[-len(tokens):]
-
         # select the top-ranked sample in each group
         selected = self.sequence_ranker.rank(tokens, sum_logprobs)
         tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
@@ -785,7 +782,7 @@ class DecodingTask:
                 no_speech_prob=no_speech_prob,
                 temperature=self.options.temperature,
                 compression_ratio=compression_ratio(text),
-                token_probs=token_probs
+                token_probs=token_probs[-len(tokens):]
             )
             for text, language, tokens, features, avg_logprob, no_speech_prob, token_probs in zip(
                 *fields
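
Why the slice moves: by the time the DecodingResult objects are built, each `tokens` entry has already been trimmed to the selected text tokens, while `token_probs` still holds one probability per sampled token (including the leading special tokens), so taking the last len(tokens) probabilities lines the two lists up one-to-one. A minimal sketch of that alignment, with made-up token strings and probabilities (not values from the actual decoder):

    # Minimal sketch of the alignment the patch relies on; the token strings and
    # probabilities below are invented for illustration, not real decoder output.
    sampled = ["<|startoftranscript|>", "<|en|>", "<|transcribe|>", "Hello", ",", " world"]
    token_probs = [1.00, 0.98, 0.99, 0.91, 0.88, 0.95]  # one probability per sampled token

    tokens = sampled[3:]                  # already trimmed to the text tokens
    aligned = token_probs[-len(tokens):]  # [0.91, 0.88, 0.95]

    assert len(aligned) == len(tokens)
    for tok, prob in zip(tokens, aligned):
        print(f"{tok!r}: {prob:.2f}")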