mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 14:35:57 +00:00
fix bugs
This commit is contained in:
parent
0f4c4e5d45
commit
34ec5c81e7
@ -247,7 +247,7 @@ class TokenDecoder:
|
||||
|
||||
def finalize(
|
||||
self, tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor
|
||||
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
|
||||
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]], Sequence[Sequence[Tensor]]]:
|
||||
"""Finalize search and return the final candidate sequences
|
||||
|
||||
Parameters
|
||||
@ -265,6 +265,8 @@ class TokenDecoder:
|
||||
|
||||
sum_logprobs : List[List[float]], length = n_audio
|
||||
sequence of cumulative log probabilities corresponding to the above
|
||||
token_probs : Sequence[Sequence[Tensor]], length = n_audio
|
||||
sequence of Tensors containing the probability of each token in the candidate sequences
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@ -304,8 +306,8 @@ class GreedyDecoder(TokenDecoder):
|
||||
def finalize(self, tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor):
    """Terminate every greedy candidate with EOT and return the final results.

    Parameters
    ----------
    tokens : Tensor
        sampled token ids; the last dimension is the sequence axis
        (the visible caller reshapes to (n_audio, n_group, n_tokens) first)
    sum_logprobs : Tensor
        cumulative log probability of each candidate sequence
    token_probs : Tensor
        per-token probabilities, index-aligned with `tokens` along the last axis

    Returns
    -------
    tokens : Tensor
        `tokens` with one extra trailing position filled with `self.eot`
    sum_logprobs : list
        `sum_logprobs` converted to (nested) Python floats
    token_probs : Tensor
        `token_probs` with one extra trailing position filled with -1
    """
    # make sure each sequence has at least one EOT token at the end
    tokens = F.pad(tokens, (0, 1), value=self.eot)

    # Pad the probabilities in lockstep so they stay aligned with `tokens`.
    # -1 is not a valid probability, so it unambiguously marks the padded
    # end-of-sequence slot (0 could be mistaken for a real value).
    token_probs = F.pad(token_probs, (0, 1), value=-1)

    # token_probs is deliberately returned as a Tensor rather than a list:
    # the caller slices each sequence and only then calls `.tolist()` on the
    # selected candidate, which would fail on plain Python lists.
    return tokens, sum_logprobs.tolist(), token_probs
|
||||
|
||||
|
||||
class BeamSearchDecoder(TokenDecoder):
|
||||
@ -747,25 +749,32 @@ class DecodingTask:
|
||||
|
||||
# call the main sampling loop
|
||||
tokens, sum_logprobs, no_speech_probs, token_probs = self._main_loop(audio_features, tokens)
|
||||
|
||||
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
|
||||
audio_features = audio_features[:: self.n_group]
|
||||
no_speech_probs = no_speech_probs[:: self.n_group]
|
||||
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
|
||||
|
||||
tokens = tokens.reshape(n_audio, self.n_group, -1)
|
||||
token_probs = token_probs.reshape(n_audio, self.n_group, -1)
|
||||
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
|
||||
|
||||
# get the final candidates for each group, and slice between the first sampled token and EOT
|
||||
tokens, sum_logprobs, token_probs = self.decoder.finalize(tokens, sum_logprobs, token_probs)
|
||||
token_probs: List[List[Tensor]] = [
|
||||
[p[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t, p in zip(s_t, s_p)]
|
||||
for s_t, s_p in zip(tokens, token_probs)
|
||||
]
|
||||
tokens: List[List[Tensor]] = [
|
||||
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
|
||||
for s in tokens
|
||||
]
|
||||
|
||||
|
||||
# select the top-ranked sample in each group
|
||||
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
||||
|
||||
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
|
||||
token_probs: List[List[float]] = [t[i].tolist() for i, t in zip(selected, token_probs)]
|
||||
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
|
||||
|
||||
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user