mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 14:35:57 +00:00
using sdpa if available
This commit is contained in:
parent
423492dda7
commit
65a353771a
@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import gzip
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Iterable, Optional
|
||||
|
||||
@ -12,6 +13,14 @@ from .decoding import decode as decode_function
|
||||
from .decoding import detect_language as detect_language_function
|
||||
from .transcribe import transcribe as transcribe_function
|
||||
|
||||
try:
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
SDPA_AVAILABLE = True
|
||||
except (ImportError, RuntimeError, OSError):
|
||||
scaled_dot_product_attention = None
|
||||
SDPA_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelDimensions:
|
||||
@ -59,7 +68,19 @@ def sinusoids(length, channels, max_timescale=10000):
|
||||
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_sdpa():
|
||||
prev_state = MultiHeadAttention.use_sdpa
|
||||
try:
|
||||
MultiHeadAttention.use_sdpa = False
|
||||
yield
|
||||
finally:
|
||||
MultiHeadAttention.use_sdpa = prev_state
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
use_sdpa = True
|
||||
|
||||
def __init__(self, n_state: int, n_head: int):
|
||||
super().__init__()
|
||||
self.n_head = n_head
|
||||
@ -92,20 +113,30 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
def qkv_attention(
|
||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||
):
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
n_batch, n_ctx, n_state = q.shape
|
||||
scale = (n_state // self.n_head) ** -0.25
|
||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
||||
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
|
||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
|
||||
qk = q @ k
|
||||
if mask is not None:
|
||||
qk = qk + mask[:n_ctx, :n_ctx]
|
||||
qk = qk.float()
|
||||
if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
|
||||
a = scaled_dot_product_attention(
|
||||
q, k, v, is_causal=mask is not None and n_ctx > 1
|
||||
)
|
||||
out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||
qk = None
|
||||
else:
|
||||
qk = (q * scale) @ (k * scale).transpose(-1, -2)
|
||||
if mask is not None:
|
||||
qk = qk + mask[:n_ctx, :n_ctx]
|
||||
qk = qk.float()
|
||||
|
||||
w = F.softmax(qk, dim=-1).to(q.dtype)
|
||||
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
|
||||
w = F.softmax(qk, dim=-1).to(q.dtype)
|
||||
out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||
qk = qk.detach()
|
||||
|
||||
return out, qk
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
|
||||
@ -191,7 +191,9 @@ def find_alignment(
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
from .model import disable_sdpa
|
||||
|
||||
with torch.no_grad(), disable_sdpa():
|
||||
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
||||
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
||||
token_probs = sampled_logits.softmax(dim=-1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user