From fbb2a4ff223e1c8bb7f2d07f47ca652fadd0fd6e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Sep 2025 00:39:41 +0100 Subject: [PATCH 01/15] feat: create AbstractPopMember for customizing PopMember --- src/ConstantOptimization.jl | 8 ++--- src/HallOfFame.jl | 31 ++++++++++++-------- src/MLJInterface.jl | 3 +- src/Migration.jl | 4 +-- src/Mutate.jl | 17 ++++++----- src/Options.jl | 6 ++++ src/OptionsStruct.jl | 2 ++ src/PopMember.jl | 40 ++++++++++++++++++++----- src/Population.jl | 58 ++++++++++++++++++++++++------------- src/SymbolicRegression.jl | 24 ++++++++++----- 10 files changed, 132 insertions(+), 61 deletions(-) diff --git a/src/ConstantOptimization.jl b/src/ConstantOptimization.jl index 585a6ef78..7fa7798d8 100644 --- a/src/ConstantOptimization.jl +++ b/src/ConstantOptimization.jl @@ -17,7 +17,7 @@ using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, specialized_options, dataset_fraction using ..UtilsModule: get_birth_order, PerTaskCache, stable_get! using ..LossFunctionsModule: eval_loss, loss_to_cost -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember function can_optimize(::AbstractExpression{T}, options) where {T} return can_optimize(T, options) @@ -31,7 +31,7 @@ end member::P, options::AbstractOptions; rng::AbstractRNG=default_rng(), -)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}} +)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,N,P<:AbstractPopMember{T,L,N}} can_optimize(member.tree, options) || return (member, 0.0) nconst = count_constants_for_optimization(member.tree) nconst == 0 && return (member, 0.0) @@ -63,7 +63,7 @@ count_constants_for_optimization(ex::Expression) = count_scalar_constants(ex) function _optimize_constants( dataset, member::P, options, algorithm, optimizer_options, rng -)::Tuple{P,Float64} where {T,L,P<:PopMember{T,L}} +)::Tuple{P,Float64} where {T,L,N,P<:AbstractPopMember{T,L,N}} tree = member.tree x0, refs = get_scalar_constants(tree) 
@assert count_constants_for_optimization(tree) == length(x0) @@ -76,7 +76,7 @@ function _optimize_constants( end function _optimize_constants_inner( f::F, fg!::G, x0, refs, dataset, member::P, options, algorithm, optimizer_options, rng -)::Tuple{P,Float64} where {F,G,T,L,P<:PopMember{T,L}} +)::Tuple{P,Float64} where {F,G,T,L,N,P<:AbstractPopMember{T,L,N}} obj = if algorithm isa Optim.Newton || options.autodiff_backend === nothing f else diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl index c90fbe3a4..4a1b7841e 100644 --- a/src/HallOfFame.jl +++ b/src/HallOfFame.jl @@ -6,12 +6,13 @@ using ..UtilsModule: split_string, AnnotatedIOBuffer, dump_buffer using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, relu, create_expression, init_value using ..ComplexityModule: compute_complexity -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember +import ..PopMemberModule: popmember_type using ..InterfaceDynamicExpressionsModule: format_dimensions, WILDCARD_UNIT_STRING using Printf: @sprintf """ - HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} + HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T},PM<:AbstractPopMember{T,L,N}} List of the best members seen all time in `.members`, with `.members[c]` being the best member seen at complexity c. Including only the members which actually @@ -19,15 +20,19 @@ have been set, you can run `.members[exists]`. # Fields -- `members::Array{PopMember{T,L,N},1}`: List of the best members seen all time. +- `members::Array{PM,1}`: List of the best members seen all time. These are ordered by complexity, with `.members[1]` the member with complexity 1. - `exists::Array{Bool,1}`: Whether the member at the given complexity has been set. 
""" -struct HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} - members::Array{PopMember{T,L,N},1} +struct HallOfFame{ + T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T},PM<:AbstractPopMember{T,L,N} +} + members::Array{PM,1} exists::Array{Bool,1} #Whether it has been set end -function Base.show(io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N}) where {T,L,N} +function Base.show( + io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N,PM} +) where {T,L,N,PM} println(io, "HallOfFame{...}:") for i in eachindex(hof.members, hof.exists) s_member, s_exists = if hof.exists[i] @@ -47,8 +52,8 @@ function Base.show(io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N}) where end return nothing end -function Base.eltype(::Union{HOF,Type{HOF}}) where {T,L,N,HOF<:HallOfFame{T,L,N}} - return PopMember{T,L,N} +function Base.eltype(::Union{HOF,Type{HOF}}) where {T,L,N,PM,HOF<:HallOfFame{T,L,N,PM}} + return PM end """ @@ -69,7 +74,7 @@ function HallOfFame( ) where {T<:DATA_TYPE,L<:LOSS_TYPE} base_tree = create_expression(init_value(T), options, dataset) - return HallOfFame{T,L,typeof(base_tree)}( + return HallOfFame{T,L,typeof(base_tree),PopMember{T,L,typeof(base_tree)}}( [ PopMember( copy(base_tree), @@ -93,11 +98,10 @@ end """ calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,P}) where {T<:DATA_TYPE,L<:LOSS_TYPE} """ -function calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,N}) where {T,L,N} +function calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,N,PM}) where {T,L,N,PM} # TODO - remove dataset from args. 
- P = PopMember{T,L,N} # Dominating pareto curve - must be better than all simpler equations - dominating = P[] + dominating = PM[] for size in eachindex(hallOfFame.members) if !hallOfFame.exists[size] continue @@ -276,4 +280,7 @@ function format_hall_of_fame(hof::AbstractVector{<:HallOfFame}, options) end # TODO: Re-use this in `string_dominating_pareto_curve` +# Type accessor for HallOfFame +popmember_type(::Type{<:HallOfFame{T,L,N,PM}}) where {T,L,N,PM} = PM + end diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 5845e2db2..b441fda2d 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -39,7 +39,8 @@ using ..CoreModule: ExpressionSpec, get_expression_type, check_warm_start_compatibility -using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS +using ..CoreModule.OptionsModule: + DEFAULT_OPTIONS, OPTION_DESCRIPTIONS, default_popmember_type using ..ComplexityModule: compute_complexity using ..HallOfFameModule: HallOfFame, format_hall_of_fame using ..UtilsModule: subscriptify, @ignore diff --git a/src/Migration.jl b/src/Migration.jl index f7fe61b89..3988b81ea 100644 --- a/src/Migration.jl +++ b/src/Migration.jl @@ -2,7 +2,7 @@ module MigrationModule using ..CoreModule: AbstractOptions using ..PopulationModule: Population -using ..PopMemberModule: PopMember, reset_birth! +using ..PopMemberModule: AbstractPopMember, PopMember, reset_birth! using ..UtilsModule: poisson_sample """ @@ -14,7 +14,7 @@ to do so. The original migrant population is not modified. 
Pass with, e.g., """ function migrate!( migration::Pair{Vector{PM},P}, options::AbstractOptions; frac::AbstractFloat -) where {T,L,N,PM<:PopMember{T,L,N},P<:Population{T,L,N}} +) where {T,L,N,PM<:AbstractPopMember{T,L,N},P<:Population{T,L,N,PM}} base_pop = migration.second population_size = length(base_pop.members) mean_number_replaced = population_size * frac diff --git a/src/Mutate.jl b/src/Mutate.jl index f5fd88457..996f474a1 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -22,7 +22,7 @@ using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: eval_cost using ..CheckConstraintsModule: check_constraints using ..AdaptiveParsimonyModule: RunningSearchStatistics -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember using ..MutationFunctionsModule: mutate_constant, mutate_operator, @@ -61,7 +61,8 @@ This struct encapsulates the result of a mutation operation. Either a new expres Return the `member` if you want to return immediately, and have computed the loss value as part of the mutation. 
""" -struct MutationResult{N<:AbstractExpression,P<:PopMember} <: AbstractMutationResult{N,P} +struct MutationResult{N<:AbstractExpression,P<:AbstractPopMember} <: + AbstractMutationResult{N,P} tree::Union{N,Nothing} member::Union{P,Nothing} num_evals::Float64 @@ -73,7 +74,7 @@ struct MutationResult{N<:AbstractExpression,P<:PopMember} <: AbstractMutationRes member::Union{_P,Nothing}=nothing, num_evals::Float64=0.0, return_immediately::Bool=false, - ) where {_N<:AbstractExpression,_P<:PopMember} + ) where {_N<:AbstractExpression,_P<:AbstractPopMember} @assert( (tree === nothing) ⊻ (member === nothing), "Mutation result must return either a tree or a pop member, not both" @@ -83,7 +84,7 @@ struct MutationResult{N<:AbstractExpression,P<:PopMember} <: AbstractMutationRes end """ - condition_mutation_weights!(weights::AbstractMutationWeights, member::PopMember, options::AbstractOptions, curmaxsize::Int, nfeatures::Int) + condition_mutation_weights!(weights::AbstractMutationWeights, member::AbstractPopMember, options::AbstractOptions, curmaxsize::Int, nfeatures::Int) Adjusts the mutation weights based on the properties of the current member and options. @@ -93,7 +94,7 @@ Note that the weights were already copied, so you don't need to worry about muta # Arguments - `weights::AbstractMutationWeights`: The mutation weights to be adjusted. -- `member::PopMember`: The current population member being mutated. +- `member::AbstractPopMember`: The current population member being mutated. - `options::AbstractOptions`: The options that guide the mutation process. - `curmaxsize::Int`: The current maximum size constraint for the member's expression tree. - `nfeatures::Int`: The number of features available in the dataset. 
@@ -104,7 +105,7 @@ function condition_mutation_weights!( options::AbstractOptions, curmaxsize::Int, nfeatures::Int, -) where {T,L,N<:AbstractExpression,P<:PopMember{T,L,N}} +) where {T,L,N<:AbstractExpression,P<:AbstractPopMember{T,L,N}} tree = get_tree(member.tree) if !preserve_sharing(typeof(member.tree)) weights.form_connection = 0.0 @@ -181,7 +182,7 @@ end tmp_recorder::RecordType, )::Tuple{ P,Bool,Float64 -} where {T,L,D<:Dataset{T,L},N<:AbstractExpression{T},P<:PopMember{T,L,N}} +} where {T,L,D<:Dataset{T,L},N<:AbstractExpression{T},P<:AbstractPopMember{T,L,N}} parent_ref = member.ref num_evals = 0.0 @@ -665,7 +666,7 @@ function crossover_generation( curmaxsize::Int, options::AbstractOptions; recorder::RecordType=RecordType(), -)::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},N,P<:PopMember{T,L,N}} +)::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},N,P<:AbstractPopMember{T,L,N}} tree1 = member1.tree tree2 = member2.tree crossover_accepted = false diff --git a/src/Options.jl b/src/Options.jl index cde0d5a31..c746b1332 100644 --- a/src/Options.jl +++ b/src/Options.jl @@ -40,6 +40,9 @@ using ..MutationWeightsModule: AbstractMutationWeights, MutationWeights, mutatio import ..OptionsStructModule: Options using ..OptionsStructModule: ComplexityMapping, operator_specialization using ..UtilsModule: @save_kwargs, @ignore + +# Forward declaration - will be defined in PopMemberModule +function default_popmember_type end using ..ExpressionSpecModule: AbstractExpressionSpec, ExpressionSpec, @@ -651,6 +654,7 @@ $(OPTION_DESCRIPTIONS) terminal_width::Union{Nothing,Integer}=nothing, use_recorder::Bool=false, recorder_file::AbstractString="pysr_recorder.json", + popmember_type::Type=default_popmember_type(), ### Not search options; just construction options: define_helper_functions::Bool=true, ######################################### @@ -1030,6 +1034,7 @@ $(OPTION_DESCRIPTIONS) expression_type, typeof(expression_options), typeof(set_mutation_weights), + 
popmember_type, turbo, bumper, deprecated_return_state::Union{Bool,Nothing}, @@ -1103,6 +1108,7 @@ $(OPTION_DESCRIPTIONS) deterministic, define_helper_functions, use_recorder, + popmember_type, ) return options diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index 4cf2ffb9e..6f83f89b0 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -181,6 +181,7 @@ struct Options{ E<:AbstractExpression, EO<:NamedTuple, MW<:AbstractMutationWeights, + PM, _turbo, _bumper, _return_state, @@ -254,6 +255,7 @@ struct Options{ deterministic::Bool define_helper_functions::Bool use_recorder::Bool + popmember_type::Type{PM} end function Base.print(io::IO, @nospecialize(options::Options)) diff --git a/src/PopMember.jl b/src/PopMember.jl index bd195a6c2..1f23ff7e8 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -7,8 +7,25 @@ import ..ComplexityModule: compute_complexity using ..UtilsModule: get_birth_order using ..LossFunctionsModule: eval_cost +""" + AbstractPopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} + +Abstract type for population members. Defines the interface that all population members must implement. + +# Required fields (accessed via getproperty/setproperty!) +- `tree::N`: The expression tree +- `cost::L`: The cost including complexity penalty and normalization +- `loss::L`: The raw loss value +- `birth::Int`: Birth order/generation number +- `ref::Int`: Unique reference ID +- `parent::Int`: Parent reference ID +- `complexity::Int`: Cached complexity (accessed via getfield/setfield! 
for special handling) +""" +abstract type AbstractPopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} end + # Define a member of population by equation, cost, and age -mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} +mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} <: + AbstractPopMember{T,L,N} tree::N cost::L # Inludes complexity penalty, normalization loss::L # Raw loss @@ -19,7 +36,9 @@ mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} ref::Int parent::Int end -@inline function Base.setproperty!(member::PopMember, field::Symbol, value) + +# Generic interface implementations for AbstractPopMember +@inline function Base.setproperty!(member::AbstractPopMember, field::Symbol, value) if field == :complexity throw( error("Don't set `.complexity` directly. Use `recompute_complexity!` instead.") @@ -34,7 +53,7 @@ end end return setfield!(member, field, value) end -@unstable @inline function Base.getproperty(member::PopMember, field::Symbol) +@unstable @inline function Base.getproperty(member::AbstractPopMember, field::Symbol) if field == :complexity throw( error("Don't access `.complexity` directly. 
Use `compute_complexity` instead.") @@ -145,7 +164,7 @@ function PopMember( ) end -function Base.copy(p::P) where {P<:PopMember} +function Base.copy(p::P) where {P<:AbstractPopMember} tree = copy(p.tree) cost = copy(p.cost) loss = copy(p.loss) @@ -156,14 +175,14 @@ function Base.copy(p::P) where {P<:PopMember} return P(tree, cost, loss, birth, complexity, ref, parent) end -function reset_birth!(p::PopMember; deterministic::Bool) +function reset_birth!(p::AbstractPopMember; deterministic::Bool) p.birth = get_birth_order(; deterministic) return p end # Can read off complexity directly from pop members function compute_complexity( - member::PopMember, options::AbstractOptions; break_sharing=Val(false) + member::AbstractPopMember, options::AbstractOptions; break_sharing=Val(false) )::Int complexity = getfield(member, :complexity) complexity == -1 && return recompute_complexity!(member, options; break_sharing) @@ -171,11 +190,18 @@ function compute_complexity( return complexity end function recompute_complexity!( - member::PopMember, options::AbstractOptions; break_sharing=Val(false) + member::AbstractPopMember, options::AbstractOptions; break_sharing=Val(false) )::Int complexity = compute_complexity(member.tree, options; break_sharing) setfield!(member, :complexity, complexity) return complexity end +# Function to extract PopMember type from Population or HallOfFame types +function popmember_type end + +# Default PopMember type for Options +import ..CoreModule.OptionsModule: default_popmember_type +default_popmember_type() = PopMember + end diff --git a/src/Population.jl b/src/Population.jl index 739ca828e..d3bd2b517 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -3,25 +3,29 @@ module PopulationModule using StatsBase: StatsBase using DispatchDoctor: @unstable using DynamicExpressions: AbstractExpression, string_tree +using ConstructionBase: constructorof using ..CoreModule: AbstractOptions, Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE using 
..ComplexityModule: compute_complexity using ..LossFunctionsModule: eval_cost, update_baseline_loss! using ..AdaptiveParsimonyModule: RunningSearchStatistics using ..MutationFunctionsModule: gen_random_tree -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember +import ..PopMemberModule: popmember_type using ..UtilsModule: bottomk_fast, argmin_fast, PerTaskCache # A list of members of the population, with easy constructors, # which allow for random generation of new populations -struct Population{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} - members::Array{PopMember{T,L,N},1} +struct Population{ + T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T},PM<:AbstractPopMember{T,L,N} +} + members::Array{PM,1} n::Int end """ - Population(pop::Array{PopMember{T,L}, 1}) + Population(pop::Array{<:AbstractPopMember, 1}) Create population from list of PopMembers. """ -function Population(pop::Vector{<:PopMember}) +function Population(pop::Vector{<:AbstractPopMember}) return Population(pop, size(pop, 1)) end @@ -41,23 +45,34 @@ function Population( npop=nothing, ) where {T,L} @assert (population_size !== nothing) ⊻ (npop !== nothing) - population_size = if npop === nothing - population_size - else - npop - end - return Population( - [ - PopMember( + population_size = something(population_size, npop) + PM = options.popmember_type + + # Create first member to get concrete type + first_member = constructorof(PM)( + dataset, + gen_random_tree(nlength, options, nfeatures, T), + options; + parent=-1, + deterministic=options.deterministic, + ) + + # Use the concrete type for the array + members = typeof(first_member)[ + if i == 1 + first_member + else + constructorof(PM)( dataset, gen_random_tree(nlength, options, nfeatures, T), options; parent=-1, deterministic=options.deterministic, - ) for _ in 1:population_size - ], - population_size, - ) + ) + end for i in 1:population_size + ] + + return Population(members, population_size) end """ 
Population(X::AbstractMatrix{T}, y::AbstractVector{T}; @@ -90,8 +105,8 @@ Create random population and score them on the dataset. ) end -function Base.copy(pop::P)::P where {T,L,N,P<:Population{T,L,N}} - copied_members = Vector{PopMember{T,L,N}}(undef, pop.n) +function Base.copy(pop::P)::P where {T,L,N,PM,P<:Population{T,L,N,PM}} + copied_members = Vector{PM}(undef, pop.n) Threads.@threads for i in 1:(pop.n) copied_members[i] = copy(pop.members[i]) end @@ -118,7 +133,7 @@ function _best_of_sample( members::Vector{P}, running_search_statistics::RunningSearchStatistics, options::AbstractOptions, -) where {T,L,P<:PopMember{T,L}} +) where {T,L,N,P<:AbstractPopMember{T,L,N}} p = options.tournament_selection_p n = length(members) # == tournament_selection_n adjusted_costs = Vector{L}(undef, n) @@ -218,4 +233,7 @@ function record_population(pop::Population, options::AbstractOptions)::RecordTyp ) end +# Type accessor for Population +popmember_type(::Type{<:Population{T,L,N,PM}}) where {T,L,N,PM} = PM + end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index de709ddde..88c791449 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -2,6 +2,7 @@ module SymbolicRegression # Types export Population, + AbstractPopMember, PopMember, HallOfFame, Options, @@ -297,7 +298,7 @@ using .MutationFunctionsModule: using .InterfaceDynamicExpressionsModule: @extend_operators, require_copy_to_workers, make_example_inputs using .LossFunctionsModule: eval_loss, eval_cost, update_baseline_loss!, score_func -using .PopMemberModule: PopMember, reset_birth! 
+using .PopMemberModule: AbstractPopMember, PopMember, reset_birth!, popmember_type using .PopulationModule: Population, best_sub_pop, record_population, best_of_sample using .HallOfFameModule: HallOfFame, calculate_pareto_frontier, string_dominating_pareto_curve @@ -810,10 +811,14 @@ function _preserve_loaded_state!( options::AbstractOptions, ) where {T,L,N} nout = length(state.worker_output) + # Get the prototype to extract types + prototype_pop = state.last_pops[1][1] + PopType = typeof(prototype_pop) + PM = popmember_type(PopType) + HallType = HallOfFame{T,L,N,PM} + for j in 1:nout, i in 1:(options.populations) - (pop, _, _, _) = extract_from_worker( - state.worker_output[j][i], Population{T,L,N}, HallOfFame{T,L,N} - ) + (pop, _, _, _) = extract_from_worker(state.worker_output[j][i], PopType, HallType) state.last_pops[j][i] = copy(pop) end return nothing @@ -843,11 +848,16 @@ function _warmup_search!( # Multi-threaded doesn't like to fetch within a new task: c_rss = deepcopy(running_search_statistics) last_pop = state.worker_output[j][i] + + # Get the prototype to extract types + prototype_pop = state.last_pops[j][i] + PopType = typeof(prototype_pop) + PM = popmember_type(PopType) + HallType = HallOfFame{T,L,N,PM} + updated_pop = @sr_spawner( begin - in_pop = first( - extract_from_worker(last_pop, Population{T,L,N}, HallOfFame{T,L,N}) - ) + in_pop = first(extract_from_worker(last_pop, PopType, HallType)) _dispatch_s_r_cycle( in_pop, dataset, From d37dc4505ef6d28e815f64973ba41737ce4dd256 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Sep 2025 01:37:21 +0100 Subject: [PATCH 02/15] refactor: no Base.copy for generic AbstractPopMember --- src/PopMember.jl | 4 ++-- src/SymbolicRegression.jl | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/PopMember.jl b/src/PopMember.jl index 1f23ff7e8..7c1092dce 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -164,7 +164,7 @@ function PopMember( ) end -function Base.copy(p::P) where 
{P<:AbstractPopMember} +function Base.copy(p::PopMember) tree = copy(p.tree) cost = copy(p.cost) loss = copy(p.loss) @@ -172,7 +172,7 @@ function Base.copy(p::P) where {P<:AbstractPopMember} complexity = copy(getfield(p, :complexity)) ref = copy(p.ref) parent = copy(p.parent) - return P(tree, cost, loss, birth, complexity, ref, parent) + return PopMember(tree, cost, loss, birth, complexity, ref, parent) end function reset_birth!(p::AbstractPopMember; deterministic::Bool) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 88c791449..acef3e7fb 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -2,7 +2,6 @@ module SymbolicRegression # Types export Population, - AbstractPopMember, PopMember, HallOfFame, Options, From b3cc71c258a44ac88e117906897aa5fc8e6ea5a9 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Sep 2025 02:08:49 +0100 Subject: [PATCH 03/15] refactor: better interface for creating new member --- src/Mutate.jl | 93 +++++++++++++++++++---------------------------- src/PopMember.jl | 71 ++++++++++++++++++++++++++++++++++++ src/Population.jl | 3 +- 3 files changed, 109 insertions(+), 58 deletions(-) diff --git a/src/Mutate.jl b/src/Mutate.jl index 996f474a1..03772f9ca 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -22,7 +22,7 @@ using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: eval_cost using ..CheckConstraintsModule: check_constraints using ..AdaptiveParsimonyModule: RunningSearchStatistics -using ..PopMemberModule: AbstractPopMember, PopMember +using ..PopMemberModule: AbstractPopMember, PopMember, create_child using ..MutationFunctionsModule: mutate_constant, mutate_operator, @@ -254,14 +254,13 @@ end end mutation_accepted = false return ( - PopMember( + create_child( + member, copy_into!(node_storage, member.tree), before_cost, before_loss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, + options; + 
parent_ref=parent_ref, ), mutation_accepted, num_evals, @@ -278,14 +277,13 @@ end end mutation_accepted = false return ( - PopMember( + create_child( + member, copy_into!(node_storage, member.tree), before_cost, before_loss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, + options; + parent_ref=parent_ref, ), mutation_accepted, num_evals, @@ -322,14 +320,13 @@ end end mutation_accepted = false return ( - PopMember( + create_child( + member, copy_into!(node_storage, member.tree), before_cost, before_loss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, + options; + parent_ref=parent_ref, ), mutation_accepted, num_evals, @@ -340,19 +337,16 @@ end tmp_recorder["reason"] = "pass" end mutation_accepted = true - return ( - PopMember( - tree, - after_cost, - after_loss, - options, - newSize; - parent=parent_ref, - deterministic=options.deterministic, - ), - mutation_accepted, - num_evals, + new_member = create_child( + member, + tree, + after_cost, + after_loss, + options; + complexity=newSize, + parent_ref=parent_ref, ) + return (new_member, mutation_accepted, num_evals) end end @@ -583,17 +577,10 @@ function mutate!( simplify_tree!(tree, options.operators) tree = combine_operators(tree, options.operators) @recorder recorder["type"] = "simplify" - return MutationResult{N,P}(; - member=PopMember( - tree, - member.cost, - member.loss, - options; - parent=parent_ref, - deterministic=options.deterministic, - ), - return_immediately=true, + new_member = create_child( + member, tree, member.cost, member.loss, options; parent_ref=parent_ref ) + return MutationResult{N,P}(; member=new_member, return_immediately=true) end function mutate!( @@ -645,14 +632,8 @@ function mutate!( recorder["reason"] = "identity" end return MutationResult{N,P}(; - member=PopMember( - tree, - member.cost, - member.loss, - options, - compute_complexity(tree, options); - 
parent=parent_ref, - deterministic=options.deterministic, + member=create_child( + member, tree, member.cost, member.loss, options; parent_ref=parent_ref ), return_immediately=true, ) @@ -705,23 +686,23 @@ function crossover_generation( ) num_evals += 2 * dataset_fraction(dataset) - baby1 = PopMember( + baby1 = create_child( + (member1, member2), child_tree1::AbstractExpression, after_cost1, after_loss1, - options, - afterSize1; - parent=member1.ref, - deterministic=options.deterministic, + options; + complexity=afterSize1, + parent_ref=member1.ref, )::P - baby2 = PopMember( + baby2 = create_child( + (member1, member2), child_tree2::AbstractExpression, after_cost2, after_loss2, - options, - afterSize2; - parent=member2.ref, - deterministic=options.deterministic, + options; + complexity=afterSize2, + parent_ref=member2.ref, )::P @recorder begin diff --git a/src/PopMember.jl b/src/PopMember.jl index 7c1092dce..b01d6dfb9 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -2,6 +2,7 @@ module PopMemberModule using DispatchDoctor: @unstable using DynamicExpressions: AbstractExpression, AbstractExpressionNode, string_tree +import DynamicExpressions: constructorof using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, create_expression import ..ComplexityModule: compute_complexity using ..UtilsModule: get_birth_order @@ -197,6 +198,74 @@ function recompute_complexity!( return complexity end +# Interface for creating child members with custom field preservation +""" + create_child(parent::P, tree, cost, loss, options; + complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs...) where P<:AbstractPopMember + +Create a new PopMember derived from a parent (mutation case). +Custom types should override to preserve their additional fields. 
+ +# Arguments +- `parent`: The parent member being mutated +- `tree`: The new expression tree +- `cost`: The new cost +- `loss`: The new loss +- `options`: The options +- `complexity`: The complexity (computed if not provided) +- `parent_ref`: Reference to parent for tracking +""" +function create_child( + parent::P, + tree::N, + cost::L, + loss::L, + options; + complexity::Union{Int,Nothing}=nothing, + parent_ref, + kwargs..., +) where {T,L,N<:AbstractExpression{T},P<:AbstractPopMember{T,L,N}} + actual_complexity = @something complexity compute_complexity(tree, options) + return constructorof(P)( + tree, + cost, + loss, + options, + actual_complexity; + parent=parent_ref, + deterministic=options.deterministic, + ) +end + +""" + create_child(parents::Tuple{P,P}, tree, cost, loss, options; + complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs...) where P<:AbstractPopMember + +Create a new PopMember from two parents (crossover case). +Custom types should override to blend their additional fields. 
+""" +function create_child( + parents::Tuple{P,P}, + tree::N, + cost::L, + loss::L, + options; + complexity::Union{Int,Nothing}=nothing, + parent_ref, + kwargs..., +) where {T,L,N<:AbstractExpression{T},P<:AbstractPopMember{T,L,N}} + actual_complexity = @something complexity compute_complexity(tree, options) + return constructorof(P)( + tree, + cost, + loss, + options, + actual_complexity; + parent=parent_ref, + deterministic=options.deterministic, + ) +end + # Function to extract PopMember type from Population or HallOfFame types function popmember_type end @@ -204,4 +273,6 @@ function popmember_type end import ..CoreModule.OptionsModule: default_popmember_type default_popmember_type() = PopMember +constructorof(::Type{<:PopMember}) = PopMember + end diff --git a/src/Population.jl b/src/Population.jl index d3bd2b517..ad24ead3d 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -2,8 +2,7 @@ module PopulationModule using StatsBase: StatsBase using DispatchDoctor: @unstable -using DynamicExpressions: AbstractExpression, string_tree -using ConstructionBase: constructorof +using DynamicExpressions: AbstractExpression, string_tree, constructorof using ..CoreModule: AbstractOptions, Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: eval_cost, update_baseline_loss! 
From 6d4d64f82e9a649d6db2359d9f094d1a24da053e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Sep 2025 02:10:24 +0100 Subject: [PATCH 04/15] refactor: force custom implementations of `create_child` --- src/PopMember.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/PopMember.jl b/src/PopMember.jl index b01d6dfb9..eb3e11820 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -224,7 +224,7 @@ function create_child( complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs..., -) where {T,L,N<:AbstractExpression{T},P<:AbstractPopMember{T,L,N}} +) where {T,L,N<:AbstractExpression{T},P<:PopMember{T,L,N}} actual_complexity = @something complexity compute_complexity(tree, options) return constructorof(P)( tree, @@ -253,7 +253,7 @@ function create_child( complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs..., -) where {T,L,N<:AbstractExpression{T},P<:AbstractPopMember{T,L,N}} +) where {T,L,N<:AbstractExpression{T},P<:PopMember{T,L,N}} actual_complexity = @something complexity compute_complexity(tree, options) return constructorof(P)( tree, From bfcf1f4e9c023b2b2963583a6fc795a8a37e4005 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Sep 2025 15:50:14 +0100 Subject: [PATCH 05/15] refactor: add missing parts of AbstractPopMember interface --- src/ExpressionBuilder.jl | 18 ++-- src/HallOfFame.jl | 37 +++++-- src/Mutate.jl | 40 +++---- src/ParametricExpression.jl | 4 +- src/PopMember.jl | 30 ++---- src/SearchUtils.jl | 19 ++-- src/SymbolicRegression.jl | 20 +++- src/TemplateExpression.jl | 6 +- test/runtests.jl | 1 + test/test_abstract_popmember.jl | 186 ++++++++++++++++++++++++++++++++ 10 files changed, 286 insertions(+), 75 deletions(-) create mode 100644 test/test_abstract_popmember.jl diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl index db6f5e82b..0867ab47a 100644 --- a/src/ExpressionBuilder.jl +++ b/src/ExpressionBuilder.jl @@ -11,7 +11,8 @@ using DynamicExpressions: using ..CoreModule: 
AbstractOptions, Dataset using ..HallOfFameModule: HallOfFame using ..PopulationModule: Population -using ..PopMemberModule: PopMember +using ..PopMemberModule: PopMember, AbstractPopMember, create_child +using ..ComplexityModule: compute_complexity import DynamicExpressions: get_operators import ..CoreModule: create_expression @@ -107,15 +108,16 @@ end return with_metadata(ex; init_params(options, dataset, ex, Val(true))...) end function embed_metadata( - member::PopMember, options::AbstractOptions, dataset::Dataset{T,L} - ) where {T,L} - return PopMember( + member::PM, options::AbstractOptions, dataset::Dataset{T,L} + ) where {T,L,N,PM<:AbstractPopMember{T,L,N}} + return create_child( + member, embed_metadata(member.tree, options, dataset), member.cost, member.loss, - nothing; - member.ref, - member.parent, + options; + complexity=compute_complexity(member, options), + parent_ref=member.ref, deterministic=options.deterministic, ) end @@ -135,7 +137,7 @@ end end function embed_metadata( vec::Vector{H}, options::AbstractOptions, dataset::Dataset{T,L} - ) where {T,L,H<:Union{HallOfFame,Population,PopMember}} + ) where {T,L,H<:Union{HallOfFame,Population,AbstractPopMember}} return map(Fix{2}(Fix{3}(embed_metadata, dataset), options), vec) end end diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl index 4a1b7841e..d18f7fd36 100644 --- a/src/HallOfFame.jl +++ b/src/HallOfFame.jl @@ -73,17 +73,36 @@ function HallOfFame( options::AbstractOptions, dataset::Dataset{T,L} ) where {T<:DATA_TYPE,L<:LOSS_TYPE} base_tree = create_expression(init_value(T), options, dataset) + PM = options.popmember_type - return HallOfFame{T,L,typeof(base_tree),PopMember{T,L,typeof(base_tree)}}( + # Create a prototype member to get the concrete type + prototype = PM( + copy(base_tree), + L(0), + L(Inf), + options, + 1; # complexity + parent=-1, + deterministic=options.deterministic, + ) + + PMtype = typeof(prototype) + + return HallOfFame{T,L,typeof(base_tree),PMtype}( [ - PopMember( - 
copy(base_tree), - L(0), - L(Inf), - options; - parent=-1, - deterministic=options.deterministic, - ) for i in 1:(options.maxsize) + if i == 1 + prototype + else + PM( + copy(base_tree), + L(0), + L(Inf), + options, + 1; # complexity + parent=-1, + deterministic=options.deterministic, + ) + end for i in 1:(options.maxsize) ], [false for i in 1:(options.maxsize)], ) diff --git a/src/Mutate.jl b/src/Mutate.jl index 03772f9ca..412d7ea06 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -40,10 +40,10 @@ using ..MutationFunctionsModule: using ..ConstantOptimizationModule: optimize_constants using ..RecorderModule: @recorder -abstract type AbstractMutationResult{N<:AbstractExpression,P<:PopMember} end +abstract type AbstractMutationResult{N<:AbstractExpression,P<:AbstractPopMember} end """ - MutationResult{N<:AbstractExpression,P<:PopMember} + MutationResult{N<:AbstractExpression,P<:AbstractPopMember} Represents the result of a mutation operation in the genetic programming algorithm. This struct is used to return values from `mutate!` functions. @@ -160,7 +160,7 @@ Use this to modify how `mutate_constant` changes for an expression type. function condition_mutate_constant!( ::Type{<:AbstractExpression}, weights::AbstractMutationWeights, - member::PopMember, + member::AbstractPopMember, options::AbstractOptions, curmaxsize::Int, ) @@ -352,7 +352,7 @@ end @generated function _dispatch_mutations!( tree::AbstractExpression, - member::PopMember, + member::AbstractPopMember, mutation_choice::Symbol, weights::W, options::AbstractOptions; @@ -381,7 +381,7 @@ end mutation_weights::AbstractMutationWeights, options::AbstractOptions; kws..., - ) where {N<:AbstractExpression,P<:PopMember,S} + ) where {N<:AbstractExpression,P<:AbstractPopMember,S} Perform a mutation on the given `tree` and `member` using the specified mutation type `S`. Various `kws` are provided to access other data needed for some mutations. @@ -409,7 +409,7 @@ so it can always return immediately. 
""" function mutate!( ::N, ::P, ::Val{S}, ::AbstractMutationWeights, ::AbstractOptions; kws... -) where {N<:AbstractExpression,P<:PopMember,S} +) where {N<:AbstractExpression,P<:AbstractPopMember,S} return error("Unknown mutation choice: $S") end @@ -422,7 +422,7 @@ function mutate!( recorder::RecordType, temperature, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = mutate_constant(tree, temperature, options) @recorder recorder["type"] = "mutate_constant" return MutationResult{N,P}(; tree=tree) @@ -436,7 +436,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = mutate_operator(tree, options) @recorder recorder["type"] = "mutate_operator" return MutationResult{N,P}(; tree=tree) @@ -451,7 +451,7 @@ function mutate!( recorder::RecordType, nfeatures, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = mutate_feature(tree, nfeatures) @recorder recorder["type"] = "mutate_feature" return MutationResult{N,P}(; tree=tree) @@ -465,7 +465,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = swap_operands(tree) @recorder recorder["type"] = "swap_operands" return MutationResult{N,P}(; tree=tree) @@ -480,7 +480,7 @@ function mutate!( recorder::RecordType, nfeatures, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} if rand() < 0.5 tree = append_random_op(tree, options, nfeatures) @recorder recorder["type"] = "add_node:append" @@ -500,7 +500,7 @@ function mutate!( recorder::RecordType, nfeatures, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = insert_random_op(tree, 
options, nfeatures) @recorder recorder["type"] = "insert_node" return MutationResult{N,P}(; tree=tree) @@ -514,7 +514,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = delete_random_op!(tree) @recorder recorder["type"] = "delete_node" return MutationResult{N,P}(; tree=tree) @@ -528,7 +528,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = form_random_connection!(tree) @recorder recorder["type"] = "form_connection" return MutationResult{N,P}(; tree=tree) @@ -542,7 +542,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = break_random_connection!(tree) @recorder recorder["type"] = "break_connection" return MutationResult{N,P}(; tree=tree) @@ -556,7 +556,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = randomly_rotate_tree!(tree) @recorder recorder["type"] = "rotate_tree" return MutationResult{N,P}(; tree=tree) @@ -572,7 +572,7 @@ function mutate!( recorder::RecordType, parent_ref, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} @assert options.should_simplify simplify_tree!(tree, options.operators) tree = combine_operators(tree, options.operators) @@ -593,7 +593,7 @@ function mutate!( curmaxsize, nfeatures, kws..., -) where {T,N<:AbstractExpression{T},P<:PopMember} +) where {T,N<:AbstractExpression{T},P<:AbstractPopMember} tree = randomize_tree(tree, curmaxsize, options, nfeatures) @recorder recorder["type"] = "randomize" return MutationResult{N,P}(; tree=tree) @@ -608,7 +608,7 
@@ function mutate!( recorder::RecordType, dataset::Dataset, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} cur_member, new_num_evals = optimize_constants(dataset, member, options) @recorder recorder["type"] = "optimize" return MutationResult{N,P}(; @@ -625,7 +625,7 @@ function mutate!( recorder::RecordType, parent_ref, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} @recorder begin recorder["type"] = "identity" recorder["result"] = "accept" diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 2717afbdc..b7c9ab2f6 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -24,7 +24,7 @@ using ..CoreModule: AbstractExpressionSpec, get_indices, ExpressionSpecModule as ES -using ..PopMemberModule: PopMember +using ..PopMemberModule: PopMember, AbstractPopMember using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE using ..LossFunctionsModule: LossFunctionsModule as LF using ..ExpressionBuilderModule: ExpressionBuilderModule as EB @@ -102,7 +102,7 @@ end function MM.condition_mutate_constant!( ::Type{<:ParametricExpression}, weights::AbstractMutationWeights, - member::PopMember, + member::AbstractPopMember, options::AbstractOptions, curmaxsize::Int, ) diff --git a/src/PopMember.jl b/src/PopMember.jl index eb3e11820..698fa9627 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -198,35 +198,25 @@ function recompute_complexity!( return complexity end -# Interface for creating child members with custom field preservation """ - create_child(parent::P, tree, cost, loss, options; - complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs...) where P<:AbstractPopMember - -Create a new PopMember derived from a parent (mutation case). -Custom types should override to preserve their additional fields. 
+ create_child(parent::P, tree::AbstractExpression{T}, cost, loss, options; + complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs...) where {T,L,P<:PopMember{T,L}} -# Arguments -- `parent`: The parent member being mutated -- `tree`: The new expression tree -- `cost`: The new cost -- `loss`: The new loss -- `options`: The options -- `complexity`: The complexity (computed if not provided) -- `parent_ref`: Reference to parent for tracking +Create a new PopMember with a potentially different expression type. +Used by embed_metadata where the expression gains metadata. """ function create_child( parent::P, - tree::N, + tree::AbstractExpression{T}, cost::L, loss::L, options; complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs..., -) where {T,L,N<:AbstractExpression{T},P<:PopMember{T,L,N}} +) where {T,L,P<:PopMember{T,L}} actual_complexity = @something complexity compute_complexity(tree, options) - return constructorof(P)( + return PopMember( tree, cost, loss, @@ -246,16 +236,16 @@ Custom types should override to blend their additional fields. 
""" function create_child( parents::Tuple{P,P}, - tree::N, + tree::AbstractExpression{T}, cost::L, loss::L, options; complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs..., -) where {T,L,N<:AbstractExpression{T},P<:PopMember{T,L,N}} +) where {T,L,P<:PopMember{T,L}} actual_complexity = @something complexity compute_complexity(tree, options) - return constructorof(P)( + return PopMember( tree, cost, loss, diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index 89eaa8cbf..014b036c4 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -17,7 +17,7 @@ using ..UtilsModule: subscriptify using ..CoreModule: Dataset, AbstractOptions, Options, RecordType, max_features using ..ComplexityModule: compute_complexity using ..PopulationModule: Population -using ..PopMemberModule: PopMember +using ..PopMemberModule: PopMember, AbstractPopMember using ..HallOfFameModule: HallOfFame, string_dominating_pareto_curve using ..ConstantOptimizationModule: optimize_constants using ..ProgressBarsModule: WrappedProgressBar, manually_iterate!, barlen @@ -583,8 +583,9 @@ The state of the search, including the populations, worker outputs, tasks, and channels. This is used to manage the search and keep track of runtime variables in a single struct. 
""" -Base.@kwdef struct SearchState{T,L,N<:AbstractExpression{T},WorkerOutputType,ChannelType} <: - AbstractSearchState{T,L,N} +Base.@kwdef struct SearchState{ + T,L,N<:AbstractExpression{T},PM<:AbstractPopMember{T,L,N},WorkerOutputType,ChannelType +} <: AbstractSearchState{T,L,N} procs::Vector{Int} we_created_procs::Bool worker_output::Vector{Vector{WorkerOutputType}} @@ -592,16 +593,16 @@ Base.@kwdef struct SearchState{T,L,N<:AbstractExpression{T},WorkerOutputType,Cha channels::Vector{Vector{ChannelType}} worker_assignment::WorkerAssignments task_order::Vector{Tuple{Int,Int}} - halls_of_fame::Vector{HallOfFame{T,L,N}} - last_pops::Vector{Vector{Population{T,L,N}}} - best_sub_pops::Vector{Vector{Population{T,L,N}}} + halls_of_fame::Vector{HallOfFame{T,L,N,PM}} + last_pops::Vector{Vector{Population{T,L,N,PM}}} + best_sub_pops::Vector{Vector{Population{T,L,N,PM}}} all_running_search_statistics::Vector{RunningSearchStatistics} num_evals::Vector{Vector{Float64}} cycles_remaining::Vector{Int} cur_maxsizes::Vector{Int} stdin_reader::StdinReader record::Base.RefValue{RecordType} - seed_members::Vector{Vector{PopMember{T,L,N}}} + seed_members::Vector{Vector{PM}} end function save_to_file( @@ -718,7 +719,7 @@ end function update_hall_of_fame!( hall_of_fame::HallOfFame, members::Vector{PM}, options::AbstractOptions -) where {PM<:PopMember} +) where {PM<:AbstractPopMember} for member in members size = compute_complexity(member, options) valid_size = 0 < size <= options.maxsize @@ -793,7 +794,7 @@ function parse_guesses( guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, datasets::Vector{D}, options::AbstractOptions, -) where {T,L,P<:PopMember{T,L},D<:Dataset{T,L}} +) where {T,L,N,P<:AbstractPopMember{T,L,N},D<:Dataset{T,L}} nout = length(datasets) out = [P[] for _ in 1:nout] guess_lists = _make_vector_vector(guesses, nout) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index acef3e7fb..ef7c86f0a 100644 --- a/src/SymbolicRegression.jl +++ 
b/src/SymbolicRegression.jl @@ -298,6 +298,7 @@ using .InterfaceDynamicExpressionsModule: @extend_operators, require_copy_to_workers, make_example_inputs using .LossFunctionsModule: eval_loss, eval_cost, update_baseline_loss!, score_func using .PopMemberModule: AbstractPopMember, PopMember, reset_birth!, popmember_type +using .CoreModule.UtilsModule: get_birth_order using .PopulationModule: Population, best_sub_pop, record_population, best_of_sample using .HallOfFameModule: HallOfFame, calculate_pareto_frontier, string_dominating_pareto_curve @@ -633,8 +634,19 @@ end example_dataset = first(datasets) example_ex = create_expression(init_value(T), options, example_dataset) NT = typeof(example_ex) - PopType = Population{T,L,NT} - HallOfFameType = HallOfFame{T,L,NT} + # Create a prototype member to get the concrete type + prototype_member = options.popmember_type( + copy(example_ex), + L(0), + L(Inf), + options, + 1; # complexity + parent=-1, + deterministic=options.deterministic, + ) + PMType = typeof(prototype_member) + PopType = Population{T,L,NT,PMType} + HallOfFameType = HallOfFame{T,L,NT,PMType} WorkerOutputType = get_worker_output_type( Val(ropt.parallelism), PopType, HallOfFameType ) @@ -692,9 +704,9 @@ end j in 1:nout ] - seed_members = [PopMember{T,L,NT}[] for j in 1:nout] + seed_members = [Vector{PMType}() for j in 1:nout] - return SearchState{T,L,typeof(example_ex),WorkerOutputType,ChannelType}(; + return SearchState{T,L,NT,PMType,WorkerOutputType,ChannelType}(; procs=procs, we_created_procs=we_created_procs, worker_output=worker_output, diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index af86f4825..0f87cb9b0 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -52,7 +52,7 @@ using ..CheckConstraintsModule: CheckConstraintsModule as CC using ..ComplexityModule: ComplexityModule using ..LossFunctionsModule: LossFunctionsModule as LF using ..MutateModule: MutateModule as MM -using ..PopMemberModule: PopMember +using 
..PopMemberModule: PopMember, AbstractPopMember using ..ComposableExpressionModule: ComposableExpression, ValidVector struct ParamVector{T} <: AbstractVector{T} @@ -745,7 +745,7 @@ function MM.condition_mutation_weights!( @nospecialize(options::AbstractOptions), curmaxsize::Int, nfeatures::Int, -) where {T,L,N<:TemplateExpression,P<:PopMember{T,L,N}} +) where {T,L,N<:TemplateExpression,P<:AbstractPopMember{T,L,N}} if !preserve_sharing(typeof(member.tree)) weights.form_connection = 0.0 weights.break_connection = 0.0 @@ -828,7 +828,7 @@ end function MM.condition_mutate_constant!( ::Type{<:TemplateExpression}, weights::AbstractMutationWeights, - member::PopMember, + member::AbstractPopMember, options::AbstractOptions, curmaxsize::Int, ) diff --git a/test/runtests.jl b/test/runtests.jl index 7aef02fea..f4bd81111 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -164,6 +164,7 @@ end end include("test_abstract_numbers.jl") +include("test_abstract_popmember.jl") include("test_logging.jl") include("test_pretty_printing.jl") diff --git a/test/test_abstract_popmember.jl b/test/test_abstract_popmember.jl new file mode 100644 index 000000000..318d260a9 --- /dev/null +++ b/test/test_abstract_popmember.jl @@ -0,0 +1,186 @@ +@testitem "Custom AbstractPopMember implementation" tags = [:part1] begin + using SymbolicRegression + using DynamicExpressions + using Test + + import SymbolicRegression.PopMemberModule: create_child + + # Define a custom PopMember that tracks generation count + mutable struct CustomPopMember{T,L,N} <: SymbolicRegression.AbstractPopMember{T,L,N} + tree::N + cost::L + loss::L + birth::Int + complexity::Int + ref::Int + parent::Int + generation::Int # Custom field to track generation + end + + # # Direct constructor that matches field order + function CustomPopMember( + tree::N, + cost::L, + loss::L, + birth::Int, + complexity::Int, + ref::Int, + parent::Int, + generation::Int, + ) where {T,L,N<:AbstractExpression{T}} + return CustomPopMember{T,L,N}( + 
tree, cost, loss, birth, complexity, ref, parent, generation + ) + end + + function CustomPopMember( + tree::N, + cost::L, + loss::L, + options, + complexity::Int; + parent=-1, + deterministic=nothing, + ) where {T,L,N<:AbstractExpression{T}} + return CustomPopMember( + tree, + cost, + loss, + SymbolicRegression.get_birth_order(; deterministic=deterministic), + complexity, + abs(rand(Int)), + parent, + 0, # Initial generation + ) + end + + # Constructor for Population initialization (dataset, tree, options) + function CustomPopMember( + dataset::SymbolicRegression.Dataset, tree, options; parent=-1, deterministic=nothing + ) + ex = SymbolicRegression.create_expression(tree, options, dataset) + complexity = SymbolicRegression.compute_complexity(ex, options) + cost, loss = SymbolicRegression.eval_cost( + dataset, ex, options; complexity=complexity + ) + + return CustomPopMember( + ex, + cost, + loss, + SymbolicRegression.get_birth_order(; deterministic=deterministic), + complexity, + abs(rand(Int)), + parent, + 0, # Initial generation + ) + end + + DynamicExpressions.constructorof(::Type{<:CustomPopMember}) = CustomPopMember + + # Define copy for CustomPopMember + function Base.copy(p::CustomPopMember) + return CustomPopMember( + copy(p.tree), + copy(p.cost), + copy(p.loss), + copy(p.birth), + copy(getfield(p, :complexity)), + copy(p.ref), + copy(p.parent), + copy(p.generation), + ) + end + + function create_child( + parent::CustomPopMember{T,L}, + tree::AbstractExpression{T}, + cost::L, + loss::L, + options; + complexity::Union{Int,Nothing}=nothing, + parent_ref, + kwargs..., + ) where {T,L} + actual_complexity = @something complexity SymbolicRegression.compute_complexity( + tree, options + ) + return CustomPopMember( + tree, + cost, + loss, + SymbolicRegression.get_birth_order(; deterministic=options.deterministic), + actual_complexity, + abs(rand(Int)), + parent_ref, + parent.generation + 1, + ) + end + + function create_child( + 
parents::Tuple{<:CustomPopMember,<:CustomPopMember}, + tree::N, + cost::L, + loss::L, + options; + complexity::Union{Int,Nothing}=nothing, + parent_ref, + kwargs..., + ) where {T,L,N<:AbstractExpression{T}} + actual_complexity = @something complexity SymbolicRegression.compute_complexity( + tree, options + ) + max_generation = max(parents[1].generation, parents[2].generation) + return CustomPopMember( + tree, + cost, + loss, + SymbolicRegression.CoreModule.UtilsModule.get_birth_order(; + deterministic=options.deterministic + ), + actual_complexity, + abs(rand(Int)), + parent_ref, + max_generation + 1, + ) + end + + # Test that we can run equation_search with CustomPopMember + X = randn(Float32, 2, 100) + y = @. X[1, :]^2 - X[2, :] + + options = SymbolicRegression.Options(; + binary_operators=[+, -], + populations=1, + population_size=20, + maxsize=5, + popmember_type=CustomPopMember, + deterministic=true, + seed=0, + ) + + # Test that options were created with correct type + @test options.popmember_type == CustomPopMember + + hall_of_fame = equation_search( + X, y; options=options, niterations=2, parallelism=:serial + ) + + # Verify that we got results + @test sum(hall_of_fame.exists) > 0 + + # Verify that the members are CustomPopMember + for i in eachindex(hall_of_fame.members, hall_of_fame.exists) + if hall_of_fame.exists[i] + @test hall_of_fame.members[i] isa CustomPopMember + # Check that generation field exists + @test hall_of_fame.members[i].generation >= 0 + end + end + + # Verify we can extract the best member + best_idx = findlast(hall_of_fame.exists) + @test !isnothing(best_idx) + best_member = hall_of_fame.members[best_idx] + @test best_member isa CustomPopMember +end From 690e0f1ef37e1fda545a15bb063e1e1bbf07f6b8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Sep 2025 15:53:40 +0100 Subject: [PATCH 06/15] refactor: move forward forward decl --- src/Options.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/Options.jl 
b/src/Options.jl index c746b1332..977c3fc96 100644 --- a/src/Options.jl +++ b/src/Options.jl @@ -40,9 +40,6 @@ using ..MutationWeightsModule: AbstractMutationWeights, MutationWeights, mutatio import ..OptionsStructModule: Options using ..OptionsStructModule: ComplexityMapping, operator_specialization using ..UtilsModule: @save_kwargs, @ignore - -# Forward declaration - will be defined in PopMemberModule -function default_popmember_type end using ..ExpressionSpecModule: AbstractExpressionSpec, ExpressionSpec, @@ -227,6 +224,8 @@ recommend_loss_function_expression(expression_type) = false create_mutation_weights(w::AbstractMutationWeights) = w create_mutation_weights(w::NamedTuple) = MutationWeights(; w...) +function default_popmember_type end + @unstable function with_max_degree_from_context( node_type, user_provided_operators, operators ) From 78b6b80a2c5536cff8c9aa588498dad7e48f7e87 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Sep 2025 16:15:58 +0100 Subject: [PATCH 07/15] refactor: move imports to top --- src/PopMember.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/PopMember.jl b/src/PopMember.jl index 698fa9627..050e2344f 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -4,6 +4,7 @@ using DispatchDoctor: @unstable using DynamicExpressions: AbstractExpression, AbstractExpressionNode, string_tree import DynamicExpressions: constructorof using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, create_expression +import ..CoreModule.OptionsModule: default_popmember_type import ..ComplexityModule: compute_complexity using ..UtilsModule: get_birth_order using ..LossFunctionsModule: eval_cost @@ -259,10 +260,7 @@ end # Function to extract PopMember type from Population or HallOfFame types function popmember_type end -# Default PopMember type for Options -import ..CoreModule.OptionsModule: default_popmember_type default_popmember_type() = PopMember - constructorof(::Type{<:PopMember}) = PopMember end From 
fd3f5a5f82b24ffded01f89b16e2ce26242c1a06 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Sep 2025 18:16:52 +0100 Subject: [PATCH 08/15] fix: mark unstable --- src/PopMember.jl | 4 ++-- test/test_abstract_popmember.jl | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/PopMember.jl b/src/PopMember.jl index 050e2344f..7e5ca84e9 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -260,7 +260,7 @@ end # Function to extract PopMember type from Population or HallOfFame types function popmember_type end -default_popmember_type() = PopMember -constructorof(::Type{<:PopMember}) = PopMember +@unstable default_popmember_type() = PopMember +@unstable constructorof(::Type{<:PopMember}) = PopMember end diff --git a/test/test_abstract_popmember.jl b/test/test_abstract_popmember.jl index 318d260a9..7e6aff8b5 100644 --- a/test/test_abstract_popmember.jl +++ b/test/test_abstract_popmember.jl @@ -2,6 +2,7 @@ using SymbolicRegression using DynamicExpressions using Test + using DispatchDoctor: @unstable import SymbolicRegression.PopMemberModule: create_child @@ -76,7 +77,7 @@ ) end - DynamicExpressions.constructorof(::Type{<:CustomPopMember}) = CustomPopMember + @unstable DynamicExpressions.constructorof(::Type{<:CustomPopMember}) = CustomPopMember # Define copy for CustomPopMember function Base.copy(p::CustomPopMember) From 26665da252a2c87a6e2dee01f5e2f856a7105ce7 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 6 Oct 2025 12:16:36 +0100 Subject: [PATCH 09/15] fix: `parse_guesses` for custom AbstractPopMember --- src/ExpressionBuilder.jl | 1 - src/MLJInterface.jl | 4 ++-- src/PopMember.jl | 12 ++++++++---- src/SearchUtils.jl | 31 +++++++++++++++++++++++++++---- test/test_abstract_popmember.jl | 2 -- 5 files changed, 37 insertions(+), 13 deletions(-) diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl index 33b1cc09e..99de5bdc6 100644 --- a/src/ExpressionBuilder.jl +++ b/src/ExpressionBuilder.jl @@ -118,7 +118,6 @@ end options; 
complexity=compute_complexity(member, options), parent_ref=member.ref, - deterministic=options.deterministic, ) end function embed_metadata( diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 4ca68ed65..ccc6c33ee 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -39,8 +39,8 @@ using ..CoreModule: ExpressionSpec, get_expression_type, check_warm_start_compatibility -using ..CoreModule.OptionsModule: - DEFAULT_OPTIONS, OPTION_DESCRIPTIONS, default_popmember_type +using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS +using ..PopMemberModule: default_popmember_type using ..ComplexityModule: compute_complexity using ..HallOfFameModule: HallOfFame, format_hall_of_fame using ..UtilsModule: subscriptify, @ignore diff --git a/src/PopMember.jl b/src/PopMember.jl index 7e5ca84e9..ec350686e 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -201,7 +201,7 @@ end """ create_child(parent::P, tree::AbstractExpression{T}, cost, loss, options; - complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs...) where {T,L,P<:PopMember{T,L}} + complexity::Union{Int,Nothing}=nothing, parent_ref) where {T,L,P<:PopMember{T,L}} Create a new PopMember with a potentially different expression type. Used by embed_metadata where the expression gains metadata. @@ -214,7 +214,6 @@ function create_child( options; complexity::Union{Int,Nothing}=nothing, parent_ref, - kwargs..., ) where {T,L,P<:PopMember{T,L}} actual_complexity = @something complexity compute_complexity(tree, options) return PopMember( @@ -230,7 +229,7 @@ end """ create_child(parents::Tuple{P,P}, tree, cost, loss, options; - complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs...) where P<:AbstractPopMember + complexity::Union{Int,Nothing}=nothing, parent_ref) where P<:AbstractPopMember Create a new PopMember from two parents (crossover case). Custom types should override to blend their additional fields. 
@@ -243,7 +242,6 @@ function create_child( options; complexity::Union{Int,Nothing}=nothing, parent_ref, - kwargs..., ) where {T,L,P<:PopMember{T,L}} actual_complexity = @something complexity compute_complexity(tree, options) return PopMember( @@ -263,4 +261,10 @@ function popmember_type end @unstable default_popmember_type() = PopMember @unstable constructorof(::Type{<:PopMember}) = PopMember +@inline function with_expression_type( + ::Type{<:PopMember{T,L}}, ::Type{N} +) where {T,L,N<:AbstractExpression{T}} + return PopMember{T,L,N} +end + end diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index ae248951a..1c71d3723 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -12,12 +12,18 @@ using DispatchDoctor: @unstable using Logging: AbstractLogger using DynamicExpressions: - AbstractExpression, string_tree, parse_expression, EvalOptions, with_type_parameters + AbstractExpression, + string_tree, + parse_expression, + EvalOptions, + with_type_parameters, + constructorof using ..UtilsModule: subscriptify -using ..CoreModule: Dataset, AbstractOptions, Options, RecordType, max_features +using ..CoreModule: + Dataset, AbstractOptions, Options, RecordType, max_features, create_expression using ..ComplexityModule: compute_complexity using ..PopulationModule: Population -using ..PopMemberModule: PopMember, AbstractPopMember +using ..PopMemberModule: PopMember, AbstractPopMember, with_expression_type using ..HallOfFameModule: HallOfFame, string_dominating_pareto_curve using ..ConstantOptimizationModule: optimize_constants using ..ProgressBarsModule: WrappedProgressBar, manually_iterate!, barlen @@ -800,7 +806,9 @@ function parse_guesses( dataset = datasets[j] for g in guess_lists[j] ex = _parse_guess_expression(T, g, dataset, options) - member = PopMember(dataset, ex, options; deterministic=options.deterministic) + member = constructorof(P)( + dataset, ex, options; deterministic=options.deterministic + ) if options.should_optimize_constants member, _ = 
optimize_constants(dataset, member, options) end @@ -818,6 +826,21 @@ function parse_guesses( end return out end + +# Deal with non-concrete PopMember types +function parse_guesses( + ::Type{P}, + guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, + datasets::Vector{D}, + options::AbstractOptions, +) where {T,L,P<:AbstractPopMember{T,L},D<:Dataset{T,L}} + NodeType = with_type_parameters(options.node_type, T) + N = Base.promote_op(create_expression, NodeType, typeof(options), D) + N === Any && error("Failed to infer expression type") + ConcreteP = with_expression_type(P, N) + return parse_guesses(ConcreteP, guesses, datasets, options) +end + function _make_vector_vector(guesses, nout) if nout == 1 if guesses isa AbstractVector{<:AbstractVector} diff --git a/test/test_abstract_popmember.jl b/test/test_abstract_popmember.jl index 7e6aff8b5..058cbc02c 100644 --- a/test/test_abstract_popmember.jl +++ b/test/test_abstract_popmember.jl @@ -101,7 +101,6 @@ options; complexity::Union{Int,Nothing}=nothing, parent_ref, - kwargs..., ) where {T,L} actual_complexity = @something complexity SymbolicRegression.compute_complexity( tree, options @@ -126,7 +125,6 @@ options; complexity::Union{Int,Nothing}=nothing, parent_ref, - kwargs..., ) where {T,L,N<:AbstractExpression{T}} actual_complexity = @something complexity SymbolicRegression.compute_complexity( tree, options From f1053282ea45125c3dce31d998fd8e586f9098e1 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 6 Oct 2025 12:26:02 +0100 Subject: [PATCH 10/15] fix: mark unstable to avoid recursion --- src/SearchUtils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index 1c71d3723..461f0a1da 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -793,7 +793,7 @@ end """Parse user-provided guess expressions and convert them into optimized `PopMember` objects for each output dataset.""" -function parse_guesses( +@unstable function parse_guesses( 
::Type{P}, guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, datasets::Vector{D}, @@ -828,7 +828,7 @@ function parse_guesses( end # Deal with non-concrete PopMember types -function parse_guesses( +@unstable function parse_guesses( ::Type{P}, guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, datasets::Vector{D}, From f45f0a1115a570ea3295402f4692e1453610d138 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 6 Oct 2025 12:51:37 +0100 Subject: [PATCH 11/15] fix: allow `_get_cost` to be generic --- src/Population.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Population.jl b/src/Population.jl index ad24ead3d..00f603258 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -171,7 +171,7 @@ function _best_of_sample( end return members[chosen_idx] end -_get_cost(member::PopMember) = member.cost +_get_cost(member::AbstractPopMember) = member.cost const CACHED_WEIGHTS = let init_k = collect(0:5), From 2dc06456e3b1ecf0832d1cc67597eb3aed9120f8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 6 Oct 2025 13:32:16 +0100 Subject: [PATCH 12/15] fix: try to avoid recursive type inference --- src/SearchUtils.jl | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index 461f0a1da..ced0707c5 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -798,6 +798,29 @@ end guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, datasets::Vector{D}, options::AbstractOptions, +) where {T,L,N,P<:AbstractPopMember{T,L,N},D<:Dataset{T,L}} + return _parse_guesses_impl(P, guesses, datasets, options) +end + +# Deal with non-concrete PopMember types +@unstable function parse_guesses( + ::Type{P}, + guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, + datasets::Vector{D}, + options::AbstractOptions, +) where {T,L,P<:AbstractPopMember{T,L},D<:Dataset{T,L}} + NodeType = with_type_parameters(options.node_type, T) + 
N = Base.promote_op(create_expression, NodeType, typeof(options), D) + N in (Any, Union{}) && error("Failed to infer expression type") + ConcreteP = with_expression_type(P, N) + return _parse_guesses_impl(ConcreteP, guesses, datasets, options) +end + +@inline function _parse_guesses_impl( + ::Type{P}, + guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, + datasets::Vector{D}, + options::AbstractOptions, ) where {T,L,N,P<:AbstractPopMember{T,L,N},D<:Dataset{T,L}} nout = length(datasets) out = [P[] for _ in 1:nout] @@ -827,20 +850,6 @@ end return out end -# Deal with non-concrete PopMember types -@unstable function parse_guesses( - ::Type{P}, - guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, - datasets::Vector{D}, - options::AbstractOptions, -) where {T,L,P<:AbstractPopMember{T,L},D<:Dataset{T,L}} - NodeType = with_type_parameters(options.node_type, T) - N = Base.promote_op(create_expression, NodeType, typeof(options), D) - N === Any && error("Failed to infer expression type") - ConcreteP = with_expression_type(P, N) - return parse_guesses(ConcreteP, guesses, datasets, options) -end - function _make_vector_vector(guesses, nout) if nout == 1 if guesses isa AbstractVector{<:AbstractVector} From 4d3e1cf371717facfc92fbbeb2e6a00a42d50be1 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 6 Oct 2025 19:19:33 +0100 Subject: [PATCH 13/15] refactor: `infer_popmember_type` --- src/PopMember.jl | 12 +++++++++++- src/SearchUtils.jl | 26 +++++++++++--------------- src/SymbolicRegression.jl | 22 ++++++---------------- test/test_abstract_popmember.jl | 7 +++++++ 4 files changed, 35 insertions(+), 32 deletions(-) diff --git a/src/PopMember.jl b/src/PopMember.jl index ec350686e..71f8707de 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -2,7 +2,7 @@ module PopMemberModule using DispatchDoctor: @unstable using DynamicExpressions: AbstractExpression, AbstractExpressionNode, string_tree -import DynamicExpressions: constructorof +import 
DynamicExpressions: constructorof, with_type_parameters using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, create_expression import ..CoreModule.OptionsModule: default_popmember_type import ..ComplexityModule: compute_complexity @@ -267,4 +267,14 @@ function popmember_type end return PopMember{T,L,N} end +@inline function with_type_parameters( + ::Type{<:PopMember}, ::Type{T}, ::Type{L}, ::Type{N} +) where {T,L,N} + return PopMember{T,L,N} +end + +@inline function expression_type(::Type{<:AbstractPopMember{<:Any,<:Any,N}}) where {N} + return N +end + end diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index ced0707c5..05efe4c09 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -23,7 +23,7 @@ using ..CoreModule: Dataset, AbstractOptions, Options, RecordType, max_features, create_expression using ..ComplexityModule: compute_complexity using ..PopulationModule: Population -using ..PopMemberModule: PopMember, AbstractPopMember, with_expression_type +using ..PopMemberModule: PopMember, AbstractPopMember using ..HallOfFameModule: HallOfFame, string_dominating_pareto_curve using ..ConstantOptimizationModule: optimize_constants using ..ProgressBarsModule: WrappedProgressBar, manually_iterate!, barlen @@ -34,6 +34,15 @@ using ..CheckConstraintsModule: check_constraints function logging_callback! 
end +@unstable @inline function infer_popmember_type( + ::Type{T}, ::Type{L}, ::Type{D}, options +) where {T,L,D<:Dataset} + NodeType = with_type_parameters(options.node_type, T) + N = Base.promote_op(create_expression, NodeType, typeof(options), D) + N in (Any, Union{}) && error("Failed to infer expression type") + return with_type_parameters(options.popmember_type, T, L, N) +end + """ @filtered_async expr @@ -793,26 +802,13 @@ end """Parse user-provided guess expressions and convert them into optimized `PopMember` objects for each output dataset.""" -@unstable function parse_guesses( - ::Type{P}, - guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, - datasets::Vector{D}, - options::AbstractOptions, -) where {T,L,N,P<:AbstractPopMember{T,L,N},D<:Dataset{T,L}} - return _parse_guesses_impl(P, guesses, datasets, options) -end - -# Deal with non-concrete PopMember types @unstable function parse_guesses( ::Type{P}, guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, datasets::Vector{D}, options::AbstractOptions, ) where {T,L,P<:AbstractPopMember{T,L},D<:Dataset{T,L}} - NodeType = with_type_parameters(options.node_type, T) - N = Base.promote_op(create_expression, NodeType, typeof(options), D) - N in (Any, Union{}) && error("Failed to infer expression type") - ConcreteP = with_expression_type(P, N) + ConcreteP = infer_popmember_type(T, L, D, options) return _parse_guesses_impl(ConcreteP, guesses, datasets, options) end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 40119fc70..4e9443f23 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -297,7 +297,8 @@ using .MutationFunctionsModule: using .InterfaceDynamicExpressionsModule: @extend_operators, require_copy_to_workers, make_example_inputs using .LossFunctionsModule: eval_loss, eval_cost, update_baseline_loss!, score_func -using .PopMemberModule: AbstractPopMember, PopMember, reset_birth!, popmember_type +using .PopMemberModule: + 
AbstractPopMember, PopMember, reset_birth!, popmember_type, expression_type using .CoreModule.UtilsModule: get_birth_order using .PopulationModule: Population, best_sub_pop, record_population, best_of_sample using .HallOfFameModule: @@ -339,7 +340,8 @@ using .SearchUtilsModule: get_cur_maxsize, update_hall_of_fame!, parse_guesses, - logging_callback! + logging_callback!, + infer_popmember_type using .LoggingModule: AbstractSRLogger, SRLogger, get_logger using .TemplateExpressionModule: TemplateExpression, TemplateStructure, TemplateExpressionSpec, ParamVector, has_params @@ -631,20 +633,8 @@ end @recorder record["options"] = "$(options)" nout = length(datasets) - example_dataset = first(datasets) - example_ex = create_expression(init_value(T), options, example_dataset) - NT = typeof(example_ex) - # Create a prototype member to get the concrete type - prototype_member = options.popmember_type( - copy(example_ex), - L(0), - L(Inf), - options, - 1; # complexity - parent=-1, - deterministic=options.deterministic, - ) - PMType = typeof(prototype_member) + PMType = infer_popmember_type(T, L, D, options) + NT = expression_type(PMType) PopType = Population{T,L,NT,PMType} HallOfFameType = HallOfFame{T,L,NT,PMType} WorkerOutputType = get_worker_output_type( diff --git a/test/test_abstract_popmember.jl b/test/test_abstract_popmember.jl index 058cbc02c..eabc6886d 100644 --- a/test/test_abstract_popmember.jl +++ b/test/test_abstract_popmember.jl @@ -79,6 +79,13 @@ @unstable DynamicExpressions.constructorof(::Type{<:CustomPopMember}) = CustomPopMember + # Define with_type_parameters for CustomPopMember + @unstable function DynamicExpressions.with_type_parameters( + ::Type{<:CustomPopMember}, ::Type{T}, ::Type{L}, ::Type{N} + ) where {T,L,N} + return CustomPopMember{T,L,N} + end + # Define copy for CustomPopMember function Base.copy(p::CustomPopMember) return CustomPopMember( From 6e9928660ca78aeb1c760274b672f43396452a1e Mon Sep 17 00:00:00 2001 From: Atharva Sehgal Date: Sun, 12 
Oct 2025 14:15:37 +0000 Subject: [PATCH 14/15] Generalize row collation in HallOfFame.jl. Presently, HallOfFame is overfit to only a select few metrics that can possibly be tracked using subclasses of AbstractPopMember. This (and following) commits attempt to generalize the framework by allowing users to add arbitrary metrics in rows. This also opens up the API to Tables.jl through the SymbolicRegressionTablesExt.jl file which will, amongst other things, allow us to save the hall-of-fame as a DataFrames.jl DataFrame or a JuliaDB table. --- Project.toml | 3 + ext/SymbolicRegressionTablesExt.jl | 12 + src/HallOfFame.jl | 264 ++++++++++++++++++--- test/runtests.jl | 2 + test/test_hof_rows.jl | 369 +++++++++++++++++++++++++++++ 5 files changed, 615 insertions(+), 35 deletions(-) create mode 100644 ext/SymbolicRegressionTablesExt.jl create mode 100644 test/test_hof_rows.jl diff --git a/Project.toml b/Project.toml index f5a02e829..0e1cc78f5 100644 --- a/Project.toml +++ b/Project.toml @@ -36,12 +36,14 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [extensions] SymbolicRegressionEnzymeExt = "Enzyme" SymbolicRegressionJSON3Ext = "JSON3" SymbolicRegressionMooncakeExt = "Mooncake" SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils" +SymbolicRegressionTablesExt = "Tables" [compat] ADTypes = "^1.4.0" @@ -74,4 +76,5 @@ StatsBase = "0.33, 0.34" StyledStrings = "1" SymbolicUtils = "0.19, ^1.0.5, 2, 3" TOML = "<0.0.1, 1" +Tables = "1" julia = "1.10" diff --git a/ext/SymbolicRegressionTablesExt.jl b/ext/SymbolicRegressionTablesExt.jl new file mode 100644 index 000000000..75a69d366 --- /dev/null +++ b/ext/SymbolicRegressionTablesExt.jl @@ -0,0 +1,12 @@ +module SymbolicRegressionTablesExt + +using Tables: Tables +import SymbolicRegression.HallOfFameModule: HOFRows + +# Make HOFRows compatible with
the Tables.jl interface +# HOFRows is already iterable via Base.iterate, so we just need to declare compatibility +Tables.istable(::Type{<:HOFRows}) = true +Tables.rowaccess(::Type{<:HOFRows}) = true +Tables.rows(view::HOFRows) = view # Return itself since it's already iterable + +end diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl index d18f7fd36..3844cfd90 100644 --- a/src/HallOfFame.jl +++ b/src/HallOfFame.jl @@ -2,6 +2,7 @@ module HallOfFameModule using StyledStrings: @styled_str using DynamicExpressions: AbstractExpression, string_tree +using DispatchDoctor: @unstable using ..UtilsModule: split_string, AnnotatedIOBuffer, dump_buffer using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, relu, create_expression, init_value @@ -146,6 +147,170 @@ function calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,N,PM}) where {T,L, return dominating end +""" + member_to_row(member::AbstractPopMember, dataset::Dataset, options::AbstractOptions; + pretty::Bool=true) + +Convert a PopMember to a row representation for display/export. + +This is the primary extension point for custom PopMember types. Users can override this +method to include additional fields in the output. + +# Arguments +- `member`: The population member to convert +- `dataset`: Dataset for formatting equation strings +- `options`: Options controlling complexity and equation formatting +- `pretty`: Whether to use pretty-printing for equations (default: true) + +# Returns +A NamedTuple containing the member's data. Default fields are: +- `complexity`: Expression complexity +- `loss`: Raw loss value +- `cost`: Cost including complexity penalty +- `birth`: Birth order/generation +- `ref`: Unique reference ID +- `parent`: Parent reference ID +- `equation`: Formatted equation string + +# Example: Adding custom fields +```julia +function SymbolicRegression.HallOfFameModule.member_to_row( + member::MyCustomPopMember, + dataset::Dataset, + options::AbstractOptions; + kwargs... 
+) + base = invoke(member_to_row, Tuple{AbstractPopMember, Dataset, AbstractOptions}, + member, dataset, options; kwargs...) + return merge(base, (my_field = member.custom_data,)) +end +``` +""" +function member_to_row( + member::AbstractPopMember, dataset::Dataset, options::AbstractOptions; pretty::Bool=true +) + eqn_string = string_tree( + member.tree, + options; + display_variable_names=dataset.display_variable_names, + X_sym_units=dataset.X_sym_units, + y_sym_units=dataset.y_sym_units, + pretty=pretty, + ) + prefix = make_prefix(member.tree, options, dataset) + eqn_string = prefix * eqn_string + return ( + complexity=compute_complexity(member, options), + loss=member.loss, + cost=member.cost, + birth=member.birth, + ref=member.ref, + parent=member.parent, + equation=eqn_string, + ) +end + +""" + HOFRows + +A lazy iterator for HallOfFame members that computes rows on-demand. +This struct implements the Tables.jl interface for easy export to DataFrames, CSV, etc. + +# Fields +- `members`: Vector of PopMembers to iterate over +- `dataset`: Dataset for formatting equations +- `options`: Options for complexity and formatting +- `include_score`: Whether to compute and include Pareto improvement scores +- `pretty`: Whether to use pretty-printing for equations +""" +struct HOFRows{PM<:AbstractPopMember} + members::Vector{PM} + dataset::Dataset + options::AbstractOptions + include_score::Bool + pretty::Bool +end + +# Helper function to create a single row with optional score +@unstable function _make_row(view::HOFRows, i::Int, scores) + row = member_to_row(view.members[i], view.dataset, view.options; pretty=view.pretty) + + return scores === nothing ? row : (; row..., score=scores[i]) +end + +# Make HOFRows iterable +Base.length(view::HOFRows) = length(view.members) +Base.eltype(::Type{<:HOFRows}) = NamedTuple + +function Base.iterate(view::HOFRows) + isempty(view.members) && return nothing + + # Compute all scores upfront if needed + scores = view.include_score ? 
compute_scores(view.members, view.options) : nothing + state = (scores, 1) + + row = _make_row(view, 1, scores) + return (row, state) +end + +function Base.iterate(view::HOFRows, state) + scores, i = state + i += 1 + i > length(view.members) && return nothing + + row = _make_row(view, i, scores) + return (row, (scores, i)) +end + +""" + hof_rows(hof::HallOfFame, dataset::Dataset, options::AbstractOptions; + pareto_only::Bool=true, include_score::Bool=pareto_only, + pretty::Bool=true) + +This function returns an `HOFRows` object. + +# Arguments +- `hof`: The HallOfFame to export +- `dataset`: Dataset for formatting equations +- `options`: Options controlling complexity and formatting +- `pareto_only`: Only include Pareto frontier members (default: true) +- `include_score`: Include Pareto improvement scores (default: same as `pareto_only`) +- `pretty`: Use pretty-printing for equations (default: true) + +# Returns +An `HOFRows` object that can be used with Tables.jl-compatible consumers like +`DataFrame`, `CSV.write`, etc. 
+ +# Examples +```julia +# Get a Tables.jl view of the Pareto frontier +rows = hof_rows(hof, dataset, options) + +# Convert to DataFrame (requires DataFrames.jl) +using DataFrames +df = DataFrame(rows) + +# Get all members without scores +all_rows = hof_rows(hof, dataset, options; pareto_only=false, include_score=false) +``` +""" +function hof_rows( + hof::HallOfFame, + dataset::Dataset, + options::AbstractOptions; + pareto_only::Bool=true, + include_score::Bool=pareto_only, + pretty::Bool=true, +) + members = if pareto_only + calculate_pareto_frontier(hof) + else + [m for (m, ex) in zip(hof.members, hof.exists) if ex] + end + + return HOFRows(members, dataset, options, include_score, pretty) +end + let header_parts = ( rpad(styled"{bold:{underline:Complexity}}", 10), rpad(styled"{bold:{underline:Loss}}", 9), @@ -170,23 +335,21 @@ function string_dominating_pareto_curve( println(buffer, HEADER_WITHOUT_SCORE) end - formatted = format_hall_of_fame(hallOfFame, options) - for (tree, score, loss, complexity) in - zip(formatted.trees, formatted.scores, formatted.losses, formatted.complexities) - eqn_string = string_tree( - tree, - options; - display_variable_names=dataset.display_variable_names, - X_sym_units=dataset.X_sym_units, - y_sym_units=dataset.y_sym_units, - pretty, - ) - prefix = make_prefix(tree, options, dataset) - eqn_string = prefix * eqn_string + # Use hof_rows to get rows with scores. Note: the default `member_to_row` already + # prepends the prefix to `row.equation`, so it must not be prepended again here + rows_view = hof_rows( + hallOfFame, dataset, options; pareto_only=true, include_score=true, pretty=pretty + ) + members = rows_view.members + + for (i, row) in enumerate(rows_view) + member = members[i] + prefix = make_prefix(member.tree, options, dataset) + eqn_string = row.equation stats_columns_string = if show_score_column(options) - @sprintf("%-10d %-8.3e %-8.3e ", complexity, loss, score) + @sprintf("%-10d %-8.3e %-8.3e ", row.complexity, row.loss, row.score) else - @sprintf("%-10d %-8.3e ",
complexity, loss) + @sprintf("%-10d %-8.3e ", row.complexity, row.loss) end left_cols_width = length(stats_columns_string) print(buffer, stats_columns_string) @@ -237,31 +400,38 @@ function wrap_equation_string(eqn_string, left_cols_width, terminal_width) return dump_buffer(buffer) end -function format_hall_of_fame(hof::HallOfFame{T,L}, options) where {T,L} - dominating = calculate_pareto_frontier(hof) +""" + compute_scores(members::Vector{<:AbstractPopMember}, options::AbstractOptions) - # Only check for negative losses if using logarithmic scaling - options.loss_scale == :log && for member in dominating - if member.loss < 0.0 - throw( - DomainError( - member.loss, - "Your loss function must be non-negative. To allow negative losses, set the `loss_scale` to linear, or consider wrapping your loss inside an exponential.", - ), - ) - end - end +Compute improvement scores for an ordered sequence of members. - trees = [member.tree for member in dominating] - losses = [member.loss for member in dominating] - complexities = [compute_complexity(member, options) for member in dominating] - scores = Array{L}(undef, length(dominating)) +Scores measure the improvement in loss per unit complexity compared to the previous +member in the sequence. The first member always has a score of zero. + +This function works with any ordered sequence of members (e.g., Pareto frontier, +complexity-sorted members, etc.). 
+ +# Arguments +- `members`: Vector of PopMembers in the desired order +- `options`: Options controlling the loss scale (`:linear` or `:log`) + +# Returns +Vector of scores with the same length as `members` +""" +function compute_scores( + members::Vector{<:AbstractPopMember{T,L,N}}, options::AbstractOptions +) where {T,L,N} + isempty(members) && return L[] + + scores = Vector{L}(undef, length(members)) - cur_loss = typemax(L) - last_loss = cur_loss + complexities = [compute_complexity(member, options) for member in members] + losses = [member.loss for member in members] + + last_loss = typemax(L) last_complexity = zero(eltype(complexities)) - for i in 1:length(dominating) + for i in eachindex(members) complexity = complexities[i] cur_loss = losses[i] delta_c = complexity - last_complexity @@ -277,6 +447,30 @@ function format_hall_of_fame(hof::HallOfFame{T,L}, options) where {T,L} last_loss = cur_loss last_complexity = complexity end + + return scores +end + +function format_hall_of_fame(hof::HallOfFame{T,L}, options) where {T,L} + dominating = calculate_pareto_frontier(hof) + + # Only check for negative losses if using logarithmic scaling + options.loss_scale == :log && for member in dominating + if member.loss < 0.0 + throw( + DomainError( + member.loss, + "Your loss function must be non-negative. 
To allow negative losses, set the `loss_scale` to linear, or consider wrapping your loss inside an exponential.", + ), + ) + end + end + + trees = [member.tree for member in dominating] + losses = [member.loss for member in dominating] + complexities = [compute_complexity(member, options) for member in dominating] + scores = compute_scores(dominating, options) + return (; trees, scores, losses, complexities) end function compute_direct_score(cur_loss, last_loss, delta_c) diff --git a/test/runtests.jl b/test/runtests.jl index f4bd81111..61e2362b6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -87,6 +87,8 @@ include("test_options.jl") include("test_hash.jl") end +include("test_hof_rows.jl") + @testitem "Test migration" tags = [:part3] begin include("test_migration.jl") end diff --git a/test/test_hof_rows.jl b/test/test_hof_rows.jl new file mode 100644 index 000000000..d8239850f --- /dev/null +++ b/test/test_hof_rows.jl @@ -0,0 +1,369 @@ +@testitem "HOF rows functionality" tags = [:part1] begin + using SymbolicRegression + using DynamicExpressions + using Test + + # Create test data + X = Float32[1.0 2.0 3.0; 4.0 5.0 6.0] + y = Float32[1.0, 2.0, 3.0] + + options = Options(; + binary_operators=[+, -], + unary_operators=[], + maxsize=5, + populations=1, + population_size=10, + tournament_selection_n=3, + deterministic=true, + seed=0, + ) + + dataset = Dataset(X, y) + + @testset "compute_scores" begin + # Create a simple HOF with multiple members + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + + # Add multiple members with different complexities + for i in 1:3 + hof.exists[i] = true + end + + members = [hof.members[i] for i in 1:3 if hof.exists[i]] + + # Test score computation + scores = SymbolicRegression.HallOfFameModule.compute_scores(members, options) + + @test length(scores) == length(members) + @test scores[1] == 0 # First member always has score 0 + @test all(s >= 0 for s in scores) # Scores should be non-negative + + # Test 
with empty members + empty_scores = SymbolicRegression.HallOfFameModule.compute_scores( + typeof(members[1])[], options + ) + @test isempty(empty_scores) + end + + @testset "HOFRows iteration" begin + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + hof.exists[1] = true + hof.exists[2] = true + + rows = SymbolicRegression.HallOfFameModule.hof_rows( + hof, dataset, options; pareto_only=false, include_score=true + ) + + # Test Base.length + @test length(rows) == 2 + + # Test Base.eltype + @test eltype(rows) == NamedTuple + + # Test iteration + collected = collect(rows) + @test length(collected) == 2 + @test all(r isa NamedTuple for r in collected) + + # Test that scores are included when include_score=true is passed explicitly + @test all(haskey(r, :score) for r in collected) + + # Test equation inclusion + @test all(haskey(r, :equation) for r in collected) + end + + @testset "hof_rows options" begin + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + for i in 1:3 + hof.exists[i] = true + end + + # Test pareto_only=false + rows_all = SymbolicRegression.HallOfFameModule.hof_rows( + hof, dataset, options; pareto_only=false + ) + # Should include all existing members (the Pareto filter is disabled) + @test length(rows_all) == 3 + + # Test include_score=false + rows_no_score = SymbolicRegression.HallOfFameModule.hof_rows( + hof, dataset, options; include_score=false + ) + for row in rows_no_score + @test !haskey(row, :score) + end + end + + @testset "Empty HOF" begin + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + # Don't mark any as existing + + rows = SymbolicRegression.HallOfFameModule.hof_rows(hof, dataset, options) + + @test length(rows) == 0 + @test isempty(collect(rows)) + end + + @testset "Backwards compatibility" begin + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + hof.exists[1] = true + hof.exists[2] = true + + # Test that format_hall_of_fame still works + formatted =
SymbolicRegression.HallOfFameModule.format_hall_of_fame(hof, options) + + @test haskey(formatted, :trees) + @test haskey(formatted, :scores) + @test haskey(formatted, :losses) + @test haskey(formatted, :complexities) + @test length(formatted.trees) == length(formatted.scores) + @test length(formatted.trees) == length(formatted.losses) + @test length(formatted.trees) == length(formatted.complexities) + + # Test that string_dominating_pareto_curve still works + curve_string = SymbolicRegression.HallOfFameModule.string_dominating_pareto_curve( + hof, dataset, options + ) + + @test curve_string isa AbstractString + @test contains(curve_string, "Complexity") + @test contains(curve_string, "Loss") + end +end + +@testitem "HOF rows with custom PopMember" tags = [:part1] begin + using SymbolicRegression + using DynamicExpressions + using Test + using DispatchDoctor: @unstable + + import SymbolicRegression.PopMemberModule: create_child + import SymbolicRegression.HallOfFameModule: member_to_row + + # Define a custom PopMember with an extra field + mutable struct TestCustomPopMember{T,L,N} <: SymbolicRegression.AbstractPopMember{T,L,N} + tree::N + cost::L + loss::L + birth::Int + complexity::Int + ref::Int + parent::Int + custom_field::Float64 # Extra field + end + + # Constructor + function TestCustomPopMember( + tree::N, + cost::L, + loss::L, + birth::Int, + complexity::Int, + ref::Int, + parent::Int, + custom_field::Float64, + ) where {T,L,N<:AbstractExpression{T}} + return TestCustomPopMember{T,L,N}( + tree, cost, loss, birth, complexity, ref, parent, custom_field + ) + end + + function TestCustomPopMember( + tree::N, + cost::L, + loss::L, + options, + complexity::Int; + parent=-1, + deterministic=nothing, + custom_field=1.0, + ) where {T,L,N<:AbstractExpression{T}} + return TestCustomPopMember( + tree, + cost, + loss, + SymbolicRegression.get_birth_order(; deterministic=deterministic), + complexity, + abs(rand(Int)), + parent, + custom_field, + ) + end + + function 
TestCustomPopMember( + dataset::SymbolicRegression.Dataset, + tree, + options; + parent=-1, + deterministic=nothing, + custom_field=1.0, + ) + ex = SymbolicRegression.create_expression(tree, options, dataset) + complexity = SymbolicRegression.compute_complexity(ex, options) + cost, loss = SymbolicRegression.eval_cost(dataset, ex, options; complexity) + + return TestCustomPopMember( + ex, + cost, + loss, + SymbolicRegression.get_birth_order(; deterministic=deterministic), + complexity, + abs(rand(Int)), + parent, + custom_field, + ) + end + + @unstable DynamicExpressions.constructorof(::Type{<:TestCustomPopMember}) = + TestCustomPopMember + + @unstable function DynamicExpressions.with_type_parameters( + ::Type{<:TestCustomPopMember}, ::Type{T}, ::Type{L}, ::Type{N} + ) where {T,L,N} + return TestCustomPopMember{T,L,N} + end + + function Base.copy(p::TestCustomPopMember) + return TestCustomPopMember( + copy(p.tree), + copy(p.cost), + copy(p.loss), + copy(p.birth), + copy(getfield(p, :complexity)), + copy(p.ref), + copy(p.parent), + copy(p.custom_field), + ) + end + + function create_child( + parent::TestCustomPopMember{T,L}, + tree::AbstractExpression{T}, + cost::L, + loss::L, + options; + complexity::Union{Int,Nothing}=nothing, + parent_ref, + ) where {T,L} + actual_complexity = @something complexity SymbolicRegression.compute_complexity( + tree, options + ) + return TestCustomPopMember( + tree, + cost, + loss, + SymbolicRegression.get_birth_order(; deterministic=options.deterministic), + actual_complexity, + abs(rand(Int)), + parent_ref, + parent.custom_field * 1.1, # Modify custom field + ) + end + + # Extend member_to_row for custom PopMember + function member_to_row( + member::TestCustomPopMember, + dataset::SymbolicRegression.Dataset, + options::SymbolicRegression.AbstractOptions; + kwargs..., + ) + base = invoke( + member_to_row, + Tuple{ + SymbolicRegression.AbstractPopMember, + SymbolicRegression.Dataset, + SymbolicRegression.AbstractOptions, + }, + member, +
dataset, + options; + kwargs..., + ) + return merge(base, (custom_field=member.custom_field,)) + end + + @testset "Custom PopMember with member_to_row extension" begin + X = Float32[1.0 2.0 3.0; 4.0 5.0 6.0] + y = Float32[1.0, 2.0, 3.0] + + options = Options(; + binary_operators=[+, -], + maxsize=5, + popmember_type=TestCustomPopMember, + deterministic=true, + seed=0, + ) + + dataset = Dataset(X, y) + + # Create a custom member + tree = SymbolicRegression.create_expression(1.0f0, options, dataset) + custom_member = TestCustomPopMember( + dataset, tree, options; deterministic=true, custom_field=42.0 + ) + + # Test that member_to_row includes custom field + row = member_to_row(custom_member, dataset, options) + + @test haskey(row, :custom_field) + @test row.custom_field == 42.0 + @test haskey(row, :complexity) + @test haskey(row, :loss) + @test haskey(row, :equation) + + # Test with HOF + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + hof.members[1] = custom_member + hof.exists[1] = true + + rows = SymbolicRegression.HallOfFameModule.hof_rows(hof, dataset, options) + collected = collect(rows) + + @test length(collected) == 1 + @test haskey(collected[1], :custom_field) + @test collected[1].custom_field == 42.0 + end +end + +@testitem "Tables.jl extension" tags = [:part1] begin + using SymbolicRegression + using Test + + # Only run if Tables.jl is available + if isdefined(Base, :get_extension) + # Try to load Tables + try + @eval using Tables + + @testset "Tables.jl integration" begin + X = Float32[1.0 2.0 3.0; 4.0 5.0 6.0] + y = Float32[1.0, 2.0, 3.0] + + options = Options(; + binary_operators=[+, -], maxsize=5, deterministic=true, seed=0 + ) + + dataset = Dataset(X, y) + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + hof.exists[1] = true + + rows = SymbolicRegression.HallOfFameModule.hof_rows(hof, dataset, options) + + # Test Tables.jl interface + @test Tables.istable(rows) + @test Tables.rowaccess(rows) + @test 
Tables.rows(rows) === rows # Should return itself + + # Test that it works with Tables.columntable + ct = Tables.columntable(rows) + @test ct isa NamedTuple + @test haskey(ct, :complexity) + @test haskey(ct, :loss) + end + catch e + @info "Skipping Tables.jl tests (Tables.jl not available): $e" + end + else + @info "Skipping Tables.jl tests (Julia version < 1.9)" + end +end From b5c0f7292ad3bdf10257dca30d3ca0e7b2b256ef Mon Sep 17 00:00:00 2001 From: Atharva Sehgal Date: Mon, 13 Oct 2025 08:04:59 +0000 Subject: [PATCH 15/15] add Column filtering --- ext/SymbolicRegressionTablesExt.jl | 36 +++- src/HallOfFame.jl | 268 +++++++++++++++++++++++++---- test/test_hof_rows.jl | 205 +++++++++++++++++++++- 3 files changed, 471 insertions(+), 38 deletions(-) diff --git a/ext/SymbolicRegressionTablesExt.jl b/ext/SymbolicRegressionTablesExt.jl index 75a69d366..0e07aed45 100644 --- a/ext/SymbolicRegressionTablesExt.jl +++ b/ext/SymbolicRegressionTablesExt.jl @@ -1,7 +1,7 @@ module SymbolicRegressionTablesExt using Tables: Tables -import SymbolicRegression.HallOfFameModule: HOFRows +import SymbolicRegression.HallOfFameModule: HOFRows, member_to_row # Make HOFRows compatible with the Tables.jl interface # HOFRows is already iterable via Base.iterate, so we just need to declare compatibility @@ -9,4 +9,38 @@ Tables.istable(::Type{<:HOFRows}) = true Tables.rowaccess(::Type{<:HOFRows}) = true Tables.rows(view::HOFRows) = view # Return itself since it's already iterable +# Provide schema information for better Tables.jl integration +function Tables.schema(rows::HOFRows) + if isempty(rows.members) + # Empty table - can't infer schema + return nothing + end + + # Get column names from either column specs or first row + if rows.columns !== nothing + # Use explicit column specs + names = Tuple(col.name for col in rows.columns) + # We can't reliably infer types without evaluating, so return nothing for types + return Tables.Schema(names, nothing) + else + # Infer from first row + first_row 
= member_to_row( + rows.members[1], rows.dataset, rows.options; pretty=rows.pretty + ) + if rows.include_score + # Will add score in iteration + names = (keys(first_row)..., :score) + else + names = keys(first_row) + end + # Get types from first row + types = if rows.include_score + (typeof.(values(first_row))..., Float64) # Assume Float64 for score + else + typeof.(values(first_row)) + end + return Tables.Schema(names, types) + end +end + end diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl index 3844cfd90..a657f1d37 100644 --- a/src/HallOfFame.jl +++ b/src/HallOfFame.jl @@ -172,7 +172,7 @@ A NamedTuple containing the member's data. Default fields are: - `parent`: Parent reference ID - `equation`: Formatted equation string -# Example: Adding custom fields +# Example 1: Adding custom fields to a custom PopMember ```julia function SymbolicRegression.HallOfFameModule.member_to_row( member::MyCustomPopMember, @@ -185,6 +185,28 @@ function SymbolicRegression.HallOfFameModule.member_to_row( return merge(base, (my_field = member.custom_data,)) end ``` + +# Example 2: Displaying custom fields in the Hall of Fame +After extending `member_to_row`, create custom columns to display your fields: +```julia +using Printf + +custom_columns = [ + HOFColumn(:complexity, "C", row -> row.complexity, string, 5, :right), + HOFColumn(:loss, "Loss", row -> row.loss, x -> @sprintf("%.3e", x), 9, :right), + HOFColumn(:my_field, "MyField", row -> row.my_field, x -> @sprintf("%.2f", x), 10, :right), + HOFColumn(:equation, "Equation", row -> row.equation, identity, nothing, :left) +] + +# Display with custom columns +str = string_dominating_pareto_curve(hof, dataset, options; columns=custom_columns) +println(str) + +# Or export via Tables.jl with custom columns +rows = hof_rows(hof, dataset, options; columns=custom_columns) +using DataFrames +df = DataFrame(rows) +``` """ function member_to_row( member::AbstractPopMember, dataset::Dataset, options::AbstractOptions; pretty::Bool=true @@ 
-210,6 +232,92 @@ function member_to_row( ) end +""" + HOFColumn + +Specification for a column in Hall of Fame display and export. + +# Fields +- `name::Symbol`: Column identifier (key in the row NamedTuple) +- `header::String`: Display header text +- `getter::Function`: Function `(row::NamedTuple) -> value` to extract/compute column value +- `formatter::Function`: Function `(value) -> String` for display formatting (display only) +- `width::Union{Int,Nothing}`: Display width (nothing for auto-sizing) +- `alignment::Symbol`: Text alignment - `:left`, `:right`, or `:center` + +# Example +```julia +# Simple column that extracts an existing field +complexity_col = HOFColumn( + :complexity, "Complexity", + row -> row.complexity, + x -> string(x), + 10, :right +) + +# Computed column +r2_col = HOFColumn( + :r2, "R²", + row -> compute_r2(row), # Custom computation + x -> @sprintf("%.3f", x), + 8, :right +) +``` +""" +struct HOFColumn + name::Symbol + header::String + getter::Function + formatter::Function + width::Union{Int,Nothing} + alignment::Symbol +end + +""" + default_columns(options::AbstractOptions) -> Vector{HOFColumn} + +Return the default column specifications for Hall of Fame display. + +The default columns are: +- Complexity (right-aligned, width 10) +- Loss (right-aligned, width 9, scientific notation) +- Score (conditional on `options.loss_scale == :log`, right-aligned, width 9) +- Equation (left-aligned, auto-width) + +Users can customize by modifying this vector or creating their own. 
+""" +function default_columns(options::AbstractOptions) + cols = HOFColumn[ + HOFColumn( + :complexity, + "Complexity", + row -> row.complexity, + x -> @sprintf("%d", x), + 10, + :right, + ), + HOFColumn(:loss, "Loss", row -> row.loss, x -> @sprintf("%.3e", x), 9, :right), + ] + + # Add score column for logarithmic loss scale + if options.loss_scale == :log + push!( + cols, + HOFColumn( + :score, "Score", row -> row.score, x -> @sprintf("%.3e", x), 9, :right + ), + ) + end + + # Equation column (special handling in display due to wrapping) + push!( + cols, + HOFColumn(:equation, "Equation", row -> row.equation, identity, nothing, :left), + ) + + return cols +end + """ HOFRows @@ -222,6 +330,7 @@ This struct implements the Tables.jl interface for easy export to DataFrames, CS - `options`: Options for complexity and formatting - `include_score`: Whether to compute and include Pareto improvement scores - `pretty`: Whether to use pretty-printing for equations +- `columns`: Optional column specifications (nothing = all columns from member_to_row) """ struct HOFRows{PM<:AbstractPopMember} members::Vector{PM} @@ -229,13 +338,26 @@ struct HOFRows{PM<:AbstractPopMember} options::AbstractOptions include_score::Bool pretty::Bool + columns::Union{Vector{HOFColumn},Nothing} end -# Helper function to create a single row with optional score +# Helper function to create a single row with optional score and column filtering @unstable function _make_row(view::HOFRows, i::Int, scores) + # Get full row from member_to_row row = member_to_row(view.members[i], view.dataset, view.options; pretty=view.pretty) - return scores === nothing ? row : (; row..., score=scores[i]) + # Add score if computed + row = scores === nothing ?
row : (; row..., score=scores[i]) + + # Apply column filtering if specified + if view.columns !== nothing + # Build filtered row via column getters (tuple: no heterogeneous Vector needed) + filtered_values = Tuple(col.getter(row) for col in view.columns) + filtered_names = Tuple(col.name for col in view.columns) + return NamedTuple{filtered_names}(filtered_values) + end + + return row end # Make HOFRows iterable @@ -265,7 +387,7 @@ end """ hof_rows(hof::HallOfFame, dataset::Dataset, options::AbstractOptions; pareto_only::Bool=true, include_score::Bool=pareto_only, - pretty::Bool=true) + pretty::Bool=true, columns::Union{Vector{HOFColumn},Nothing}=nothing) This function returns an `HOFRows` object. @@ -276,6 +398,7 @@ This function returns an `HOFRows` object. - `pareto_only`: Only include Pareto frontier members (default: true) - `include_score`: Include Pareto improvement scores (default: same as `pareto_only`) - `pretty`: Use pretty-printing for equations (default: true) +- `columns`: Optional column specifications (default: nothing = all columns from member_to_row) # Returns An `HOFRows` object that can be used with Tables.jl-compatible consumers like @@ -292,6 +415,13 @@ df = DataFrame(rows) # Get all members without scores all_rows = hof_rows(hof, dataset, options; pareto_only=false, include_score=false) + +# Get only specific columns +custom_cols = [ + HOFColumn(:complexity, "Complexity", row -> row.complexity, string, 10, :right), + HOFColumn(:loss, "Loss", row -> row.loss, x -> @sprintf("%.3e", x), 9, :right) +] +filtered_rows = hof_rows(hof, dataset, options; columns=custom_cols) ``` """ function hof_rows( @@ -301,6 +431,7 @@ function hof_rows( pareto_only::Bool=true, include_score::Bool=pareto_only, pretty::Bool=true, + columns::Union{Vector{HOFColumn},Nothing}=nothing, ) members = if pareto_only calculate_pareto_frontier(hof) @@ -308,58 +439,123 @@ function hof_rows( [m for (m, ex) in zip(hof.members, hof.exists) if ex] end - return HOFRows(members, dataset, options, include_score, pretty) +
return HOFRows(members, dataset, options, include_score, pretty, columns) end -let header_parts = ( - rpad(styled"{bold:{underline:Complexity}}", 10), - rpad(styled"{bold:{underline:Loss}}", 9), - rpad(styled"{bold:{underline:Score}}", 9), - styled"{bold:{underline:Equation}}", +""" + string_dominating_pareto_curve( + hallOfFame, dataset, options; + width::Union{Integer,Nothing}=nothing, + pretty::Bool=true, + columns::Union{Vector{HOFColumn},Nothing}=nothing ) - @eval const HEADER = join($(header_parts), " ") - @eval const HEADER_WITHOUT_SCORE = join($(header_parts[[1, 2, 4]]), " ") -end -show_score_column(options::AbstractOptions) = options.loss_scale == :log +Format the Pareto frontier as a pretty-printed string table. + +# Arguments +- `hallOfFame`: The HallOfFame to display +- `dataset`: Dataset for formatting equations +- `options`: Options controlling complexity and formatting +- `width`: Terminal width (default: 100; values below 100 are raised to 100) +- `pretty`: Use pretty-printing for equations (default: true) +- `columns`: Column specifications (default: nothing = use default_columns(options)) +# Example with custom columns +```julia +custom_cols = [ + HOFColumn(:complexity, "C", row -> row.complexity, string, 5, :right), + HOFColumn(:loss, "Loss", row -> row.loss, x -> @sprintf("%.2e", x), 8, :right), + HOFColumn(:equation, "Equation", row -> row.equation, identity, nothing, :left) +] +str = string_dominating_pareto_curve(hof, dataset, options; columns=custom_cols) +``` +""" function string_dominating_pareto_curve( - hallOfFame, dataset, options; width::Union{Integer,Nothing}=nothing, pretty::Bool=true + hallOfFame, + dataset, + options; + width::Union{Integer,Nothing}=nothing, + pretty::Bool=true, + columns::Union{Vector{HOFColumn},Nothing}=nothing, ) + # Use default columns if not specified + cols = columns === nothing ? default_columns(options) : columns + terminal_width = (width === nothing) ?
100 : max(100, width::Integer) buffer = AnnotatedIOBuffer(IOBuffer()) + + # Print top border println(buffer, '─'^(terminal_width - 1)) - if show_score_column(options) - println(buffer, HEADER) - else - println(buffer, HEADER_WITHOUT_SCORE) + + # Build header from column specs + header_parts = map(cols) do col + header_text = styled"{bold:{underline:$(col.header)}}" + if col.width === nothing + # Last column (typically equation) - no padding + header_text + else + # Fixed-width column - pad to width + rpad(header_text, col.width) + end end + println(buffer, join(header_parts, " ")) - # Use hof_rows to get data with scores but without prefix - # (we need to format prefix specially for wrapping) + # Get rows (without column filtering, we'll format ourselves) rows_view = hof_rows( hallOfFame, dataset, options; pareto_only=true, include_score=true, pretty=pretty ) members = rows_view.members - for (i, row) in enumerate(rows_view) + # Format each row + for (i, full_row) in enumerate(rows_view) member = members[i] - prefix = make_prefix(member.tree, options, dataset) - eqn_string = row.equation - stats_columns_string = if show_score_column(options) - @sprintf("%-10d %-8.3e %-8.3e ", row.complexity, row.loss, row.score) - else - @sprintf("%-10d %-8.3e ", row.complexity, row.loss) + + # Format all columns except the last one (which may need wrapping) + formatted_cols = String[] + for (col_idx, col) in enumerate(cols) + value = col.getter(full_row) + formatted = col.formatter(value) + + if col_idx == length(cols) + # Last column - handle separately for wrapping + # Left margin from previous columns (init=0: there may be none) + left_cols_width = sum( + c -> length(c) + 2, formatted_cols; init=0 + ) + + # Handle equation prefix if it's an equation column + if col.name == :equation && haskey(full_row, :equation) + prefix = make_prefix(member.tree, options, dataset) + wrapped = wrap_equation_string( + formatted, left_cols_width + length(prefix), terminal_width + ) +
print(buffer, join(formatted_cols, " ")) + print(buffer, " ") + print(buffer, wrapped) + else + # Non-equation last column - just print + push!(formatted_cols, formatted) + println(buffer, join(formatted_cols, " ")) + end + else + # Non-last column - format with alignment and width + if col.width !== nothing + if col.alignment == :right + formatted = lpad(formatted, col.width) + elseif col.alignment == :center + formatted = lpad( + rpad(formatted, (col.width + length(formatted)) ÷ 2), col.width + ) + else # :left + formatted = rpad(formatted, col.width) + end + end + push!(formatted_cols, formatted) + end end - left_cols_width = length(stats_columns_string) - print(buffer, stats_columns_string) - print( - buffer, - wrap_equation_string( - eqn_string, left_cols_width + length(prefix), terminal_width - ), - ) end + + # Print bottom border print(buffer, '─'^(terminal_width - 1)) return dump_buffer(buffer) end diff --git a/test/test_hof_rows.jl b/test/test_hof_rows.jl index d8239850f..5a038b31c 100644 --- a/test/test_hof_rows.jl +++ b/test/test_hof_rows.jl @@ -237,7 +237,7 @@ end end function create_child( - parent::CustomPopMember{T,L}, + parent::TestCustomPopMember{T,L}, tree::AbstractExpression{T}, cost::L, loss::L, @@ -367,3 +367,206 @@ end @info "Skipping Tables.jl tests (Julia version < 1.9)" end end + +@testitem "Column specifications" tags = [:part1] begin + using SymbolicRegression + using Test + using Printf: @sprintf + + X = Float32[1.0 2.0 3.0; 4.0 5.0 6.0] + y = Float32[1.0, 2.0, 3.0] + + options = Options(; binary_operators=[+, -], maxsize=5, deterministic=true, seed=0) + + dataset = Dataset(X, y) + + @testset "HOFColumn basics" begin + # Create a simple column + col = SymbolicRegression.HallOfFameModule.HOFColumn( + :loss, "Loss", row -> row.loss, x -> @sprintf("%.2e", x), 8, :right + ) + + @test col.name == :loss + @test col.header == "Loss" + @test col.width == 8 + @test col.alignment == :right + + # Test getter and formatter + test_row =
(loss=0.123456, complexity=5) + @test col.getter(test_row) == 0.123456 + @test col.formatter(0.123456) == "1.23e-01" + end + + @testset "default_columns" begin + # Test default columns without score (linear loss scale) + options_linear = Options(; + binary_operators=[+, -], maxsize=5, loss_scale=:linear, deterministic=true + ) + cols_linear = SymbolicRegression.HallOfFameModule.default_columns(options_linear) + + @test length(cols_linear) == 3 # complexity, loss, equation + @test cols_linear[1].name == :complexity + @test cols_linear[2].name == :loss + @test cols_linear[3].name == :equation + + # Test default columns with score (log loss scale) + options_log = Options(; + binary_operators=[+, -], maxsize=5, loss_scale=:log, deterministic=true + ) + cols_log = SymbolicRegression.HallOfFameModule.default_columns(options_log) + + @test length(cols_log) == 4 # complexity, loss, score, equation + @test cols_log[1].name == :complexity + @test cols_log[2].name == :loss + @test cols_log[3].name == :score + @test cols_log[4].name == :equation + end + + @testset "Custom columns with HOFRows" begin + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + hof.exists[1] = true + hof.exists[2] = true + + # Create custom column specs + custom_cols = [ + SymbolicRegression.HallOfFameModule.HOFColumn( + :complexity, "C", row -> row.complexity, string, 5, :right + ), + SymbolicRegression.HallOfFameModule.HOFColumn( + :loss, "L", row -> row.loss, x -> @sprintf("%.2e", x), 8, :right + ), + ] + + # Get rows with custom columns + rows = SymbolicRegression.HallOfFameModule.hof_rows( + hof, dataset, options; pareto_only=false, columns=custom_cols + ) + + # Collect and verify + collected = collect(rows) + @test length(collected) == 2 + + # Should only have the two specified columns + for row in collected + @test haskey(row, :complexity) + @test haskey(row, :loss) + @test !haskey(row, :equation) # Not requested + @test !haskey(row, :cost) # Not requested + end + end + + 
@testset "string_dominating_pareto_curve with custom columns" begin + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + hof.exists[1] = true + + # Test with default columns + str_default = SymbolicRegression.HallOfFameModule.string_dominating_pareto_curve( + hof, dataset, options + ) + @test str_default isa AbstractString + @test contains(str_default, "Complexity") + @test contains(str_default, "Loss") + + # Test with custom columns + custom_cols = [ + SymbolicRegression.HallOfFameModule.HOFColumn( + :complexity, "C", row -> row.complexity, string, 5, :right + ), + SymbolicRegression.HallOfFameModule.HOFColumn( + :loss, "L", row -> row.loss, x -> @sprintf("%.2e", x), 8, :right + ), + SymbolicRegression.HallOfFameModule.HOFColumn( + :equation, "Eq", row -> row.equation, identity, nothing, :left + ), + ] + + str_custom = SymbolicRegression.HallOfFameModule.string_dominating_pareto_curve( + hof, dataset, options; columns=custom_cols + ) + @test str_custom isa AbstractString + @test contains(str_custom, "C") # Custom header + @test contains(str_custom, "L") # Custom header + @test !contains(str_custom, "Complexity") # Original header should not appear + end + + @testset "Computed columns" begin + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + hof.exists[1] = true + + # Create a computed column (e.g., cost/loss ratio) + custom_cols = [ + SymbolicRegression.HallOfFameModule.HOFColumn( + :complexity, "C", row -> row.complexity, string, 5, :right + ), + SymbolicRegression.HallOfFameModule.HOFColumn( + :ratio, + "Cost/Loss", + row -> row.cost / row.loss, # Computed from multiple fields + x -> @sprintf("%.2f", x), + 10, + :right, + ), + ] + + rows = SymbolicRegression.HallOfFameModule.hof_rows( + hof, dataset, options; pareto_only=false, columns=custom_cols + ) + + collected = collect(rows) + @test length(collected) == 1 + @test haskey(collected[1], :ratio) + @test collected[1].ratio isa Number + end +end + +@testitem "Column 
specs with Tables.jl" tags = [:part1] begin + using SymbolicRegression + using Test + using Printf: @sprintf + + # Only run if Tables.jl is available + if isdefined(Base, :get_extension) + try + @eval using Tables + + X = Float32[1.0 2.0 3.0; 4.0 5.0 6.0] + y = Float32[1.0, 2.0, 3.0] + + options = Options(; + binary_operators=[+, -], maxsize=5, deterministic=true, seed=0 + ) + + dataset = Dataset(X, y) + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + hof.exists[1] = true + + @testset "Tables.jl with custom columns" begin + custom_cols = [ + SymbolicRegression.HallOfFameModule.HOFColumn( + :complexity, "C", row -> row.complexity, string, 5, :right + ), + SymbolicRegression.HallOfFameModule.HOFColumn( + :loss, "L", row -> row.loss, x -> @sprintf("%.2e", x), 8, :right + ), + ] + + rows = SymbolicRegression.HallOfFameModule.hof_rows( + hof, dataset, options; columns=custom_cols + ) + + # Test schema + schema = Tables.schema(rows) + @test schema !== nothing + @test schema.names == (:complexity, :loss) + + # Test columntable + ct = Tables.columntable(rows) + @test haskey(ct, :complexity) + @test haskey(ct, :loss) + @test !haskey(ct, :equation) # Not in custom columns + end + catch e + @info "Skipping Tables.jl column spec tests: $e" + end + end +end