diff --git a/whisper/__init__.py b/whisper/__init__.py index e596887..dd2629a 100644 --- a/whisper/__init__.py +++ b/whisper/__init__.py @@ -27,12 +27,11 @@ _MODELS = { } -def _download(url: str, root: str) -> bytes: +def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: os.makedirs(root, exist_ok=True) - filename = os.path.basename(url) expected_sha256 = url.split("/")[-2] - download_target = os.path.join(root, filename) + download_target = os.path.join(root, os.path.basename(url)) if os.path.exists(download_target) and not os.path.isfile(download_target): raise RuntimeError(f"{download_target} exists and is not a regular file") @@ -40,7 +39,7 @@ def _download(url: str, root: str) -> bytes: if os.path.isfile(download_target): model_bytes = open(download_target, "rb").read() if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: - return model_bytes + return model_bytes if in_memory else download_target else: warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") @@ -58,7 +57,7 @@ def _download(url: str, root: str) -> bytes: if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.") - return model_bytes + return model_bytes if in_memory else download_target def available_models() -> List[str]: @@ -66,7 +65,7 @@ def available_models() -> List[str]: return list(_MODELS.keys()) -def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None) -> Whisper: +def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper: """ Load a Whisper ASR model @@ -79,28 +78,33 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow the PyTorch device to put the model into download_root: str path to download the model files; by default, it uses "~/.cache/whisper" + in_memory: bool + whether to preload the model weights into host memory Returns ------- model : Whisper The Whisper ASR model instance """ - if name in _MODELS: - model_bytes = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/whisper")) - elif os.path.isfile(name): - model_bytes = open(name, "rb").read() - else: - raise RuntimeError(f"Model {name} not found; available models = {available_models()}") - - with io.BytesIO(model_bytes) as fp: - checkpoint = torch.load(fp, map_location="cpu") - - dims = ModelDimensions(**checkpoint["dims"]) - state_dict = checkpoint["model_state_dict"] - model = Whisper(dims) - model.load_state_dict(state_dict) if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" + if download_root is None: + download_root = os.path.join(os.path.expanduser("~"), ".cache", "whisper") + + if name in _MODELS: + checkpoint_file = _download(_MODELS[name], download_root, in_memory) + elif os.path.isfile(name): + checkpoint_file = open(name, "rb").read() if in_memory else name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp: + checkpoint = torch.load(fp, map_location=device) + del checkpoint_file + + dims = ModelDimensions(**checkpoint["dims"]) + model = Whisper(dims) + model.load_state_dict(checkpoint["model_state_dict"]) return model.to(device)