diff --git a/whisper/decoding.py b/whisper/decoding.py index 49485d0..1e870a3 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -125,6 +125,7 @@ class DecodingResult: no_speech_prob: float = np.nan temperature: float = np.nan compression_ratio: float = np.nan + tokens_probs: list[float] = field(default_factory=list) class Inference: @@ -218,8 +219,8 @@ class TokenDecoder: """Initialize any stateful variables for decoding a new sequence""" def update( - self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor - ) -> Tuple[Tensor, bool]: + self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, tokens_probs: list[list[float]] + ) -> Tuple[Tensor, list, bool]: """Specify how to select the next token, based on the current trace and logits Parameters @@ -275,8 +276,8 @@ class GreedyDecoder(TokenDecoder): self.eot = eot def update( - self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor - ) -> Tuple[Tensor, bool]: + self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, tokens_probs: list + ) -> Tuple[Tensor, list, bool]: if self.temperature == 0: next_tokens = logits.argmax(dim=-1) else: @@ -284,18 +285,25 @@ class GreedyDecoder(TokenDecoder): logprobs = F.log_softmax(logits.float(), dim=-1) current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens] + current_probs = torch.exp(current_logprobs) sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot) + tokens_probs = [t_p + [c_p.item()] for t_p, c_p in zip(tokens_probs, current_probs)] next_tokens[tokens[:, -1] == self.eot] = self.eot tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1) completed = (tokens[:, -1] == self.eot).all() - return tokens, completed + + return tokens, tokens_probs, completed - def finalize(self, tokens: Tensor, sum_logprobs: Tensor): - # make sure each sequence has at least one EOT token at the end + def finalize( + self, tokens: Tensor, tokens_probs: list, sum_logprobs: Tensor + ) -> Tuple[Tensor, list, list]: tokens = F.pad(tokens, (0, 1), value=self.eot) - return tokens, sum_logprobs.tolist() + tokens_probs = [[ t + [1.0] for t in s] for s in tokens_probs] + return tokens, tokens_probs, sum_logprobs.tolist() + + class BeamSearchDecoder(TokenDecoder): @@ -321,37 +329,39 @@ class BeamSearchDecoder(TokenDecoder): self.finished_sequences = None def update( - self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor - ) -> Tuple[Tensor, bool]: + self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, tokens_probs: list + ) -> Tuple[Tensor, list, bool]: if tokens.shape[0] % self.beam_size != 0: raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0") n_audio = tokens.shape[0] // self.beam_size - if self.finished_sequences is None: # for the first update + if self.finished_sequences is None: self.finished_sequences = [{} for _ in range(n_audio)] logprobs = F.log_softmax(logits.float(), dim=-1) next_tokens, source_indices, finished_sequences = [], [], [] for i in range(n_audio): - scores, sources, finished = {}, {}, {} + scores, sources, finished, probs = {}, {}, {}, {} - # STEP 1: calculate the cumulative log probabilities for possible candidates for j in range(self.beam_size): idx = i * self.beam_size + j prefix = tokens[idx].tolist() for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)): + prob = torch.exp(logprob).item() new_logprob = (sum_logprobs[idx] + logprob).item() sequence = tuple(prefix + [token.item()]) scores[sequence] = new_logprob sources[sequence] = idx - # STEP 2: rank the candidates and keep the top beam_size sequences for each audio + probs[sequence] = tokens_probs[idx] + [prob] + saved = 0 for sequence in sorted(scores, key=scores.get, reverse=True): if sequence[-1] == self.eot: - finished[sequence] = scores[sequence] + finished[sequence] = (scores[sequence], probs[sequence]) else: sum_logprobs[len(next_tokens)] = scores[sequence] + tokens_probs[len(next_tokens)] = probs[sequence] next_tokens.append(sequence) source_indices.append(sources[sequence]) @@ -364,44 +374,42 @@ class BeamSearchDecoder(TokenDecoder): tokens = torch.tensor(next_tokens, device=tokens.device) self.inference.rearrange_kv_cache(source_indices) - # add newly finished sequences to self.finished_sequences assert len(self.finished_sequences) == len(finished_sequences) - for previously_finished, newly_finished in zip( - self.finished_sequences, finished_sequences - ): + for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences): for seq in sorted(newly_finished, key=newly_finished.get, reverse=True): if len(previously_finished) >= self.max_candidates: - break # the candidate list is full + break previously_finished[seq] = newly_finished[seq] - # mark as completed if all audio has enough number of samples completed = all( - len(sequences) >= self.max_candidates - for sequences in self.finished_sequences + len(sequences) >= self.max_candidates for sequences in self.finished_sequences ) - return tokens, completed - def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor): - # collect all finished sequences, including patience, and add unfinished ones if not enough + return tokens, tokens_probs, completed + + def finalize( + self, preceding_tokens: Tensor, preceding_tokens_prob: list, sum_logprobs: Tensor + ) -> Tuple[list, list, list]: sum_logprobs = sum_logprobs.cpu() for i, sequences in enumerate(self.finished_sequences): - if ( - len(sequences) < self.beam_size - ): # when not enough sequences are finished + if len(sequences) < self.beam_size: for j in list(np.argsort(sum_logprobs[i]))[::-1]: sequence = preceding_tokens[i, j].tolist() + [self.eot] - sequences[tuple(sequence)] = sum_logprobs[i][j].item() + sequences[tuple(sequence)] = (sum_logprobs[i][j].item(), preceding_tokens_prob[i][j] + [1.0]) if len(sequences) >= self.beam_size: break tokens: List[List[Tensor]] = [ - [torch.tensor(seq) for seq in sequences.keys()] - for sequences in self.finished_sequences + [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences ] sum_logprobs: List[List[float]] = [ - list(sequences.values()) for sequences in self.finished_sequences + [v[0] for v in sequences.values()] for sequences in self.finished_sequences ] - return tokens, sum_logprobs + tokens_probs: list[list[list[float]]] = [ + [v[1] for v in sequences.values()] for sequences in self.finished_sequences + ] + + return tokens, tokens_probs, sum_logprobs class LogitFilter: @@ -700,7 +708,8 @@ class DecodingTask: logit_filter.apply(logits, tokens) # expand the tokens tensor with the selected next tokens - tokens, completed = self.decoder.update(tokens, logits, sum_logprobs) + tokens, tokens_probs, completed = self.decoder.update(tokens, logits, sum_logprobs, tokens_probs) + if completed or tokens.shape[-1] > self.n_ctx: break @@ -734,7 +743,7 @@ class DecodingTask: tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device) # call the main sampling loop - tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens) + tokens, sum_logprobs, no_speech_probs, tokens_probs = self._main_loop(audio_features, tokens) # reshape the tensors to have (n_audio, n_group) as the first two dimensions audio_features = audio_features[:: self.n_group] @@ -747,8 +756,10 @@ class DecodingTask: # get the final candidates for each group, and slice between the first sampled token and EOT tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs) tokens: List[List[Tensor]] = [ - [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] - for s in tokens + [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens + ] + tokens_probs: list[list[list[float]]] = [ + [probs[:tokens.shape[0]]for probs, tokens in zip(s, t)] for s, t in zip(tokens_probs, tokens) ] # select the top-ranked sample in each group