Avoid loading large models into host memory when in_memory is False.

This commit is contained in:
Chad Fraleigh 2025-07-06 18:28:58 -07:00
parent c0d2f624c0
commit 578c01b83c
No known key found for this signature in database
GPG Key ID: 2415C39758458A8F

View File

@ -3,7 +3,7 @@ import io
import os import os
import urllib import urllib
import warnings import warnings
from typing import List, Optional, Union from typing import List, Optional, Union, BinaryIO
import torch import torch
from tqdm import tqdm from tqdm import tqdm
@ -51,7 +51,25 @@ _ALIGNMENT_HEADS = {
} }
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: # hashlib.file_digest() added in Python 3.11
if not hasattr(hashlib, 'file_digest'):
def _file_digest(file: BinaryIO, algo: str):
d = hashlib.new(algo)
while True:
buf = file.read(65536)
if not buf:
break
d.update(buf)
return d
hashlib.file_digest = _file_digest
def _download(url: str, root: str) -> str:
os.makedirs(root, exist_ok=True) os.makedirs(root, exist_ok=True)
expected_sha256 = url.split("/")[-2] expected_sha256 = url.split("/")[-2]
@ -62,10 +80,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
if os.path.isfile(download_target): if os.path.isfile(download_target):
with open(download_target, "rb") as f: with open(download_target, "rb") as f:
model_bytes = f.read() if hashlib.file_digest(f, "sha256").hexdigest() == expected_sha256:
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: return download_target
return model_bytes if in_memory else download_target
else:
warnings.warn( warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
) )
@ -86,13 +103,13 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
output.write(buffer) output.write(buffer)
loop.update(len(buffer)) loop.update(len(buffer))
model_bytes = open(download_target, "rb").read() with open(download_target, "rb") as f:
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: if hashlib.file_digest(f, "sha256").hexdigest() != expected_sha256:
raise RuntimeError( raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
) )
return model_bytes if in_memory else download_target return download_target
def available_models() -> List[str]: def available_models() -> List[str]:
@ -134,22 +151,25 @@ def load_model(
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper") download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
if name in _MODELS: if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory) checkpoint_file = _download(_MODELS[name], download_root)
alignment_heads = _ALIGNMENT_HEADS[name] alignment_heads = _ALIGNMENT_HEADS[name]
elif os.path.isfile(name): elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name checkpoint_file = name
alignment_heads = None alignment_heads = None
else: else:
raise RuntimeError( raise RuntimeError(
f"Model {name} not found; available models = {available_models()}" f"Model {name} not found; available models = {available_models()}"
) )
with ( fp = open(checkpoint_file, "rb")
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp: if in_memory:
with fp:
fp = io.BytesIO(fp.read())
with fp:
kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {} kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {}
checkpoint = torch.load(fp, map_location=device, **kwargs) checkpoint = torch.load(fp, map_location=device, **kwargs)
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"]) dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims) model = Whisper(dims)