mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 06:26:03 +00:00
Update model.py
This commit is contained in:
parent
65a353771a
commit
3211024b53
@ -2,7 +2,7 @@ import base64
|
||||
import gzip
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Iterable, Optional
|
||||
from typing import Dict, Iterable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -113,7 +113,7 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
def qkv_attention(
|
||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
n_batch, n_ctx, n_state = q.shape
|
||||
scale = (n_state // self.n_head) ** -0.25
|
||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user