diff --git a/whisper/decoding.py b/whisper/decoding.py index 49485d0..312a9e0 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -169,7 +169,7 @@ class PyTorchInference(Inference): self.kv_cache = {} self.hooks = [] - def rearrange_kv_cache(self, source_indices): + def rearrange_kv_cache(self, source_indices : List[int]): if source_indices != list(range(len(source_indices))): for module in self.kv_modules: # update the key/value cache to contain the selected sequences