diff --git a/whisper/timing.py b/whisper/timing.py index e563414..2340000 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -117,7 +117,7 @@ def dtw_cuda(x, BLOCK_SIZE=1024): x_skew = x_skew.T.contiguous() cost = torch.ones(N + M + 2, M + 2) * np.inf cost[0, 0] = 0 - cost = cost.cuda() + cost = cost.to(x.device) trace = torch.zeros_like(cost, dtype=torch.int32) dtw_kernel[(1,)](