Add util checking hpu tensor

This commit is contained in:
PiotrBLL 2024-11-11 22:20:50 +01:00
parent e1d4b7b4d7
commit d36696f808

View File

@ -1,8 +1,13 @@
import torch
def get_x_hpu(x_numpy):
from habana_frameworks.torch.utils.library_loader import load_habana_module
load_habana_module()
def get_x_hpu(x_numpy):
x_hpu = torch.from_numpy(x_numpy).to("hpu")
return x_hpu
def is_hpu_device(device: torch.device):
return device in (torch.device("hpu:0"), torch.device("hpu"))