diff --git a/whisper/__init__.py b/whisper/__init__.py
index e210718..3ebb451 100644
--- a/whisper/__init__.py
+++ b/whisper/__init__.py
@@ -7,6 +7,7 @@ from typing import List, Optional, Union
 
 import torch
 from tqdm import tqdm
+import subprocess
 
 from .audio import load_audio, log_mel_spectrogram, pad_or_trim
 from .decoding import DecodingOptions, DecodingResult, decode, detect_language
@@ -51,7 +52,18 @@ _ALIGNMENT_HEADS = {
 }
 
 
-def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
+def _download(
+    url: str, root: str, in_memory: bool, use_aria2: bool = True
+) -> Union[bytes, str]:
+    """Download the file at `url` into `root` and verify its SHA256 checksum.
+
+    The expected checksum is the second-to-last path component of `url`. When
+    `use_aria2` is set and the `aria2c` command-line tool can be run, it is used
+    so that a partial download is resumed; otherwise, or if `aria2c` fails, the
+    file is fetched with urllib. Returns the file contents if `in_memory` is
+    set, else the path of the downloaded file. Raises RuntimeError if the
+    checksum of the downloaded file does not match.
+    """
     os.makedirs(root, exist_ok=True)
 
     expected_sha256 = url.split("/")[-2]
@@ -70,30 +82,54 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
                 f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
             )
 
-    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
-        with tqdm(
-            total=int(source.info().get("Content-Length")),
-            ncols=80,
-            unit="iB",
-            unit_scale=True,
-            unit_divisor=1024,
-        ) as loop:
-            while True:
-                buffer = source.read(8192)
-                if not buffer:
-                    break
-
-                output.write(buffer)
-                loop.update(len(buffer))
+    downloaded = False
+    if use_aria2 and _is_aria2_installed():
+        # "-c" makes aria2c continue a partial download instead of restarting
+        result = subprocess.run(
+            ["aria2c", "-c", "-d", root, "-o", os.path.basename(url), url]
+        )
+        downloaded = result.returncode == 0
+        if not downloaded:
+            warnings.warn(
+                f"aria2c exited with status {result.returncode}; falling back to urllib"
+            )
+
+    if not downloaded:
+        with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+            with tqdm(
+                total=int(source.info().get("Content-Length")),
+                ncols=80,
+                unit="iB",
+                unit_scale=True,
+                unit_divisor=1024,
+            ) as loop:
+                while True:
+                    buffer = source.read(8192)
+                    if not buffer:
+                        break
+
+                    output.write(buffer)
+                    loop.update(len(buffer))
 
-    model_bytes = open(download_target, "rb").read()
+    with open(download_target, "rb") as f:
+        model_bytes = f.read()
     if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
         raise RuntimeError(
             "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
         )
 
     return model_bytes if in_memory else download_target
 
 
+def _is_aria2_installed() -> bool:
+    """Return True if the `aria2c` executable can be run on this system."""
+    try:
+        subprocess.run(["aria2c", "--version"], capture_output=True, check=True)
+    except (OSError, subprocess.CalledProcessError):
+        # OSError (e.g. FileNotFoundError) is raised when aria2c is not installed
+        return False
+    return True
+
+
 def available_models() -> List[str]:
     """Returns the names of available models"""