Avoid keeping redundant copies of model weights in memory during load (#42)

* don't keep copies of model weights in host memory
* adding type annotation

Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
parent a4fe05aa71
commit f296bcd3fa
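For orientation, a minimal usage sketch of the API after this change; `in_memory` is the flag this commit introduces, and "base" is just an example model name:

    import whisper

    # Default: the checkpoint stays on disk; torch.load reads from an open
    # file handle, so no separate bytes copy of the weights is kept around.
    model = whisper.load_model("base")

    # Opt in to holding the raw checkpoint bytes in host memory, e.g. when
    # the cache directory lives on slow or ephemeral storage.
    model = whisper.load_model("base", in_memory=True)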
whisper/__init__.py
@@ -27,12 +27,11 @@ _MODELS = {
 }
 
 
-def _download(url: str, root: str) -> bytes:
+def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     os.makedirs(root, exist_ok=True)
-    filename = os.path.basename(url)
 
     expected_sha256 = url.split("/")[-2]
-    download_target = os.path.join(root, filename)
+    download_target = os.path.join(root, os.path.basename(url))
 
     if os.path.exists(download_target) and not os.path.isfile(download_target):
         raise RuntimeError(f"{download_target} exists and is not a regular file")
@@ -40,7 +39,7 @@ def _download(url: str, root: str) -> bytes:
     if os.path.isfile(download_target):
         model_bytes = open(download_target, "rb").read()
         if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
-            return model_bytes
+            return model_bytes if in_memory else download_target
         else:
             warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
 
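As the context lines above show, `_download` derives the expected checksum from the URL itself: the second-to-last path segment is the file's SHA256. A self-contained sketch of that convention (the URL is made up, built from the fake payload's own digest):

    import hashlib

    payload = b"stand-in for the checkpoint bytes"   # fake model payload
    digest = hashlib.sha256(payload).hexdigest()
    url = f"https://example.com/{digest}/tiny.pt"    # hypothetical URL layout

    expected_sha256 = url.split("/")[-2]             # second-to-last segment
    assert hashlib.sha256(payload).hexdigest() == expected_sha256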
@@ -58,7 +57,7 @@ def _download(url: str, root: str) -> bytes:
     if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
 
-    return model_bytes
+    return model_bytes if in_memory else download_target
 
 
 def available_models() -> List[str]:
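`_download` now returns `Union[bytes, str]`: the raw bytes when `in_memory` is true, otherwise the path of the verified file on disk. A hedged caller-side sketch of that contract (the helper name is illustrative; `load_model` itself branches on the `in_memory` flag rather than on the value's type):

    import io
    from typing import BinaryIO, Union

    def open_checkpoint(checkpoint_file: Union[bytes, str]) -> BinaryIO:
        # Illustrative helper, not part of the diff.
        if isinstance(checkpoint_file, bytes):
            # bytes: the weights are already in host memory
            return io.BytesIO(checkpoint_file)
        # str: a path; stream from disk without an extra in-memory copy
        return open(checkpoint_file, "rb")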
@@ -66,7 +65,7 @@ def available_models() -> List[str]:
     return list(_MODELS.keys())
 
 
-def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None) -> Whisper:
+def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
     """
     Load a Whisper ASR model
 
@@ -79,28 +78,33 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
         the PyTorch device to put the model into
     download_root: str
         path to download the model files; by default, it uses "~/.cache/whisper"
+    in_memory: bool
+        whether to preload the model weights into host memory
 
     Returns
     -------
     model : Whisper
         The Whisper ASR model instance
     """
-    if name in _MODELS:
-        model_bytes = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/whisper"))
-    elif os.path.isfile(name):
-        model_bytes = open(name, "rb").read()
-    else:
-        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
-
-    with io.BytesIO(model_bytes) as fp:
-        checkpoint = torch.load(fp, map_location="cpu")
-
-    dims = ModelDimensions(**checkpoint["dims"])
-    state_dict = checkpoint["model_state_dict"]
-    model = Whisper(dims)
-    model.load_state_dict(state_dict)
 
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
+    if download_root is None:
+        download_root = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
+
+    if name in _MODELS:
+        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
+    elif os.path.isfile(name):
+        checkpoint_file = open(name, "rb").read() if in_memory else name
+    else:
+        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+    with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
+        checkpoint = torch.load(fp, map_location=device)
+    del checkpoint_file
+
+    dims = ModelDimensions(**checkpoint["dims"])
+    model = Whisper(dims)
+    model.load_state_dict(checkpoint["model_state_dict"])
 
     return model.to(device)
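The payoff is in the rewritten body above: the old code always read the full checkpoint into a `bytes` object and deserialized it onto the CPU before moving it to the device, so the weights transiently existed in host memory at least twice; the new code streams from the open file (unless `in_memory=True`) and maps tensors straight to the target device. A hedged before/after sketch ("tiny.pt" is a placeholder checkpoint path):

    import io
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Before this commit: two full copies in host memory at once, the raw
    # bytes object plus the CPU tensors torch.load builds from it.
    model_bytes = open("tiny.pt", "rb").read()
    checkpoint = torch.load(io.BytesIO(model_bytes), map_location="cpu")

    # After this commit: torch.load reads from the file handle and places
    # tensors directly on the target device; no standalone bytes copy.
    with open("tiny.pt", "rb") as fp:
        checkpoint = torch.load(fp, map_location=device)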