diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl index e8cac78be6..925c357e1a 100644 --- a/src/CompileOptions.jl +++ b/src/CompileOptions.jl @@ -229,6 +229,7 @@ function CompileOptions(; :canonicalize, :just_batch, :none, + :probprog, ] end diff --git a/src/Compiler.jl b/src/Compiler.jl index 1226b7757d..87b9269e20 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1297,6 +1297,7 @@ end # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass]. const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" +const probprog_pass::String = "probprog{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize\"}" function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true) pm = MLIR.IR.PassManager() @@ -1712,6 +1713,7 @@ function compile_mlir!( blas_int_width = sizeof(BlasInt) * 8 lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \ blas_int_width=$blas_int_width}" + lower_enzyme_probprog_pass = "lower-enzyme-probprog{backend=$backend}" legalize_chlo_to_stablehlo = if legalize_stablehlo_to_mhlo || compile_options.legalize_chlo_to_stablehlo @@ -1878,6 +1880,67 @@ function compile_mlir!( ), "no_enzyme", ) + elseif compile_options.optimization_passes === :probprog + run_pass_pipeline!( + mod, + join( + if compile_options.raise_first + [ + "mark-func-memory-effects", + opt_passes, + kern, + raise_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + probprog_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ( + if compile_options.legalize_chlo_to_stablehlo + ["func.func(chlo-legalize-to-stablehlo)"] + else + [] + end + )..., + opt_passes2, + lower_enzymexla_linalg_pass, + lower_enzyme_probprog_pass, + jit, + ] + else + [ + "mark-func-memory-effects", + opt_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + probprog_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ( + if compile_options.legalize_chlo_to_stablehlo + ["func.func(chlo-legalize-to-stablehlo)"] + else + [] + end + )..., + opt_passes2, + kern, + raise_passes, + lower_enzymexla_linalg_pass, + lower_enzyme_probprog_pass, + jit, + ] + end, + ",", + ), + "probprog", + ) elseif compile_options.optimization_passes === :only_enzyme run_pass_pipeline!( mod, diff --git a/src/Reactant.jl b/src/Reactant.jl index df78eccae2..97c733a58b 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -242,6 +242,7 @@ include("Tracing.jl") include("Compiler.jl") include("Overlay.jl") +include("probprog/ProbProg.jl") # Serialization include("serialization/Serialization.jl") diff --git a/src/Types.jl b/src/Types.jl index cc257c4ebf..df1813d1b5 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -229,6 +229,7 @@ function ConcretePJRTArray( end Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data) +Base.isready(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = all(isready, x.data) XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data) function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) x.sharding isa Sharding.NoShardInfo && return XLA.device(only(x.data)) @@ -405,6 +406,7 @@ function ConcreteIFRTArray( end Base.wait(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = wait(x.data) +Base.isready(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = isready(x.data) XLA.client(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = XLA.client(x.data) function XLA.device(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) return XLA.device(x.data) diff --git a/src/probprog/Display.jl b/src/probprog/Display.jl new file mode 100644 index 0000000000..a81992eb71 --- /dev/null +++ b/src/probprog/Display.jl @@ -0,0 +1,87 @@ +# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104 +function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple) + VERT = '\u2502' + PLUS = '\u251C' + HORZ = '\u2500' + LAST = '\u2514' + + indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) + indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) + + for i in vert_bars + indent_vert[i] = VERT + indent[i] = VERT + indent_last[i] = VERT + end + + indent_vert_str = join(indent_vert) + indent_str = join(indent) + indent_last_str = join(indent_last) + + sorted_choices = sort(collect(trace.choices); by=x -> x[1]) + n = length(sorted_choices) + + if trace.retval !== nothing + n += 1 + end + + if trace.weight !== nothing + n += 1 + end + + cur = 1 + + if trace.retval !== nothing + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "retval : $(trace.retval)\n") + cur += 1 + end + + if trace.weight !== nothing + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n") + cur += 1 + end + + for (key, value) in sorted_choices + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") + cur += 1 + end + + sorted_subtraces = sort(collect(trace.subtraces); by=x -> x[1]) + n += length(sorted_subtraces) + + for (key, subtrace) in sorted_subtraces + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "subtrace on $(repr(key))\n") + _show_pretty( + io, subtrace, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre + 1) + ) + cur += 1 + end +end + +function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace) + println(io, "ProbProgTrace:") + if isempty(trace.choices) && trace.retval === nothing && trace.weight === nothing + println(io, " (empty)") + else + _show_pretty(io, trace, 0, ()) + end +end + +function Base.show(io::IO, trace::ProbProgTrace) + if get(io, :compact, false) + choices_count = length(trace.choices) + has_retval = trace.retval !== nothing + print(io, "ProbProgTrace($(choices_count) choices") + if has_retval + print(io, ", retval=$(trace.retval), weight=$(trace.weight)") + end + print(io, ")") + else + show(io, MIME"text/plain"(), trace) + end +end diff --git a/src/probprog/FFI.jl b/src/probprog/FFI.jl new file mode 100644 index 0000000000..70fb6c0618 --- /dev/null +++ b/src/probprog/FFI.jl @@ -0,0 +1,346 @@ +using ..Reactant: MLIR + +function initTrace(trace_ptr_ptr::Ptr{Ptr{Any}}) + tr = ProbProgTrace() + _keepalive!(tr) + + unsafe_store!(trace_ptr_ptr, pointer_from_objref(tr)) + return nothing +end + +function addSampleToTrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + sample_ptr_array::Ptr{Ptr{Any}}, + num_outputs_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + num_outputs = unsafe_load(num_outputs_ptr) + ndims_array = unsafe_wrap(Array, ndims_array, num_outputs) + width_array = unsafe_wrap(Array, width_array, num_outputs) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_outputs) + sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_outputs) + + vals = Any[] + for i in 1:num_outputs + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + sample_ptr = sample_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %lld\n"::Cstring, width::Int64 + )::Cvoid + return nothing + end + + if ndims == 0 + push!(vals, unsafe_load(Ptr{julia_type}(sample_ptr))) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + push!(vals, copy(unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)))) + end + end + + trace.choices[symbol] = tuple(vals...) + + return nothing +end + +function addSubtrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + subtrace_ptr_ptr::Ptr{Ptr{Any}}, +) + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + subtrace = unsafe_pointer_to_objref(unsafe_load(subtrace_ptr_ptr))::ProbProgTrace + + trace.subtraces[symbol] = subtrace + + return nothing +end + +function addWeightToTrace(trace_ptr_ptr::Ptr{Ptr{Any}}, weight_ptr::Ptr{Any}) + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + trace.weight = unsafe_load(Ptr{Float64}(weight_ptr)) + return nothing +end + +function addRetvalToTrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + retval_ptr_array::Ptr{Ptr{Any}}, + num_results_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + + num_results = unsafe_load(num_results_ptr) + + if num_results == 0 + return nothing + end + + ndims_array = unsafe_wrap(Array, ndims_array, num_results) + width_array = unsafe_wrap(Array, width_array, num_results) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_results) + retval_ptr_array = unsafe_wrap(Array, retval_ptr_array, num_results) + + vals = Any[] + for i in 1:num_results + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + retval_ptr = retval_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %lld\n"::Cstring, width::Int64 + )::Cvoid + return nothing + end + + if ndims == 0 + push!(vals, unsafe_load(Ptr{julia_type}(retval_ptr))) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + push!(vals, copy(unsafe_wrap(Array, Ptr{julia_type}(retval_ptr), Tuple(shape)))) + end + end + + trace.retval = tuple(vals...) + + return nothing +end + +function getSampleFromConstraint( + constraint_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + sample_ptr_array::Ptr{Ptr{Any}}, + num_samples_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + constraint = unsafe_pointer_to_objref(unsafe_load(constraint_ptr_ptr))::Constraint + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + num_samples = unsafe_load(num_samples_ptr) + ndims_array = unsafe_wrap(Array, ndims_array, num_samples) + width_array = unsafe_wrap(Array, width_array, num_samples) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_samples) + sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_samples) + + tostore = get(constraint, Address(symbol), nothing) + + if tostore === nothing + @ccall printf( + "No constraint found for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + return nothing + end + + for i in 1:num_samples + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + sample_ptr = sample_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %zd\n"::Cstring, width::Csize_t + )::Cvoid + return nothing + end + + if julia_type != eltype(tostore[i]) + @ccall printf( + "Type mismatch in constrained sample: %s != %s\n"::Cstring, + string(julia_type)::Cstring, + string(eltype(tostore[i]))::Cstring, + )::Cvoid + return nothing + end + + if ndims == 0 + unsafe_store!(Ptr{julia_type}(sample_ptr), tostore[i]) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + dest = unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)) + + if size(dest) != size(tostore[i]) + if length(size(dest)) != length(size(tostore[i])) + @ccall printf( + "Shape size mismatch in constrained sample: %zd != %zd\n"::Cstring, + length(size(dest))::Csize_t, + length(size(tostore[i]))::Csize_t, + )::Cvoid + return nothing + end + for i in 1:length(size(dest)) + d = size(dest)[i] + t = size(tostore[i])[i] + if d != t + @ccall printf( + "Shape mismatch in `%zd`th dimension of constrained sample: %zd != %zd\n"::Cstring, + i::Csize_t, + size(dest)[i]::Csize_t, + size(tostore[i])[i]::Csize_t, + )::Cvoid + return nothing + end + end + end + + dest .= tostore[i] + end + end + + return nothing +end + +function getSubconstraint( + constraint_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + subconstraint_ptr_ptr::Ptr{Ptr{Any}}, +) + constraint = unsafe_pointer_to_objref(unsafe_load(constraint_ptr_ptr))::Constraint + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + + subconstraint = Constraint() + + for (key, value) in constraint + if key.path[1] == symbol + @assert isa(key, Address) "Expected Address type for constraint key" + @assert length(key.path) > 1 "Expected composite address with length > 1" + tail_address = Address(key.path[2:end]) + subconstraint[tail_address] = value + end + end + + if isempty(subconstraint) + @ccall printf( + "No subconstraint found for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + return nothing + end + + _keepalive!(subconstraint) + unsafe_store!(subconstraint_ptr_ptr, pointer_from_objref(subconstraint)) + return nothing +end + +function __init__() + init_trace_ptr = @cfunction(initTrace, Cvoid, (Ptr{Ptr{Any}},)) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_init_trace::Cstring, init_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_sample_to_trace_ptr = @cfunction( + addSampleToTrace, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_sample_to_trace::Cstring, add_sample_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_subtrace_ptr = @cfunction( + addSubtrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Ptr{Any}}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_subtrace::Cstring, add_subtrace_ptr::Ptr{Cvoid} + )::Cvoid + + add_weight_to_trace_ptr = @cfunction(addWeightToTrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Any})) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_weight_to_trace::Cstring, add_weight_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_retval_to_trace_ptr = @cfunction( + addRetvalToTrace, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ), + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_retval_to_trace::Cstring, add_retval_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + get_sample_from_constraint_ptr = @cfunction( + getSampleFromConstraint, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_sample_from_constraint::Cstring, + get_sample_from_constraint_ptr::Ptr{Cvoid}, + )::Cvoid + + get_subconstraint_ptr = @cfunction( + getSubconstraint, Cvoid, (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Ptr{Any}}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_subconstraint::Cstring, get_subconstraint_ptr::Ptr{Cvoid} + )::Cvoid + + return nothing +end diff --git a/src/probprog/MH.jl b/src/probprog/MH.jl new file mode 100644 index 0000000000..a0446a35bb --- /dev/null +++ b/src/probprog/MH.jl @@ -0,0 +1,100 @@ +using ..Reactant: ConcreteRNumber, TracedRArray + +function mh( + rng::AbstractRNG, + original_trace::Union{ProbProgTrace,TracedRArray{UInt64,0}}, + f::Function, + args::Vararg{Any,Nargs}; + selection::Selection, +) where {Nargs} + args = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "mh") + + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + + if original_trace isa TracedRArray{UInt64,0} + # Use MLIR data from previous iteration + trace_val = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [original_trace.mlir_data]; outputs=[trace_ty] + ), + 1, + ) + else + # First iteration: create constant from pointer + trace_ptr = reinterpret(UInt64, pointer_from_objref(original_trace)) + tt = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(UInt64)) + splatattr = MLIR.API.mlirDenseElementsAttrUInt64SplatGet(tt, trace_ptr) + cst_op = MLIR.Dialects.stablehlo.constant(; output=tt, value=splatattr) + trace_ptr_val = MLIR.IR.result(cst_op) + + trace_val = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [trace_ptr_val]; outputs=[trace_ty] + ), + 1, + ) + end + + selection_attr = MLIR.IR.Attribute[] + for address in selection + address_attr = MLIR.IR.Attribute[] + for sym in address.path + sym_addr = reinterpret(UInt64, pointer_from_objref(sym)) + push!( + address_attr, + @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, sym_addr::UInt64 + )::MLIR.IR.Attribute + ) + end + push!(selection_attr, MLIR.IR.Attribute(address_attr)) + end + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + accepted_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Bool)) + + mh_op = MLIR.Dialects.enzyme.mh( + inputs, + trace_val; + new_trace=trace_ty, + accepted=accepted_ty, + output_rng_state=out_tys[1], # by convention + fn=fn_attr, + selection=MLIR.IR.Attribute(selection_attr), + ) + + # Return (new_trace, accepted, output_rng_state) + process_probprog_outputs( + mh_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2, true + ) + + new_trace_val = MLIR.IR.result(mh_op, 1) + new_trace_ptr = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [new_trace_val]; outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(UInt64))] + ), + 1, + ) + + new_trace = TracedRArray{UInt64,0}((), new_trace_ptr, ()) + accepted = TracedRArray{Bool,0}((), MLIR.IR.result(mh_op, 2), ()) + + return new_trace, accepted, result +end + +const metropolis_hastings = mh \ No newline at end of file diff --git a/src/probprog/Modeling.jl b/src/probprog/Modeling.jl new file mode 100644 index 0000000000..cb3e417b26 --- /dev/null +++ b/src/probprog/Modeling.jl @@ -0,0 +1,275 @@ +using ..Reactant: MLIR, TracedUtils, AbstractRNG, TracedRArray, ConcreteRNumber +using ..Compiler: @jit, @compile + +include("Utils.jl") + +function sample( + rng::AbstractRNG, + f::Function, + args::Vararg{Any,Nargs}; + symbol::Symbol=gensym("sample"), + logpdf::Union{Nothing,Function}=nothing, +) where {Nargs} + args_with_rng = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_probprog_function( + f, args_with_rng, "sample" + ) + + (; result, linear_args, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + inputs = process_probprog_inputs(linear_args, f, args_with_rng, fnwrap, argprefix) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + sym = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) + + symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) + symbol_attr = @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, symbol_addr::UInt64 + )::MLIR.IR.Attribute + + # Construct logpdf attribute if `logpdf` function is provided. + logpdf_attr = nothing + if logpdf isa Function + samples = f(args_with_rng...) + + # Assume that logpdf parameters follow `(sample, args...)` convention. + logpdf_args = (samples, args...) + + logpdf_mlir = TracedUtils.make_mlir_fn( + logpdf, + logpdf_args, + (), + string(logpdf), + false; + do_transpose=false, + args_in_result=:result, + ) + + logpdf_sym = TracedUtils.get_attribute_by_name(logpdf_mlir.f, "sym_name") + logpdf_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(logpdf_sym)) + end + + sample_op = MLIR.Dialects.enzyme.sample( + inputs; + outputs=out_tys, + fn=fn_attr, + logpdf=logpdf_attr, + symbol=symbol_attr, + name=Base.String(symbol), + ) + + process_probprog_outputs( + sample_op, linear_results, result, f, args_with_rng, fnwrap, resprefix, argprefix + ) + + return result +end + +function untraced_call(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} + args_with_rng = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_probprog_function( + f, args_with_rng, "call" + ) + + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + inputs = process_probprog_inputs(linear_args, f, args_with_rng, fnwrap, argprefix) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + call_op = MLIR.Dialects.enzyme.untracedCall(inputs; outputs=out_tys, fn=fn_attr) + + process_probprog_outputs( + call_op, linear_results, result, f, args_with_rng, fnwrap, resprefix, argprefix + ) + + return result +end + +# Gen-like helper function. +function simulate_(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} + trace = nothing + + compiled_fn = @compile optimize = :probprog simulate(rng, f, args...) + + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer begin + trace, _, _ = compiled_fn(rng, f, args...) + + while !isready(trace) + yield() + end + end + + trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace)[1])) + + trace.rng = rng + trace.fn = f + trace.args = args + + return trace, trace.weight +end + +function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} + args = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "simulate") + + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)) + + simulate_op = MLIR.Dialects.enzyme.simulate( + inputs; trace=trace_ty, weight=weight_ty, outputs=out_tys, fn=fn_attr + ) + + process_probprog_outputs( + simulate_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2 + ) + + trace = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [MLIR.IR.result(simulate_op, 1)]; + outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(UInt64))], + ), + 1, + ) + + trace = TracedRArray{UInt64,0}((), trace, ()) + weight = TracedRArray{Float64,0}((), MLIR.IR.result(simulate_op, 2), ()) + + return trace, weight, result +end + +# Gen-like helper function. +function generate_( + rng::AbstractRNG, + f::Function, + args::Vararg{Any,Nargs}; + constraint::Constraint=Constraint(), +) where {Nargs} + trace = nothing + + constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint))) + + constrained_addresses = extract_addresses(constraint) + + function wrapper_fn(rng, constraint_ptr, args...) + return generate(rng, f, args...; constraint_ptr, constrained_addresses) + end + + compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr, args...) + + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer constraint begin + trace, _, _ = compiled_fn(rng, constraint_ptr, args...) + + while !isready(trace) + yield() + end + end + + trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace)[1])) + + trace.rng = rng + trace.fn = f + trace.args = args + + return trace, trace.weight +end + +function generate( + rng::AbstractRNG, + f::Function, + args::Vararg{Any,Nargs}; + constraint_ptr::TracedRNumber, + constrained_addresses::Set{Address}, +) where {Nargs} + args = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "generate") + + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + constraint_ty = @ccall MLIR.API.mlir_c.enzymeConstraintTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + + constraint_val = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [TracedUtils.get_mlir_data(constraint_ptr)]; outputs=[constraint_ty] + ), + 1, + ) + + constrained_addresses_attr = MLIR.IR.Attribute[] + for address in constrained_addresses + address_attr = MLIR.IR.Attribute[] + for sym in address.path + sym_addr = reinterpret(UInt64, pointer_from_objref(sym)) + push!( + address_attr, + @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, sym_addr::UInt64 + )::MLIR.IR.Attribute + ) + end + push!(constrained_addresses_attr, MLIR.IR.Attribute(address_attr)) + end + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)) + + generate_op = MLIR.Dialects.enzyme.generate( + inputs, + constraint_val; + trace=trace_ty, + weight=weight_ty, + outputs=out_tys, + fn=fn_attr, + constrained_addresses=MLIR.IR.Attribute(constrained_addresses_attr), + ) + + process_probprog_outputs( + generate_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2 + ) + + trace = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [MLIR.IR.result(generate_op, 1)]; + outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(UInt64))], + ), + 1, + ) + + trace = TracedRArray{UInt64,0}((), trace, ()) + weight = TracedRArray{Float64,0}((), MLIR.IR.result(generate_op, 2), ()) + + return trace, weight, result +end diff --git a/src/probprog/ProbProg.jl b/src/probprog/ProbProg.jl new file mode 100644 index 0000000000..3e6dfccdde --- /dev/null +++ b/src/probprog/ProbProg.jl @@ -0,0 +1,25 @@ +module ProbProg + +using ..Reactant: + MLIR, TracedUtils, AbstractRNG, TracedRArray, TracedRNumber, ConcreteRNumber +using ..Compiler: @jit, @compile + +include("Types.jl") +include("FFI.jl") +include("Modeling.jl") +include("Display.jl") +include("MH.jl") + +# Types. +export ProbProgTrace, Constraint, Selection, Address + +# Utility functions. +export get_choices, select + +# Core MLIR ops. +export sample, untraced_call, simulate, generate, mh + +# Gen-like helper functions. +export simulate_, generate_ + +end diff --git a/src/probprog/Types.jl b/src/probprog/Types.jl new file mode 100644 index 0000000000..9cde0752c3 --- /dev/null +++ b/src/probprog/Types.jl @@ -0,0 +1,88 @@ +using Base: ReentrantLock + +mutable struct ProbProgTrace + choices::Dict{Symbol,Any} + retval::Any + weight::Any + subtraces::Dict{Symbol,Any} + rng::Union{Nothing,AbstractRNG} + fn::Union{Nothing,Function} + args::Union{Nothing,Tuple} + + function ProbProgTrace() + return new( + Dict{Symbol,Any}(), + nothing, + nothing, + Dict{Symbol,Any}(), + nothing, + nothing, + nothing, + ) + end +end + +struct Address + path::Vector{Symbol} + + Address(path::Vector{Symbol}) = new(path) +end + +Address(sym::Symbol) = Address([sym]) +Address(syms::Symbol...) = Address([syms...]) + +Base.:(==)(a::Address, b::Address) = a.path == b.path +Base.hash(a::Address, h::UInt) = hash(a.path, h) + +mutable struct Constraint <: AbstractDict{Address,Any} + dict::Dict{Address,Any} + + function Constraint(pairs::Pair...) + dict = Dict{Address,Any}() + for pair in pairs + symbols = Symbol[] + current = pair + while isa(current, Pair) && isa(current.first, Symbol) + push!(symbols, current.first) + current = current.second + end + dict[Address(symbols...)] = current + end + return new(dict) + end + + Constraint() = new(Dict{Address,Any}()) + Constraint(d::Dict{Address,Any}) = new(d) +end + +Base.getindex(c::Constraint, k::Address) = c.dict[k] +Base.setindex!(c::Constraint, v, k::Address) = (c.dict[k] = v) +Base.delete!(c::Constraint, k::Address) = delete!(c.dict, k) +Base.keys(c::Constraint) = keys(c.dict) +Base.values(c::Constraint) = values(c.dict) +Base.iterate(c::Constraint) = iterate(c.dict) +Base.iterate(c::Constraint, state) = iterate(c.dict, state) +Base.length(c::Constraint) = length(c.dict) +Base.isempty(c::Constraint) = isempty(c.dict) +Base.haskey(c::Constraint, k::Address) = haskey(c.dict, k) +Base.get(c::Constraint, k::Address, default) = get(c.dict, k, default) + +extract_addresses(constraint::Constraint) = Set(keys(constraint)) + +const Selection = Set{Address} + +const _probprog_ref_lock = ReentrantLock() +const _probprog_refs = IdDict() + +function _keepalive!(tr::Any) + lock(_probprog_ref_lock) + try + _probprog_refs[tr] = tr + finally + unlock(_probprog_ref_lock) + end + return tr +end + +get_choices(trace::ProbProgTrace) = trace.choices +select(addrs::Address...) = Set{Address}([addrs...]) \ No newline at end of file diff --git a/src/probprog/Utils.jl b/src/probprog/Utils.jl new file mode 100644 index 0000000000..086e5bcb80 --- /dev/null +++ b/src/probprog/Utils.jl @@ -0,0 +1,125 @@ +using ..Reactant: MLIR, TracedUtils + +""" + process_probprog_function(f, args_with_rng, op_name) + +This function handles the probprog argument convention where: +- **Index 1**: RNG state +- **Index 2**: Function `f` (when wrapped) +- **Index 3+**: Remaining arguments + +This wrapper ensures the RNG state is threaded through as the first result, +followed by the actual function results. +""" +function process_probprog_function(f, args_with_rng, op_name) + argprefix = gensym(op_name * "arg") + resprefix = gensym(op_name * "result") + resargprefix = gensym(op_name * "resarg") + + wrapper_fn = (all_args...) -> begin + res = f(all_args...) + (all_args[1], (res isa Tuple ? res : (res,))...) + end + + mlir_fn_res = TracedUtils.make_mlir_fn( + wrapper_fn, + args_with_rng, + (), + string(f), + false; + do_transpose=false, + args_in_result=:result, + argprefix, + resprefix, + resargprefix, + ) + + return mlir_fn_res, argprefix, resprefix, resargprefix +end + +""" + process_probprog_inputs(linear_args, f, args_with_rng, fnwrap, argprefix) + +This function handles the probprog argument convention where: +- **Index 1**: RNG state +- **Index 2**: Function `f` (when `fnwrap` is true) +- **Index 3+**: Other arguments +""" +function process_probprog_inputs(linear_args, f, args_with_rng, fnwrap, argprefix) + inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 2 && fnwrap + TracedUtils.push_val!(inputs, f, path[3:end]) + else + if fnwrap && idx > 1 + idx -= 1 + end + TracedUtils.push_val!(inputs, args_with_rng[idx], path[3:end]) + end + end + return inputs +end + +""" + process_probprog_outputs(op, linear_results, result, f, args_with_rng, fnwrap, resprefix, argprefix, start_idx=0, rng_only=false) + +This function handles the probprog argument convention where: +- **Index 1**: RNG state +- **Index 2**: Function `f` (when `fnwrap` is true) +- **Index 3+**: Other arguments + +When setting results, the function checks: +1. If result path matches `resprefix`, store in `result` +2. If result path matches `argprefix`, store in `args_with_rng` (adjust indices for wrapped function) + +`start_idx` varies depending on the ProbProg operation: +- `sample` and `untraced_call` return only function outputs: + Use `start_idx=0`: `linear_results[i]` corresponds to `op.result[i]` +- `simulate` and `generate` return trace, weight, then outputs: + Use `start_idx=2`: `linear_results[i]` corresponds to `op.result[i+2]` +- `mh` and `regenerate` return trace, accepted/weight, rng_state (no model outputs): + Use `start_idx=2, rng_only=true`: only process first result (rng_state) + +`rng_only`: When true, only process the first result (RNG state), skipping model outputs +""" +function process_probprog_outputs( + op, + linear_results, + result, + f, + args_with_rng, + fnwrap, + resprefix, + argprefix, + start_idx=0, + rng_only=false, +) + num_to_process = rng_only ? 1 : length(linear_results) + + for i in 1:num_to_process + res = linear_results[i] + resv = MLIR.IR.result(op, i + start_idx) + + if TracedUtils.has_idx(res, resprefix) + path = TracedUtils.get_idx(res, resprefix) + TracedUtils.set!(result, path[2:end], resv) + end + + if TracedUtils.has_idx(res, argprefix) + idx, path = TracedUtils.get_argidx(res, argprefix) + if fnwrap && idx == 2 + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap && idx > 2 + idx -= 1 + end + TracedUtils.set!(args_with_rng[idx], path[3:end], resv) + end + end + + if !TracedUtils.has_idx(res, resprefix) && !TracedUtils.has_idx(res, argprefix) + TracedUtils.set!(res, (), resv) + end + end +end diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl new file mode 100644 index 0000000000..fdbcf20b6d --- /dev/null +++ b/test/probprog/generate.jl @@ -0,0 +1,149 @@ +using Reactant, Test, Random, Statistics +using Reactant: ProbProg, ReactantRNG + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function normal_logpdf(x, μ, σ, _) + return -sum(log.(σ)) - length(x) / 2 * log(2π) - sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end + +function model(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + _, t = ProbProg.sample(rng, normal, s, σ, shape; symbol=:t, logpdf=normal_logpdf) + return t +end + +function two_normals(rng, μ, σ, shape) + _, x = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:x, logpdf=normal_logpdf) + _, y = ProbProg.sample(rng, normal, x, σ, shape; symbol=:y, logpdf=normal_logpdf) + return y +end + +function nested_model(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + _, t = ProbProg.sample(rng, two_normals, s, σ, shape; symbol=:t) + _, u = ProbProg.sample(rng, two_normals, t, σ, shape; symbol=:u) + return u +end + +@testset "Generate" begin + @testset "unconstrained" begin + shape = (1000,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + trace, weight = ProbProg.generate_(rng, model, μ, σ, shape) + @test mean(trace.retval[1]) ≈ 0.0 atol = 0.05 rtol = 0.05 + end + + @testset "constrained" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + constraint = ProbProg.Constraint(:s => (fill(0.1, shape),)) + + trace, weight = ProbProg.generate_(rng, model, μ, σ, shape; constraint) + + @test trace.choices[:s][1] == constraint[ProbProg.Address(:s)][1] + + expected_weight = + normal_logpdf(constraint[ProbProg.Address(:s)][1], 0.0, 1.0, shape) + + normal_logpdf( + trace.choices[:t][1], constraint[ProbProg.Address(:s)][1], 1.0, shape + ) + @test weight ≈ expected_weight atol = 1e-6 + end + + @testset "composite addresses" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + constraint = ProbProg.Constraint( + :s => (fill(0.1, shape),), + :t => :x => (fill(0.2, shape),), + :u => :y => (fill(0.3, shape),), + ) + + trace, weight = ProbProg.generate_(rng, nested_model, μ, σ, shape; constraint) + + @test trace.choices[:s][1] == fill(0.1, shape) + @test trace.subtraces[:t].choices[:x][1] == fill(0.2, shape) + @test trace.subtraces[:u].choices[:y][1] == fill(0.3, shape) + + s_weight = normal_logpdf(fill(0.1, shape), 0.0, 1.0, shape) + tx_weight = normal_logpdf(fill(0.2, shape), fill(0.1, shape), 1.0, shape) + ty_weight = normal_logpdf( + trace.subtraces[:t].choices[:y][1], fill(0.2, shape), 1.0, shape + ) + ux_weight = normal_logpdf( + trace.subtraces[:u].choices[:x][1], + trace.subtraces[:t].choices[:y][1], + 1.0, + shape, + ) + uy_weight = normal_logpdf( + fill(0.3, shape), trace.subtraces[:u].choices[:x][1], 1.0, shape + ) + + expected_weight = s_weight + tx_weight + ty_weight + ux_weight + uy_weight + @test weight ≈ expected_weight atol = 1e-6 + end + + @testset "compiled" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + constraint1 = ProbProg.Constraint(:s => (fill(0.1, shape),)) + + constrained_addresses = ProbProg.extract_addresses(constraint1) + + constraint_ptr1 = Reactant.ConcreteRNumber( + reinterpret(UInt64, pointer_from_objref(constraint1)) + ) + + wrapper_fn(rng, constraint_ptr, μ, σ) = ProbProg.generate( + rng, model, μ, σ, shape; constraint_ptr, constrained_addresses + ) + + compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr1, μ, σ) + + trace1 = nothing + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer constraint1 begin + trace1, _ = compiled_fn(rng, constraint_ptr1, μ, σ) + + while !isready(trace1) + yield() + end + end + trace1 = unsafe_pointer_to_objref(Ptr{Any}(Array(trace1)[1])) + + constraint2 = ProbProg.Constraint(:s => (fill(0.2, shape),)) + constraint_ptr2 = Reactant.ConcreteRNumber( + reinterpret(UInt64, pointer_from_objref(constraint2)) + ) + + trace2 = nothing + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer constraint2 begin + trace2, _ = compiled_fn(rng, constraint_ptr2, μ, σ) + + while !isready(trace2) + yield() + end + end + trace2 = unsafe_pointer_to_objref(Ptr{Any}(Array(trace2)[1])) + + @test trace1.choices[:s][1] != trace2.choices[:s][1] + end +end diff --git a/test/probprog/mh.jl b/test/probprog/mh.jl new file mode 100644 index 0000000000..b66d9d55e8 --- /dev/null +++ b/test/probprog/mh.jl @@ -0,0 +1,100 @@ +using Reactant, Test, Random +using Reactant: ProbProg, ReactantRNG + +# Reference: https://www.gen.dev/docs/stable/getting_started/linear_regression/ + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function normal_logpdf(x, μ, σ, _) + return -sum(log.(σ)) - length(x) / 2 * log(2π) - sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end + +function model(rng, xs) + _, slope = ProbProg.sample( + rng, normal, 0.0, 2.0, (1,); symbol=:slope, logpdf=normal_logpdf + ) + _, intercept = ProbProg.sample( + rng, normal, 0.0, 10.0, (1,); symbol=:intercept, logpdf=normal_logpdf + ) + + _, ys = ProbProg.sample( + rng, + normal, + slope .* xs .+ intercept, + 1.0, + (length(xs),); + symbol=:ys, + logpdf=normal_logpdf, + ) + + return ys +end + +function mh_program(rng, t, model, xs, num_iters) + trace_ptr_val = reinterpret(UInt64, pointer_from_objref(t)) + trace_ptr = Reactant.Ops.fill(trace_ptr_val, Int64[]) + + @trace for _ in 1:num_iters + trace_ptr, _ = ProbProg.mh( + rng, trace_ptr, model, xs; selection=ProbProg.select(ProbProg.Address(:slope)) + ) + trace_ptr, _ = ProbProg.mh( + rng, + trace_ptr, + model, + xs; + selection=ProbProg.select(ProbProg.Address(:intercept)), + ) + end + + return trace_ptr +end + +@testset "linear_regression" begin + @testset "simulate" begin + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + + xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + xs_r = Reactant.to_rarray(xs) + + trace, _ = ProbProg.simulate_(rng, model, xs_r) + + @test haskey(trace.choices, :slope) + @test haskey(trace.choices, :intercept) + @test haskey(trace.choices, :ys) + end + + @testset "inference" begin + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + + xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90] + obs = ProbProg.Constraint(:ys => (ys,)) + init_trace, _ = ProbProg.generate_(rng, model, xs; constraint=obs) + + code = @code_hlo optimize = false mh_program(rng, init_trace, model, xs, 10000) + println(code) + + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer init_trace begin + trace_ptr = @compile optimize = :probprog mh_program( + rng, init_trace, model, xs, 10000 + ) + + while !isready(trace_ptr) + yield() + end + + trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace_ptr)[1])) + end + + slope = trace.choices[:slope][1] + intercept = trace.choices[:intercept][1] + @show slope, intercept + + @test slope ≈ -2.0 rtol = 0.05 + @test intercept ≈ 10.0 rtol = 0.05 + end +end \ No newline at end of file diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl new file mode 100644 index 0000000000..b7889c46dd --- /dev/null +++ b/test/probprog/sample.jl @@ -0,0 +1,88 @@ +using Reactant, Test, Random +using Reactant: ProbProg, ReactantRNG + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function one_sample(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape) + return s +end + +function two_samples(rng, μ, σ, shape) + _ = ProbProg.sample(rng, normal, μ, σ, shape) + _, t = ProbProg.sample(rng, normal, μ, σ, shape) + return t +end + +function compose(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape) + _, t = ProbProg.sample(rng, normal, s, σ, shape) + return t +end + +@testset "test" begin + @testset "normal_hlo" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + code = @code_hlo optimize = false ProbProg.sample(rng, normal, μ, σ, shape) + @test contains(repr(code), "enzyme.sample") + end + + @testset "two_samples_hlo" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + code = @code_hlo optimize = false ProbProg.sample(rng, two_samples, μ, σ, shape) + @test contains(repr(code), "enzyme.sample") + end + + @testset "compose" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + before = @code_hlo optimize = false ProbProg.untraced_call( + rng, compose, μ, σ, shape + ) + @test contains(repr(before), "enzyme.sample") + + after = @code_hlo optimize = :probprog ProbProg.untraced_call( + rng, compose, μ, σ, shape + ) + @test !contains(repr(after), "enzyme.sample") + end + + @testset "rng_state" begin + shape = (10,) + + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + rng1 = ReactantRNG(copy(seed)) + + _, X = @jit optimize = :probprog ProbProg.untraced_call( + rng1, one_sample, μ, σ, shape + ) + @test !all(rng1.seed .== seed) + + rng2 = ReactantRNG(copy(seed)) + _, Y = @jit optimize = :probprog ProbProg.untraced_call( + rng2, two_samples, μ, σ, shape + ) + + @test !all(rng2.seed .== seed) + @test !all(rng2.seed .== rng1.seed) + + @test !all(X .≈ Y) + end +end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl new file mode 100644 index 0000000000..423a818ebf --- /dev/null +++ b/test/probprog/simulate.jl @@ -0,0 +1,114 @@ +using Reactant, Test, Random +using Reactant: ProbProg, ReactantRNG + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function normal_logpdf(x, μ, σ, _) + return -sum(log.(σ)) - length(x) / 2 * log(2π) - sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end + +function product_two_normals(rng, μ, σ, shape) + _, a = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:a, logpdf=normal_logpdf) + _, b = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:b, logpdf=normal_logpdf) + return a .* b +end + +function model(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + _, t = ProbProg.sample(rng, normal, s, σ, shape; symbol=:t, logpdf=normal_logpdf) + return t +end + +function model2(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, product_two_normals, μ, σ, shape; symbol=:s) + _, t = ProbProg.sample(rng, product_two_normals, s, σ, shape; symbol=:t) + return t +end + +@testset "Simulate" begin + @testset "hlo" begin + shape = (3, 3, 3) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + before = @code_hlo optimize = false ProbProg.simulate(rng, model, μ, σ, shape) + @test contains(repr(before), "enzyme.simulate") + + after = @code_hlo optimize = :probprog ProbProg.simulate(rng, model, μ, σ, shape) + @test !contains(repr(after), "enzyme.simulate") + @test !contains(repr(after), "enzyme.addSampleToTrace") + @test !contains(repr(after), "enzyme.addWeightToTrace") + @test !contains(repr(after), "enzyme.addRetvalToTrace") + end + + @testset "normal_simulate" begin + shape = (3, 3, 3) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + trace, weight = ProbProg.simulate_(rng, model, μ, σ, shape) + + @test size(trace.retval[1]) == shape + @test haskey(trace.choices, :s) + @test haskey(trace.choices, :t) + @test size(trace.choices[:s][1]) == shape + @test size(trace.choices[:t][1]) == shape + @test trace.weight isa Float64 + end + + @testset "simple_fake" begin + op(_, x, y) = x * y' + logpdf(res, _, _) = sum(res) + function fake_model(rng, x, y) + _, res = ProbProg.sample(rng, op, x, y; symbol=:matmul, logpdf=logpdf) + return res + end + + x = reshape(collect(Float64, 1:12), (4, 3)) + y = reshape(collect(Float64, 1:12), (4, 3)) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + + trace, weight = ProbProg.simulate_(rng, fake_model, x_ra, y_ra) + + @test Array(trace.retval[1]) == op(rng, x, y) + @test haskey(trace.choices, :matmul) + @test trace.choices[:matmul][1] == op(rng, x, y) + @test trace.weight == logpdf(op(rng, x, y), x, y) + end + + @testset "submodel_fake" begin + shape = (3, 3, 3) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + trace, weight = ProbProg.simulate_(rng, model2, μ, σ, shape) + + @test size(trace.retval[1]) == shape + + @test length(trace.choices) == 2 + @test haskey(trace.choices, :s) + @test haskey(trace.choices, :t) + + @test length(trace.subtraces) == 2 + @test haskey(trace.subtraces[:s].choices, :a) + @test haskey(trace.subtraces[:s].choices, :b) + @test haskey(trace.subtraces[:t].choices, :a) + @test haskey(trace.subtraces[:t].choices, :b) + + @test size(trace.choices[:s][1]) == shape + @test size(trace.choices[:t][1]) == shape + + @test trace.weight isa Float64 + + @test trace.weight ≈ trace.subtraces[:s].weight + trace.subtraces[:t].weight + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 5edaa478e5..ccb96984ef 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -68,4 +68,10 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Lux Integration" include("nn/lux.jl") end end + + if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "probprog" + @safetestset "ProbProg Sample" include("probprog/sample.jl") + @safetestset "ProbProg Simulate" include("probprog/simulate.jl") + @safetestset "ProbProg Generate" include("probprog/generate.jl") + end end