Add weights_only parameter to load_model function and extent docstring

This commit is contained in:
Ultr4_dev 2024-08-13 00:02:11 +02:00
parent 2317050239
commit 895e4fb88e

View File

@ -101,6 +101,7 @@ def load_model(
device: Optional[Union[str, torch.device]] = None, device: Optional[Union[str, torch.device]] = None,
download_root: str = None, download_root: str = None,
in_memory: bool = False, in_memory: bool = False,
weights_only: bool = False,
) -> Whisper: ) -> Whisper:
""" """
Load a Whisper ASR model Load a Whisper ASR model
@ -116,6 +117,8 @@ def load_model(
path to download the model files; by default, it uses "~/.cache/whisper" path to download the model files; by default, it uses "~/.cache/whisper"
in_memory: bool in_memory: bool
whether to preload the model weights into host memory whether to preload the model weights into host memory
weights_only: bool
whether to load only the model weights
Returns Returns
------- -------
@ -143,7 +146,7 @@ def load_model(
with ( with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp: ) as fp:
checkpoint = torch.load(fp, map_location=device, weights_only=True) checkpoint = torch.load(fp, map_location=device, weights_only=weights_only)
del checkpoint_file del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"]) dims = ModelDimensions(**checkpoint["dims"])