From 6335afea82b8d693863c98bcbac20eef9b6e302b Mon Sep 17 00:00:00 2001 From: ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com> Date: Wed, 7 May 2025 09:43:32 +0200 Subject: [PATCH] support old triton versions --- whisper/triton_ops.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/whisper/triton_ops.py b/whisper/triton_ops.py index 8eb30ca..0dfccfb 100644 --- a/whisper/triton_ops.py +++ b/whisper/triton_ops.py @@ -60,8 +60,9 @@ def median_kernel(filter_width: int): tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821 kernel = triton.JITFunction(kernel.fn) + triton_3_kernel_update = hasattr(kernel, "_unsafe_update_src") - kernel._unsafe_update_src(kernel.src.replace( + new_kernel = kernel.src.replace( " LOAD_ALL_ROWS_HERE", "\n".join( [ @@ -69,8 +70,13 @@ def median_kernel(filter_width: int): for i in range(filter_width) ] ), - )) - kernel._unsafe_update_src(kernel.src.replace( + ) + if triton_3_kernel_update is True: + kernel._unsafe_update_src(new_kernel) + else: + kernel.src = new_kernel + + new_kernel = kernel.src.replace( " BUBBLESORT_HERE", "\n\n".join( [ @@ -90,9 +96,20 @@ def median_kernel(filter_width: int): for i in range(filter_width // 2 + 1) ] ), - )) - kernel._unsafe_update_src(kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")) - kernel.hash = None + ) + if triton_3_kernel_update is True: + kernel._unsafe_update_src(new_kernel) + else: + kernel.src = new_kernel + + new_kernel = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}") + + if triton_3_kernel_update is True: + kernel._unsafe_update_src(new_kernel) + kernel.hash = None + else: + kernel.src = new_kernel + return kernel