mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 06:26:03 +00:00
Add weights_only parameter to load_model function and extent docstring
This commit is contained in:
parent
2317050239
commit
895e4fb88e
@ -101,6 +101,7 @@ def load_model(
|
|||||||
device: Optional[Union[str, torch.device]] = None,
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
download_root: str = None,
|
download_root: str = None,
|
||||||
in_memory: bool = False,
|
in_memory: bool = False,
|
||||||
|
weights_only: bool = False,
|
||||||
) -> Whisper:
|
) -> Whisper:
|
||||||
"""
|
"""
|
||||||
Load a Whisper ASR model
|
Load a Whisper ASR model
|
||||||
@ -116,6 +117,8 @@ def load_model(
|
|||||||
path to download the model files; by default, it uses "~/.cache/whisper"
|
path to download the model files; by default, it uses "~/.cache/whisper"
|
||||||
in_memory: bool
|
in_memory: bool
|
||||||
whether to preload the model weights into host memory
|
whether to preload the model weights into host memory
|
||||||
|
weights_only: bool
|
||||||
|
whether to load only the model weights
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -143,7 +146,7 @@ def load_model(
|
|||||||
with (
|
with (
|
||||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||||
) as fp:
|
) as fp:
|
||||||
checkpoint = torch.load(fp, map_location=device, weights_only=True)
|
checkpoint = torch.load(fp, map_location=device, weights_only=weights_only)
|
||||||
del checkpoint_file
|
del checkpoint_file
|
||||||
|
|
||||||
dims = ModelDimensions(**checkpoint["dims"])
|
dims = ModelDimensions(**checkpoint["dims"])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user