diff --git a/whisper/__init__.py b/whisper/__init__.py index 3c027d7..55f461a 100644 --- a/whisper/__init__.py +++ b/whisper/__init__.py @@ -157,7 +157,7 @@ def load_model( del checkpoint_file dims = ModelDimensions(**checkpoint["dims"]) - model = Whisper(dims, device=torch.device(device)) + model = Whisper(dims, compute_device=torch.device(device)) model.load_state_dict(checkpoint["model_state_dict"]) if alignment_heads is not None: