mirror of
https://github.com/openai/whisper.git
synced 2025-03-30 14:28:27 +00:00
Merge c632345c1c5b2f630502a4f9c35d39b170513eac into 517a43ecd132a2089d85f4ebc044728a71d49f6e
This commit is contained in:
commit
cca29c39a0
@ -7,6 +7,7 @@ from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import subprocess
|
||||
|
||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||
@ -51,7 +52,9 @@ _ALIGNMENT_HEADS = {
|
||||
}
|
||||
|
||||
|
||||
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
|
||||
|
||||
def _download(url: str, root: str, in_memory: bool, use_aria2: bool = True) -> Union[bytes, str]:
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
@ -61,39 +64,39 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
with open(download_target, "rb") as f:
|
||||
model_bytes = f.read()
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return model_bytes if in_memory else download_target
|
||||
else:
|
||||
warnings.warn(
|
||||
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||
)
|
||||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(
|
||||
total=int(source.info().get("Content-Length")),
|
||||
ncols=80,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
if use_aria2 and _is_aria2_installed():
|
||||
# Use aria2 for downloading with resume capability
|
||||
subprocess.run(["aria2c", "-c", "-d", root, "-o", os.path.basename(url), url])
|
||||
else:
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
)
|
||||
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
|
||||
|
||||
return model_bytes if in_memory else download_target
|
||||
|
||||
def _is_aria2_installed() -> bool:
|
||||
try:
|
||||
subprocess.run(["aria2c", "--version"], capture_output=True, check=True)
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def available_models() -> List[str]:
|
||||
"""Returns the names of available models"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user