
Conversation

@avik-pal commented on Sep 25, 2025

import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)
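
Saved as vector_add.py next to the Julia script (the driver appends @__DIR__ to Python's sys.path before importing it), the Triton vector-add kernel above can be imported and traced directly from Julia: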
using PythonCall, Reactant

# Make the directory containing vector_add.py importable, then grab the Triton kernel.
pyimport("sys").path.append(@__DIR__)
kernel = pyimport("vector_add").add_kernel

# Inputs and output buffer as Reactant device arrays.
x = Reactant.to_rarray(rand(Float32, 1024));
y = Reactant.to_rarray(rand(Float32, 1024));
out = Reactant.to_rarray(zeros(Float32, 1024));

# Trace the kernel call and show the generated MLIR module (as in the dumps below).
# Positional arguments follow the kernel signature (x_ptr, y_ptr, output_ptr, n_elements,
# BLOCK_SIZE); keyword arguments configure the launch.
@code_hlo kernel(
    x,
    y,
    out,
    length(x),
    64;
    grid=cld(length(x), 64),
    num_warps=1,
    num_stages=3,
    hints=Dict(1 => 16),
)
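
For completeness, a minimal sketch (not part of the PR) of executing the traced call rather than only inspecting the HLO. It assumes the same signature accepted by @code_hlo also compiles and runs under Reactant.@jit, that the kernel fills out in place, and that the concrete arrays convert back to host memory via Array:

# Keep host copies of the inputs before launching, in case the device buffers get donated.
xh, yh = Array(x), Array(y)

@jit kernel(
    x,
    y,
    out,
    length(x),
    64;
    grid=cld(length(x), 64),
    num_warps=1,
    num_stages=3,
    hints=Dict(1 => 16),
)

# out should now hold the elementwise sum.
@assert Array(out) ≈ xh .+ yh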

@avik-pal force-pushed the ap/triton_integration branch 2 times, most recently from a7ece19 to f776758 on September 27, 2025 at 13:17

avik-pal commented Sep 27, 2025

module @reactant_JITFunc... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  module @tt_module_0 {
    tt.func @add_kernel_call_e72661bb113efd0f(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>) attributes {noinline = false} {
      %0 = tt.get_program_id x : i32
      %c64_i32 = arith.constant 64 : i32
      %c64_i32_0 = arith.constant 64 : i32
      %1 = arith.extsi %0 : i32 to i64
      %2 = arith.extsi %c64_i32_0 : i32 to i64
      %3 = arith.muli %1, %2 : i64
      %c2147483647_i64 = arith.constant 2147483647 : i64
      %c-2147483648_i64 = arith.constant -2147483648 : i64
      %4 = arith.cmpi sle, %3, %c2147483647_i64 : i64
      %5 = arith.cmpi sge, %3, %c-2147483648_i64 : i64
      %6 = arith.andi %4, %5 : i1
      %7 = arith.muli %0, %c64_i32_0 : i32
      %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
      %9 = tt.splat %7 : i32 -> tensor<64xi32>
      %10 = arith.extsi %9 : tensor<64xi32> to tensor<64xi64>
      %11 = arith.extsi %8 : tensor<64xi32> to tensor<64xi64>
      %12 = arith.addi %10, %11 : tensor<64xi64>
      %c2147483647_i64_1 = arith.constant 2147483647 : i64
      %c-2147483648_i64_2 = arith.constant -2147483648 : i64
      %cst = arith.constant dense<2147483647> : tensor<64xi64>
      %13 = arith.cmpi sle, %12, %cst : tensor<64xi64>
      %cst_3 = arith.constant dense<-2147483648> : tensor<64xi64>
      %14 = arith.cmpi sge, %12, %cst_3 : tensor<64xi64>
      %15 = arith.andi %13, %14 : tensor<64xi1>
      %16 = arith.addi %9, %8 : tensor<64xi32>
      %c1024_i32 = arith.constant 1024 : i32
      %cst_4 = arith.constant dense<1024> : tensor<64xi32>
      %17 = arith.cmpi slt, %16, %cst_4 : tensor<64xi32>
      %18 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
      %19 = tt.addptr %18, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
      %20 = tt.load %19, %17 : tensor<64x!tt.ptr<f32>>
      %21 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
      %22 = tt.addptr %21, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
      %23 = tt.load %22, %17 : tensor<64x!tt.ptr<f32>>
      %24 = arith.addf %20, %23 : tensor<64xf32>
      %25 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
      %26 = tt.addptr %25, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
      tt.store %26, %24, %17 : tensor<64x!tt.ptr<f32>>
      tt.return
    }
  }
  func.func @main(%arg0: tensor<1024xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<1024xf32> {tf.aliasing_output = 1 : i32}, %arg2: tensor<1024xf32> {tf.aliasing_output = 2 : i32}) -> (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %2 = stablehlo.transpose %arg2, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %c = stablehlo.constant dense<16> : tensor<i64>
    %c_0 = stablehlo.constant dense<1> : tensor<i64>
    %c_1 = stablehlo.constant dense<1> : tensor<i64>
    %c_2 = stablehlo.constant dense<0> : tensor<i64>
    enzymexla.triton_call @tt_module_0::@add_kernel_call_e72661bb113efd0f blocks in(%c, %c_0, %c_1) shmem = %c_2 (%0, %1, %2) : (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) -> ()
    %3 = stablehlo.transpose %0, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %4 = stablehlo.transpose %1, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %5 = stablehlo.transpose %2, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    return %3, %4, %5 : tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>
  }
}
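
In this first lowering the Triton kernel is embedded verbatim as a tt.func inside a nested tt_module_0 module, and @main launches it through enzymexla.triton_call, passing the grid (16 × 1 × 1 blocks, i.e. cld(1024, 64)) and the shared-memory size (0 bytes) as scalar constants; the three buffers are donated back as outputs via the tf.aliasing_output attributes.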

@avik-pal force-pushed the ap/triton_integration branch 2 times, most recently from 95598f9 to 7f0afd8 on September 28, 2025 at 16:20
@avik-pal

module @reactant_JITFunc... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  triton_ext.module @add_kernel_tt_module_e72661bb113efd0f {
    builtin.module @add_kernel_module_e72661bb113efd0f {
      tt.func private @add_kernel_call_e72661bb113efd0f(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>) attributes {enzymexla.memory_effects = ["read", "write"], noinline = false} {
        %0 = tt.get_program_id x : i32
        %c64_i32 = arith.constant 64 : i32
        %c64_i32_0 = arith.constant 64 : i32
        %1 = arith.extsi %0 : i32 to i64
        %2 = arith.extsi %c64_i32_0 : i32 to i64
        %3 = arith.muli %1, %2 : i64
        %c2147483647_i64 = arith.constant 2147483647 : i64
        %c-2147483648_i64 = arith.constant -2147483648 : i64
        %4 = arith.cmpi sle, %3, %c2147483647_i64 : i64
        %5 = arith.cmpi sge, %3, %c-2147483648_i64 : i64
        %6 = arith.andi %4, %5 : i1
        %7 = arith.muli %0, %c64_i32_0 : i32
        %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
        %9 = tt.splat %7 : i32 -> tensor<64xi32>
        %10 = arith.extsi %9 : tensor<64xi32> to tensor<64xi64>
        %11 = arith.extsi %8 : tensor<64xi32> to tensor<64xi64>
        %12 = arith.addi %10, %11 : tensor<64xi64>
        %c2147483647_i64_1 = arith.constant 2147483647 : i64
        %c-2147483648_i64_2 = arith.constant -2147483648 : i64
        %cst = arith.constant dense<2147483647> : tensor<64xi64>
        %13 = arith.cmpi sle, %12, %cst : tensor<64xi64>
        %cst_3 = arith.constant dense<-2147483648> : tensor<64xi64>
        %14 = arith.cmpi sge, %12, %cst_3 : tensor<64xi64>
        %15 = arith.andi %13, %14 : tensor<64xi1>
        %16 = arith.addi %9, %8 : tensor<64xi32>
        %c1024_i32 = arith.constant 1024 : i32
        %cst_4 = arith.constant dense<1024> : tensor<64xi32>
        %17 = arith.cmpi slt, %16, %cst_4 : tensor<64xi32>
        %18 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
        %19 = tt.addptr %18, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
        %20 = tt.load %19, %17 : tensor<64x!tt.ptr<f32>>
        %21 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
        %22 = tt.addptr %21, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
        %23 = tt.load %22, %17 : tensor<64x!tt.ptr<f32>>
        %24 = arith.addf %20, %23 : tensor<64xf32>
        %25 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
        %26 = tt.addptr %25, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
        tt.store %26, %24, %17 : tensor<64x!tt.ptr<f32>>
        tt.return
      }
    }
  }
  func.func @main(%arg0: tensor<1024xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<1024xf32> {tf.aliasing_output = 1 : i32}, %arg2: tensor<1024xf32> {tf.aliasing_output = 2 : i32}) -> (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %2 = stablehlo.transpose %arg2, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %c = stablehlo.constant dense<16> : tensor<i64>
    %c_0 = stablehlo.constant dense<1> : tensor<i64>
    %c_1 = stablehlo.constant dense<1> : tensor<i64>
    %c_2 = stablehlo.constant dense<0> : tensor<i64>
    triton_ext.call @add_kernel_tt_module_e72661bb113efd0f::@add_kernel_module_e72661bb113efd0f::@add_kernel_call_e72661bb113efd0f blocks in(%c, %c_0, %c_1) shmem = %c_2 (%0, %1, %2) : (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) -> ()
    %3 = stablehlo.transpose %0, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %4 = stablehlo.transpose %1, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %5 = stablehlo.transpose %2, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    return %3, %4, %5 : tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>
  }
}
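
Compared to the previous dump, the kernel module is now nested under the new triton_ext.module / builtin.module pair and launched via triton_ext.call instead of enzymexla.triton_call, and enzymexla.memory_effects attributes are recorded on both the kernel function and @main.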

@avik-pal force-pushed the ap/triton_integration branch from 7f0afd8 to 4876110 on September 29, 2025 at 20:02
@avik-pal

module @reactant_JITFunc... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  triton_ext.module @add_kernel_tt_module_e72661bb113efd0f {
    builtin.module @add_kernel_module_e72661bb113efd0f attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 0 : i32, ttg.target = "cuda:120", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
      llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
      llvm.func @add_kernel_call_e72661bb113efd0f(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"], noinline = false, nvvm.kernel = 1 : ui1, nvvm.reqntid = array<i32: 32>, ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32} {
        %0 = llvm.mlir.undef : vector<1xf32>
        %1 = llvm.mlir.constant(0 : i32) : i32
        %2 = llvm.mlir.constant(32 : i32) : i32
        %3 = llvm.mlir.constant(31 : i32) : i32
        %4 = llvm.mlir.constant(0 : index) : i32
        %5 = llvm.mlir.constant(1024 : i32) : i32
        %6 = llvm.mlir.constant(64 : i32) : i32
        %7 = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.ctaid.x"() : () -> i32
        %8 = llvm.mul %7, %6 : i32
        %9 = nvvm.read.ptx.sreg.tid.x : i32
        %10 = llvm.and %9, %3 : i32
        %11 = llvm.shl %10, %1 : i32
        %12 = llvm.or %1, %11 : i32
        %13 = llvm.or %12, %1 : i32
        %14 = llvm.and %13, %3 : i32
        %15 = llvm.lshr %14, %1 : i32
        %16 = llvm.xor %1, %15 : i32
        %17 = llvm.xor %1, %16 : i32
        %18 = llvm.xor %17, %1 : i32
        %19 = llvm.xor %17, %2 : i32
        %20 = llvm.add %18, %4 : i32
        %21 = llvm.add %19, %4 : i32
        %22 = llvm.add %8, %20 : i32
        %23 = llvm.add %8, %21 : i32
        %24 = llvm.icmp "slt" %22, %5 : i32
        %25 = llvm.icmp "slt" %23, %5 : i32
        %26 = llvm.getelementptr %arg0[%22] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
        %27 = llvm.getelementptr %arg0[%23] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
        %28 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %26, %24 : (!llvm.ptr<1>, i1) -> i32
        %29 = llvm.bitcast %28 : i32 to vector<1xf32>
        %30 = llvm.extractelement %29[%4 : i32] : vector<1xf32>
        %31 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %27, %25 : (!llvm.ptr<1>, i1) -> i32
        %32 = llvm.bitcast %31 : i32 to vector<1xf32>
        %33 = llvm.extractelement %32[%4 : i32] : vector<1xf32>
        %34 = llvm.getelementptr %arg1[%22] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
        %35 = llvm.getelementptr %arg1[%23] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
        %36 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %34, %24 : (!llvm.ptr<1>, i1) -> i32
        %37 = llvm.bitcast %36 : i32 to vector<1xf32>
        %38 = llvm.extractelement %37[%4 : i32] : vector<1xf32>
        %39 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %35, %25 : (!llvm.ptr<1>, i1) -> i32
        %40 = llvm.bitcast %39 : i32 to vector<1xf32>
        %41 = llvm.extractelement %40[%4 : i32] : vector<1xf32>
        %42 = llvm.fadd %30, %38 : f32
        %43 = llvm.fadd %33, %41 : f32
        %44 = llvm.getelementptr %arg2[%22] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
        %45 = llvm.getelementptr %arg2[%23] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
        %46 = llvm.insertelement %42, %0[%1 : i32] : vector<1xf32>
        %47 = llvm.bitcast %46 : vector<1xf32> to i32
        %48 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b" %47, %44, %24 : (i32, !llvm.ptr<1>, i1) -> !llvm.void
        %49 = llvm.insertelement %43, %0[%1 : i32] : vector<1xf32>
        %50 = llvm.bitcast %49 : vector<1xf32> to i32
        %51 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b" %50, %45, %25 : (i32, !llvm.ptr<1>, i1) -> !llvm.void
        llvm.return
      }
    }
  }
  func.func @main(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>, %arg2: tensor<1024xf32>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
    %c = stablehlo.constant dense<0> : tensor<i64>
    %c_0 = stablehlo.constant dense<1> : tensor<i64>
    %c_1 = stablehlo.constant dense<16> : tensor<i64>
    triton_ext.call @add_kernel_tt_module_e72661bb113efd0f::@add_kernel_module_e72661bb113efd0f::@add_kernel_call_e72661bb113efd0f blocks in(%c_1, %c_0, %c_0) shmem = %c (%arg0, %arg1, %arg2) : (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) -> ()
    return
  }
}
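
Here the kernel has already been lowered through the Triton GPU pipeline down to the LLVM/NVVM dialect: the ttg.* module attributes capture the launch configuration (one warp of 32 threads, matching num_warps=1 and nvvm.reqntid = 32), each thread handles two of the 64 block elements with predicated ld.global / st.global inline assembly, and @main no longer returns anything since the kernel writes directly into the argument buffers.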

@avik-pal force-pushed the ap/triton_integration branch from 3315b07 to 4a9a1ce on October 1, 2025 at 21:33