@@ -792,6 +792,53 @@ function stream2token(source::Value; result::IR.Type, location=Location())
792
792
)
793
793
end
794
794
795
+ function triton_call (
796
+ gridx:: Value ,
797
+ gridy:: Value ,
798
+ gridz:: Value ,
799
+ shmem:: Value ,
800
+ inputs:: Vector{Value} ;
801
+ result_0:: Vector{IR.Type} ,
802
+ fn,
803
+ backend_config= nothing ,
804
+ operand_layouts= nothing ,
805
+ result_layouts= nothing ,
806
+ arg_attrs= nothing ,
807
+ res_attrs= nothing ,
808
+ output_operand_aliases= nothing ,
809
+ xla_side_effect_free= nothing ,
810
+ location= Location (),
811
+ )
812
+ op_ty_results = IR. Type[result_0... ,]
813
+ operands = Value[gridx, gridy, gridz, shmem, inputs... ]
814
+ owned_regions = Region[]
815
+ successors = Block[]
816
+ attributes = NamedAttribute[namedattribute (" fn" , fn),]
817
+ ! isnothing (backend_config) &&
818
+ push! (attributes, namedattribute (" backend_config" , backend_config))
819
+ ! isnothing (operand_layouts) &&
820
+ push! (attributes, namedattribute (" operand_layouts" , operand_layouts))
821
+ ! isnothing (result_layouts) &&
822
+ push! (attributes, namedattribute (" result_layouts" , result_layouts))
823
+ ! isnothing (arg_attrs) && push! (attributes, namedattribute (" arg_attrs" , arg_attrs))
824
+ ! isnothing (res_attrs) && push! (attributes, namedattribute (" res_attrs" , res_attrs))
825
+ ! isnothing (output_operand_aliases) &&
826
+ push! (attributes, namedattribute (" output_operand_aliases" , output_operand_aliases))
827
+ ! isnothing (xla_side_effect_free) &&
828
+ push! (attributes, namedattribute (" xla_side_effect_free" , xla_side_effect_free))
829
+
830
+ return create_operation (
831
+ " enzymexla.triton_call" ,
832
+ location;
833
+ operands,
834
+ owned_regions,
835
+ successors,
836
+ attributes,
837
+ results= op_ty_results,
838
+ result_inference= false ,
839
+ )
840
+ end
841
+
795
842
function wrap (
796
843
operand:: Value ;
797
844
result= nothing :: Union{Nothing,IR.Type} ,
0 commit comments