support old triton versions

This commit is contained in:
ExtReMLapin 2025-05-07 09:43:32 +02:00 committed by GitHub
parent 31ee44b071
commit 6335afea82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -60,8 +60,9 @@ def median_kernel(filter_width: int):
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
kernel = triton.JITFunction(kernel.fn)
triton_3_kernel_update = hasattr(kernel, "_unsafe_update_src")
kernel._unsafe_update_src(kernel.src.replace(
new_kernel = kernel.src.replace(
" LOAD_ALL_ROWS_HERE",
"\n".join(
[
@ -69,8 +70,13 @@ def median_kernel(filter_width: int):
for i in range(filter_width)
]
),
))
kernel._unsafe_update_src(kernel.src.replace(
)
if triton_3_kernel_update is True:
kernel._unsafe_update_src(new_kernel)
else:
kernel.src = new_kernel
new_kernel = kernel.src.replace(
" BUBBLESORT_HERE",
"\n\n".join(
[
@ -90,9 +96,20 @@ def median_kernel(filter_width: int):
for i in range(filter_width // 2 + 1)
]
),
))
kernel._unsafe_update_src(kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}"))
)
if triton_3_kernel_update is True:
kernel._unsafe_update_src(new_kernel)
else:
kernel.src = new_kernel
new_kernel = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
if triton_3_kernel_update is True:
kernel._unsafe_update_src(new_kernel)
kernel.hash = None
else:
kernel.src = new_kernel
return kernel