mirror of
https://github.com/openai/whisper.git
synced 2025-11-23 22:15:58 +00:00
Avoid loading large models into host memory when in_memory is False.
This commit is contained in:
parent
c0d2f624c0
commit
578c01b83c
@ -3,7 +3,7 @@ import io
|
|||||||
import os
|
import os
|
||||||
import urllib
|
import urllib
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union, BinaryIO
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -51,7 +51,25 @@ _ALIGNMENT_HEADS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
# hashlib.file_digest() was added in Python 3.11. On older interpreters,
# install a minimal stand-in so callers can hash a file in fixed-size chunks
# instead of reading the whole file into memory first.
if not hasattr(hashlib, "file_digest"):
    from typing import Callable

    def _file_digest(file: BinaryIO, algo: Union[str, Callable]):
        """Hash the remaining contents of a binary file object in chunks.

        Parameters
        ----------
        file : BinaryIO
            File-like object opened in binary mode; it is read from its
            current position until ``read()`` returns an empty result.
        algo : Union[str, Callable]
            Either an algorithm name accepted by ``hashlib.new()`` or a
            zero-argument callable returning a hash object (for example
            ``hashlib.sha256``), mirroring what the standard-library
            ``hashlib.file_digest()`` accepts.

        Returns
        -------
        The hash object after all chunks were fed to it; call
        ``.hexdigest()`` / ``.digest()`` on it to obtain the checksum.
        """
        # A string names the algorithm; anything else is treated as a
        # constructor, matching the stdlib function this shim stands in for.
        d = hashlib.new(algo) if isinstance(algo, str) else algo()

        while True:
            # 64 KiB per read keeps memory use flat regardless of file size.
            buf = file.read(65536)
            if not buf:
                break
            d.update(buf)

        return d

    # Expose the shim under the stdlib name so call sites are identical on
    # every supported Python version.
    hashlib.file_digest = _file_digest
|
||||||
|
|
||||||
|
|
||||||
|
def _download(url: str, root: str) -> str:
|
||||||
os.makedirs(root, exist_ok=True)
|
os.makedirs(root, exist_ok=True)
|
||||||
|
|
||||||
expected_sha256 = url.split("/")[-2]
|
expected_sha256 = url.split("/")[-2]
|
||||||
@ -62,10 +80,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
|||||||
|
|
||||||
if os.path.isfile(download_target):
|
if os.path.isfile(download_target):
|
||||||
with open(download_target, "rb") as f:
|
with open(download_target, "rb") as f:
|
||||||
model_bytes = f.read()
|
if hashlib.file_digest(f, "sha256").hexdigest() == expected_sha256:
|
||||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
return download_target
|
||||||
return model_bytes if in_memory else download_target
|
|
||||||
else:
|
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||||
)
|
)
|
||||||
@ -86,13 +103,13 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
|||||||
output.write(buffer)
|
output.write(buffer)
|
||||||
loop.update(len(buffer))
|
loop.update(len(buffer))
|
||||||
|
|
||||||
model_bytes = open(download_target, "rb").read()
|
with open(download_target, "rb") as f:
|
||||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
if hashlib.file_digest(f, "sha256").hexdigest() != expected_sha256:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||||
)
|
)
|
||||||
|
|
||||||
return model_bytes if in_memory else download_target
|
return download_target
|
||||||
|
|
||||||
|
|
||||||
def available_models() -> List[str]:
|
def available_models() -> List[str]:
|
||||||
@ -134,22 +151,25 @@ def load_model(
|
|||||||
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
|
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
|
||||||
|
|
||||||
if name in _MODELS:
|
if name in _MODELS:
|
||||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
checkpoint_file = _download(_MODELS[name], download_root)
|
||||||
alignment_heads = _ALIGNMENT_HEADS[name]
|
alignment_heads = _ALIGNMENT_HEADS[name]
|
||||||
elif os.path.isfile(name):
|
elif os.path.isfile(name):
|
||||||
checkpoint_file = open(name, "rb").read() if in_memory else name
|
checkpoint_file = name
|
||||||
alignment_heads = None
|
alignment_heads = None
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Model {name} not found; available models = {available_models()}"
|
f"Model {name} not found; available models = {available_models()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
with (
|
fp = open(checkpoint_file, "rb")
|
||||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
|
||||||
) as fp:
|
if in_memory:
|
||||||
|
with fp:
|
||||||
|
fp = io.BytesIO(fp.read())
|
||||||
|
|
||||||
|
with fp:
|
||||||
kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {}
|
kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {}
|
||||||
checkpoint = torch.load(fp, map_location=device, **kwargs)
|
checkpoint = torch.load(fp, map_location=device, **kwargs)
|
||||||
del checkpoint_file
|
|
||||||
|
|
||||||
dims = ModelDimensions(**checkpoint["dims"])
|
dims = ModelDimensions(**checkpoint["dims"])
|
||||||
model = Whisper(dims)
|
model = Whisper(dims)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user