mirror of
https://github.com/openai/whisper.git
synced 2025-11-23 22:15:58 +00:00
Merge bf24f74a349445487d53f58cbacc8d17fdb10c38 into c0d2f624c09dc18e709e37c2ad90c039a4eb72a2
This commit is contained in:
commit
7a48d668b5
@ -5,3 +5,4 @@ tqdm
|
|||||||
more-itertools
|
more-itertools
|
||||||
tiktoken
|
tiktoken
|
||||||
triton>=2.0.0;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2"
|
triton>=2.0.0;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2"
|
||||||
|
packaging
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import os
|
|||||||
import urllib
|
import urllib
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -147,7 +148,7 @@ def load_model(
|
|||||||
with (
|
with (
|
||||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||||
) as fp:
|
) as fp:
|
||||||
kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {}
|
kwargs = {"weights_only": True} if version.parse(torch.__version__ ) >= version.parse("1.13") else {}
|
||||||
checkpoint = torch.load(fp, map_location=device, **kwargs)
|
checkpoint = torch.load(fp, map_location=device, **kwargs)
|
||||||
del checkpoint_file
|
del checkpoint_file
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user