Update triton kernel using _unsafe_update_src

This commit is contained in:
ExtReMLapin 2025-05-07 09:37:48 +02:00 committed by GitHub
parent 517a43ecd1
commit 31ee44b071
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -60,7 +60,8 @@ def median_kernel(filter_width: int):
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821 tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
kernel = triton.JITFunction(kernel.fn) kernel = triton.JITFunction(kernel.fn)
kernel.src = kernel.src.replace(
kernel._unsafe_update_src(kernel.src.replace(
" LOAD_ALL_ROWS_HERE", " LOAD_ALL_ROWS_HERE",
"\n".join( "\n".join(
[ [
@ -68,8 +69,8 @@ def median_kernel(filter_width: int):
for i in range(filter_width) for i in range(filter_width)
] ]
), ),
) ))
kernel.src = kernel.src.replace( kernel._unsafe_update_src(kernel.src.replace(
" BUBBLESORT_HERE", " BUBBLESORT_HERE",
"\n\n".join( "\n\n".join(
[ [
@ -89,8 +90,9 @@ def median_kernel(filter_width: int):
for i in range(filter_width // 2 + 1) for i in range(filter_width // 2 + 1)
] ]
), ),
) ))
kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}") kernel._unsafe_update_src(kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}"))
kernel.hash = None
return kernel return kernel