Fix: compute_device name in Whisper model

This commit is contained in:
PiotrBLL 2024-11-11 22:22:08 +01:00
parent b0cf21b9b5
commit e1545f4776

View File

@ -157,7 +157,7 @@ def load_model(
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims, device=torch.device(device))
model = Whisper(dims, compute_device=torch.device(device))
model.load_state_dict(checkpoint["model_state_dict"])
if alignment_heads is not None: