mirror of
https://github.com/openai/whisper.git
synced 2025-11-26 23:46:09 +00:00
Fixed token_prob length! :)
This commit is contained in:
parent
2ff7dbb41a
commit
6750a98bdd
@ -40,7 +40,6 @@ result = whisper.decode(model, mel, options)
|
||||
def get_colored_text(tokens: List[int], token_probs: List[float], tokenizer, prompt: str=""):
|
||||
init(autoreset=True) # Initialize colorama
|
||||
text_tokens = [tokenizer.decode([t]) for t in tokens]
|
||||
# token_probs = token_probs[-len(text_tokens):]
|
||||
|
||||
output_text = ""
|
||||
for i, (token, prob) in enumerate(zip(text_tokens, token_probs)):
|
||||
|
||||
@ -750,9 +750,6 @@ class DecodingTask:
|
||||
for s in tokens
|
||||
]
|
||||
|
||||
# fix token_probs length
|
||||
token_probs = token_probs[-len(tokens):]
|
||||
|
||||
# select the top-ranked sample in each group
|
||||
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
||||
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
|
||||
@ -785,7 +782,7 @@ class DecodingTask:
|
||||
no_speech_prob=no_speech_prob,
|
||||
temperature=self.options.temperature,
|
||||
compression_ratio=compression_ratio(text),
|
||||
token_probs=token_probs
|
||||
token_probs=token_probs[-len(tokens):]
|
||||
)
|
||||
for text, language, tokens, features, avg_logprob, no_speech_prob, token_probs in zip(
|
||||
*fields
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user