diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index cb8250d95693..46224cb32011 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -285,8 +285,8 @@ def pallas_call_lowering( raise NotImplementedError( "dynamic grid bounds not supported in the Triton backend" ) + triton_compiler_params = compiler_params.get("triton", compiler_params) triton_params = compiler_params.get("triton_params", {}) - triton_compiler_params = compiler_params.get("triton", {}) num_warps = triton_compiler_params.pop("num_warps", 4) if len(ctx.module_context.platforms) > 1: raise NotImplementedError("multi-platform lowering for Pallas kernels")