diff --git a/whisper/model.py b/whisper/model.py
index 1b5890f..b3b6844 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -72,15 +72,15 @@ class MultiHeadAttention(nn.Module):
     ):
         q = self.query(x)

-        if kv_cache is None or xa is None:
+        if kv_cache is None or xa is None or self.key not in kv_cache:
             # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
             # otherwise, perform key/value projections for self- or cross-attention as usual.
             k = self.key(x if xa is None else xa)
             v = self.value(x if xa is None else xa)
         else:
             # for cross-attention, calculate keys and values once and reuse in subsequent calls.
-            k = kv_cache.get(self.key, self.key(xa))
-            v = kv_cache.get(self.value, self.value(xa))
+            k = kv_cache[self.key]
+            v = kv_cache[self.value]

         wv = self.qkv_attention(q, k, v, mask)
         return self.out(wv)