Skip to content

Commit

Permalink
[Pallas GPU] Fix to how we pass num_warps/stages
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615599862
  • Loading branch information
sharadmv authored and jax authors committed Mar 14, 2024
1 parent d4f532e commit 6046d7d
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/pallas/triton/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,8 @@ def pallas_call_lowering(
raise NotImplementedError(
"dynamic grid bounds not supported in the Triton backend"
)
triton_compiler_params = compiler_params.get("triton", compiler_params)
triton_params = compiler_params.get("triton_params", {})
triton_compiler_params = compiler_params.get("triton", {})
num_warps = triton_compiler_params.pop("num_warps", 4)
if len(ctx.module_context.platforms) > 1:
raise NotImplementedError("multi-platform lowering for Pallas kernels")
Expand Down

0 comments on commit 6046d7d

Please sign in to comment.