typing addition

This commit is contained in:
OliverCai0 2023-11-14 16:00:45 -06:00
parent 1cea435768
commit ee36decb1a

View File

@ -169,7 +169,7 @@ class PyTorchInference(Inference):
self.kv_cache = {}
self.hooks = []
def rearrange_kv_cache(self, source_indices):
def rearrange_kv_cache(self, source_indices : List[int]):
if source_indices != list(range(len(source_indices))):
for module in self.kv_modules:
# update the key/value cache to contain the selected sequences