From 2946864a357fe9e5529c75aed2eb64579b6db368 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Wed, 13 Mar 2024 09:41:01 -0700 Subject: [PATCH] [Pallas GPU] Fix to how we pass num_warps/stages PiperOrigin-RevId: 615449420 --- jax/_src/pallas/triton/pallas_call_registration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index cb8250d95693..46224cb32011 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -285,8 +285,8 @@ def pallas_call_lowering( raise NotImplementedError( "dynamic grid bounds not supported in the Triton backend" ) + triton_compiler_params = compiler_params.get("triton", compiler_params) triton_params = compiler_params.get("triton_params", {}) - triton_compiler_params = compiler_params.get("triton", {}) num_warps = triton_compiler_params.pop("num_warps", 4) if len(ctx.module_context.platforms) > 1: raise NotImplementedError("multi-platform lowering for Pallas kernels")