diff --git a/whisper/model.py b/whisper/model.py index 9e09a6d..e537447 100644 --- a/whisper/model.py +++ b/whisper/model.py @@ -2,7 +2,7 @@ import base64 import gzip from contextlib import contextmanager from dataclasses import dataclass -from typing import Dict, Iterable, Optional +from typing import Dict, Iterable, Optional, Tuple import numpy as np import torch @@ -113,7 +113,7 @@ class MultiHeadAttention(nn.Module): def qkv_attention( self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: n_batch, n_ctx, n_state = q.shape scale = (n_state // self.n_head) ** -0.25 q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)