diff --git a/whisper/triton_ops.py b/whisper/triton_ops.py index edd4564..13d417b 100644 --- a/whisper/triton_ops.py +++ b/whisper/triton_ops.py @@ -60,7 +60,7 @@ def median_kernel(filter_width: int): tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821 kernel = triton.JITFunction(kernel.fn) - kernel.src = kernel.src.replace( + new_kernel = kernel.src.replace( " LOAD_ALL_ROWS_HERE", "\n".join( [ @@ -69,7 +69,8 @@ def median_kernel(filter_width: int): ] ), ) - kernel.src = kernel.src.replace( + + new_kernel = new_kernel.replace( " BUBBLESORT_HERE", "\n\n".join( [ @@ -90,7 +91,14 @@ def median_kernel(filter_width: int): ] ), ) - kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}") + + new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}") + + if hasattr(kernel, "_unsafe_update_src") is True: + kernel._unsafe_update_src(new_kernel) + kernel.hash = None + else: + kernel.src = new_kernel return kernel