using sdpa if available

This commit is contained in:
Jong Wook Kim 2024-09-30 02:38:26 -07:00
parent 423492dda7
commit 65a353771a
2 changed files with 43 additions and 10 deletions

View File

@ -1,5 +1,6 @@
import base64
import gzip
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, Iterable, Optional
@ -12,6 +13,14 @@ from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function
try:
from torch.nn.functional import scaled_dot_product_attention
SDPA_AVAILABLE = True
except (ImportError, RuntimeError, OSError):
scaled_dot_product_attention = None
SDPA_AVAILABLE = False
@dataclass
class ModelDimensions:
@ -59,7 +68,19 @@ def sinusoids(length, channels, max_timescale=10000):
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
@contextmanager
def disable_sdpa():
prev_state = MultiHeadAttention.use_sdpa
try:
MultiHeadAttention.use_sdpa = False
yield
finally:
MultiHeadAttention.use_sdpa = prev_state
class MultiHeadAttention(nn.Module):
use_sdpa = True
def __init__(self, n_state: int, n_head: int):
super().__init__()
self.n_head = n_head
@ -92,20 +113,30 @@ class MultiHeadAttention(nn.Module):
def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
):
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.float()
if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
a = scaled_dot_product_attention(
q, k, v, is_causal=mask is not None and n_ctx > 1
)
out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
qk = None
else:
qk = (q * scale) @ (k * scale).transpose(-1, -2)
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.float()
w = F.softmax(qk, dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
w = F.softmax(qk, dim=-1).to(q.dtype)
out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
qk = qk.detach()
return out, qk
class ResidualAttentionBlock(nn.Module):

View File

@ -191,7 +191,9 @@ def find_alignment(
for i, block in enumerate(model.decoder.blocks)
]
with torch.no_grad():
from .model import disable_sdpa
with torch.no_grad(), disable_sdpa():
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
token_probs = sampled_logits.softmax(dim=-1)