Fixed triton kernel update to support latest triton versions (#2588)

* Update triton kernel using _unsafe_update_src

* support old triton versions

* refactored changes to update triton kernel only once

* Update triton_ops.py

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
Co-authored-by: Jong Wook Kim <ilikekjw@gmail.com>
This commit is contained in:
ExtReMLapin 2025-06-26 02:02:54 +02:00 committed by GitHub
parent 5dff4db81a
commit 86899243e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -60,7 +60,7 @@ def median_kernel(filter_width: int):
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821 tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
kernel = triton.JITFunction(kernel.fn) kernel = triton.JITFunction(kernel.fn)
kernel.src = kernel.src.replace( new_kernel = kernel.src.replace(
" LOAD_ALL_ROWS_HERE", " LOAD_ALL_ROWS_HERE",
"\n".join( "\n".join(
[ [
@ -69,7 +69,8 @@ def median_kernel(filter_width: int):
] ]
), ),
) )
kernel.src = kernel.src.replace(
new_kernel = new_kernel.replace(
" BUBBLESORT_HERE", " BUBBLESORT_HERE",
"\n\n".join( "\n\n".join(
[ [
@ -90,7 +91,14 @@ def median_kernel(filter_width: int):
] ]
), ),
) )
kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
if hasattr(kernel, "_unsafe_update_src") is True:
kernel._unsafe_update_src(new_kernel)
kernel.hash = None
else:
kernel.src = new_kernel
return kernel return kernel