1 change: 1 addition & 0 deletions CondaPkg.toml
@@ -2,3 +2,4 @@
jax = ">= 0.6"
tensorflow = ">= 2.17"
numpy = ">= 2"
triton = ">= 3.4"
44 changes: 42 additions & 2 deletions deps/ReactantExtra/API.cpp
@@ -502,8 +502,8 @@ MakeGPUClient(int node_id, int num_nodes, int64_t *allowed_devices,
return client.release();
}
#else
*error = "ReactantExtra was not built with GPU support";
return nullptr;
*error = "ReactantExtra was not built with GPU support";
return nullptr;
#endif
}

@@ -731,16 +731,56 @@ std::vector<int64_t> row_major(int64_t dim) {
static void noop() {}

#ifdef REACTANT_CUDA

#include "third_party/gpus/cuda/include/cuda.h"

REACTANT_ABI int32_t ReactantCudaDriverGetVersion() {
int32_t data;
ReactantHandleCuResult(cuDriverGetVersion(&data));
return data;
}

REACTANT_ABI int32_t ReactantHermeticCudaGetVersion() { return CUDA_VERSION; }

REACTANT_ABI int32_t ReactantCudaDeviceGetComputeCapalilityMajor() {
CUdevice cuDevice;
ReactantHandleCuResult(cuDeviceGet(&cuDevice, 0));
int major;
ReactantHandleCuResult(cuDeviceGetAttribute(
&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, cuDevice));
return major;
}

REACTANT_ABI int32_t ReactantCudaDeviceGetComputeCapalilityMinor() {
CUdevice cuDevice;
ReactantHandleCuResult(cuDeviceGet(&cuDevice, 0));
int minor;
ReactantHandleCuResult(cuDeviceGetAttribute(
&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, cuDevice));
return minor;
}

REACTANT_ABI int32_t ReactantCudaDeviceGetWarpSizeInThreads() {
CUdevice cuDevice;
ReactantHandleCuResult(cuDeviceGet(&cuDevice, 0));
int warpSize;
ReactantHandleCuResult(cuDeviceGetAttribute(
&warpSize, CU_DEVICE_ATTRIBUTE_WARP_SIZE, cuDevice));
return warpSize;
}

#else

REACTANT_ABI int32_t ReactantCudaDriverGetVersion() { return 0; }

REACTANT_ABI int32_t ReactantHermeticCudaGetVersion() { return 0; }

REACTANT_ABI int32_t ReactantCudaDeviceGetComputeCapalilityMajor() { return 0; }

REACTANT_ABI int32_t ReactantCudaDeviceGetComputeCapalilityMinor() { return 0; }

REACTANT_ABI int32_t ReactantCudaDeviceGetWarpSizeInThreads() { return 0; }

#endif

REACTANT_ABI void *UnsafeBufferPointer(PjRtBuffer *buffer) {
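For reference, here is a minimal sketch (not part of this diff) of how the new query entry points could be exercised from Julia. It assumes the shared library handle is exposed as `Reactant_jll.libReactantExtra`; on builds without CUDA the stubs above simply return 0.

```julia
using Reactant, Reactant_jll

# Assumption: the JLL exposes the ReactantExtra shared library under this name.
const lib = Reactant_jll.libReactantExtra

major = @ccall lib.ReactantCudaDeviceGetComputeCapalilityMajor()::Int32
minor = @ccall lib.ReactantCudaDeviceGetComputeCapalilityMinor()::Int32
warp  = @ccall lib.ReactantCudaDeviceGetWarpSizeInThreads()::Int32

# Illustration only: these are the kinds of values the Triton path later reads
# back out of the compiler globals `Reactant.Compiler.cubinChip[]` and
# `Reactant.Compiler.cuWarpSize[]`. How Reactant itself populates them is not
# shown in this diff.
println("compute capability sm_$(major)$(minor), warp size $(warp)")
```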
23 changes: 22 additions & 1 deletion deps/ReactantExtra/BUILD
@@ -1,4 +1,4 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@xla//tools/toolchains/cross_compile/cc:cc_toolchain_config.bzl", "cc_toolchain_config")
@@ -979,6 +979,9 @@ cc_library(
"-Wl,-exported_symbol,_ReactantFuncSetArgAttr",
"-Wl,-exported_symbol,_ReactantHermeticCudaGetVersion",
"-Wl,-exported_symbol,_ReactantCudaDriverGetVersion",
"-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMajor",
"-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMinor",
"-Wl,-exported_symbol,_ReactantCudaDeviceGetWarpSizeInThreads",
"-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions",
"-Wl,-exported_symbol,_PjRtDeviceGetLocalDeviceId",
"-Wl,-exported_symbol,_PjRtDeviceGetGlobalDeviceId",
@@ -1432,6 +1435,24 @@ gentbl_cc_library(
],
)

gentbl_cc_library(
name = "TritonExtJLIncGen",
tbl_outs = [
(
[
"--generator=jl-op-defs",
"--disable-module-wrap=0",
],
"TritonExt.jl",
),
],
tblgen = "//:mlir-jl-tblgen",
td_file = "@enzyme_ad//src/enzyme_ad/jax:Dialect/TritonExt/Ops.td",
deps = [
"@enzyme_ad//src/enzyme_ad/jax:TritonExtDialectTdFiles",
],
)

gentbl_cc_library(
name = "TPUJLIncGen",
tbl_outs = [
2 changes: 1 addition & 1 deletion deps/ReactantExtra/WORKSPACE
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"

NSYNC_SHA256 = ""

ENZYMEXLA_COMMIT = "ccfcd699469d7244f103ef678cd9ed663bb24fd0"
ENZYMEXLA_COMMIT = "ad4bb14fcd9deb87b5e7ac440eeea02e27596c8d"

ENZYMEXLA_SHA256 = ""

1 change: 1 addition & 0 deletions deps/ReactantExtra/make-bindings.jl
@@ -42,6 +42,7 @@ for file in [
"MPI.jl",
"MemRef.jl",
"SparseTensor.jl",
"TritonExt.jl",
]
build_file(joinpath(src_dir, "mlir", "Dialects", file))
end
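As a side note, once `make-bindings.jl` has been run, the generated wrapper lands in `src/mlir/Dialects/TritonExt.jl` and the dialect is reachable under the module path used by the docs page added in this PR; a minimal check might look like:

```julia
using Reactant

# The generated op builders live in this submodule (the same module the
# `@autodocs` block in docs/src/api/dialects/tritonext.md documents).
@assert isdefined(Reactant.MLIR.Dialects, :triton_ext)
```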
2 changes: 2 additions & 0 deletions docs/src/.vitepress/config.mts
@@ -131,6 +131,7 @@ export default defineConfig({
{ text: "SparseTensor", link: "/api/dialects/sparsetensor" },
{ text: "StableHLO", link: "/api/dialects/stablehlo" },
{ text: "Triton", link: "/api/dialects/triton" },
{ text: "TritonExt", link: "/api/dialects/tritonext" },
{ text: "TPU", link: "/api/dialects/tpu" },
{ text: "VHLO", link: "/api/dialects/vhlo" },
],
@@ -221,6 +222,7 @@ export default defineConfig({
{ text: "SparseTensor", link: "/api/dialects/sparsetensor" },
{ text: "StableHLO", link: "/api/dialects/stablehlo" },
{ text: "Triton", link: "/api/dialects/triton" },
{ text: "TritonExt", link: "/api/dialects/tritonext" },
{ text: "TPU", link: "/api/dialects/tpu" },
{ text: "VHLO", link: "/api/dialects/vhlo" },
],
11 changes: 11 additions & 0 deletions docs/src/api/dialects/tritonext.md
@@ -0,0 +1,11 @@
```@meta
CollapsedDocStrings = true
```

# TritonExt Dialect

Provides extensions to the Triton dialect.

```@autodocs
Modules = [Reactant.MLIR.Dialects.triton_ext]
```
8 changes: 0 additions & 8 deletions ext/ReactantCUDAExt.jl
@@ -1460,14 +1460,6 @@ function Reactant.make_tracer(
return newa
end

function __init__()
if CUDA.functional() && !Reactant.precompiling()
cap = CUDA.capability(CUDA.device())
Reactant.Compiler.cubinChip[] = "sm_$(cap.major)$(cap.minor)"
end
return nothing
end

# In Julia v1.11.3 precompiling this module caches bad code:
# <https://github.com/EnzymeAD/Reactant.jl/issues/614>.
@static if !Sys.isapple()
38 changes: 37 additions & 1 deletion ext/ReactantPythonCallExt/ReactantPythonCallExt.jl
@@ -1,14 +1,20 @@
module ReactantPythonCallExt

using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist
using PythonCall:
PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist, pyisinstance, pytuple
using Reactant: Reactant, TracedRArray, TracedRNumber, @reactant_overlay
using Reactant.Ops: @opcall
using Reactant_jll: Reactant_jll

const jaxptr = Ref{Py}()
const jnpptr = Ref{Py}()

const JAX_TRACING_SUPPORTED = Ref{Bool}(false)

const tritonptr = Ref{Py}()

const TRITON_COMPILE_SUPPORTED = Ref{Bool}(false)

const tfptr = Ref{Py}()
const tf2xlaptr = Ref{Py}()
const npptr = Ref{Py}()
@@ -33,6 +39,28 @@ const NUMPY_SIMPLE_TYPES = Dict(
ComplexF64 => :complex64,
)

const MLIR_TYPE_STRING = Dict(
Float64 => "fp64",
Float32 => "fp32",
Float16 => "fp16",
Int64 => "i64",
Int32 => "i32",
Int16 => "i16",
Int8 => "i8",
UInt64 => "ui64",
UInt32 => "ui32",
UInt16 => "ui16",
UInt8 => "ui8",
Bool => "i1",
Reactant.F8E4M3FN => "fp8e4nv",
Reactant.F8E5M2FNUZ => "fp8e5b16",
Reactant.F8E4M3FNUZ => "fp8e4b8",
Reactant.F8E5M2 => "fp8e5",
)
if isdefined(Core, :BFloat16)
MLIR_TYPE_STRING[Core.BFloat16] = "bf16"
end

function __init__()
try
jaxptr[] = pyimport("jax")
@@ -43,6 +71,14 @@ function __init__()
be supported." exception = (err, catch_backtrace())
end

try
tritonptr[] = pyimport("triton")
TRITON_COMPILE_SUPPORTED[] = true
catch err
@warn "Failed to import triton. Compiling jax functions with triton won't be \
supported." exception = (err, catch_backtrace())
end

try
tfptr[] = pyimport("tensorflow")
tfptr[].config.set_visible_devices(pylist(); device_type="GPU")
6 changes: 3 additions & 3 deletions ext/ReactantPythonCallExt/overlays.jl
@@ -1,7 +1,7 @@
@reactant_overlay function PythonCall.pycall(f::Py, args...)
@reactant_overlay function PythonCall.pycall(f::Py, args...; kwargs...)
if Reactant.looped_any(Reactant.use_overlayed_version, args)
return pycall_with_jax_tracing(f, args...)
return overlayed_pycall(f, args...; kwargs...)
else
return Base.inferencebarrier(PythonCall.pycall)(f, args...)
return Base.inferencebarrier(PythonCall.pycall)(f, args...; kwargs...)
end
end
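For orientation, a hedged usage sketch (the Python module and kernel names are hypothetical, not taken from this PR): with the keyword-aware overlay above, a Triton `JITFunction` can be called like any other Python callable inside a function traced by `Reactant.@compile`, with the launch geometry passed as keyword arguments that the Triton branch in `pycall.jl` consumes.

```julia
using PythonCall, Reactant

# Hypothetical Python module exposing a `@triton.jit` kernel
# `scale_kernel(x_ptr, n, BLOCK: tl.constexpr)`.
triton_kernels = pyimport("my_triton_kernels")

function scale!(x)
    n = length(x)
    # `grid`/`blocks` (plus optional `num_warps`, `num_stages`, `hints`) are
    # the launch keywords handled by `overlayed_pycall_with_triton`.
    triton_kernels.scale_kernel(x, n, 256; grid=cld(n, 256), blocks=256)
    return x
end

x = Reactant.to_rarray(rand(Float32, 1024))
scale_compiled = Reactant.@compile scale!(x)  # pycall is rerouted during tracing
```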
99 changes: 98 additions & 1 deletion ext/ReactantPythonCallExt/pycall.jl
@@ -7,7 +7,18 @@ function Reactant.convert_to_jax_dtype_struct(x::Union{TracedRArray,TracedRNumbe
)
end

function pycall_with_jax_tracing(f::Py, args...)
function overlayed_pycall(f::Py, args...; kwargs...)
@assert JAX_TRACING_SUPPORTED[] || TRITON_COMPILE_SUPPORTED[]
# TODO: check for Autotuner and Heuristics as well
if TRITON_COMPILE_SUPPORTED[] && pyisinstance(f, tritonptr[].JITFunction)
return overlayed_pycall_with_triton(f, args...; kwargs...)
else
@assert isempty(kwargs) "`kwargs` are not supported for jax traced functions."
return overlayed_pycall_with_jax_tracing(f, args...)
end
end

function overlayed_pycall_with_jax_tracing(f::Py, args...)
JAX_TRACING_SUPPORTED[] || throw("jax could not be loaded.")

seen_args = Reactant.OrderedIdDict()
@@ -35,3 +46,89 @@ function pycall_with_jax_tracing(f::Py, args...)
res = @opcall hlo_call(pyconvert(String, lowered.as_text()), linear_args...)
return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res)
end

normalize_grid_and_blocks(grid::Integer) = normalize_grid_and_blocks((grid,))
function normalize_grid_and_blocks(grid::Dims{N}) where {N}
@assert N <= 3
@assert all(grid .> 0)
return (grid..., ntuple(_ -> 1, 3 - N)...)
end

signature_string(::TracedRArray{T}) where {T} = "*$(MLIR_TYPE_STRING[T])", nothing
signature_string(::TracedRNumber{T}) where {T} = "$(MLIR_TYPE_STRING[T])", nothing
signature_string(x::T) where {T<:Number} = string(x), x
signature_string(x) = error("Unsupported argument type: $(typeof(x))")

# TODO: better name for hints?
function overlayed_pycall_with_triton(
kernel::Py,
args...;
grid,
blocks,
num_warps::Integer=1,
num_stages::Integer=3,
hints=nothing,
)
triton = tritonptr[]

grid = normalize_grid_and_blocks(grid)
blocks = normalize_grid_and_blocks(blocks)

mapped = map(signature_string, args)
signature = first.(mapped)
# TODO: are hints actually correctly set?
hints =
hints === nothing ? Dict() : Dict(kernel.arg_names[i - 1] => v for (i, v) in hints)
constants = Dict(
kernel.arg_names[i - 1] => constant for
(i, constant) in enumerate(last.(mapped)) if constant !== nothing
)
for (k, v) in hints
v == 1 && (constants[kernel.arg_names[k - 1]] = v)
end
attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16)

sigmap = Dict(kernel.arg_names[i - 1] => sig for (i, sig) in enumerate(signature))
for k in keys(constants)
sigmap[k] = "constexpr"
end

for h in values(hints)
@assert h in (1, 16) "Only 1 and 16 are valid hints, got $h"
end

src = triton.compiler.ASTSource(;
fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs
)

target = triton.backends.compiler.GPUTarget(
"cuda",
parse(Int, Reactant.Compiler.cubinChip[][4:end]),
Reactant.Compiler.cuWarpSize[],
)
backend = triton.compiler.make_backend(target)
options = backend.parse_options(
pydict(
"num_warps" => num_warps,
"num_stages" => num_stages,
"extern_libs" => pytuple((pytuple(("libdevice", Reactant_jll.libdevice)),)),
),
)

# Currently we do a double compilation here: we compile once at this point and
# then lower again inside Enzyme-JAX. Can we do better?
ccinfo = triton.compile(src; target=target, options=options.__dict__)

return @opcall triton_call(
pyconvert(String, ccinfo.asm["source"]),
filter(x -> x isa Reactant.TracedType, args)...;
func_name=pyconvert(String, ccinfo.metadata.name),
grid_x=@opcall(constant(grid[1])),
grid_y=@opcall(constant(grid[2])),
grid_z=@opcall(constant(grid[3])),
block_x=@opcall(constant(blocks[1])),
block_y=@opcall(constant(blocks[2])),
block_z=@opcall(constant(blocks[3])),
)
end
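To make the data flow above concrete, a small worked sketch follows (kernel, argument names, and values are hypothetical):

```julia
# Suppose a Triton kernel with Python signature
#     add_kernel(x_ptr, n, BLOCK: tl.constexpr)
# is invoked from Julia as
#     add_kernel(x, 1024, 128; grid=cld(1024, 128), blocks=128)
# with x::TracedRArray{Float32}.
#
# normalize_grid_and_blocks pads the launch dimensions to 3-tuples:
#     normalize_grid_and_blocks(8)      == (8, 1, 1)
#     normalize_grid_and_blocks((8, 4)) == (8, 4, 1)
#
# signature_string classifies each argument:
#     x    :: TracedRArray{Float32} -> ("*fp32", nothing)  # pointer, runtime arg
#     1024 :: Int64                 -> ("1024", 1024)      # literal, constexpr
#     128  :: Int64                 -> ("128", 128)        # literal, constexpr
#
# so `constants` maps `n` and `BLOCK` to their values, `sigmap` marks them as
# "constexpr", and only `x` survives the `filter` into the `triton_call` op,
# alongside the constant grid/block dimensions.
```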
1 change: 1 addition & 0 deletions src/CompileOptions.jl
@@ -229,6 +229,7 @@ function CompileOptions(;
:canonicalize,
:just_batch,
:none,
:no_triton,
]
end
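For completeness, a hedged sketch of how the new symbol might be requested. It assumes `:no_triton` is forwarded through the `optimize` keyword of `Reactant.@compile`, the same way the other symbols validated above are; the exact plumbing is not shown in this diff.

```julia
using Reactant

f(x) = x .+ 1
x = Reactant.to_rarray(ones(Float32, 16))

# Assumption: `:no_triton` selects an optimization pipeline that skips the
# Triton lowering while keeping the rest of the passes.
f_compiled = Reactant.@compile optimize=:no_triton f(x)
```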
