Avoid loading large models into host memory when in_memory is False.

This commit is contained in:
Chad Fraleigh 2025-07-06 18:28:58 -07:00
parent c0d2f624c0
commit 578c01b83c
No known key found for this signature in database
GPG Key ID: 2415C39758458A8F

View File

@ -3,7 +3,7 @@ import io
import os
import urllib
import warnings
from typing import List, Optional, Union
from typing import List, Optional, Union, BinaryIO
import torch
from tqdm import tqdm
@ -51,7 +51,25 @@ _ALIGNMENT_HEADS = {
}
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
# hashlib.file_digest() was added in Python 3.11.  On older interpreters,
# install a minimal fallback under the same name so callers can hash a file
# in fixed-size chunks instead of reading the whole file into memory.
if not hasattr(hashlib, "file_digest"):

    def _file_digest(file: BinaryIO, algo):
        """Feed the remaining contents of a binary file object into a hash.

        Parameters
        ----------
        file : BinaryIO
            File object opened for binary reading.  It is read from its
            current position until ``read()`` returns an empty result.
        algo : str or callable
            Either an algorithm name accepted by ``hashlib.new`` (for example
            ``"sha256"``) or a zero-argument callable returning a hash object
            (for example ``hashlib.sha256``), matching the two forms the
            standard-library ``hashlib.file_digest`` accepts.

        Returns
        -------
        The hash object after every chunk has been passed to ``update()``.
        """
        # The fallback is visible to the whole process once installed, so it
        # accepts a constructor as well as a name, like the stdlib function.
        digest = hashlib.new(algo) if isinstance(algo, str) else algo()
        while True:
            # 64 KiB per read keeps memory use flat regardless of file size.
            chunk = file.read(65536)
            if not chunk:
                break
            digest.update(chunk)
        return digest

    hashlib.file_digest = _file_digest
def _download(url: str, root: str) -> str:
os.makedirs(root, exist_ok=True)
expected_sha256 = url.split("/")[-2]
@ -62,10 +80,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
if os.path.isfile(download_target):
with open(download_target, "rb") as f:
model_bytes = f.read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
if hashlib.file_digest(f, "sha256").hexdigest() == expected_sha256:
return download_target
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)
@ -86,13 +103,13 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
with open(download_target, "rb") as f:
if hashlib.file_digest(f, "sha256").hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
return model_bytes if in_memory else download_target
return download_target
def available_models() -> List[str]:
@ -134,22 +151,25 @@ def load_model(
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
checkpoint_file = _download(_MODELS[name], download_root)
alignment_heads = _ALIGNMENT_HEADS[name]
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
checkpoint_file = name
alignment_heads = None
else:
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
)
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
fp = open(checkpoint_file, "rb")
if in_memory:
with fp:
fp = io.BytesIO(fp.read())
with fp:
kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {}
checkpoint = torch.load(fp, map_location=device, **kwargs)
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims)