Fix bug: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)
This commit is contained in:
Michael Monashev 2022-10-17 21:38:20 +03:00 committed by GitHub
parent d18e9ea5dd
commit f680570016
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -55,7 +55,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
array = array.index_select(dim=axis, index=torch.arange(length))
array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim