diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py
index 883a9be249d5..8735ceff4a08 100644
--- a/jax/_src/pallas/triton/pallas_call_registration.py
+++ b/jax/_src/pallas/triton/pallas_call_registration.py
@@ -63,8 +63,8 @@ def compile_jaxpr(
   # which is fine when we have multiple of the same GPU but this won't work in
   # general.
   device = 0
-  arch = triton_kernel_call_lib.get_compute_capability(device)
-  target = ("cuda", arch)
+  compute_capability = triton_kernel_call_lib.get_compute_capability(device)
+  target = ("cuda", compute_capability)
   cuda_backend = cb.CUDABackend(target)
   cuda_options = cuda_backend.parse_options(
       dict(
@@ -78,13 +78,11 @@ def compile_jaxpr(
   )
   ttir = str(lowering_result.module)
 
-  ptx, name, shared_mem_bytes, compute_capability, _ = (
-      compile_ttir_to_ptx_inplace(
-          lowering_result.module,
-          cuda_backend,
-          cuda_options,
-          device=device,
-      )
+  ptx, name, shared_mem_bytes, _ = compile_ttir_to_ptx_inplace(
+      lowering_result.module,
+      cuda_backend,
+      cuda_options,
+      compute_capability,
   )
   return CompilationResult(
       name, ttir, ptx, shared_mem_bytes, compute_capability, lowering_result
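
For readers tracing the interface change: after this diff the caller queries the compute capability once, passes it straight through to compile_ttir_to_ptx_inplace, and the callee no longer returns it. The sketch below only illustrates that call contract; PtxResult, compile_ttir_to_ptx_inplace_stub, and the placeholder return values are hypothetical stand-ins, not JAX's actual helper.

from typing import Any, NamedTuple


class PtxResult(NamedTuple):
  """Hypothetical container mirroring the 4-tuple unpacked at the call site."""
  ptx: str
  name: str
  shared_mem_bytes: int
  extra: Any  # the trailing element the caller discards as `_`


def compile_ttir_to_ptx_inplace_stub(
    module: Any,
    cuda_backend: Any,
    cuda_options: Any,
    compute_capability: int,
) -> PtxResult:
  """Stand-in for the post-diff signature: the compute capability is supplied
  by the caller instead of being re-derived from a device ordinal, and it is
  not echoed back in the result."""
  del module, cuda_backend, cuda_options, compute_capability  # unused in stub
  return PtxResult(ptx="", name="kernel", shared_mem_bytes=0, extra=None)


# Caller-side pattern matching the new hunk above:
#   compute_capability = triton_kernel_call_lib.get_compute_capability(device)
#   ptx, name, shared_mem_bytes, _ = compile_ttir_to_ptx_inplace(
#       lowering_result.module, cuda_backend, cuda_options, compute_capability,
#   )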