From 7858aa9c08d98f75575035ecd6481f462d66ca27 Mon Sep 17 00:00:00 2001
From: Andrey Chernykh
Date: Thu, 2 Feb 2023 06:46:51 +0700
Subject: [PATCH] Fix infinite loop caused by incorrect timestamp tokens prediction (#914)

* Fix infinite loop caused by incorrect timestamp tokens prediction

https://github.com/openai/whisper/discussions/810

* Update decoding.py

---------

Co-authored-by: Jong Wook Kim
---
 whisper/decoding.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/whisper/decoding.py b/whisper/decoding.py
index 983c898..7613a0c 100644
--- a/whisper/decoding.py
+++ b/whisper/decoding.py
@@ -412,7 +412,8 @@ class ApplyTimestampRules(LogitFilter):
 
         # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
         for k in range(tokens.shape[0]):
-            seq = [t for t in tokens[k, self.sample_begin :].tolist()]
+            sampled_tokens = tokens[k, self.sample_begin :]
+            seq = [t for t in sampled_tokens.tolist()]
             last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
             penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
 
@@ -422,6 +423,11 @@
                 else:  # cannot be normal text tokens
                     logits[k, : self.tokenizer.eot] = -np.inf
 
+            timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]
+            if timestamps.numel() > 0:
+                # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
+                logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
+
         if tokens.shape[1] == self.sample_begin:
             # suppress generating non-timestamp tokens at the beginning
             logits[:, : self.tokenizer.timestamp_begin] = -np.inf
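
Below is a minimal, self-contained sketch (not part of the patch) of the rule the added lines enforce: once a timestamp token has been sampled, any timestamp token that would move backwards in time is masked out, so decoding cannot keep re-emitting earlier segment boundaries. The names TIMESTAMP_BEGIN, VOCAB_SIZE, and forbid_decreasing_timestamps are toy placeholders for this illustration, not identifiers from whisper/decoding.py.

import torch

TIMESTAMP_BEGIN = 100  # hypothetical index of the first timestamp token (the real code uses tokenizer.timestamp_begin)
VOCAB_SIZE = 150       # toy vocabulary size for the example

def forbid_decreasing_timestamps(logits: torch.Tensor, sampled_tokens: torch.Tensor) -> torch.Tensor:
    """Mask timestamp logits that lie before the most recent sampled timestamp."""
    timestamps = sampled_tokens[sampled_tokens.ge(TIMESTAMP_BEGIN)]
    if timestamps.numel() > 0:
        # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last one seen
        logits[TIMESTAMP_BEGIN : timestamps[-1]] = -float("inf")
    return logits

# The last timestamp sampled so far is token 120, so timestamp tokens
# 100..119 are suppressed while 120 and later stay available.
logits = torch.zeros(VOCAB_SIZE)
sampled = torch.tensor([5, 7, 120, 9, 11])
masked = forbid_decreasing_timestamps(logits, sampled)
assert torch.isinf(masked[TIMESTAMP_BEGIN:120]).all()
assert not torch.isinf(masked[120:]).any()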