Mirror of https://github.com/openai/whisper.git
Commit 9ff12f3825: Merge branch 'main' into jongwook/test-versions-update
whisper/model.py

@@ -1,7 +1,8 @@
 import base64
 import gzip
+from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Dict, Iterable, Optional
+from typing import Dict, Iterable, Optional, Tuple
 
 import numpy as np
 import torch
@@ -12,6 +13,14 @@ from .decoding import decode as decode_function
 from .decoding import detect_language as detect_language_function
 from .transcribe import transcribe as transcribe_function
 
+try:
+    from torch.nn.functional import scaled_dot_product_attention
+
+    SDPA_AVAILABLE = True
+except (ImportError, RuntimeError, OSError):
+    scaled_dot_product_attention = None
+    SDPA_AVAILABLE = False
+
 
 @dataclass
 class ModelDimensions:
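For reference, a standalone sketch (not part of the commit; it assumes PyTorch 2.x) of what the function probed above computes, using PyTorch's (batch, heads, sequence, head_dim) layout:

# Standalone sketch: F.scaled_dot_product_attention computes
# softmax(q @ k^T / sqrt(head_dim)) @ v in one fused call.
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # (batch, heads, sequence, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

if hasattr(F, "scaled_dot_product_attention"):  # only present on PyTorch >= 2.0
    out = F.scaled_dot_product_attention(q, k, v)
    print(out.shape)  # torch.Size([1, 8, 16, 64])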
@@ -59,7 +68,19 @@ def sinusoids(length, channels, max_timescale=10000):
     return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
 
 
+@contextmanager
+def disable_sdpa():
+    prev_state = MultiHeadAttention.use_sdpa
+    try:
+        MultiHeadAttention.use_sdpa = False
+        yield
+    finally:
+        MultiHeadAttention.use_sdpa = prev_state
+
+
 class MultiHeadAttention(nn.Module):
+    use_sdpa = True
+
     def __init__(self, n_state: int, n_head: int):
         super().__init__()
         self.n_head = n_head
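A hedged usage sketch of the new context manager, assuming the patched whisper package is importable: it flips the class-level use_sdpa flag off for the duration of the block and restores the previous value on exit, even if an exception is raised.

# Usage sketch: temporarily route attention through the manual path.
from whisper.model import MultiHeadAttention, disable_sdpa

print(MultiHeadAttention.use_sdpa)      # True by default
with disable_sdpa():
    print(MultiHeadAttention.use_sdpa)  # False inside the block
print(MultiHeadAttention.use_sdpa)      # previous state restored on exit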
@@ -92,20 +113,30 @@ class MultiHeadAttention(nn.Module):
 
     def qkv_attention(
         self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
-    ):
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         n_batch, n_ctx, n_state = q.shape
         scale = (n_state // self.n_head) ** -0.25
-        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
-        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
+        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
         v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
 
-        qk = q @ k
-        if mask is not None:
-            qk = qk + mask[:n_ctx, :n_ctx]
-        qk = qk.float()
+        if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
+            a = scaled_dot_product_attention(
+                q, k, v, is_causal=mask is not None and n_ctx > 1
+            )
+            out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
+            qk = None
+        else:
+            qk = (q * scale) @ (k * scale).transpose(-1, -2)
+            if mask is not None:
+                qk = qk + mask[:n_ctx, :n_ctx]
+            qk = qk.float()
 
-        w = F.softmax(qk, dim=-1).to(q.dtype)
-        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
+            w = F.softmax(qk, dim=-1).to(q.dtype)
+            out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
+            qk = qk.detach()
+
+        return out, qk
 
 
 class ResidualAttentionBlock(nn.Module):
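A standalone check, not part of the commit (random tensors, no mask, PyTorch 2.x assumed), that the two branches above agree: splitting the scale as (q * scale) @ (k * scale).transpose(-1, -2) with scale = head_dim ** -0.25 is the same as q @ k^T / sqrt(head_dim), which is the default scaling inside scaled_dot_product_attention.

# Equivalence check for the two paths in qkv_attention (no-mask case).
import torch
import torch.nn.functional as F

n_batch, n_head, n_ctx, d_head = 2, 4, 16, 64
q = torch.randn(n_batch, n_head, n_ctx, d_head)
k = torch.randn(n_batch, n_head, n_ctx, d_head)
v = torch.randn(n_batch, n_head, n_ctx, d_head)

# Manual branch: scale split across q and k, as in the else: block above.
scale = d_head ** -0.25
qk = ((q * scale) @ (k * scale).transpose(-1, -2)).float()
w = F.softmax(qk, dim=-1).to(q.dtype)
out_manual = w @ v

# SDPA branch: defaults to a 1/sqrt(d_head) scale, so the results should match.
out_sdpa = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(out_manual, out_sdpa, atol=1e-4))  # expected: True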
whisper/timing.py

@@ -191,7 +191,9 @@ def find_alignment(
         for i, block in enumerate(model.decoder.blocks)
     ]
 
-    with torch.no_grad():
+    from .model import disable_sdpa
+
+    with torch.no_grad(), disable_sdpa():
         logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
         sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
         token_probs = sampled_logits.softmax(dim=-1)
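For context, a hedged illustration of why this pass disables SDPA: the fused kernel never materializes the attention matrix, so qkv_attention returns qk = None on that branch, and the cross-attention weights that find_alignment reads for word timestamps would be missing; forcing the manual branch keeps them available. The layer sizes below are arbitrary.

# Illustration: qk is only materialized on the manual path.
import torch
from whisper.model import MultiHeadAttention, disable_sdpa

attn = MultiHeadAttention(n_state=64, n_head=4)
x = torch.randn(1, 10, 64)

out, qk = attn(x)       # SDPA path (when available): qk is None
with disable_sdpa():
    out, qk = attn(x)   # manual path: qk holds the attention logits
print(qk is not None)   # expected: True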