From 31ee44b0710afd84c1289ae1ec29fa44a301b1e7 Mon Sep 17 00:00:00 2001 From: ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com> Date: Wed, 7 May 2025 09:37:48 +0200 Subject: [PATCH] Update triton kernel using _unsafe_update_src --- whisper/triton_ops.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/whisper/triton_ops.py b/whisper/triton_ops.py index edd4564..8eb30ca 100644 --- a/whisper/triton_ops.py +++ b/whisper/triton_ops.py @@ -60,7 +60,8 @@ def median_kernel(filter_width: int): tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821 kernel = triton.JITFunction(kernel.fn) - kernel.src = kernel.src.replace( + + kernel._unsafe_update_src(kernel.src.replace( " LOAD_ALL_ROWS_HERE", "\n".join( [ @@ -68,8 +69,8 @@ def median_kernel(filter_width: int): for i in range(filter_width) ] ), - ) - kernel.src = kernel.src.replace( + )) + kernel._unsafe_update_src(kernel.src.replace( " BUBBLESORT_HERE", "\n\n".join( [ @@ -89,8 +90,9 @@ def median_kernel(filter_width: int): for i in range(filter_width // 2 + 1) ] ), - ) - kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}") + )) + kernel._unsafe_update_src(kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")) + kernel.hash = None return kernel