diff --git a/whisper/__init__.py b/whisper/__init__.py index ca918a9..f284ec0 100644 --- a/whisper/__init__.py +++ b/whisper/__init__.py @@ -147,10 +147,8 @@ def load_model( with ( io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") ) as fp: - try: - checkpoint = torch.load(fp, map_location=device, weights_only=True) - except TypeError: # for compatibility with older torch - checkpoint = torch.load(fp, map_location=device) + kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {} + checkpoint = torch.load(fp, map_location=device, **kwargs) del checkpoint_file dims = ModelDimensions(**checkpoint["dims"])