mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 06:26:03 +00:00
Fix: compute_device name in Whisper model
This commit is contained in:
parent
b0cf21b9b5
commit
e1545f4776
@ -157,7 +157,7 @@ def load_model(
|
|||||||
del checkpoint_file
|
del checkpoint_file
|
||||||
|
|
||||||
dims = ModelDimensions(**checkpoint["dims"])
|
dims = ModelDimensions(**checkpoint["dims"])
|
||||||
model = Whisper(dims, device=torch.device(device))
|
model = Whisper(dims, compute_device=torch.device(device))
|
||||||
model.load_state_dict(checkpoint["model_state_dict"])
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
|
||||||
if alignment_heads is not None:
|
if alignment_heads is not None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user