mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 06:26:03 +00:00
fix bugs
This commit is contained in:
parent
0f4c4e5d45
commit
34ec5c81e7
@ -247,7 +247,7 @@ class TokenDecoder:
|
|||||||
|
|
||||||
def finalize(
|
def finalize(
|
||||||
self, tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor
|
self, tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor
|
||||||
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
|
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]], Sequence[Sequence[Tensor]]]:
|
||||||
"""Finalize search and return the final candidate sequences
|
"""Finalize search and return the final candidate sequences
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -265,6 +265,8 @@ class TokenDecoder:
|
|||||||
|
|
||||||
sum_logprobs : List[List[float]], length = n_audio
|
sum_logprobs : List[List[float]], length = n_audio
|
||||||
sequence of cumulative log probabilities corresponding to the above
|
sequence of cumulative log probabilities corresponding to the above
|
||||||
|
token_probs : Sequence[Sequence[Tensor]], length = n_audio
|
||||||
|
sequence of Tensors containing the probability of each token in the candidate sequences
|
||||||
|
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -304,8 +306,8 @@ class GreedyDecoder(TokenDecoder):
|
|||||||
def finalize(self, tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor):
|
def finalize(self, tokens: Tensor, sum_logprobs: Tensor, token_probs: Tensor):
|
||||||
# make sure each sequence has at least one EOT token at the end
|
# make sure each sequence has at least one EOT token at the end
|
||||||
tokens = F.pad(tokens, (0, 1), value=self.eot)
|
tokens = F.pad(tokens, (0, 1), value=self.eot)
|
||||||
token_probs = F.pad(token_probs, (0, 1), value=0) # 0 ok?
|
token_probs = F.pad(token_probs, (0, 1), value=-1) # -1 to indicate the end of the sequence
|
||||||
return tokens, sum_logprobs.tolist(), token_probs.tolist()
|
return tokens, sum_logprobs.tolist(), token_probs
|
||||||
|
|
||||||
|
|
||||||
class BeamSearchDecoder(TokenDecoder):
|
class BeamSearchDecoder(TokenDecoder):
|
||||||
@ -747,25 +749,32 @@ class DecodingTask:
|
|||||||
|
|
||||||
# call the main sampling loop
|
# call the main sampling loop
|
||||||
tokens, sum_logprobs, no_speech_probs, token_probs = self._main_loop(audio_features, tokens)
|
tokens, sum_logprobs, no_speech_probs, token_probs = self._main_loop(audio_features, tokens)
|
||||||
|
|
||||||
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
|
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
|
||||||
audio_features = audio_features[:: self.n_group]
|
audio_features = audio_features[:: self.n_group]
|
||||||
no_speech_probs = no_speech_probs[:: self.n_group]
|
no_speech_probs = no_speech_probs[:: self.n_group]
|
||||||
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
|
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
|
||||||
|
|
||||||
tokens = tokens.reshape(n_audio, self.n_group, -1)
|
tokens = tokens.reshape(n_audio, self.n_group, -1)
|
||||||
|
token_probs = token_probs.reshape(n_audio, self.n_group, -1)
|
||||||
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
|
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
|
||||||
|
|
||||||
# get the final candidates for each group, and slice between the first sampled token and EOT
|
# get the final candidates for each group, and slice between the first sampled token and EOT
|
||||||
tokens, sum_logprobs, token_probs = self.decoder.finalize(tokens, sum_logprobs, token_probs)
|
tokens, sum_logprobs, token_probs = self.decoder.finalize(tokens, sum_logprobs, token_probs)
|
||||||
|
token_probs: List[List[Tensor]] = [
|
||||||
|
[p[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t, p in zip(s_t, s_p)]
|
||||||
|
for s_t, s_p in zip(tokens, token_probs)
|
||||||
|
]
|
||||||
tokens: List[List[Tensor]] = [
|
tokens: List[List[Tensor]] = [
|
||||||
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
|
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
|
||||||
for s in tokens
|
for s in tokens
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# select the top-ranked sample in each group
|
# select the top-ranked sample in each group
|
||||||
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
||||||
|
|
||||||
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
|
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
|
||||||
|
token_probs: List[List[float]] = [t[i].tolist() for i, t in zip(selected, token_probs)]
|
||||||
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
|
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
|
||||||
|
|
||||||
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
|
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user