mirror of
https://github.com/openai/whisper.git
synced 2025-11-27 07:48:45 +00:00
Fixed token_prob length! :)
This commit is contained in:
parent
2ff7dbb41a
commit
6750a98bdd
@ -40,7 +40,6 @@ result = whisper.decode(model, mel, options)
|
|||||||
def get_colored_text(tokens: List[int], token_probs: List[float], tokenizer, prompt: str=""):
|
def get_colored_text(tokens: List[int], token_probs: List[float], tokenizer, prompt: str=""):
|
||||||
init(autoreset=True) # Initialize colorama
|
init(autoreset=True) # Initialize colorama
|
||||||
text_tokens = [tokenizer.decode([t]) for t in tokens]
|
text_tokens = [tokenizer.decode([t]) for t in tokens]
|
||||||
# token_probs = token_probs[-len(text_tokens):]
|
|
||||||
|
|
||||||
output_text = ""
|
output_text = ""
|
||||||
for i, (token, prob) in enumerate(zip(text_tokens, token_probs)):
|
for i, (token, prob) in enumerate(zip(text_tokens, token_probs)):
|
||||||
|
|||||||
@ -750,9 +750,6 @@ class DecodingTask:
|
|||||||
for s in tokens
|
for s in tokens
|
||||||
]
|
]
|
||||||
|
|
||||||
# fix token_probs length
|
|
||||||
token_probs = token_probs[-len(tokens):]
|
|
||||||
|
|
||||||
# select the top-ranked sample in each group
|
# select the top-ranked sample in each group
|
||||||
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
||||||
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
|
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
|
||||||
@ -785,7 +782,7 @@ class DecodingTask:
|
|||||||
no_speech_prob=no_speech_prob,
|
no_speech_prob=no_speech_prob,
|
||||||
temperature=self.options.temperature,
|
temperature=self.options.temperature,
|
||||||
compression_ratio=compression_ratio(text),
|
compression_ratio=compression_ratio(text),
|
||||||
token_probs=token_probs
|
token_probs=token_probs[-len(tokens):]
|
||||||
)
|
)
|
||||||
for text, language, tokens, features, avg_logprob, no_speech_prob, token_probs in zip(
|
for text, language, tokens, features, avg_logprob, no_speech_prob, token_probs in zip(
|
||||||
*fields
|
*fields
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user