From 626db677aa00f37295149156319c695986e6dda5 Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Mon, 30 Sep 2024 01:52:06 -0700 Subject: [PATCH] Integrate Triton up to [6fa4f504](https://github.com/openai/triton/commits/6fa4f504d61e48f7cea454ce0c1f6169907d5aaa) PiperOrigin-RevId: 680473520 --- jax_triton/triton_lib.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 96d06693..33459d3e 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -538,7 +538,8 @@ def triton_kernel_call_lowering( named_args = dict(unsafe_zip(fn.arg_names, args)) if isinstance(fn, autotuner.Autotuner): - if any(idx not in fn.key_idx for idx, _, _ in scalar_args): + key_idxs = [fn.arg_names.index(k) for k in fn.keys] + if any(idx not in key_idxs for idx, _, _ in scalar_args): logging.warning( "Auto-tuning key does not include all scalar arguments. " "We may perform redundant auto-tuning."