Fix attention caching to make it actually work (#370)

Vicki Anand 2022-10-19 19:44:03 -04:00 committed by GitHub
parent 7f3e408e09
commit 9f70a352f9

@@ -72,15 +72,15 @@ class MultiHeadAttention(nn.Module):
     ):
         q = self.query(x)
 
-        if kv_cache is None or xa is None:
+        if kv_cache is None or xa is None or self.key not in kv_cache:
             # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
             # otherwise, perform key/value projections for self- or cross-attention as usual.
             k = self.key(x if xa is None else xa)
             v = self.value(x if xa is None else xa)
         else:
             # for cross-attention, calculate keys and values once and reuse in subsequent calls.
-            k = kv_cache.get(self.key, self.key(xa))
-            v = kv_cache.get(self.value, self.value(xa))
+            k = kv_cache[self.key]
+            v = kv_cache[self.value]
 
         wv = self.qkv_attention(q, k, v, mask)
         return self.out(wv)
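
The root cause is visible in the removed lines: dict.get evaluates its default argument eagerly, so kv_cache.get(self.key, self.key(xa)) re-ran the key/value projections of the encoder output on every decoding step even when cached tensors were already present. Below is a minimal, illustrative sketch of that pitfall and of the membership-check pattern the commit switches to; CountingLinear and the toy shapes are hypothetical helpers for the demo, not part of the commit.

# Minimal illustrative sketch (not from the commit): dict.get evaluates its default
# eagerly, so the old pattern re-ran the projection even on a cache hit, while the
# new membership-check pattern reuses the cached tensor.
import torch
from torch import nn


class CountingLinear(nn.Linear):
    """Hypothetical helper: a Linear projection that counts how often it runs."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.calls = 0

    def forward(self, x):
        self.calls += 1
        return super().forward(x)


key = CountingLinear(4, 4, bias=False)
xa = torch.randn(1, 3, 4)      # stands in for the (fixed) encoder output
kv_cache = {key: key(xa)}      # as if a forward hook had already cached k; key.calls == 1

# old pattern: the default is computed before .get performs the lookup
k = kv_cache.get(key, key(xa))
assert key.calls == 2          # the projection ran again despite the cache hit

# fixed pattern from the diff: check membership first, then index the cache
k = kv_cache[key] if key in kv_cache else key(xa)
assert key.calls == 2          # cached tensor reused; no extra projection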