From 3211024b5386e1e4c191029272a682d904cb2267 Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Mon, 30 Sep 2024 10:23:39 -0700 Subject: [PATCH] Update model.py --- whisper/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/whisper/model.py b/whisper/model.py index 9e09a6d..e537447 100644 --- a/whisper/model.py +++ b/whisper/model.py @@ -2,7 +2,7 @@ import base64 import gzip from contextlib import contextmanager from dataclasses import dataclass -from typing import Dict, Iterable, Optional +from typing import Dict, Iterable, Optional, Tuple import numpy as np import torch @@ -113,7 +113,7 @@ class MultiHeadAttention(nn.Module): def qkv_attention( self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: n_batch, n_ctx, n_state = q.shape scale = (n_state // self.n_head) ** -0.25 q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)