mirror of
https://github.com/openai/whisper.git
synced 2025-11-23 22:15:58 +00:00
Merge 578c01b83ccc947b86e3551fe7939278c1f3747d into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2
This commit is contained in:
commit
91a4dd478f
@ -3,7 +3,7 @@ import io
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Union, BinaryIO
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
@ -51,7 +51,25 @@ _ALIGNMENT_HEADS = {
|
||||
}
|
||||
|
||||
|
||||
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
# hashlib.file_digest() was added in Python 3.11. On older interpreters,
# install a minimal stand-in under the same name so the rest of the module
# can call hashlib.file_digest() unconditionally.
if not hasattr(hashlib, "file_digest"):

    def _file_digest(file: BinaryIO, algo):
        """Hash the remaining contents of *file* and return the hash object.

        Parameters
        ----------
        file : BinaryIO
            A file-like object opened for binary reading; it is consumed
            from its current position to EOF.
        algo : str or callable
            Either an algorithm name understood by ``hashlib.new`` (for
            example ``"sha256"``) or a zero-argument callable returning a
            hash object (for example ``hashlib.sha256``), mirroring the two
            forms the standard-library function accepts.

        Returns
        -------
        A hash object; call ``.hexdigest()`` / ``.digest()`` on it.
        """
        # Accept a constructor as well as a name so this behaves like the
        # real hashlib.file_digest for both calling conventions.
        d = hashlib.new(algo) if isinstance(algo, str) else algo()

        # Read in 64 KiB chunks so the whole file is never held in memory.
        while True:
            buf = file.read(65536)
            if not buf:
                break
            d.update(buf)

        return d

    hashlib.file_digest = _file_digest
|
||||
|
||||
|
||||
def _download(url: str, root: str) -> str:
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
@ -62,10 +80,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
with open(download_target, "rb") as f:
|
||||
model_bytes = f.read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return model_bytes if in_memory else download_target
|
||||
else:
|
||||
if hashlib.file_digest(f, "sha256").hexdigest() == expected_sha256:
|
||||
return download_target
|
||||
|
||||
warnings.warn(
|
||||
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||
)
|
||||
@ -86,13 +103,13 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
)
|
||||
with open(download_target, "rb") as f:
|
||||
if hashlib.file_digest(f, "sha256").hexdigest() != expected_sha256:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
return model_bytes if in_memory else download_target
|
||||
return download_target
|
||||
|
||||
|
||||
def available_models() -> List[str]:
|
||||
@ -134,22 +151,25 @@ def load_model(
|
||||
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
|
||||
|
||||
if name in _MODELS:
|
||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||
checkpoint_file = _download(_MODELS[name], download_root)
|
||||
alignment_heads = _ALIGNMENT_HEADS[name]
|
||||
elif os.path.isfile(name):
|
||||
checkpoint_file = open(name, "rb").read() if in_memory else name
|
||||
checkpoint_file = name
|
||||
alignment_heads = None
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Model {name} not found; available models = {available_models()}"
|
||||
)
|
||||
|
||||
with (
|
||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||
) as fp:
|
||||
fp = open(checkpoint_file, "rb")
|
||||
|
||||
if in_memory:
|
||||
with fp:
|
||||
fp = io.BytesIO(fp.read())
|
||||
|
||||
with fp:
|
||||
kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {}
|
||||
checkpoint = torch.load(fp, map_location=device, **kwargs)
|
||||
del checkpoint_file
|
||||
|
||||
dims = ModelDimensions(**checkpoint["dims"])
|
||||
model = Whisper(dims)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user