mirror of
https://github.com/openai/whisper.git
synced 2025-11-27 15:54:00 +00:00
Fixed triton kernel update to support latest triton versions (#2588)
* Update triton kernel using _unsafe_update_src * support old triton versions * refactored changes to update triton kernel only once * Update triton_ops.py --------- Co-authored-by: Jong Wook Kim <jongwook@openai.com> Co-authored-by: Jong Wook Kim <ilikekjw@gmail.com>
This commit is contained in:
parent
5dff4db81a
commit
86899243e9
@ -60,7 +60,7 @@ def median_kernel(filter_width: int):
|
|||||||
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
|
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
|
||||||
|
|
||||||
kernel = triton.JITFunction(kernel.fn)
|
kernel = triton.JITFunction(kernel.fn)
|
||||||
kernel.src = kernel.src.replace(
|
new_kernel = kernel.src.replace(
|
||||||
" LOAD_ALL_ROWS_HERE",
|
" LOAD_ALL_ROWS_HERE",
|
||||||
"\n".join(
|
"\n".join(
|
||||||
[
|
[
|
||||||
@ -69,7 +69,8 @@ def median_kernel(filter_width: int):
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
kernel.src = kernel.src.replace(
|
|
||||||
|
new_kernel = new_kernel.replace(
|
||||||
" BUBBLESORT_HERE",
|
" BUBBLESORT_HERE",
|
||||||
"\n\n".join(
|
"\n\n".join(
|
||||||
[
|
[
|
||||||
@ -90,7 +91,14 @@ def median_kernel(filter_width: int):
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
|
|
||||||
|
new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
|
||||||
|
|
||||||
|
if hasattr(kernel, "_unsafe_update_src") is True:
|
||||||
|
kernel._unsafe_update_src(new_kernel)
|
||||||
|
kernel.hash = None
|
||||||
|
else:
|
||||||
|
kernel.src = new_kernel
|
||||||
|
|
||||||
return kernel
|
return kernel
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user