Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clean up symbolics #267

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/HarmonicBalance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ export get_krylov_equations
include("modules/FFTWExt.jl")
using .FFTWExt

@setup_workload begin
# Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the
# precompile file and potentially make loading faster.
@compile_workload begin
# all calls in this block will be precompiled, regardless of whether
# they belong to your package or not (on Julia 1.8 and higher)
include("precompilation.jl")
end
end
# @setup_workload begin
# # Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the
# # precompile file and potentially make loading faster.
# @compile_workload begin
# # all calls in this block will be precompiled, regardless of whether
# # they belong to your package or not (on Julia 1.8 and higher)
# include("precompilation.jl")
# end
# end

end # module
219 changes: 90 additions & 129 deletions src/Symbolics_customised.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Symbolics
using SymbolicUtils:
SymbolicUtils,
Postwalk,
Expand All @@ -8,14 +9,19 @@ using SymbolicUtils:
isadd,
isdiv,
ismul,
issym,
add_with_div,
frac_maketerm, #, @compactified
issym

frac_maketerm,
@rule,
@acrule,
@compactified,
isliteral,
Chain
using Symbolics:
Symbolics,
Num,
unwrap,
wrap,
get_variables,
simplify,
expand_derivatives,
Expand All @@ -28,149 +34,104 @@ using Symbolics:
substitute,
term,
expand,
operation

"Returns true if expr is an exponential"
is_exp(expr) = isterm(expr) && expr.f == exp

"Expand powers of exponential such that exp(x)^n => exp(x*n) "
expand_exp_power(expr) =
ispow(expr) && is_exp(expr.base) ? exp(expr.base.arguments[1] * expr.exp) : expr
expand_exp_power_add(expr) = sum([expand_exp_power(arg) for arg in arguments(expr)])
expand_exp_power_mul(expr) = prod([expand_exp_power(arg) for arg in arguments(expr)])
expand_exp_power(expr::Num) = expand_exp_power(expr.val)
# operation,
_iszero

function expand_exp_power(expr::BasicSymbolic)
if isadd(expr)
return expand_exp_power_add(expr)
elseif ismul(expr)
return expand_exp_power_mul(expr)
else
return if ispow(expr) && is_exp(expr.base)
exp(expr.base.arguments[1] * expr.exp)
else
expr
end
end
# ∨ Complex{Num} can contain
is_false_or_zero(x) = x === false || _iszero(x)
is_literal_complex(x) = isliteral(Complex)(x)
is_literal_real(x) = isliteral(Real)(x)
function is_not_complex(x)
return !is_literal_real(x) && (is_literal_complex(x) && is_false_or_zero(unwrap(x.im)))
end

"Expands using SymbolicUtils.expand and expand_exp_power (changes exp(x)^n to exp(x*n)"
expand_all(x) = Postwalk(expand_exp_power)(SymbolicUtils.expand(x))
expand_all(x::Complex{Num}) = expand_all(x.re) + im * expand_all(x.im)
expand_all(x::Num) = Num(expand_all(x.val))
const expand_exp_power = @rule(exp(~x)^(~y) => exp(~x * ~y))
const simplify_exp_mul = @acrule(exp(~x) * exp(~y) => _iszero(~x + ~y) ? 1 : exp(~x + ~y))
const sin_euler = @rule(sin(~x) => (exp(im * ~x) - exp(-im * ~x)) / (2 * im))
const cos_euler = @rule(cos(~x) => (exp(im * ~x) + exp(-im * ~x)) / 2)
const not_complex = @rule(~x::is_not_complex => real(~x))

"Apply a function f on every member of a sum or a product"
_apply_termwise(f, x) = f(x)
function _apply_termwise(f, x::BasicSymbolic)
if isadd(x)
return sum([f(arg) for arg in arguments(x)])
elseif ismul(x)
return prod([f(arg) for arg in arguments(x)])
elseif isdiv(x)
return _apply_termwise(f, x.num) / _apply_termwise(f, x.den)
else
return f(x)
end
end
# We could use @compactified to do the achive thing wit a speed-up. Neverthless, it yields less readable code.
# @compactified is what SymbolicUtils uses internally
# function _apply_termwise(f, x::BasicSymbolic)
# @compactified x::BasicSymbolic begin
# Add => sum([f(arg) for arg in arguments(x)])
# Mul => prod([f(arg) for arg in arguments(x)])
# Div => _apply_termwise(f, x.num) / _apply_termwise(f, x.den)
# _ => f(x)
# end
# end
"""
Expands using SymbolicUtils.expand and expand_exp_power (changes exp(x)^n to exp(x*n)
"""
expand_all(x) = simplify(expand(x); rewriter=Postwalk(expand_exp_power))
expand_all(x::Num) = wrap(expand_all(unwrap(x)))
function expand_all(x::Complex{Num})
re_val = is_false_or_zero(unwrap(x.re)) ? 0.0 : expand_all(x.re)
im_val = is_false_or_zero(unwrap(x.im)) ? 0.0 : expand_all(x.im)
return re_val + im * im_val
end # This code is stupid, we can just use simplify

