mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 14:35:57 +00:00
Merge 6cbc47dd44a03899a83709f955903b2f9e8a7c49 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2
This commit is contained in:
commit
ec54c3a890
@ -52,7 +52,6 @@ _ALIGNMENT_HEADS = {
|
|||||||
|
|
||||||
|
|
||||||
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||||
os.makedirs(root, exist_ok=True)
|
|
||||||
|
|
||||||
expected_sha256 = url.split("/")[-2]
|
expected_sha256 = url.split("/")[-2]
|
||||||
download_target = os.path.join(root, os.path.basename(url))
|
download_target = os.path.join(root, os.path.basename(url))
|
||||||
@ -69,7 +68,19 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
|||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# If not already downloaded, check global locations and use
|
||||||
|
# those before downloading. Prefer /var/ for host specific
|
||||||
|
# override over /usr/ file controlled by package manager.
|
||||||
|
for globalroot in ('/var/lib/openai-whisper/', '/usr/share/openai-whisper/'):
|
||||||
|
candidate = os.path.join(globalroot, os.path.basename(url))
|
||||||
|
if os.path.isfile(candidate):
|
||||||
|
with open(candidate, "rb") as f:
|
||||||
|
model_bytes = f.read()
|
||||||
|
return model_bytes if in_memory else candidate
|
||||||
|
|
||||||
|
# Time to download, make sure download directory is available.
|
||||||
|
os.makedirs(root, exist_ok=True)
|
||||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||||
with tqdm(
|
with tqdm(
|
||||||
total=int(source.info().get("Content-Length")),
|
total=int(source.info().get("Content-Length")),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user