Mirror of https://github.com/openai/whisper.git, synced 2025-11-24 14:35:57 +00:00.
Commit 35530894b4 (parent 086108095a)
.gitignore (vendored): 3 additions
@@ -9,3 +9,6 @@ thumbs.db
 .DS_Store
 .idea

+.venv/
+
+samples/
examples/confidence_per_token.py (new file): 57 additions
@@ -0,0 +1,57 @@
+# IMPORTANT: This is just for using the local whisper dir as the package directly. Delete until next comment when just installing whisper normally.
+import sys
+from pathlib import Path
+sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
+# end of dev import
+import whisper
+
+import colorsys
+from typing import List
+from whisper.tokenizer import get_tokenizer
+from colorama import init, Style
+
+
+print('Loading model')
+model = whisper.load_model("large")
+
+
+print('Loading audio')  # load audio and pad/trim it to fit 30 seconds
+audio = whisper.load_audio("samples/your_audio.wav")
+audio = whisper.pad_or_trim(audio)
+
+
+mel = whisper.log_mel_spectrogram(audio).to(model.device)  # make log-Mel spectrogram and move to the same device as the model
+
+
+detect_lang = False
+language = "en"
+if detect_lang:  # detect the spoken language
+    print('Detecting language')
+    _, probs = model.detect_language(mel)
+    print(f"Detected language: {max(probs, key=probs.get)}")
+    language = max(probs, key=probs.get)
+
+
+print('Decoding audio')  # decode the audio
+options = whisper.DecodingOptions()
+result = whisper.decode(model, mel, options)
+
+
+def print_colored_text(tokens: List[int], token_probs: List[float], tokenizer):
+    init(autoreset=True)  # Initialize colorama
+    text_tokens = [tokenizer.decode([t]) for t in tokens]
+
+    for token, prob in zip(text_tokens, token_probs):
+        # Interpolate between red and green in the HSV color space
+        r, g, b = colorsys.hsv_to_rgb(prob * (1/3), 1, 1)
+        r, g, b = int(r * 255), int(g * 255), int(b * 255)
+        color_code = f"\033[38;2;{r};{g};{b}m"
+
+        colored_token = f"{color_code}{Style.BRIGHT}{token}{Style.RESET_ALL}"
+        print(colored_token, end="")
+
+    print()
+
+
+tokenizer = get_tokenizer(multilingual=model.is_multilingual, language=language, task=options.task)
+print_colored_text(result.tokens, result.token_probs, tokenizer)  # print text with fancy confidence colors
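Aside on the coloring scheme in print_colored_text: the hue term `prob * (1/3)` sweeps the HSV hue from red (probability 0) to green (probability 1), passing through yellow. A standalone sketch of just that mapping, separate from the example file:

# Standalone sketch of the probability-to-color mapping used above:
# hue 0 is red, hue 1/6 yellow, hue 1/3 green.
import colorsys

for prob in (0.0, 0.5, 1.0):
    r, g, b = (int(c * 255) for c in colorsys.hsv_to_rgb(prob * (1 / 3), 1, 1))
    print(f"prob={prob:.1f} -> RGB({r}, {g}, {b})")
# prob=0.0 -> RGB(255, 0, 0)    red
# prob=0.5 -> RGB(255, 255, 0)  yellow
# prob=1.0 -> RGB(0, 255, 0)    green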
whisper/decoding.py
@@ -120,6 +120,7 @@ class DecodingResult:
     language: str
     language_probs: Optional[Dict[str, float]] = None
     tokens: List[int] = field(default_factory=list)
+    token_probs: List[float] = field(default_factory=list)
     text: str = ""
     avg_logprob: float = np.nan
     no_speech_prob: float = np.nan
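The new dataclass field is index-aligned with `tokens`: one probability per sampled token. A minimal illustration of the intended pairing, using a hypothetical stand-in dataclass rather than whisper's own DecodingResult:

# Hypothetical stand-in showing how token_probs pairs with tokens.
from dataclasses import dataclass, field
from typing import List

@dataclass
class MiniResult:  # not whisper's class; illustration only
    tokens: List[int] = field(default_factory=list)
    token_probs: List[float] = field(default_factory=list)

r = MiniResult(tokens=[50364, 2425], token_probs=[0.98, 0.73])
for t, p in zip(r.tokens, r.token_probs):
    print(t, p)  # each token id with its sampled probability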
@@ -218,7 +219,7 @@ class TokenDecoder:
         """Initialize any stateful variables for decoding a new sequence"""

     def update(
-        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, token_probs: Tensor
     ) -> Tuple[Tensor, bool]:
         """Specify how to select the next token, based on the current trace and logits

@@ -245,7 +246,7 @@ class TokenDecoder:
         raise NotImplementedError

     def finalize(
-        self, tokens: Tensor, sum_logprobs: Tensor
+        self, tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor
     ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
         """Finalize search and return the final candidate sequences

@@ -275,7 +276,7 @@ class GreedyDecoder(TokenDecoder):
         self.eot = eot

     def update(
-        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, token_probs: Tensor
     ) -> Tuple[Tensor, bool]:
         if self.temperature == 0:
             next_tokens = logits.argmax(dim=-1)
@@ -283,19 +284,28 @@ class GreedyDecoder(TokenDecoder):
             next_tokens = Categorical(logits=logits / self.temperature).sample()

         logprobs = F.log_softmax(logits.float(), dim=-1)
+        probs = torch.exp(logprobs)
         current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
         sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)

         next_tokens[tokens[:, -1] == self.eot] = self.eot
         tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

-        completed = (tokens[:, -1] == self.eot).all()
-        return tokens, completed
+        current_token_probs = probs[torch.arange(probs.shape[0]), next_tokens]
+        token_probs = torch.cat([token_probs, current_token_probs[:, None]], dim=-1)
+        # token_logits = torch.stack([logits[k, next_tokens[k]] for k in range(next_tokens.shape[0])], dim=0)
+        # or use logprobs, the log softmax of the logits
+        # return it along with tokens and completed
+
+        completed = (tokens[:, -1] == self.eot).all()
+        return tokens, completed, token_probs

-    def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
+    def finalize(self, tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor):
         # make sure each sequence has at least one EOT token at the end
         tokens = F.pad(tokens, (0, 1), value=self.eot)
-        return tokens, sum_logprobs.tolist()
+        token_probs = F.pad(token_probs, (0, 1), value=0)  # 0 ok?
+        return tokens, sum_logprobs.tolist(), token_probs.tolist()


 class BeamSearchDecoder(TokenDecoder):
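Two tensor idioms do the work in this hunk: `probs[torch.arange(n), next_tokens]` gathers one probability per batch row, and `F.pad(..., (0, 1))` appends a final column for the trailing EOT token (which the commit fills with probability 0, hence the author's "0 ok?" note). A self-contained sketch with made-up numbers:

# Sketch of the gather-and-pad idioms from GreedyDecoder; values are
# made up for illustration.
import torch
import torch.nn.functional as F

probs = torch.tensor([[0.1, 0.7, 0.2],
                      [0.5, 0.3, 0.2]])      # (n_batch, n_vocab)
next_tokens = torch.tensor([1, 0])           # one sampled token id per row
current = probs[torch.arange(probs.shape[0]), next_tokens]
print(current)                               # tensor([0.7000, 0.5000])

token_probs = current[:, None]               # one column per decoding step
token_probs = F.pad(token_probs, (0, 1), value=0)  # appended EOT slot gets prob 0
print(token_probs)
# tensor([[0.7000, 0.0000],
#         [0.5000, 0.0000]])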
@@ -381,7 +391,7 @@ class BeamSearchDecoder(TokenDecoder):
         )
         return tokens, completed

-    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
+    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor):
         # collect all finished sequences, including patience, and add unfinished ones if not enough
         sum_logprobs = sum_logprobs.cpu()
         for i, sequences in enumerate(self.finished_sequences):
@@ -682,6 +692,8 @@ class DecodingTask:
         sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
         no_speech_probs = [np.nan] * n_batch

+        token_probs = torch.zeros_like(tokens).float()
+
         try:
             for i in range(self.sample_len):
                 logits = self.inference.logits(tokens, audio_features)
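`torch.zeros_like(tokens).float()` gives the probability buffer the same (n_batch, n_prompt) shape as the initial token buffer, so the prompt/SOT positions carry probability 0 and the decoder's update() then appends one column per sampled step. A small sketch (the token ids below are placeholders, not real Whisper ids):

# Sketch: the initial token_probs buffer mirrors the prompt tokens' shape.
import torch

tokens = torch.tensor([[101, 102, 103]])        # (n_batch=1, 3 prompt tokens)
token_probs = torch.zeros_like(tokens).float()
print(token_probs)        # tensor([[0., 0., 0.]])
print(token_probs.dtype)  # torch.float32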
@@ -700,14 +712,14 @@ class DecodingTask:
                     logit_filter.apply(logits, tokens)

                 # expand the tokens tensor with the selected next tokens
-                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
+                tokens, completed, token_probs = self.decoder.update(tokens, logits, sum_logprobs, token_probs)

                 if completed or tokens.shape[-1] > self.n_ctx:
                     break
         finally:
             self.inference.cleanup_caching()

-        return tokens, sum_logprobs, no_speech_probs
+        return tokens, sum_logprobs, no_speech_probs, token_probs

     @torch.no_grad()
     def run(self, mel: Tensor) -> List[DecodingResult]:
@@ -734,7 +746,7 @@ class DecodingTask:
         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

         # call the main sampling loop
-        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
+        tokens, sum_logprobs, no_speech_probs, token_probs = self._main_loop(audio_features, tokens)

         # reshape the tensors to have (n_audio, n_group) as the first two dimensions
         audio_features = audio_features[:: self.n_group]
@@ -745,7 +757,7 @@ class DecodingTask:
         sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)

         # get the final candidates for each group, and slice between the first sampled token and EOT
-        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
+        tokens, sum_logprobs, token_probs = self.decoder.finalize(tokens, sum_logprobs, token_probs)
         tokens: List[List[Tensor]] = [
             [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
             for s in tokens
@@ -768,6 +780,7 @@ class DecodingTask:
             audio_features,
             avg_logprobs,
             no_speech_probs,
+            token_probs
         )
         if len(set(map(len, fields))) != 1:
             raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
@@ -782,8 +795,9 @@ class DecodingTask:
                 no_speech_prob=no_speech_prob,
                 temperature=self.options.temperature,
                 compression_ratio=compression_ratio(text),
+                token_probs=token_probs
             )
-            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
+            for text, language, tokens, features, avg_logprob, no_speech_prob, token_probs in zip(
                 *fields
             )
         ]
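Taken together, a caller can pair each decoded token with its sampled probability straight off the DecodingResult. A hedged end-to-end sketch of the new surface, reusing names from examples/confidence_per_token.py ("samples/your_audio.wav" is a placeholder path):

# Hedged sketch of using the new token_probs field end to end.
import whisper
from whisper.tokenizer import get_tokenizer

model = whisper.load_model("base")
audio = whisper.pad_or_trim(whisper.load_audio("samples/your_audio.wav"))
mel = whisper.log_mel_spectrogram(audio).to(model.device)

options = whisper.DecodingOptions()
result = whisper.decode(model, mel, options)

tokenizer = get_tokenizer(multilingual=model.is_multilingual, language="en", task=options.task)
for tok, prob in zip(result.tokens, result.token_probs):
    print(f"{tokenizer.decode([tok])!r}: {prob:.3f}")  # token text with its confidence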