Mirror of https://github.com/openai/whisper.git (synced 2025-11-24 14:35:57 +00:00)
Fix attention caching to make it actually work (#370)
This commit is contained in: parent 7f3e408e09, commit 9f70a352f9
@@ -72,15 +72,15 @@ class MultiHeadAttention(nn.Module):
     ):
         q = self.query(x)
 
-        if kv_cache is None or xa is None:
+        if kv_cache is None or xa is None or self.key not in kv_cache:
             # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
             # otherwise, perform key/value projections for self- or cross-attention as usual.
             k = self.key(x if xa is None else xa)
             v = self.value(x if xa is None else xa)
         else:
             # for cross-attention, calculate keys and values once and reuse in subsequent calls.
-            k = kv_cache.get(self.key, self.key(xa))
-            v = kv_cache.get(self.value, self.value(xa))
+            k = kv_cache[self.key]
+            v = kv_cache[self.value]
 
         wv = self.qkv_attention(q, k, v, mask)
         return self.out(wv)
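Context for the change above: per the in-code comments, kv_cache is keyed by the key/value projection modules themselves, and forward hooks installed on those modules record their outputs when caching is enabled. The old else-branch used kv_cache.get(self.key, self.key(xa)); because Python evaluates the default argument of dict.get eagerly, the projection was recomputed on every call even when a cached tensor existed, so the cache never saved any work. The fix indexes the cache directly and routes the "nothing cached yet" case through the first branch via the self.key not in kv_cache check. Below is a minimal, self-contained sketch of this pattern; it is not the repository's implementation, and names such as SketchAttention, install_cache_hooks, and save_to_cache are assumptions made for illustration.

    # Illustrative sketch only (assumed names, simplified attention, no scaling/mask).
    import torch
    from torch import nn, Tensor
    from typing import Optional


    class SketchAttention(nn.Module):
        def __init__(self, n_state: int):
            super().__init__()
            self.query = nn.Linear(n_state, n_state)
            self.key = nn.Linear(n_state, n_state, bias=False)
            self.value = nn.Linear(n_state, n_state)

        def forward(self, x: Tensor, xa: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
            q = self.query(x)
            if kv_cache is None or xa is None or self.key not in kv_cache:
                # first call (or no caching): run the projections; hooks, if installed,
                # record the results in kv_cache
                k = self.key(x if xa is None else xa)
                v = self.value(x if xa is None else xa)
            else:
                # cross-attention keys/values were cached on the first call; reuse them
                k = kv_cache[self.key]
                v = kv_cache[self.value]
            w = (q @ k.transpose(-1, -2)).softmax(dim=-1)
            return w @ v


    def install_cache_hooks(attn: SketchAttention, cache: dict):
        # store each projection module's output in the cache, keyed by the module itself
        def save_to_cache(module, _inputs, output):
            cache[module] = output.detach()

        return [
            attn.key.register_forward_hook(save_to_cache),
            attn.value.register_forward_hook(save_to_cache),
        ]


    if __name__ == "__main__":
        attn = SketchAttention(n_state=8)
        cache: dict = {}
        hooks = install_cache_hooks(attn, cache)

        audio = torch.randn(1, 4, 8)  # encoder output (cross-attention memory)
        tok1 = torch.randn(1, 1, 8)   # first decoded token
        tok2 = torch.randn(1, 1, 8)   # second decoded token

        attn(tok1, xa=audio, kv_cache=cache)  # projects k/v once; hooks cache them
        attn(tok2, xa=audio, kv_cache=cache)  # reuses cached k/v; no reprojection

        for h in hooks:
            h.remove()

In this sketch, the first call with a given cache projects the encoder output once (the hooks record the results), and every later call with the same cache reuses the stored tensors, matching the intent of the comment "calculate keys and values once and reuse in subsequent calls."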