Skip to content

Commit

Permalink
fix Num typepiracy
Browse files Browse the repository at this point in the history
  • Loading branch information
oameye committed Oct 13, 2024
1 parent 577138a commit a716f58
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 43 deletions.
39 changes: 20 additions & 19 deletions src/Symbolics/Symbolics_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,15 @@ function simplify_complex(x::BasicSymbolic)
end
end

# sometimes, expressions get stored as Complex{Num} with no way to decode what real(x) and imag(x)
# this overloads the Num constructor to return a Num if x.re and x.im have similar arguments
function Num(x::Complex{Num})::Num
if x.re.val isa Float64 && x.im.val isa Float64
return Num(x.re.val)
else
if isequal(x.re.val.arguments, x.im.val.arguments)
Num(first(x.re.val.arguments))
else
error("Cannot convert Complex{Num} " * string(x) * " to Num")
end
end
end
# ^ This function commits type-piracy with Symbolics.jl. We should change this.

"""
$(TYPEDSIGNATURES)
Perform substitutions in `rules` on `x`.
`include_derivatives=true` also includes all derivatives of the variables of the keys of `rules`.
"""
function substitute_all(
x::T, rules::Dict; include_derivatives=true
)::T where {T<:Union{Equation,Num}}
subtype=Union{Num,Equation,BasicSymbolic}
function substitute_all(x::subtype, rules::Dict; include_derivatives=true)
if include_derivatives
rules = merge(
rules,
Expand All @@ -71,7 +56,24 @@ function substitute_all(dict::Dict, rules::Dict)::Dict
end
Collections = Union{Dict,Pair,Vector,OrderedDict}
substitute_all(v::AbstractArray, rules) = [substitute_all(x, rules) for x in v]
substitute_all(x::Union{Num,Equation}, rules::Collections) = substitute_all(x, Dict(rules))
substitute_all(x::subtype, rules::Collections) = substitute_all(x, Dict(rules))
# Collections = Union{Dict,OrderedDict}
# function substitute_all(x, rules::Collections; include_derivatives=true)
# if include_derivatives
# rules = merge(
# rules,
# Dict([Differential(var) => Differential(rules[var]) for var in keys(rules)]),
# )
# end
# return substitute(x, rules)
# end
# "Variable substitution - dictionary"
# function substitute_all(dict::Dict, rules::Dict)::Dict
# new_keys = substitute_all.(keys(dict), rules)
# new_values = substitute_all.(values(dict), rules)
# return Dict(zip(new_keys, new_values))
# end
# substitute_all(v::AbstractArray, rules::Collections) = [substitute_all(x, rules) for x in v]


get_independent(x::Num, t::Num) = get_independent(x.val, t)
Expand Down Expand Up @@ -125,6 +127,5 @@ end
is_harmonic(x::Equation, t::Num) = is_harmonic(x.lhs, t) && is_harmonic(x.rhs, t)
is_harmonic(x, t) = is_harmonic(Num(x), Num(t))


"Return true if `f` is a function of `var`."
is_function(f, var) = any(isequal.(get_variables(f), var))
1 change: 0 additions & 1 deletion src/Symbolics/drop_powers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ function drop_powers(expr::Num, vars::Vector{Num}, deg::Int)
removal = Dict([ϵ^d => Num(0) for d in deg:max_deg])
res = substitute_all(substitute_all(subs_expr, removal), Dict=> Num(1)))
return Symbolics.expand(res)
#res isa Complex ? Num(res.re.val.arguments[1]) : res
end

function drop_powers(expr::Vector{Num}, var::Vector{Num}, deg::Int)
Expand Down
6 changes: 3 additions & 3 deletions src/Symbolics/fourier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ function trig_to_exp(x::Num)
end

result = Symbolics.substitute(x, Dict(rules))
#result = result isa Complex ? Num(first(result.re.val.arguments)) : result
result = Num(result)
return result
return convert_to_Num(result)
end
convert_to_Num(x::Complex{Num})::Num = Num(first(x.re.val.arguments))
convert_to_Num(x::Num)::Num = x

function exp_to_trig(x::BasicSymbolic)
if isadd(x) || isdiv(x) || ismul(x)
Expand Down
4 changes: 2 additions & 2 deletions src/transform_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ function to_lab_frame(soln, res, times)::Vector{AbstractFloat}
timetrace = zeros(length(times))

for var in res.problem.eom.variables
val = unwrap(substitute_all(_remove_brackets(var), soln))
ω = unwrap(substitute_all(var.ω, soln))
val = real(substitute_all(unwrap(_remove_brackets(var)), soln))
ω = real(unwrap(substitute_all(var.ω, soln)))
if var.type == "u"
timetrace .+= val * cos.(ω * times)
elseif var.type == "v"
Expand Down
35 changes: 17 additions & 18 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,23 @@ using Random
const SEED = 0xd8e5d8df
Random.seed!(SEED)

# @testset "Code quality" begin
# using ExplicitImports, Aqua
# ignore_deps = [:Random, :LinearAlgebra, :Printf, :Test, :Pkg]
@testset "Code quality" begin
using ExplicitImports, Aqua
ignore_deps = [:Random, :LinearAlgebra, :Printf, :Test, :Pkg]

# @test check_no_stale_explicit_imports(HarmonicBalance) == nothing
# @test check_all_explicit_imports_via_owners(HarmonicBalance) == nothing
# Aqua.test_ambiguities(HarmonicBalance)
# Aqua.test_all(
# HarmonicBalance;
# deps_compat=(
# ignore=ignore_deps,
# check_extras=(ignore=ignore_deps,),
# check_weakdeps=(ignore=ignore_deps,),
# ),
# piracies=(treat_as_own=[HarmonicBalance.Num],),
# ambiguities=false,
# )
# end
@test check_no_stale_explicit_imports(HarmonicBalance) == nothing
@test check_all_explicit_imports_via_owners(HarmonicBalance) == nothing
Aqua.test_ambiguities(HarmonicBalance)
Aqua.test_all(
HarmonicBalance;
deps_compat=(
ignore=ignore_deps,
check_extras=(ignore=ignore_deps,),
check_weakdeps=(ignore=ignore_deps,),
),
ambiguities=false,
)
end

@testset "Code linting" begin
using JET
Expand All @@ -34,7 +33,7 @@ end
end

@testset "Symbolics customised" begin
include("Symbolics.jl")
include("symbolics.jl")
end

@testset "IO" begin
Expand Down

0 comments on commit a716f58

Please sign in to comment.