Skip to content

Commit

Permalink
[jax_triton] Add parameter allowing user to compile for specific compute capability.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 612647104
  • Loading branch information
chr1sj0nes authored and jax authors committed Mar 5, 2024
1 parent bc3f123 commit 9996b1f
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions jax/_src/pallas/triton/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def compile_jaxpr(
# which is fine when we have multiple of the same GPU but this won't work in
# general.
device = 0
arch = triton_kernel_call_lib.get_compute_capability(device)
target = ("cuda", arch)
compute_capability = triton_kernel_call_lib.get_compute_capability(device)
target = ("cuda", compute_capability)
cuda_backend = cb.CUDABackend(target)
cuda_options = cuda_backend.parse_options(
dict(
Expand All @@ -78,13 +78,11 @@ def compile_jaxpr(
)

ttir = str(lowering_result.module)
ptx, name, shared_mem_bytes, compute_capability, _ = (
compile_ttir_to_ptx_inplace(
lowering_result.module,
cuda_backend,
cuda_options,
device=device,
)
ptx, name, shared_mem_bytes, _ = compile_ttir_to_ptx_inplace(
lowering_result.module,
cuda_backend,
cuda_options,
compute_capability,
)
return CompilationResult(
name, ttir, ptx, shared_mem_bytes, compute_capability, lowering_result
Expand Down

0 comments on commit 9996b1f

Please sign in to comment.