simplify_complex(x::Complex) = isequal(x.im, 0) ? x.re : x.re + im * x.im
simplify_complex(x) = x
function simplify_complex(x::BasicSymbolic)
if isadd(x) || ismul(x) || isdiv(x)
return _apply_termwise(simplify_complex, x)
else
return x
end
end
# simplify_complex(x) = simplify(expand(x); rewriter=Postwalk(not_complex))
simplify_complex(x) = Postwalk(@rule(~x::is_not_complex => real(~x)))(x)

"Simplify products of exponentials such that exp(a)*exp(b) => exp(a+b)
This is included in SymbolicUtils as of 17.0 but the method here avoid other simplify calls"
function simplify_exp_products_mul(expr)
ind = findall(x -> is_exp(x), arguments(expr))
rest_ind = setdiff(1:length(arguments(expr)), ind)
rest = isempty(rest_ind) ? 1 : prod(arguments(expr)[rest_ind])
total = isempty(ind) ? 0 : sum(getindex.(arguments.(arguments(expr)[ind]), 1))
if SymbolicUtils.is_literal_number(total)
(total == 0 && return rest)
else
return rest * exp(total)
end
"""
Simplify products of exponentials such that exp(a)*exp(b) => exp(a+b)"
"""
simplify_exp_products(x) = Postwalk(simplify_exp_mul)(x)
function simplify_exp_products(x::Complex{Num})
re_val = is_false_or_zero(unwrap(x.re)) ? 0.0 : simplify_exp_products(x.re)
im_val = is_false_or_zero(unwrap(x.im)) ? 0.0 : simplify_exp_products(x.im)
return re_val + im * im_val
end

function simplify_exp_products(x::Complex{Num})
return Complex{Num}(simplify_exp_products(x.re.val), simplify_exp_products(x.im.val))
"""
Converts the trigonometric functions to exponentials using Euler's formulas.
"""
trig_to_exp(x) = simplify(x; rewriter=Postwalk(Chain([sin_euler, cos_euler])))
trig_to_exp(x::Num) = wrap(trig_to_exp(unwrap(x)))
function trig_to_exp(x::Complex{Num})
re_val = is_false_or_zero(unwrap(x.re)) ? 0.0 : trig_to_exp(x.re)
im_val = is_false_or_zero(unwrap(x.im)) ? 0.0 : trig_to_exp(x.im)
return re_val + im * im_val
end
simplify_exp_products(x::Num) = simplify_exp_products(x.val)
simplify_exp_products(x) = x

function simplify_exp_products(expr::BasicSymbolic)
if isadd(expr) || isdiv(expr)
return _apply_termwise(simplify_exp_products, expr)
elseif ismul(expr)
return simplify_exp_products_mul(expr)
else
return expr
end
"""
Reparse the symbolic expression.
Symbolics.jl applies some simplifications when doing this
"""
function reparse(x)
str = string(x)
str′ = replace(str, ")(" => ")*(")
parse_expr_to_symbolic(Meta.parse(str′), @__MODULE__)
end
# ^ parsing and reevaluting makes it that (1 - 0.0im) becomes (1)

function exp_to_trig(x::BasicSymbolic)
if isadd(x) || isdiv(x) || ismul(x)
return _apply_termwise(exp_to_trig, x)
elseif isterm(x) && x.f == exp
arg = first(x.arguments)
trigarg = Symbolics.expand(-im * arg) # the argument of the to-be trig function
trigarg = simplify_complex(trigarg)
"""
Converts the sinusoidal function to the the cananical form, i.e.,
sin(x) => -sin(-x) or cos(-x) => cos(x)
"""
function make_positive_trig(x::Num)
all_terms = get_all_terms(x)
trigs = filter(z -> is_trig(z), all_terms)

# put arguments of trigs into a standard form such that sin(x) = -sin(-x), cos(x) = cos(-x) are recognized
if isadd(trigarg)
first_symbol = minimum(
cat(string.(arguments(trigarg)), string.(arguments(-trigarg)); dims=1)
)
rules = []
for trig in trigs
is_pow = ispow(trig.val) # trig is either a trig or a power of a trig
power = is_pow ? trig.val.exp : 1
arg = is_pow ? arguments(trig.val.base)[1] : arguments(trig.val)[1]
type = is_pow ? operation(trig.val.base) : operation(trig.val)
negative =
!issym(arg) && prod(Number.(filter(x -> x isa Number, arguments(arg)))) < 0

