committed

This commit is contained in:
SinanAkkoyun 2023-03-23 02:25:21 +01:00
parent 5e6714ef11
commit 2ff7dbb41a
2 changed files with 4 additions and 1 deletions

View File

@ -40,7 +40,7 @@ result = whisper.decode(model, mel, options)
def get_colored_text(tokens: List[int], token_probs: List[float], tokenizer, prompt: str=""): def get_colored_text(tokens: List[int], token_probs: List[float], tokenizer, prompt: str=""):
init(autoreset=True) # Initialize colorama init(autoreset=True) # Initialize colorama
text_tokens = [tokenizer.decode([t]) for t in tokens] text_tokens = [tokenizer.decode([t]) for t in tokens]
token_probs = token_probs[-len(text_tokens):] # token_probs = token_probs[-len(text_tokens):]
output_text = "" output_text = ""
for i, (token, prob) in enumerate(zip(text_tokens, token_probs)): for i, (token, prob) in enumerate(zip(text_tokens, token_probs)):

View File

@ -750,6 +750,9 @@ class DecodingTask:
for s in tokens for s in tokens
] ]
# fix token_probs length
token_probs = token_probs[-len(tokens):]
# select the top-ranked sample in each group # select the top-ranked sample in each group
selected = self.sequence_ranker.rank(tokens, sum_logprobs) selected = self.sequence_ranker.rank(tokens, sum_logprobs)
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)] tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]