From 2ff7dbb41a3b21cd8f6c6fa26a9d300f1d95f4d8 Mon Sep 17 00:00:00 2001
From: SinanAkkoyun
Date: Thu, 23 Mar 2023 02:25:21 +0100
Subject: [PATCH] committed

---
 examples/confidence_per_token.py | 2 +-
 whisper/decoding.py              | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/examples/confidence_per_token.py b/examples/confidence_per_token.py
index a6c90d0..8d8ed50 100644
--- a/examples/confidence_per_token.py
+++ b/examples/confidence_per_token.py
@@ -40,7 +40,7 @@ result = whisper.decode(model, mel, options)
 def get_colored_text(tokens: List[int], token_probs: List[float], tokenizer, prompt: str=""):
     init(autoreset=True)  # Initialize colorama
     text_tokens = [tokenizer.decode([t]) for t in tokens]
-    token_probs = token_probs[-len(text_tokens):]
+    # token_probs = token_probs[-len(text_tokens):]
 
     output_text = ""
     for i, (token, prob) in enumerate(zip(text_tokens, token_probs)):
diff --git a/whisper/decoding.py b/whisper/decoding.py
index f8e8200..29c7801 100644
--- a/whisper/decoding.py
+++ b/whisper/decoding.py
@@ -750,6 +750,9 @@ class DecodingTask:
             for s in tokens
         ]
 
+        # fix token_probs length
+        token_probs = token_probs[-len(tokens):]
+
         # select the top-ranked sample in each group
         selected = self.sequence_ranker.rank(tokens, sum_logprobs)
         tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
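
With the decoding.py change above, token_probs is trimmed to match the selected tokens inside DecodingTask, which is why the corresponding slice in examples/confidence_per_token.py is commented out. For reference, a minimal sketch (not part of the patch) of how the aligned tokens/token_probs pair could be rendered with per-token confidence colors; the helper name color_tokens, the thresholds, and the color choices are illustrative assumptions, not taken from the patch:

# Minimal sketch, assuming colorama is installed and the caller already
# decoded the tokens to strings; thresholds and colors are illustrative.
from typing import List

from colorama import Fore, Style, init


def color_tokens(text_tokens: List[str], token_probs: List[float]) -> str:
    init(autoreset=True)  # colorama resets styling after each print
    out = ""
    for token, prob in zip(text_tokens, token_probs):
        if prob > 0.8:
            color = Fore.GREEN   # high confidence
        elif prob > 0.5:
            color = Fore.YELLOW  # medium confidence
        else:
            color = Fore.RED     # low confidence
        out += color + token + Style.RESET_ALL
    return out

Because the slicing now happens once in DecodingTask, helpers like this can zip tokens and probabilities directly without re-aligning the lists on the caller side.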