From fcb08e4c91cf8bd06db60232acf21b81c675e99e Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Wed, 10 Sep 2025 03:39:31 +0200 Subject: [PATCH 1/4] Add apply function for type stable calls of functions --- src/LightSumTypes.jl | 57 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/src/LightSumTypes.jl b/src/LightSumTypes.jl index 3ea8ab1..5cc3232 100644 --- a/src/LightSumTypes.jl +++ b/src/LightSumTypes.jl @@ -194,6 +194,63 @@ is_sumtype(T::Type) = false function variant_idx end + function _is_sumtype_structurally(T) + return T isa DataType && fieldcount(T) == 1 && fieldname(T, 1) === :variants && fieldtype(T, 1) isa Union +end + +function _get_variant_types(T_sum) + field_T = fieldtype(T_sum, 1) + + !(field_T isa Union) && return [field_T] + + types = [] + curr = field_T + while curr isa Union + push!(types, curr.a) + curr = curr.b + end + push!(types, curr) + return types +end + +@generated function apply(f::F, args::Tuple) where {F} + + + args = fieldtypes(args) + sumtype_args = [(i, T) for (i, T) in enumerate(args) if _is_sumtype_structurally(T)] + + if isempty(sumtype_args) + return :(f(args...)) + end + + final_args = Any[:(args[$i]) for i in 1:length(args)] + for (idx, T) in sumtype_args + final_args[idx] = Symbol("v_", idx) + end + + body = :(f($(final_args...))) + + for (idx, T) in reverse(sumtype_args) + unwrapped_var = Symbol("v_", idx) + + variant_types = _get_variant_types(T) + + branch_expr = :(error("THIS_SHOULD_BE_UNREACHABLE")) + for V_type in reverse(variant_types) + condition = :($unwrapped_var isa $V_type) + branch_expr = Expr(:elseif, condition, body, branch_expr) + end + branch_expr = Expr(:if, branch_expr.args...) + + body = quote + let $(unwrapped_var) = $LightSumTypes.unwrap(args[$idx]) + $branch_expr + end + end + end + return body +end + include("precompile.jl") end From f1c7cd2833e5a852a9a59225e798882cc0200066 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Wed, 10 Sep 2025 03:40:50 +0200 Subject: [PATCH 2/4] Add 'apply' to exported functions in LightSumTypes --- src/LightSumTypes.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/LightSumTypes.jl b/src/LightSumTypes.jl index 5cc3232..c251357 100644 --- a/src/LightSumTypes.jl +++ b/src/LightSumTypes.jl @@ -3,7 +3,7 @@ module LightSumTypes using MacroTools: namify -export @sumtype, sumtype_expr, variant, variantof, allvariants, is_sumtype +export @sumtype, sumtype_expr, variant, variantof, allvariants, is_sumtype, apply unwrap(sumt) = getfield(sumt, :variants) @@ -194,7 +194,7 @@ is_sumtype(T::Type) = false function variant_idx end - function _is_sumtype_structurally(T) +function _is_sumtype_structurally(T) return T isa DataType && fieldcount(T) == 1 && fieldname(T, 1) === :variants && fieldtype(T, 1) isa Union end From 7157817c94d38f3c439b3e16e4bf98d035660ca2 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Wed, 10 Sep 2025 04:51:46 +0200 Subject: [PATCH 3/4] Update LightSumTypes.jl --- src/LightSumTypes.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/LightSumTypes.jl b/src/LightSumTypes.jl index c251357..1a275d8 100644 --- a/src/LightSumTypes.jl +++ b/src/LightSumTypes.jl @@ -200,9 +200,6 @@ end function _get_variant_types(T_sum) field_T = fieldtype(T_sum, 1) - - !(field_T isa Union) && return [field_T] - types = [] curr = field_T while curr isa Union From 5bdff85650727f9c81b7667ad051abbcfac27b2e Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Wed, 10 Sep 2025 04:52:37 +0200 Subject: [PATCH 4/4] Update LightSumTypes.jl --- src/LightSumTypes.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/LightSumTypes.jl b/src/LightSumTypes.jl index 1a275d8..c2f3503 100644 --- a/src/LightSumTypes.jl +++ b/src/LightSumTypes.jl @@ -211,15 +211,10 @@ function _get_variant_types(T_sum) end @generated function apply(f::F, args::Tuple) where {F} - args = fieldtypes(args) sumtype_args = [(i, T) for (i, T) in enumerate(args) if _is_sumtype_structurally(T)] - if isempty(sumtype_args) - return :(f(args...)) - end - final_args = Any[:(args[$i]) for i in 1:length(args)] for (idx, T) in sumtype_args final_args[idx] = Symbol("v_", idx)