mirror of
https://github.com/openai/whisper.git
synced 2025-11-27 07:48:45 +00:00
committed
This commit is contained in:
parent
5e6714ef11
commit
2ff7dbb41a
@ -40,7 +40,7 @@ result = whisper.decode(model, mel, options)
|
|||||||
def get_colored_text(tokens: List[int], token_probs: List[float], tokenizer, prompt: str=""):
|
def get_colored_text(tokens: List[int], token_probs: List[float], tokenizer, prompt: str=""):
|
||||||
init(autoreset=True) # Initialize colorama
|
init(autoreset=True) # Initialize colorama
|
||||||
text_tokens = [tokenizer.decode([t]) for t in tokens]
|
text_tokens = [tokenizer.decode([t]) for t in tokens]
|
||||||
token_probs = token_probs[-len(text_tokens):]
|
# token_probs = token_probs[-len(text_tokens):]
|
||||||
|
|
||||||
output_text = ""
|
output_text = ""
|
||||||
for i, (token, prob) in enumerate(zip(text_tokens, token_probs)):
|
for i, (token, prob) in enumerate(zip(text_tokens, token_probs)):
|
||||||
|
|||||||
@ -750,6 +750,9 @@ class DecodingTask:
|
|||||||
for s in tokens
|
for s in tokens
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# fix token_probs length
|
||||||
|
token_probs = token_probs[-len(tokens):]
|
||||||
|
|
||||||
# select the top-ranked sample in each group
|
# select the top-ranked sample in each group
|
||||||
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
||||||
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
|
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user