diff --git a/whisper/triton_ops.py b/whisper/triton_ops.py index 0dfccfb..c0ff935 100644 --- a/whisper/triton_ops.py +++ b/whisper/triton_ops.py @@ -60,7 +60,6 @@ def median_kernel(filter_width: int): tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821 kernel = triton.JITFunction(kernel.fn) - triton_3_kernel_update = hasattr(kernel, "_unsafe_update_src") new_kernel = kernel.src.replace( " LOAD_ALL_ROWS_HERE", @@ -71,12 +70,8 @@ def median_kernel(filter_width: int): ] ), ) - if triton_3_kernel_update is True: - kernel._unsafe_update_src(new_kernel) - else: - kernel.src = new_kernel - - new_kernel = kernel.src.replace( + + new_kernel = new_kernel.replace( " BUBBLESORT_HERE", "\n\n".join( [ @@ -97,14 +92,10 @@ def median_kernel(filter_width: int): ] ), ) - if triton_3_kernel_update is True: - kernel._unsafe_update_src(new_kernel) - else: - kernel.src = new_kernel + + new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}") - new_kernel = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}") - - if triton_3_kernel_update is True: + if hasattr(kernel, "_unsafe_update_src") is True: kernel._unsafe_update_src(new_kernel) kernel.hash = None else: