diff --git a/requirements.txt b/requirements.txt index 8ee5920..9aea844 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ tqdm more-itertools tiktoken triton>=2.0.0;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2" +packaging diff --git a/whisper/__init__.py b/whisper/__init__.py index f284ec0..cae01b3 100644 --- a/whisper/__init__.py +++ b/whisper/__init__.py @@ -4,6 +4,7 @@ import os import urllib import warnings from typing import List, Optional, Union +from packaging import version import torch from tqdm import tqdm @@ -147,7 +148,7 @@ def load_model( with ( io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") ) as fp: - kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {} + kwargs = {"weights_only": True} if version.parse(torch.__version__ ) >= version.parse("1.13") else {} checkpoint = torch.load(fp, map_location=device, **kwargs) del checkpoint_file