Revert "saving the qk matrix in the attention module for convenience"

This reverts commit 68e44bd83ce6c3e352f74b266aa39d8b649af9e3.
Jong Wook Kim 2022-12-29 23:53:31 -07:00
parent 68e44bd83c
commit 9323b2526c


@@ -62,7 +62,6 @@ class MultiHeadAttention(nn.Module):
         self.key = Linear(n_state, n_state, bias=False)
         self.value = Linear(n_state, n_state)
         self.out = Linear(n_state, n_state)
-        self.last_qk = None
 
     def forward(
         self,
@@ -97,8 +96,6 @@ class MultiHeadAttention(nn.Module):
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
-        self.last_qk = qk.detach()
 
         w = F.softmax(qk.float(), dim=-1).to(q.dtype)
         return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
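
For anyone who was using the removed last_qk attribute: below is a minimal sketch, not part of this commit, of how an intermediate tensor can still be captured with a standard PyTorch forward hook instead of stashing it on the module. The model variable, module path, and names in the usage comments are illustrative assumptions, not code from this repository.

    import torch

    captured = {}

    def save_output(name):
        """Return a forward hook that records a module's output under `name`."""
        def hook(module, inputs, output):
            # Detach so the saved tensor does not keep the autograd graph alive,
            # mirroring the qk.detach() call in the reverted commit.
            captured[name] = output.detach() if isinstance(output, torch.Tensor) else output
        return hook

    # Hypothetical usage with a loaded model (names are illustrative):
    #   handle = model.decoder.blocks[0].cross_attn.register_forward_hook(
    #       save_output("block0_cross_attn"))
    #   model(mel, tokens)           # run a forward pass
    #   attn_out = captured["block0_cross_attn"]
    #   handle.remove()              # remove the hook when done

A hook on the module only sees its inputs and outputs, so capturing qk itself (the pre-softmax scores) would still require hooking or modifying the code where qk is computed, which is what the reverted commit did by assigning to self.last_qk.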