-
Notifications
You must be signed in to change notification settings - Fork 30
feat: initial triton setup #1702
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
avik-pal
wants to merge
15
commits into
main
Choose a base branch
from
ap/triton_integration
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
+632
−92
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
a7ece19
to
f776758
Compare
module @reactant_JITFunc... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
module @tt_module_0 {
tt.func @add_kernel_call_e72661bb113efd0f(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>) attributes {noinline = false} {
%0 = tt.get_program_id x : i32
%c64_i32 = arith.constant 64 : i32
%c64_i32_0 = arith.constant 64 : i32
%1 = arith.extsi %0 : i32 to i64
%2 = arith.extsi %c64_i32_0 : i32 to i64
%3 = arith.muli %1, %2 : i64
%c2147483647_i64 = arith.constant 2147483647 : i64
%c-2147483648_i64 = arith.constant -2147483648 : i64
%4 = arith.cmpi sle, %3, %c2147483647_i64 : i64
%5 = arith.cmpi sge, %3, %c-2147483648_i64 : i64
%6 = arith.andi %4, %5 : i1
%7 = arith.muli %0, %c64_i32_0 : i32
%8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%9 = tt.splat %7 : i32 -> tensor<64xi32>
%10 = arith.extsi %9 : tensor<64xi32> to tensor<64xi64>
%11 = arith.extsi %8 : tensor<64xi32> to tensor<64xi64>
%12 = arith.addi %10, %11 : tensor<64xi64>
%c2147483647_i64_1 = arith.constant 2147483647 : i64
%c-2147483648_i64_2 = arith.constant -2147483648 : i64
%cst = arith.constant dense<2147483647> : tensor<64xi64>
%13 = arith.cmpi sle, %12, %cst : tensor<64xi64>
%cst_3 = arith.constant dense<-2147483648> : tensor<64xi64>
%14 = arith.cmpi sge, %12, %cst_3 : tensor<64xi64>
%15 = arith.andi %13, %14 : tensor<64xi1>
%16 = arith.addi %9, %8 : tensor<64xi32>
%c1024_i32 = arith.constant 1024 : i32
%cst_4 = arith.constant dense<1024> : tensor<64xi32>
%17 = arith.cmpi slt, %16, %cst_4 : tensor<64xi32>
%18 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
%19 = tt.addptr %18, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
%20 = tt.load %19, %17 : tensor<64x!tt.ptr<f32>>
%21 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
%22 = tt.addptr %21, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
%23 = tt.load %22, %17 : tensor<64x!tt.ptr<f32>>
%24 = arith.addf %20, %23 : tensor<64xf32>
%25 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
%26 = tt.addptr %25, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
tt.store %26, %24, %17 : tensor<64x!tt.ptr<f32>>
tt.return
}
}
func.func @main(%arg0: tensor<1024xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<1024xf32> {tf.aliasing_output = 1 : i32}, %arg2: tensor<1024xf32> {tf.aliasing_output = 2 : i32}) -> (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) {
%0 = stablehlo.transpose %arg0, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
%1 = stablehlo.transpose %arg1, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
%2 = stablehlo.transpose %arg2, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
%c = stablehlo.constant dense<16> : tensor<i64>
%c_0 = stablehlo.constant dense<1> : tensor<i64>
%c_1 = stablehlo.constant dense<1> : tensor<i64>
%c_2 = stablehlo.constant dense<0> : tensor<i64>
enzymexla.triton_call @tt_module_0::@add_kernel_call_e72661bb113efd0f blocks in(%c, %c_0, %c_1) shmem = %c_2 (%0, %1, %2) : (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) -> ()
%3 = stablehlo.transpose %0, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
%4 = stablehlo.transpose %1, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
%5 = stablehlo.transpose %2, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
return %3, %4, %5 : tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>
}
} |
95598f9
to
7f0afd8
Compare
module @reactant_JITFunc... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
triton_ext.module @add_kernel_tt_module_e72661bb113efd0f {
builtin.module @add_kernel_module_e72661bb113efd0f {
tt.func private @add_kernel_call_e72661bb113efd0f(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>) attributes {enzymexla.memory_effects = ["read", "write"], noinline = false} {
%0 = tt.get_program_id x : i32
%c64_i32 = arith.constant 64 : i32
%c64_i32_0 = arith.constant 64 : i32
%1 = arith.extsi %0 : i32 to i64
%2 = arith.extsi %c64_i32_0 : i32 to i64
%3 = arith.muli %1, %2 : i64
%c2147483647_i64 = arith.constant 2147483647 : i64
%c-2147483648_i64 = arith.constant -2147483648 : i64
%4 = arith.cmpi sle, %3, %c2147483647_i64 : i64
%5 = arith.cmpi sge, %3, %c-2147483648_i64 : i64
%6 = arith.andi %4, %5 : i1
%7 = arith.muli %0, %c64_i32_0 : i32
%8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%9 = tt.splat %7 : i32 -> tensor<64xi32>
%10 = arith.extsi %9 : tensor<64xi32> to tensor<64xi64>
%11 = arith.extsi %8 : tensor<64xi32> to tensor<64xi64>
%12 = arith.addi %10, %11 : tensor<64xi64>
%c2147483647_i64_1 = arith.constant 2147483647 : i64
%c-2147483648_i64_2 = arith.constant -2147483648 : i64
%cst = arith.constant dense<2147483647> : tensor<64xi64>
%13 = arith.cmpi sle, %12, %cst : tensor<64xi64>
%cst_3 = arith.constant dense<-2147483648> : tensor<64xi64>
%14 = arith.cmpi sge, %12, %cst_3 : tensor<64xi64>
%15 = arith.andi %13, %14 : tensor<64xi1>
%16 = arith.addi %9, %8 : tensor<64xi32>
%c1024_i32 = arith.constant 1024 : i32
%cst_4 = arith.constant dense<1024> : tensor<64xi32>
%17 = arith.cmpi slt, %16, %cst_4 : tensor<64xi32>
%18 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
%19 = tt.addptr %18, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
%20 = tt.load %19, %17 : tensor<64x!tt.ptr<f32>>
%21 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
%22 = tt.addptr %21, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
%23 = tt.load %22, %17 : tensor<64x!tt.ptr<f32>>
%24 = arith.addf %20, %23 : tensor<64xf32>
%25 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
%26 = tt.addptr %25, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
tt.store %26, %24, %17 : tensor<64x!tt.ptr<f32>>
tt.return
}
}
}
func.func @main(%arg0: tensor<1024xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<1024xf32> {tf.aliasing_output = 1 : i32}, %arg2: tensor<1024xf32> {tf.aliasing_output = 2 : i32}) -> (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
%0 = stablehlo.transpose %arg0, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
%1 = stablehlo.transpose %arg1, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
%2 = stablehlo.transpose %arg2, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
%c = stablehlo.constant dense<16> : tensor<i64>
%c_0 = stablehlo.constant dense<1> : tensor<i64>
%c_1 = stablehlo.constant dense<1> : tensor<i64>
%c_2 = stablehlo.constant dense<0> : tensor<i64>
triton_ext.call @add_kernel_tt_module_e72661bb113efd0f::@add_kernel_module_e72661bb113efd0f::@add_kernel_call_e72661bb113efd0f blocks in(%c, %c_0, %c_1) shmem = %c_2 (%0, %1, %2) : (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) -> ()
%3 = stablehlo.transpose %0, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
%4 = stablehlo.transpose %1, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
%5 = stablehlo.transpose %2, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
return %3, %4, %5 : tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>
}
} |
7f0afd8
to
4876110
Compare
module @reactant_JITFunc... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
triton_ext.module @add_kernel_tt_module_e72661bb113efd0f {
builtin.module @add_kernel_module_e72661bb113efd0f attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 0 : i32, ttg.target = "cuda:120", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
llvm.func @add_kernel_call_e72661bb113efd0f(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"], noinline = false, nvvm.kernel = 1 : ui1, nvvm.reqntid = array<i32: 32>, ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32} {
%0 = llvm.mlir.undef : vector<1xf32>
%1 = llvm.mlir.constant(0 : i32) : i32
%2 = llvm.mlir.constant(32 : i32) : i32
%3 = llvm.mlir.constant(31 : i32) : i32
%4 = llvm.mlir.constant(0 : index) : i32
%5 = llvm.mlir.constant(1024 : i32) : i32
%6 = llvm.mlir.constant(64 : i32) : i32
%7 = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.ctaid.x"() : () -> i32
%8 = llvm.mul %7, %6 : i32
%9 = nvvm.read.ptx.sreg.tid.x : i32
%10 = llvm.and %9, %3 : i32
%11 = llvm.shl %10, %1 : i32
%12 = llvm.or %1, %11 : i32
%13 = llvm.or %12, %1 : i32
%14 = llvm.and %13, %3 : i32
%15 = llvm.lshr %14, %1 : i32
%16 = llvm.xor %1, %15 : i32
%17 = llvm.xor %1, %16 : i32
%18 = llvm.xor %17, %1 : i32
%19 = llvm.xor %17, %2 : i32
%20 = llvm.add %18, %4 : i32
%21 = llvm.add %19, %4 : i32
%22 = llvm.add %8, %20 : i32
%23 = llvm.add %8, %21 : i32
%24 = llvm.icmp "slt" %22, %5 : i32
%25 = llvm.icmp "slt" %23, %5 : i32
%26 = llvm.getelementptr %arg0[%22] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
%27 = llvm.getelementptr %arg0[%23] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
%28 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %26, %24 : (!llvm.ptr<1>, i1) -> i32
%29 = llvm.bitcast %28 : i32 to vector<1xf32>
%30 = llvm.extractelement %29[%4 : i32] : vector<1xf32>
%31 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %27, %25 : (!llvm.ptr<1>, i1) -> i32
%32 = llvm.bitcast %31 : i32 to vector<1xf32>
%33 = llvm.extractelement %32[%4 : i32] : vector<1xf32>
%34 = llvm.getelementptr %arg1[%22] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
%35 = llvm.getelementptr %arg1[%23] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
%36 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %34, %24 : (!llvm.ptr<1>, i1) -> i32
%37 = llvm.bitcast %36 : i32 to vector<1xf32>
%38 = llvm.extractelement %37[%4 : i32] : vector<1xf32>
%39 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %35, %25 : (!llvm.ptr<1>, i1) -> i32
%40 = llvm.bitcast %39 : i32 to vector<1xf32>
%41 = llvm.extractelement %40[%4 : i32] : vector<1xf32>
%42 = llvm.fadd %30, %38 : f32
%43 = llvm.fadd %33, %41 : f32
%44 = llvm.getelementptr %arg2[%22] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
%45 = llvm.getelementptr %arg2[%23] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
%46 = llvm.insertelement %42, %0[%1 : i32] : vector<1xf32>
%47 = llvm.bitcast %46 : vector<1xf32> to i32
%48 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b" %47, %44, %24 : (i32, !llvm.ptr<1>, i1) -> !llvm.void
%49 = llvm.insertelement %43, %0[%1 : i32] : vector<1xf32>
%50 = llvm.bitcast %49 : vector<1xf32> to i32
%51 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b" %50, %45, %25 : (i32, !llvm.ptr<1>, i1) -> !llvm.void
llvm.return
}
}
}
func.func @main(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>, %arg2: tensor<1024xf32>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
%c = stablehlo.constant dense<0> : tensor<i64>
%c_0 = stablehlo.constant dense<1> : tensor<i64>
%c_1 = stablehlo.constant dense<16> : tensor<i64>
triton_ext.call @add_kernel_tt_module_e72661bb113efd0f::@add_kernel_module_e72661bb113efd0f::@add_kernel_call_e72661bb113efd0f blocks in(%c_1, %c_0, %c_0) shmem = %c (%arg0, %arg1, %arg2) : (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) -> ()
return
}
} |
3315b07
to
4a9a1ce
Compare
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.