diff --git a/src/classification.jl b/src/classification.jl index 994ab0cd..1a21d0c2 100644 --- a/src/classification.jl +++ b/src/classification.jl @@ -20,20 +20,20 @@ classify_solutions!(res, "sqrt(u1^2 + v1^2) > 1.0" , "large_amplitude") ``` """ -function classify_solutions!(res::Result, func::Union{String, Function}, name::String; physical=true) - values = classify_solutions(res, func; physical=physical) +function classify_solutions!(res::Result, func::Union{String, Function}, name::String; physical=true, kwargs...) + values = classify_solutions(res, func; physical=physical, kwargs...) res.classes[name] = values end -function classify_solutions(res::Result, func; physical=true) +function classify_solutions(res::Result, func; physical=true, kwargs...) func = isa(func, Function) ? func : _build_substituted(func, res) if physical - f_comp(soln) = _is_physical(soln) && func(real.(soln)) - transform_solutions(res, f_comp) + f_comp(soln; kwargs...) = _is_physical(soln) && func(real.(soln); kwargs...) + transform_solutions(res, f_comp; kwargs...) else - transform_solutions(res, func) + transform_solutions(res, func; kwargs...) end end diff --git a/src/transform_solutions.jl b/src/transform_solutions.jl index 9362a319..7e219930 100644 --- a/src/transform_solutions.jl +++ b/src/transform_solutions.jl @@ -11,14 +11,14 @@ Takes a `Result` object and a string `f` representing a Symbolics.jl expression. Returns an array with the values of `f` evaluated for the respective solutions. Additional substitution rules can be specified in `rules` in the format `("a" => val)` or `(a => val)` """ -function transform_solutions(res::Result, func; branches = 1:branch_count(res)) +function transform_solutions(res::Result, func; branches = 1:branch_count(res), kwargs...) # preallocate an array for the numerical values, rewrite parts of it # when looping through the solutions pars = res.swept_parameters |> values |> collect n_vars = length(get_variables(res)) n_pars = length(pars) - vtype = isa(Base.invokelatest(func, rand(ComplexF64, n_vars+n_pars)), Bool) ? BitVector : Vector{ComplexF64} + vtype = isa(Base.invokelatest(func, rand(ComplexF64, n_vars+n_pars); kwargs...), Bool) ? BitVector : Vector{ComplexF64} transformed = _similar(vtype, res; branches=branches) batches = Iterators.partition(CartesianIndices(res.solutions), ceil(Int, length(res.solutions)/Threads.nthreads())) @@ -30,7 +30,7 @@ function transform_solutions(res::Result, func; branches = 1:branch_count(res)) end for (k, branch) in enumerate(branches) _vals[1:n_vars] .= res.solutions[idx][branch] - transformed[idx][k] = Base.invokelatest(func, _vals) # beware, func may be mutating + transformed[idx][k] = Base.invokelatest(func, _vals; kwargs...) # beware, func may be mutating end end end