Mirror of https://github.com/openai/whisper.git
Avoid rearranging all caches (#1483)
* avoid rearranging all kv_caches
* avoid calculating the same kv_cache from cross attn
* Update decoding.py
* linter fix

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
commit b91c907694
parent f572f2161b
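At a high level, the patch makes `rearrange_kv_cache` a no-op when the beam permutation is the identity (as in greedy decoding) and restricts the reorder to the decoder self-attention caches. A minimal standalone sketch of the identity check, with illustrative names rather than the actual Whisper classes:

```python
import torch

# toy cache: one entry per cached projection, batch of 4 hypotheses
kv_cache = {"layer0.key": torch.randn(4, 10, 8)}

def rearrange_kv_cache(kv_cache, source_indices):
    # only index into the tensors when the permutation actually moves rows
    if source_indices != list(range(len(source_indices))):
        for name in kv_cache:
            kv_cache[name] = kv_cache[name][source_indices].detach()

rearrange_kv_cache(kv_cache, [0, 1, 2, 3])  # identity: returns without touching tensors
rearrange_kv_cache(kv_cache, [2, 0, 1, 3])  # beam reorder: rows are permuted
```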
```diff
@@ -146,6 +146,10 @@ class PyTorchInference(Inference):
         self.kv_cache = {}
         self.hooks = []
 
+        key_modules = [block.attn.key for block in self.model.decoder.blocks]
+        value_modules = [block.attn.value for block in self.model.decoder.blocks]
+        self.kv_modules = key_modules + value_modules
+
     def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
         if not self.kv_cache:
             self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
```
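For context, `install_kv_cache_hooks` returns a cache dict keyed by the key/value projection modules, so precomputing this list once in the constructor gives `rearrange_kv_cache` a fixed set of self-attention entries to reorder; the decoder blocks' `cross_attn` modules are deliberately left out. A rough illustration with simplified stand-in classes (not Whisper's actual block implementation):

```python
import torch.nn as nn

# simplified stand-ins for the decoder blocks (illustration only)
class Attention(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.key = nn.Linear(d, d, bias=False)
        self.value = nn.Linear(d, d)

class Block(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.attn = Attention(d)        # self-attention: cache rows follow the beams
        self.cross_attn = Attention(d)  # cross-attention: cache depends only on the audio

blocks = [Block(8) for _ in range(4)]
key_modules = [block.attn.key for block in blocks]
value_modules = [block.attn.value for block in blocks]
kv_modules = key_modules + value_modules  # cross_attn projections intentionally absent
print(len(kv_modules))  # 8: one key and one value module per block
```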
```diff
@@ -164,9 +168,10 @@ class PyTorchInference(Inference):
         self.hooks = []
 
     def rearrange_kv_cache(self, source_indices):
-        for module, tensor in self.kv_cache.items():
-            # update the key/value cache to contain the selected sequences
-            self.kv_cache[module] = tensor[source_indices].detach()
+        if source_indices != list(range(len(source_indices))):
+            for module in self.kv_modules:
+                # update the key/value cache to contain the selected sequences
+                self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
 
 
 class SequenceRanker:
```
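One reason the identity check pays off: advanced indexing in PyTorch always allocates a copy, even when the index list is the identity permutation, so the old code re-allocated every cache tensor on every decoding step. A quick demonstration of that copying behavior (the performance motivation here is my reading of the commit message):

```python
import torch

t = torch.randn(5, 3)
u = t[[0, 1, 2, 3, 4]]  # advanced indexing with the identity permutation
print(u.data_ptr() == t.data_ptr())  # False: a fresh tensor was allocated
print(torch.equal(u, t))             # True: the contents are unchanged
```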
```diff
@@ -668,7 +673,6 @@ class DecodingTask:
         return languages, lang_probs
 
     def _main_loop(self, audio_features: Tensor, tokens: Tensor):
-        assert audio_features.shape[0] == tokens.shape[0]
         n_batch = tokens.shape[0]
         sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
         no_speech_probs = [np.nan] * n_batch
```
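The assert is removed because, after the change in the last hunk below, `audio_features` is no longer repeated per beam: its batch dimension is now smaller than the token batch by a factor of `n_group`, so the old equality no longer holds. Illustrative shapes (sizes chosen arbitrarily):

```python
import torch

n_audio, n_group, d = 1, 5, 384
audio_features = torch.randn(n_audio, 1500, d)     # encoded once per segment
tokens = torch.zeros(n_audio * n_group, 3).long()  # one row per beam hypothesis
print(audio_features.shape[0], tokens.shape[0])    # 1 5: the old assert would fail
```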
```diff
@@ -721,8 +725,7 @@ class DecodingTask:
             )
         ]
 
-        # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
-        audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
+        # repeat text tensors by the group size, for beam search or best-of-n sampling
         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
 
         # call the main sampling loop
```
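Only the token tensor is repeated now. Previously `audio_features` was repeated the same way, so every beam recomputed identical cross-attention keys and values from its own copy of the audio; leaving the audio un-repeated lets that work happen once per segment, which is the "avoid calculating the same kv_cache from cross attn" part of the commit message. For reference, `repeat_interleave` duplicates each row in place:

```python
import torch

tokens = torch.tensor([[1, 2], [3, 4]])    # two initial sequences
print(tokens.repeat_interleave(3, dim=0))  # each row duplicated for its 3 beams
# tensor([[1, 2],
#         [1, 2],
#         [1, 2],
#         [3, 4],
#         [3, 4],
#         [3, 4]])
```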