diff --git a/examples/test_prob.py b/examples/test_prob.py index 130762b..f6beebd 100644 --- a/examples/test_prob.py +++ b/examples/test_prob.py @@ -48,6 +48,18 @@ def decode_audio(model, audio, language="en", f16=True): return text_tokens, result.token_probs +def get_colored_text(text_tokens: List[int], token_probs: List[float]): + init(autoreset=False) # Initialize colorama with autoreset=True to reset colors after each print + output_text = "" + for i, (token, prob) in enumerate(zip(text_tokens, token_probs)): + # Interpolate between red and green in the HSV color space + r, g, b = colorsys.hsv_to_rgb(prob * (1/3), 1, 1) + r, g, b = int(r * 255), int(g * 255), int(b * 255) + color_code = f"\033[38;2;{r};{g};{b}m" + colored_token = f"{color_code}{Style.BRIGHT}{str(token)}{Style.RESET_ALL}" + output_text += colored_token + return output_text + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--audio', type=str, help='the path of the audio file') @@ -63,5 +75,4 @@ if __name__ == '__main__': model = whisper.load_model(model) audio = load_audio_from_source(audio_source=audio) text, proba = decode_audio(model=model, audio=audio) - print(text) - print(proba) + print(get_colored_text(text, proba)) \ No newline at end of file