Merge 578c01b83ccc947b86e3551fe7939278c1f3747d into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2

This commit is contained in:
ChadF 2025-07-13 01:21:15 +00:00 committed by GitHub
commit 91a4dd478f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,7 +3,7 @@ import io
import os import os
import urllib import urllib
import warnings import warnings
from typing import List, Optional, Union from typing import List, Optional, Union, BinaryIO
import torch import torch
from tqdm import tqdm from tqdm import tqdm
@ -51,7 +51,25 @@ _ALIGNMENT_HEADS = {
} }
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: # hashlib.file_digest() added in Python 3.11
if not hasattr(hashlib, 'file_digest'):
def _file_digest(file: BinaryIO, algo: str):
d = hashlib.new(algo)
while True:
buf = file.read(65536)
if not buf:
break
d.update(buf)
return d
hashlib.file_digest = _file_digest
def _download(url: str, root: str) -> str:
os.makedirs(root, exist_ok=True) os.makedirs(root, exist_ok=True)
expected_sha256 = url.split("/")[-2] expected_sha256 = url.split("/")[-2]
@ -62,10 +80,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
if os.path.isfile(download_target): if os.path.isfile(download_target):
with open(download_target, "rb") as f: with open(download_target, "rb") as f:
model_bytes = f.read() if hashlib.file_digest(f, "sha256").hexdigest() == expected_sha256:
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: return download_target
return model_bytes if in_memory else download_target
else:
warnings.warn( warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
) )
@ -86,13 +103,13 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
output.write(buffer) output.write(buffer)
loop.update(len(buffer)) loop.update(len(buffer))
model_bytes = open(download_target, "rb").read() with open(download_target, "rb") as f:
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: if hashlib.file_digest(f, "sha256").hexdigest() != expected_sha256:
raise RuntimeError( raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
) )
return model_bytes if in_memory else download_target return download_target
def available_models() -> List[str]: def available_models() -> List[str]:
@ -134,22 +151,25 @@ def load_model(
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper") download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
if name in _MODELS: if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory) checkpoint_file = _download(_MODELS[name], download_root)
alignment_heads = _ALIGNMENT_HEADS[name] alignment_heads = _ALIGNMENT_HEADS[name]
elif os.path.isfile(name): elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name checkpoint_file = name
alignment_heads = None alignment_heads = None
else: else:
raise RuntimeError( raise RuntimeError(
f"Model {name} not found; available models = {available_models()}" f"Model {name} not found; available models = {available_models()}"
) )
with ( fp = open(checkpoint_file, "rb")
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp: if in_memory:
with fp:
fp = io.BytesIO(fp.read())
with fp:
kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {} kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {}
checkpoint = torch.load(fp, map_location=device, **kwargs) checkpoint = torch.load(fp, map_location=device, **kwargs)
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"]) dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims) model = Whisper(dims)