Fix attention caching to make it actually work (#370)

Vicki Anand 2022-10-19 19:44:03 -04:00 committed by GitHub
parent 7f3e408e09
commit 9f70a352f9


@@ -72,15 +72,15 @@ class MultiHeadAttention(nn.Module):
     ):
         q = self.query(x)
 
-        if kv_cache is None or xa is None:
+        if kv_cache is None or xa is None or self.key not in kv_cache:
             # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
             # otherwise, perform key/value projections for self- or cross-attention as usual.
             k = self.key(x if xa is None else xa)
             v = self.value(x if xa is None else xa)
         else:
             # for cross-attention, calculate keys and values once and reuse in subsequent calls.
-            k = kv_cache.get(self.key, self.key(xa))
-            v = kv_cache.get(self.value, self.value(xa))
+            k = kv_cache[self.key]
+            v = kv_cache[self.value]
 
         wv = self.qkv_attention(q, k, v, mask)
         return self.out(wv)
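
Why the old else-branch defeated the cache: Python evaluates the arguments of dict.get eagerly, so kv_cache.get(self.key, self.key(xa)) ran the key/value projections on every decoding step even when cached tensors were already present. The new membership test self.key not in kv_cache routes the first call through the projection branch (where the installed hook can store the result), and later calls simply index into the cache. A minimal sketch of the pitfall follows; it is illustrative only, and expensive_projection is a hypothetical stand-in for the self.key / self.value projections, not part of the Whisper code.

# Sketch: dict.get evaluates its default argument eagerly, so
# cache.get(key, expensive()) cannot avoid the expensive call.
calls = 0

def expensive_projection():
    # hypothetical stand-in for self.key(xa) / self.value(xa)
    global calls
    calls += 1
    return "projected tensor"

cache = {"k": "cached tensor"}

# old pattern: the default is computed even though "k" is already cached
k = cache.get("k", expensive_projection())
assert k == "cached tensor" and calls == 1   # projection ran despite the cache hit

# new pattern: check membership first and compute only on a miss
calls = 0
if "k" not in cache:
    cache["k"] = expensive_projection()
k = cache["k"]
assert k == "cached tensor" and calls == 0   # cache hit, nothing recomputed

The membership test also covers the very first decoding step, when the cache dict exists but has no entry for this module yet: without it, the plain indexing in the else-branch would raise a KeyError before the hook has populated the cache.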