Update model.py

This commit is contained in:
Jong Wook Kim 2024-09-30 10:23:39 -07:00 committed by GitHub
parent 65a353771a
commit 3211024b53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,7 +2,7 @@ import base64
import gzip import gzip
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Iterable, Optional from typing import Dict, Iterable, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
@ -113,7 +113,7 @@ class MultiHeadAttention(nn.Module):
def qkv_attention( def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
n_batch, n_ctx, n_state = q.shape n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25 scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)