diff --git a/whisper/decoding.py b/whisper/decoding.py index ea44088..1922c42 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -247,7 +247,7 @@ class TokenDecoder: def finalize( self, tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor - ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]: + ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]], Sequence[Sequence[Tensor]]]: """Finalize search and return the final candidate sequences Parameters @@ -265,6 +265,8 @@ class TokenDecoder: sum_logprobs : List[List[float]], length = n_audio sequence of cumulative log probabilities corresponding to the above + token_probs : Sequence[Sequence[Tensor]], length = n_audio + sequence of Tensors containing the probability of each token in the candidate sequences """ raise NotImplementedError @@ -304,8 +306,8 @@ class GreedyDecoder(TokenDecoder): def finalize(self, tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor): # make sure each sequence has at least one EOT token at the end tokens = F.pad(tokens, (0, 1), value=self.eot) - token_probs = F.pad(token_probs, (0, 1), value=0) # 0 ok? - return tokens, sum_logprobs.tolist(), token_probs.tolist() + token_probs = F.pad(token_probs, (0, 1), value=-1) # -1 to indicate the end of the sequence + return tokens, sum_logprobs.tolist(), token_probs class BeamSearchDecoder(TokenDecoder): @@ -747,25 +749,32 @@ class DecodingTask: # call the main sampling loop tokens, sum_logprobs, no_speech_probs, token_probs = self._main_loop(audio_features, tokens) - # reshape the tensors to have (n_audio, n_group) as the first two dimensions audio_features = audio_features[:: self.n_group] no_speech_probs = no_speech_probs[:: self.n_group] assert audio_features.shape[0] == len(no_speech_probs) == n_audio tokens = tokens.reshape(n_audio, self.n_group, -1) + token_probs = token_probs.reshape(n_audio, self.n_group, -1) sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group) # get the final candidates for each group, and slice between the first sampled token and EOT tokens, sum_logprobs, token_probs = self.decoder.finalize(tokens, sum_logprobs, token_probs) + token_probs: List[List[Tensor]] = [ + [p[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t, p in zip(s_t, s_p)] + for s_t, s_p in zip(tokens, token_probs) + ] tokens: List[List[Tensor]] = [ [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens ] + # select the top-ranked sample in each group selected = self.sequence_ranker.rank(tokens, sum_logprobs) + tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)] + token_probs: List[List[float]] = [t[i].tolist() for i, t in zip(selected, token_probs)] texts: List[str] = [tokenizer.decode(t).strip() for t in tokens] sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]