diff --git a/whisper/hpu_utils.py b/whisper/hpu_utils.py index 5ec464a..e00077c 100644 --- a/whisper/hpu_utils.py +++ b/whisper/hpu_utils.py @@ -1,8 +1,13 @@ import torch -from habana_frameworks.torch.utils.library_loader import load_habana_module - -load_habana_module() def get_x_hpu(x_numpy): + from habana_frameworks.torch.utils.library_loader import load_habana_module + + load_habana_module() + x_hpu = torch.from_numpy(x_numpy).to("hpu") return x_hpu + + +def is_hpu_device(device: torch.device): + return device in (torch.device("hpu:0"), torch.device("hpu"))