From 578c01b83ccc947b86e3551fe7939278c1f3747d Mon Sep 17 00:00:00 2001 From: Chad Fraleigh Date: Sun, 6 Jul 2025 18:28:58 -0700 Subject: [PATCH] Avoid loading large models into host memory when in_memory is False. --- whisper/__init__.py | 56 ++++++++++++++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 18 deletions(-) diff --git a/whisper/__init__.py b/whisper/__init__.py index f284ec0..9c7a0fd 100644 --- a/whisper/__init__.py +++ b/whisper/__init__.py @@ -3,7 +3,7 @@ import io import os import urllib import warnings -from typing import List, Optional, Union +from typing import List, Optional, Union, BinaryIO import torch from tqdm import tqdm @@ -51,7 +51,25 @@ _ALIGNMENT_HEADS = { } -def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: +# hashlib.file_digest() added in Python 3.11 +if not hasattr(hashlib, 'file_digest'): + def _file_digest(file: BinaryIO, algo: str): + d = hashlib.new(algo) + + while True: + buf = file.read(65536) + + if not buf: + break + + d.update(buf) + + return d + + hashlib.file_digest = _file_digest + + +def _download(url: str, root: str) -> str: os.makedirs(root, exist_ok=True) expected_sha256 = url.split("/")[-2] @@ -62,10 +80,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: if os.path.isfile(download_target): with open(download_target, "rb") as f: - model_bytes = f.read() - if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: - return model_bytes if in_memory else download_target - else: + if hashlib.file_digest(f, "sha256").hexdigest() == expected_sha256: + return download_target + warnings.warn( f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" ) @@ -86,13 +103,13 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: output.write(buffer) loop.update(len(buffer)) - model_bytes = open(download_target, "rb").read() - if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: - raise RuntimeError( - "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." - ) + with open(download_target, "rb") as f: + if hashlib.file_digest(f, "sha256").hexdigest() != expected_sha256: + raise RuntimeError( + "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." + ) - return model_bytes if in_memory else download_target + return download_target def available_models() -> List[str]: @@ -134,22 +151,25 @@ def load_model( download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper") if name in _MODELS: - checkpoint_file = _download(_MODELS[name], download_root, in_memory) + checkpoint_file = _download(_MODELS[name], download_root) alignment_heads = _ALIGNMENT_HEADS[name] elif os.path.isfile(name): - checkpoint_file = open(name, "rb").read() if in_memory else name + checkpoint_file = name alignment_heads = None else: raise RuntimeError( f"Model {name} not found; available models = {available_models()}" ) - with ( - io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") - ) as fp: + fp = open(checkpoint_file, "rb") + + if in_memory: + with fp: + fp = io.BytesIO(fp.read()) + + with fp: kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {} checkpoint = torch.load(fp, map_location=device, **kwargs) - del checkpoint_file dims = ModelDimensions(**checkpoint["dims"]) model = Whisper(dims)