# put trigarg => -trigarg the lowest alphabetic argument of trigarg is lower than that of -trigarg
# this is a meaningless key but gives unique signs to all sums
is_first = minimum(string.(arguments(trigarg))) == first_symbol
return if is_first
cos(-trigarg) - im * sin(-trigarg)
else
cos(trigarg) + im * sin(trigarg)
if negative
if type == cos
term = cos(-arg)
elseif type == sin
term = (-1)^power * sin(-arg)^power
end
append!(rules, [trig => term])
end
return if ismul(trigarg) && trigarg.coeff < 0
cos(-trigarg) - im * sin(-trigarg)
else
cos(trigarg) + im * sin(trigarg)
end
else
return x
end
result = Symbolics.substitute(x, Dict(rules))
return result
end

exp_to_trig(x) = x
exp_to_trig(x::Num) = exp_to_trig(x.val)
exp_to_trig(x::Complex{Num}) = exp_to_trig(x.re) + im * exp_to_trig(x.im)

# sometimes, expressions get stored as Complex{Num} with no way to decode what real(x) and imag(x)
# this overloads the Num constructor to return a Num if x.re and x.im have similar arguments
function Num(x::Complex{Num})::Num
if x.re.val isa Float64 && x.im.val isa Float64
return Num(x.re.val)
else
if isequal(x.re.val.arguments, x.im.val.arguments)
Num(first(x.re.val.arguments))
else
error("Cannot convert Complex{Num} " * string(x) * " to Num")
end
end
"""
Converts all the exponentials to trigonometric functions by reparsing. Symbolics does this
somewhere internally, so we just reparse the expression. This is a workaround.
"""
function exp_to_trig(z::Complex{Num})
z = reparse(z)
return simplify_complex(make_positive_trig(z.re) + im * make_positive_trig(z.im))
end
# ^ This function commits type-piracy with Symbolics.jl. We should change this.
60 changes: 30 additions & 30 deletions src/Symbolics_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,36 +185,36 @@ end
is_harmonic(x::Equation, t::Num) = is_harmonic(x.lhs, t) && is_harmonic(x.rhs, t)
is_harmonic(x, t) = is_harmonic(Num(x), Num(t))

"Convert all sin/cos terms in `x` into exponentials."
function trig_to_exp(x::Num)
all_terms = get_all_terms(x)
trigs = filter(z -> is_trig(z), all_terms)

rules = []
for trig in trigs
is_pow = ispow(trig.val) # trig is either a trig or a power of a trig
power = is_pow ? trig.val.exp : 1
arg = is_pow ? arguments(trig.val.base)[1] : arguments(trig.val)[1]
type = is_pow ? operation(trig.val.base) : operation(trig.val)

if type == cos
term = Complex{Num}((exp(im * arg) + exp(-im * arg))^power * (1//2)^power, 0)
elseif type == sin
term =
(1 * im^power) *
Complex{Num}(((exp(-im * arg) - exp(im * arg)))^power * (1//2)^power, 0)
end
# avoid Complex{Num} where possible as this causes bugs
# instead, the Nums store SymbolicUtils Complex types
term = Num(Symbolics.expand(term.re.val + im * term.im.val))
append!(rules, [trig => term])
end

result = Symbolics.substitute(x, Dict(rules))
#result = result isa Complex ? Num(first(result.re.val.arguments)) : result
result = Num(result)
return result
end
# "Convert all sin/cos terms in `x` into exponentials."
# function trig_to_exp(x::Num)
# all_terms = get_all_terms(x)
# trigs = filter(z -> is_trig(z), all_terms)

# rules = []
# for trig in trigs
# is_pow = ispow(trig.val) # trig is either a trig or a power of a trig
# power = is_pow ? trig.val.exp : 1
# arg = is_pow ? arguments(trig.val.base)[1] : arguments(trig.val)[1]
# type = is_pow ? operation(trig.val.base) : operation(trig.val)

# if type == cos
# term = Complex{Num}((exp(im * arg) + exp(-im * arg))^power * (1//2)^power, 0)
# elseif type == sin
# term =
# (1 * im^power) *
# Complex{Num}(((exp(-im * arg) - exp(im * arg)))^power * (1//2)^power, 0)
# end
# # avoid Complex{Num} where possible as this causes bugs
# # instead, the Nums store SymbolicUtils Complex types
# term = Num(Symbolics.expand(term.re.val + im * term.im.val))
# append!(rules, [trig => term])
# end

# result = Symbolics.substitute(x, Dict(rules))
# #result = result isa Complex ? Num(first(result.re.val.arguments)) : result
# result = Num(result)
# return result
# end

"Return true if `f` is a function of `var`."
is_function(f, var) = any(isequal.(get_variables(f), var))
Expand Down
Loading
Loading