mirror of
https://github.com/openai/whisper.git
synced 2025-11-23 22:15:58 +00:00
Fix: Ensure DTW cost tensor is on the same device as input tensor (#2561)
Co-authored-by: Jong Wook Kim <jongwook@openai.com>
This commit is contained in:
parent
f50c4f264e
commit
679ae1d141
@ -117,7 +117,7 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
|
||||
x_skew = x_skew.T.contiguous()
|
||||
cost = torch.ones(N + M + 2, M + 2) * np.inf
|
||||
cost[0, 0] = 0
|
||||
cost = cost.cuda()
|
||||
cost = cost.to(x.device)
|
||||
trace = torch.zeros_like(cost, dtype=torch.int32)
|
||||
|
||||
dtw_kernel[(1,)](
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user