diff --git a/whisper/decoding.py b/whisper/decoding.py index 49485d0..8d14dbf 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -730,7 +730,8 @@ class DecodingTask: ) ] - # repeat text tensors by the group size, for beam search or best-of-n sampling + # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling + audio_features = audio_features.repeat_interleave(self.n_group, dim=0) tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device) # call the main sampling loop