This commit is contained in:
Nathanael Perraudin 2024-05-22 10:57:00 +02:00
parent 0f4c4e5d45
commit 34ec5c81e7

View File

@ -247,7 +247,7 @@ class TokenDecoder:
def finalize( def finalize(
self, tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor self, tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]: ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]], Sequence[Sequence[Tensor]]]:
"""Finalize search and return the final candidate sequences """Finalize search and return the final candidate sequences
Parameters Parameters
@ -265,6 +265,8 @@ class TokenDecoder:
sum_logprobs : List[List[float]], length = n_audio sum_logprobs : List[List[float]], length = n_audio
sequence of cumulative log probabilities corresponding to the above sequence of cumulative log probabilities corresponding to the above
token_probs : Sequence[Sequence[Tensor]], length = n_audio
sequence of Tensors containing the probability of each token in the candidate sequences
""" """
raise NotImplementedError raise NotImplementedError
@ -304,8 +306,8 @@ class GreedyDecoder(TokenDecoder):
def finalize(self, tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor): def finalize(self, tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor):
# make sure each sequence has at least one EOT token at the end # make sure each sequence has at least one EOT token at the end
tokens = F.pad(tokens, (0, 1), value=self.eot) tokens = F.pad(tokens, (0, 1), value=self.eot)
token_probs = F.pad(token_probs, (0, 1), value=0) # 0 ok? token_probs = F.pad(token_probs, (0, 1), value=-1) # -1 to indicate the end of the sequence
return tokens, sum_logprobs.tolist(), token_probs.tolist() return tokens, sum_logprobs.tolist(), token_probs
class BeamSearchDecoder(TokenDecoder): class BeamSearchDecoder(TokenDecoder):
@ -747,25 +749,32 @@ class DecodingTask:
# call the main sampling loop # call the main sampling loop
tokens, sum_logprobs, no_speech_probs, token_probs = self._main_loop(audio_features, tokens) tokens, sum_logprobs, no_speech_probs, token_probs = self._main_loop(audio_features, tokens)
# reshape the tensors to have (n_audio, n_group) as the first two dimensions # reshape the tensors to have (n_audio, n_group) as the first two dimensions
audio_features = audio_features[:: self.n_group] audio_features = audio_features[:: self.n_group]
no_speech_probs = no_speech_probs[:: self.n_group] no_speech_probs = no_speech_probs[:: self.n_group]
assert audio_features.shape[0] == len(no_speech_probs) == n_audio assert audio_features.shape[0] == len(no_speech_probs) == n_audio
tokens = tokens.reshape(n_audio, self.n_group, -1) tokens = tokens.reshape(n_audio, self.n_group, -1)
token_probs = token_probs.reshape(n_audio, self.n_group, -1)
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group) sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
# get the final candidates for each group, and slice between the first sampled token and EOT # get the final candidates for each group, and slice between the first sampled token and EOT
tokens, sum_logprobs, token_probs = self.decoder.finalize(tokens, sum_logprobs, token_probs) tokens, sum_logprobs, token_probs = self.decoder.finalize(tokens, sum_logprobs, token_probs)
token_probs: List[List[Tensor]] = [
[p[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t, p in zip(s_t, s_p)]
for s_t, s_p in zip(tokens, token_probs)
]
tokens: List[List[Tensor]] = [ tokens: List[List[Tensor]] = [
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
for s in tokens for s in tokens
] ]
# select the top-ranked sample in each group # select the top-ranked sample in each group
selected = self.sequence_ranker.rank(tokens, sum_logprobs) selected = self.sequence_ranker.rank(tokens, sum_logprobs)
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)] tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
token_probs: List[List[float]] = [t[i].tolist() for i, t in zip(selected, token_probs)]
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens] texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)] sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]