Fix: Ensure DTW cost tensor is on the same device as input tensor (#2561)

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
This commit is contained in:
Nathan Harmon 2025-06-25 18:42:09 -06:00 committed by GitHub
parent f50c4f264e
commit 679ae1d141
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -117,7 +117,7 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
x_skew = x_skew.T.contiguous() x_skew = x_skew.T.contiguous()
cost = torch.ones(N + M + 2, M + 2) * np.inf cost = torch.ones(N + M + 2, M + 2) * np.inf
cost[0, 0] = 0 cost[0, 0] = 0
cost = cost.cuda() cost = cost.to(x.device)
trace = torch.zeros_like(cost, dtype=torch.int32) trace = torch.zeros_like(cost, dtype=torch.int32)
dtw_kernel[(1,)]( dtw_kernel[(1,)](