From d36696f808c25e18933fe093db02e38cac968e26 Mon Sep 17 00:00:00 2001 From: PiotrBLL Date: Mon, 11 Nov 2024 22:20:50 +0100 Subject: [PATCH] Add util checking hpu tensor --- whisper/hpu_utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/whisper/hpu_utils.py b/whisper/hpu_utils.py index 5ec464a..e00077c 100644 --- a/whisper/hpu_utils.py +++ b/whisper/hpu_utils.py @@ -1,8 +1,13 @@ import torch -from habana_frameworks.torch.utils.library_loader import load_habana_module - -load_habana_module() def get_x_hpu(x_numpy): + from habana_frameworks.torch.utils.library_loader import load_habana_module + + load_habana_module() + x_hpu = torch.from_numpy(x_numpy).to("hpu") return x_hpu + + +def is_hpu_device(device: torch.device): + return device in (torch.device("hpu:0"), torch.device("hpu"))