Avoid keeping redundant copies of model weights in memory during load (#42)

* don't keep copies of model weights in host memory

* add type annotation

Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
Authored by Niklas K on 2022-09-23 05:57:39 +02:00; committed by GitHub
parent a4fe05aa71
commit f296bcd3fa

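For reference, the user-facing effect of this commit is the new in_memory keyword on load_model. A minimal usage sketch, assuming the whisper package is installed and the "tiny" checkpoint can be downloaded:

    import whisper

    # Default: the checkpoint is verified on disk and torch.load streams it
    # from a file handle, so no extra copy of the weights stays in host RAM.
    model = whisper.load_model("tiny")

    # Opt-in: preload the raw checkpoint bytes into host memory first.
    model = whisper.load_model("tiny", in_memory=True)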

@@ -27,12 +27,11 @@ _MODELS = {
 }

-def _download(url: str, root: str) -> bytes:
+def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     os.makedirs(root, exist_ok=True)

-    filename = os.path.basename(url)
     expected_sha256 = url.split("/")[-2]
-    download_target = os.path.join(root, filename)
+    download_target = os.path.join(root, os.path.basename(url))

     if os.path.exists(download_target) and not os.path.isfile(download_target):
         raise RuntimeError(f"{download_target} exists and is not a regular file")
@@ -40,7 +39,7 @@ def _download(url: str, root: str) -> bytes:
     if os.path.isfile(download_target):
         model_bytes = open(download_target, "rb").read()
         if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
-            return model_bytes
+            return model_bytes if in_memory else download_target
         else:
             warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
@@ -58,7 +57,7 @@ def _download(url: str, root: str) -> bytes:
     if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")

-    return model_bytes
+    return model_bytes if in_memory else download_target

 def available_models() -> List[str]:
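Taken together, the hunks above change _download's contract: it still verifies the SHA256 digest embedded in the URL, but by default it now returns the path of the verified file instead of its contents, and hands back raw bytes only when in_memory=True. A self-contained sketch of that return pattern (the file name and contents here are hypothetical stand-ins, not Whisper's real checkpoint format):

    import hashlib

    def cache_or_bytes(path: str, in_memory: bool):
        # Verify the file once, then return either the raw bytes or just
        # the path, mirroring the pattern _download uses above.
        data = open(path, "rb").read()
        print("sha256:", hashlib.sha256(data).hexdigest()[:12], "...")
        return data if in_memory else path

    with open("toy.bin", "wb") as f:  # hypothetical stand-in checkpoint
        f.write(b"example weights")

    print(type(cache_or_bytes("toy.bin", in_memory=True)))   # <class 'bytes'>
    print(type(cache_or_bytes("toy.bin", in_memory=False)))  # <class 'str'>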
@@ -66,7 +65,7 @@ def available_models() -> List[str]:
     return list(_MODELS.keys())

-def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None) -> Whisper:
+def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
     """
     Load a Whisper ASR model
@@ -79,28 +78,33 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
         the PyTorch device to put the model into
     download_root: str
         path to download the model files; by default, it uses "~/.cache/whisper"
+    in_memory: bool
+        whether to preload the model weights into host memory

     Returns
     -------
     model : Whisper
         The Whisper ASR model instance
     """
-    if name in _MODELS:
-        model_bytes = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/whisper"))
-    elif os.path.isfile(name):
-        model_bytes = open(name, "rb").read()
-    else:
-        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
-
-    with io.BytesIO(model_bytes) as fp:
-        checkpoint = torch.load(fp, map_location="cpu")
-
-    dims = ModelDimensions(**checkpoint["dims"])
-    state_dict = checkpoint["model_state_dict"]
-    model = Whisper(dims)
-    model.load_state_dict(state_dict)
-
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
+    if download_root is None:
+        download_root = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
+
+    if name in _MODELS:
+        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
+    elif os.path.isfile(name):
+        checkpoint_file = open(name, "rb").read() if in_memory else name
+    else:
+        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+    with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
+        checkpoint = torch.load(fp, map_location=device)
+    del checkpoint_file
+
+    dims = ModelDimensions(**checkpoint["dims"])
+    model = Whisper(dims)
+    model.load_state_dict(checkpoint["model_state_dict"])

     return model.to(device)
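The rewritten body of load_model is where the memory saving comes from: the old version always read the whole checkpoint into model_bytes and then let torch.load materialize a second copy as tensors, while the new version streams straight from the open file unless the caller asked for in_memory=True, and in either case drops checkpoint_file as soon as torch.load returns. A self-contained sketch of the two load paths (the toy checkpoint below is an assumption for illustration, not Whisper's real format):

    import io
    import torch

    # Toy checkpoint with the same top-level keys the loader expects.
    torch.save({"dims": {}, "model_state_dict": {}}, "toy_checkpoint.pt")

    for in_memory in (False, True):
        # in_memory=True reads the whole file into host RAM up front;
        # in_memory=False lets torch.load stream from the file handle.
        checkpoint_file = (
            open("toy_checkpoint.pt", "rb").read() if in_memory else "toy_checkpoint.pt"
        )
        with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
            checkpoint = torch.load(fp, map_location="cpu")
        del checkpoint_file  # release the redundant copy immediately
        print(in_memory, sorted(checkpoint))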