Mirror of https://github.com/openai/whisper.git (synced 2025-03-30 14:28:27 +00:00)
Fix attention caching to make it actually work (#370)
parent 7f3e408e09 · commit 9f70a352f9
@@ -72,15 +72,15 @@ class MultiHeadAttention(nn.Module):
     ):
         q = self.query(x)
 
-        if kv_cache is None or xa is None:
+        if kv_cache is None or xa is None or self.key not in kv_cache:
             # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
             # otherwise, perform key/value projections for self- or cross-attention as usual.
             k = self.key(x if xa is None else xa)
             v = self.value(x if xa is None else xa)
         else:
             # for cross-attention, calculate keys and values once and reuse in subsequent calls.
-            k = kv_cache.get(self.key, self.key(xa))
-            v = kv_cache.get(self.value, self.value(xa))
+            k = kv_cache[self.key]
+            v = kv_cache[self.value]
 
         wv = self.qkv_attention(q, k, v, mask)
         return self.out(wv)
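The bug this commit fixes is a Python evaluation detail: in the old else branch, `kv_cache.get(self.key, self.key(xa))` evaluates its fallback argument `self.key(xa)` eagerly on every call, even when the cache already holds the tensor, so the cross-attention key/value projections were recomputed at every decoding step. The fix moves the membership check into the if condition and indexes the cache directly. A minimal, self-contained sketch of the difference (the projection, cache, and counter names are illustrative, not whisper's):

import torch
import torch.nn as nn

key = nn.Linear(4, 4)          # stand-in for the cross-attention key projection
xa = torch.randn(1, 3, 4)      # stand-in for the encoder output
kv_cache = {}
calls = 0

def project(x):
    # Count projections to show that the old pattern recomputes every call.
    global calls
    calls += 1
    return key(x)

kv_cache[key] = project(xa)    # first call populates the cache (calls == 1)

# Old pattern: dict.get evaluates its default eagerly, so project(xa)
# runs again even though the cache already holds the tensor.
k = kv_cache.get(key, project(xa))
assert calls == 2

# Fixed pattern: check membership first, then index; no recomputation.
if key not in kv_cache:
    kv_cache[key] = project(xa)
k = kv_cache[key]
assert calls == 2              # unchanged: the cached tensor was reused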
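For context on the first branch's comment about hooks: whisper populates `kv_cache` by registering forward hooks on the `key`/`value` projections, so each new projection is stored in (or appended to) the cache. A rough sketch of that mechanism, relying on the standard PyTorch `register_forward_hook` behavior that a non-None return value replaces the module's output; the helper name `install_cache_hook` is hypothetical and this is not the exact whisper implementation:

import torch
import torch.nn as nn

def install_cache_hook(module: nn.Linear, cache: dict):
    # A forward hook that stores (or extends) the module's output in `cache`.
    # Returning a tensor from a forward hook replaces the module's output,
    # which is how cached keys/values get spliced in during decoding.
    def save_to_cache(mod, _inputs, output):
        if mod not in cache:
            cache[mod] = output  # first step: save as-is
        else:
            cache[mod] = torch.cat([cache[mod], output], dim=1).detach()
        return cache[mod]
    return module.register_forward_hook(save_to_cache)

# Usage: hook a projection, run it twice, and watch the cache grow along
# the sequence dimension (dim=1).
proj = nn.Linear(4, 4)
cache: dict = {}
handle = install_cache_hook(proj, cache)
out1 = proj(torch.randn(1, 1, 4))  # cached as-is, shape (1, 1, 4)
out2 = proj(torch.randn(1, 1, 4))  # hook returns the concatenation, shape (1, 2, 4)
assert cache[proj].shape[1] == 2
handle.remove()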