mirror of
https://github.com/openai/whisper.git
synced 2025-11-28 00:03:40 +00:00
Fix infinite loop caused by incorrect timestamp tokens prediction (#914)
* Fix infinite loop caused by incorrect timestamp tokens prediction
  https://github.com/openai/whisper/discussions/810
* Update decoding.py

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
This commit is contained in:
parent
5c1a8c10e7
commit
7858aa9c08
@@ -412,7 +412,8 @@ class ApplyTimestampRules(LogitFilter):
|
|||||||
|
|
||||||
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
||||||
for k in range(tokens.shape[0]):
|
for k in range(tokens.shape[0]):
|
||||||
seq = [t for t in tokens[k, self.sample_begin :].tolist()]
|
sampled_tokens = tokens[k, self.sample_begin :]
|
||||||
|
seq = [t for t in sampled_tokens.tolist()]
|
||||||
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
||||||
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
|
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
|
||||||
|
|
||||||
@@ -422,6 +423,11 @@ class ApplyTimestampRules(LogitFilter):
|
|||||||
else: # cannot be normal text tokens
|
else: # cannot be normal text tokens
|
||||||
logits[k, : self.tokenizer.eot] = -np.inf
|
logits[k, : self.tokenizer.eot] = -np.inf
|
||||||
|
|
||||||
|
timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]
|
||||||
|
if timestamps.numel() > 0:
|
||||||
|
# timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
|
||||||
|
logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
|
||||||
|
|
||||||
if tokens.shape[1] == self.sample_begin:
|
if tokens.shape[1] == self.sample_begin:
|
||||||
# suppress generating non-timestamp tokens at the beginning
|
# suppress generating non-timestamp tokens at the beginning
|
||||||
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
|
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user