mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 14:35:57 +00:00
Update model.py
This commit is contained in:
parent
65a353771a
commit
3211024b53
@ -2,7 +2,7 @@ import base64
|
|||||||
import gzip
|
import gzip
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Iterable, Optional
|
from typing import Dict, Iterable, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -113,7 +113,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
|
|
||||||
def qkv_attention(
|
def qkv_attention(
|
||||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
n_batch, n_ctx, n_state = q.shape
|
n_batch, n_ctx, n_state = q.shape
|
||||||
scale = (n_state // self.n_head) ** -0.25
|
scale = (n_state // self.n_head) ** -0.25
|
||||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user