mirror of
https://github.com/openai/whisper.git
synced 2025-11-24 06:26:03 +00:00
support old triton versions
This commit is contained in:
parent
31ee44b071
commit
6335afea82
@ -60,8 +60,9 @@ def median_kernel(filter_width: int):
|
||||
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
|
||||
|
||||
kernel = triton.JITFunction(kernel.fn)
|
||||
triton_3_kernel_update = hasattr(kernel, "_unsafe_update_src")
|
||||
|
||||
kernel._unsafe_update_src(kernel.src.replace(
|
||||
new_kernel = kernel.src.replace(
|
||||
" LOAD_ALL_ROWS_HERE",
|
||||
"\n".join(
|
||||
[
|
||||
@ -69,8 +70,13 @@ def median_kernel(filter_width: int):
|
||||
for i in range(filter_width)
|
||||
]
|
||||
),
|
||||
))
|
||||
kernel._unsafe_update_src(kernel.src.replace(
|
||||
)
|
||||
if triton_3_kernel_update is True:
|
||||
kernel._unsafe_update_src(new_kernel)
|
||||
else:
|
||||
kernel.src = new_kernel
|
||||
|
||||
new_kernel = kernel.src.replace(
|
||||
" BUBBLESORT_HERE",
|
||||
"\n\n".join(
|
||||
[
|
||||
@ -90,9 +96,20 @@ def median_kernel(filter_width: int):
|
||||
for i in range(filter_width // 2 + 1)
|
||||
]
|
||||
),
|
||||
))
|
||||
kernel._unsafe_update_src(kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}"))
|
||||
kernel.hash = None
|
||||
)
|
||||
if triton_3_kernel_update is True:
|
||||
kernel._unsafe_update_src(new_kernel)
|
||||
else:
|
||||
kernel.src = new_kernel
|
||||
|
||||
new_kernel = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
|
||||
|
||||
if triton_3_kernel_update is True:
|
||||
kernel._unsafe_update_src(new_kernel)
|
||||
kernel.hash = None
|
||||
else:
|
||||
kernel.src = new_kernel
|
||||
|
||||
|
||||
return kernel
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user