diff --git "a/docs/src/tutorials/Schr\303\266dingerEquation.md" "b/docs/src/tutorials/Schr\303\266dingerEquation.md" index ac07440a..96df9aaf 100644 --- "a/docs/src/tutorials/Schr\303\266dingerEquation.md" +++ "b/docs/src/tutorials/Schr\303\266dingerEquation.md" @@ -45,7 +45,7 @@ function train(pde_system, prob, sampler, strategy, resample_period = 500, n=10) res = Optimization.solve(prob, bfgs; maxiters=2000) for i in 1:n - data = Sophon.sample(pde_system, sampler, strategy) + data = Sophon.sample(pde_system, sampler) prob = remake(prob; u0=res.u, p=data) res = Optimization.solve(prob, bfgs; maxiters=resample_period) end diff --git a/docs/src/tutorials/allen_cahn.md b/docs/src/tutorials/allen_cahn.md index d922ad40..118f1c54 100644 --- a/docs/src/tutorials/allen_cahn.md +++ b/docs/src/tutorials/allen_cahn.md @@ -42,7 +42,7 @@ function train(allen, prob, sampler, strategy) for tmax in [0.5, 0.75, 1.0] allen.domain[2] = t ∈ 0.0..tmax - data = Sophon.sample(allen, sampler, strategy) + data = Sophon.sample(allen, sampler) prob = remake(prob; u0=res.u, p=data) res = Optimization.solve(prob, bfgs; maxiters=2000) end diff --git a/docs/src/tutorials/burgers.jl b/docs/src/tutorials/burgers.jl index 0d1620cc..b1befca2 100644 --- a/docs/src/tutorials/burgers.jl +++ b/docs/src/tutorials/burgers.jl @@ -1,7 +1,8 @@ using Sophon, ModelingToolkit using DomainSets using DomainSets: × -using Optimization, OptimizationOptimJL, Optimisers +using Optimization, OptimizationOptimJL +import OptimizationFlux: Adam using Interpolations, GaussianRandomFields using Setfield @@ -12,32 +13,34 @@ Dₜ = Differential(t) Dₓ² = Dₓ^2 ν = 0.001 -eq = Dₜ(u(x,t)) + u(x,t) * Dₓ(u(x,t)) ~ ν * Dₓ²(u(x,t)) +eq = Dₜ(u(x, t)) + u(x, t) * Dₓ(u(x, t)) ~ ν * Dₓ²(u(x, t)) domain = (0.0 .. 1.0) × (0.0 .. 1.0) eq = eq => domain -bcs = [(u(0.0, t) ~ u(1.0, t)) => (0.0 .. 1.0), - (u(x, t) ~ a(x)) => (0.0 .. 1.0) × (0.0 .. 0.0)] +bcs = [ + (u(0.0, t) ~ u(1.0, t)) => (0.0 .. 1.0), + (u(x, t) ~ a(x)) => (0.0 .. 1.0) × (0.0 .. 0.0), +] boundary = 0.0 .. 
diff --git a/docs/src/tutorials/burgers.jl b/docs/src/tutorials/burgers.jl
index 0d1620cc..b1befca2 100644
--- a/docs/src/tutorials/burgers.jl
+++ b/docs/src/tutorials/burgers.jl
@@ -1,7 +1,8 @@
 using Sophon, ModelingToolkit
 using DomainSets
 using DomainSets: ×
-using Optimization, OptimizationOptimJL, Optimisers
+using Optimization, OptimizationOptimJL
+import OptimizationFlux: Adam
 using Interpolations, GaussianRandomFields
 using Setfield
 
@@ -12,32 +13,34 @@ Dₜ = Differential(t)
 Dₓ² = Dₓ^2
 ν = 0.001
-eq = Dₜ(u(x,t)) + u(x,t) * Dₓ(u(x,t)) ~ ν * Dₓ²(u(x,t))
+eq = Dₜ(u(x, t)) + u(x, t) * Dₓ(u(x, t)) ~ ν * Dₓ²(u(x, t))
 
 domain = (0.0 .. 1.0) × (0.0 .. 1.0)
 eq = eq => domain
 
-bcs = [(u(0.0, t) ~ u(1.0, t)) => (0.0 .. 1.0),
-       (u(x, t) ~ a(x)) => (0.0 .. 1.0) × (0.0 .. 0.0)]
+bcs = [
+    (u(0.0, t) ~ u(1.0, t)) => (0.0 .. 1.0),
+    (u(x, t) ~ a(x)) => (0.0 .. 1.0) × (0.0 .. 0.0),
+]
 
 boundary = 0.0 .. 1.0
 
-Burgers = Sophon.ParametricPDESystem([eq], bcs, [t, x], [u(x,t)], [a(x)])
+Burgers = Sophon.ParametricPDESystem([eq], bcs, [t, x], [u(x, t)], [a(x)])
 
 chain = DeepONet((50, 50, 50, 50), tanh, (2, 50, 50, 50, 50), tanh)
 pinn = PINN(chain)
 sampler = QuasiRandomSampler(500, 50)
 strategy = NonAdaptiveTraining()
 
-struct MyFuncSampler <: Sophon.FunctionSampler
-    pts
-    grf
-    n
+struct MyFuncSampler <: Sophon.FunctionSampler
+    pts::Any
+    grf::Any
+    n::Any
 end
 
 function MyFuncSampler(pts, n)
-    cov = CovarianceFunction(1, Whittle(.1))
-    grf = GaussianRandomField(cov, KarhunenLoeve(5), pts)
-    return MyFuncSampler(pts, grf, n)
+    cov = CovarianceFunction(1, Whittle(0.1))
+    grf = GaussianRandomField(cov, KarhunenLoeve(5), pts)
+    return MyFuncSampler(pts, grf, n)
 end
@@ -48,15 +51,15 @@ function Sophon.sample(sampler::MyFuncSampler)
         push!(ys, y)
     end
     return ys
-end 
+end
+
+cord_branch_net = range(0.0, 1.0; length=50)
 
-cord_branch_net = range(0.0, 1.0, length=50)
+fsampler = MyFuncSampler(cord_branch_net, 10)
 
-fsampler = MyFuncSampler(cord_branch_net, 10)
-
 prob = Sophon.discretize(Burgers, pinn, sampler, strategy, fsampler, cord_branch_net)
 
-function callback(p,l)
+function callback(p, l)
     println("Loss: $l")
     return false
 end
@@ -64,25 +67,25 @@ end
 @time res = Optimization.solve(prob, BFGS(); maxiters=1000, callback=callback)
 
 using ProgressMeter
- 
-n = 20000 
+
+n = 20000
 k = 10
 pg = Progress(n; showspeed=true)
- 
-function callback(p,l)
-    ProgressMeter.next!(pg; showvalues = [(:loss, l)])
+
+function callback(p, l)
+    ProgressMeter.next!(pg; showvalues=[(:loss, l)])
     return false
 end
- 
-adam = Adam() 
+
+adam = Adam()
 for i in 1:k
     cords = Sophon.sample(Burgers, sampler, strategy)
     fs = Sophon.sample(fsampler)
     p = Sophon.PINOParameterHandler(cords, fs)
-    prob = remake(prob; u0 = res.u, p = p)
-    res = Optimization.solve(prob, adam; maxiters= n ÷ k, callback=callback)
+    prob = remake(prob; u0=res.u, p=p)
+    res = Optimization.solve(prob, adam; maxiters=n ÷ k, callback=callback)
 end
- 
+
 using CairoMakie
 
 phi = pinn.phi
@@ -93,4 +96,4 @@ f_test(x) = sinpi(2x)
 u0 = reshape(f_test.(cord_branch_net), :, 1)
 axis = (xlabel="t", ylabel="x", title="Prediction")
 u_pred = [sum(pinn.phi((u0, [x, t]), res.u)) for x in xs, t in ts]
-fig, ax, hm = heatmap(ts, xs, u_pred', axis=axis, colormap=:jet)
+fig, ax, hm = heatmap(ts, xs, u_pred'; axis=axis, colormap=:jet)
diff --git a/docs/src/tutorials/sod.jl b/docs/src/tutorials/sod.jl
index 8431fb23..86415a20 100644
--- a/docs/src/tutorials/sod.jl
+++ b/docs/src/tutorials/sod.jl
@@ -33,7 +33,7 @@ domains = [t ∈ Interval(t_min, t_max), x ∈ Interval(x_min, x_max)]
 
 pinn = PINN(; u=FullyConnected(2, 1, tanh; num_layers=4, hidden_dims=16),
             ρ=FullyConnected(2, 1, tanh; num_layers=4, hidden_dims=16),
-             p=FullyConnected(2, 1, tanh; num_layers=4, hidden_dims=16))
+            p=FullyConnected(2, 1, tanh; num_layers=4, hidden_dims=16))
 
 sampler = QuasiRandomSampler(1000, 100)
 function pde_weights(phi, x, θ)
diff --git a/src/compact/NeuralPDE/discretize.jl b/src/compact/NeuralPDE/discretize.jl
index 597472c7..e62341d0 100644
--- a/src/compact/NeuralPDE/discretize.jl
+++ b/src/compact/NeuralPDE/discretize.jl
@@ -1,8 +1,6 @@
-function get_datafree_pinn_loss_function(pde_system::NeuralPDE.PDESystem, pinn::PINN,
-                                         strategy::AbstractTrainingAlg;
-                                         additional_loss=Sophon.null_additional_loss,
-                                         derivative=finitediff)
+function build_loss_function(pde_system::NeuralPDE.PDESystem, pinn::PINN,
+                             strategy::AbstractTrainingAlg; derivative=finitediff)
     (; eqs, bcs, domain, ps, defaults, indvars, depvars) = pde_system
     (; phi, init_params) = pinn
 
@@ -26,11 +24,10 @@ function get_datafree_pinn_loss_function(pde_system::NeuralPDE.PDESystem, pinn::
     bc_integration_vars = NeuralPDE.get_integration_variables(bcs, dict_indvars,
                                                               dict_depvars)
 
-    pinnrep = (; eqs, bcs, domain, ps, defaults, default_p, additional_loss, depvars,
-               indvars, dict_indvars, dict_depvars, dict_depvar_input, multioutput,
-               init_params, phi, derivative, strategy, pde_indvars, bc_indvars,
-               pde_integration_vars, bc_integration_vars, fdtype=Float64,
-               eq_params=SciMLBase.NullParameters())
+    pinnrep = (; eqs, bcs, domain, ps, defaults, default_p, depvars, indvars, dict_indvars,
+               dict_depvars, dict_depvar_input, multioutput, init_params, phi, derivative,
+               strategy, pde_indvars, bc_indvars, pde_integration_vars, bc_integration_vars,
+               fdtype=Float64, eq_params=SciMLBase.NullParameters())
 
     integral = get_numeric_integral(pinnrep)
     pinnrep = merge(pinnrep, (; integral))
@@ -46,17 +43,11 @@ function get_datafree_pinn_loss_function(pde_system::NeuralPDE.PDESystem, pinn::
     pde_and_bcs_loss_function = scalarize(strategy, phi, datafree_pde_loss_functions,
                                           datafree_bc_loss_functions)
-
-    function full_loss_function(θ, p)
-        return pde_and_bcs_loss_function(θ, p) + additional_loss(phi, θ)
-    end
-    return full_loss_function
+    return pde_and_bcs_loss_function
 end
 
-function get_datafree_pinn_loss_function(pde_system::PDESystem, pinn::PINN,
-                                         strategy::AbstractTrainingAlg;
-                                         additional_loss=Sophon.null_additional_loss,
-                                         derivative=finitediff)
+function build_loss_function(pde_system::PDESystem, pinn::PINN,
+                             strategy::AbstractTrainingAlg; derivative=finitediff)
     (; eqs, bcs, ivs, dvs) = pde_system
     (; phi, init_params) = pinn
 
@@ -74,10 +65,10 @@ function get_datafree_pinn_loss_function(pde_system::PDESystem, pinn::PINN,
     bc_integration_vars = NeuralPDE.get_integration_variables(map(first, bcs),
                                                               dict_indvars, dict_depvars)
 
-    pinnrep = (; eqs, bcs, additional_loss, depvars, indvars, dict_indvars, dict_depvars,
-               dict_depvar_input, multioutput, init_params, phi, derivative, strategy,
-               pde_indvars, bc_indvars, pde_integration_vars, bc_integration_vars,
-               fdtype=Float64, eq_params=SciMLBase.NullParameters())
+    pinnrep = (; eqs, bcs, depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input,
+               multioutput, init_params, phi, derivative, strategy, pde_indvars, bc_indvars,
+               pde_integration_vars, bc_integration_vars, fdtype=Float64,
+               eq_params=SciMLBase.NullParameters())
 
     integral = nothing
     pinnrep = merge(pinnrep, (; integral))
@@ -93,17 +84,12 @@ function get_datafree_pinn_loss_function(pde_system::PDESystem, pinn::PINN,
     pde_and_bcs_loss_function = scalarize(strategy, phi, datafree_pde_loss_functions,
                                           datafree_bc_loss_functions)
-
-    function full_loss_function(θ, p)
-        return pde_and_bcs_loss_function(θ, p) + additional_loss(phi, θ)
-    end
-    return full_loss_function
+    return pde_and_bcs_loss_function
 end
 
-function get_datafree_pinn_loss_function(pde_system::ParametricPDESystem, pinn::PINN,
-                                         strategy::AbstractTrainingAlg, pfs, cord_branch_net;
-                                         additional_loss=Sophon.null_additional_loss,
-                                         derivative=finitediff)
+function build_loss_function(pde_system::ParametricPDESystem, pinn::PINN,
+                             strategy::AbstractTrainingAlg, cord_branch_net;
+                             derivative=finitediff)
     (; eqs, bcs, ivs, dvs, pvs) = pde_system
     (; phi, init_params) = pinn
 
@@ -121,10 +107,10 @@ function get_datafree_pinn_loss_function(pde_system::ParametricPDESystem, pinn::
     bc_integration_vars = NeuralPDE.get_integration_variables(map(first, bcs),
                                                               dict_indvars, dict_depvars)
 
-    pinnrep = (; eqs, bcs, additional_loss, depvars, indvars, dict_indvars, dict_depvars,
-               dict_depvar_input, dict_pmdepvars, dict_pmdepvar_input, multioutput, pvs,
-               init_params, pinn, derivative, strategy, pde_indvars, bc_indvars,
-               pde_integration_vars, bc_integration_vars, fdtype=Float64, cord_branch_net,
+    pinnrep = (; eqs, bcs, depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input,
+               dict_pmdepvars, dict_pmdepvar_input, multioutput, pvs, init_params, pinn,
+               derivative, strategy, pde_indvars, bc_indvars, pde_integration_vars,
+               bc_integration_vars, fdtype=Float64, cord_branch_net,
                eq_params=SciMLBase.NullParameters())
 
     datafree_pde_loss_functions = Tuple(build_loss_function(pinnrep, first(eq), i)
@@ -139,18 +125,14 @@ function get_datafree_pinn_loss_function(pde_system::ParametricPDESystem, pinn::
     pde_and_bcs_loss_function = scalarize(strategy, phi, datafree_pde_loss_functions,
                                           datafree_bc_loss_functions)
-
-    function full_loss_function(θ, p)
-        return pde_and_bcs_loss_function(θ, p) + additional_loss(phi, θ)
-    end
-    return full_loss_function
+    return pde_and_bcs_loss_function
 end
 
 """
     discretize(pde_system::PDESystem, pinn::PINN, sampler::PINNSampler,
               strategy::AbstractTrainingAlg; additional_loss)
 
-Convert the PDESystem into an `OptimizationProblem`. You will have access to each loss function 
+Convert the PDESystem into an `OptimizationProblem`. You will have access to each loss function
 `Sophon.residual_function_1`, `Sophon.residual_function_2`... after calling this function.
 """
 function discretize(pde_system, pinn::PINN, sampler::PINNSampler,
@@ -158,13 +140,17 @@ function discretize(pde_system, pinn::PINN, sampler::PINNSampler,
                     additional_loss=Sophon.null_additional_loss, derivative=finitediff,
                     adtype=Optimization.AutoZygote())
     datasets = sample(pde_system, sampler, strategy)
-    datasets = pinn.init_params isa AbstractGPUComponentVector ?
+    init_params = _ComponentArray(pinn.init_params)
+    datasets = init_params isa AbstractGPUComponentVector ?
               map(Base.Fix1(adapt, CuArray), datasets) : datasets
-    loss_function = get_datafree_pinn_loss_function(pde_system, pinn, strategy;
-                                                    additional_loss=additional_loss,
+    pde_and_bcs_loss_function = build_loss_function(pde_system, pinn, strategy;
                                                     derivative=derivative)
-    f = OptimizationFunction(loss_function, adtype)
-    return Optimization.OptimizationProblem(f, pinn.init_params, datasets)
+
+    function full_loss_function(θ, p)
+        return pde_and_bcs_loss_function(θ, p) + additional_loss(pinn.phi, θ)
+    end
+    f = OptimizationFunction(full_loss_function, adtype)
+    return Optimization.OptimizationProblem(f, init_params, datasets)
 end
 
 function discretize(pde_system::ParametricPDESystem, pinn::PINN, sampler::PINNSampler,
@@ -173,18 +159,20 @@ function discretize(pde_system::ParametricPDESystem, pinn::PINN, sampler::PINNSa
                     additional_loss=Sophon.null_additional_loss, derivative=finitediff,
                     adtype=Optimization.AutoZygote())
     datasets = sample(pde_system, sampler, strategy)
-    datasets = pinn.init_params isa AbstractGPUComponentVector ?
+    init_params = _ComponentArray(pinn.init_params)
+    datasets = init_params isa AbstractGPUComponentVector ?
               map(Base.Fix1(adapt, CuArray), datasets) : datasets
     pfs = sample(functionsampler)
 
     cord_branch_net = cord_branch_net isa Union{AbstractVector, StepRangeLen} ?
                       [cord_branch_net] : cord_branch_net
-    loss_function = get_datafree_pinn_loss_function(pde_system, pinn, strategy, pfs,
-                                                    cord_branch_net;
-                                                    additional_loss=additional_loss,
-                                                    derivative=derivative)
-    f = OptimizationFunction(loss_function, adtype)
+    pde_and_bcs_loss_function = build_loss_function(pde_system, pinn, strategy,
+                                                    cord_branch_net; derivative=derivative)
+    function full_loss_function(θ, p, pfs)
+        return pde_and_bcs_loss_function(θ, p, pfs) + additional_loss(pinn.phi, θ)
+    end
+    f = OptimizationFunction(full_loss_function, adtype)
     p = PINOParameterHandler(datasets, pfs)
-    return Optimization.OptimizationProblem(f, pinn.init_params, p)
+    return Optimization.OptimizationProblem(f, init_params, p)
 end
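With this refactor, `build_loss_function` returns only the PDE/boundary residual loss, and `discretize` adds `additional_loss` on top (now invoked as `additional_loss(phi, θ)`). A hedged sketch of how a user-supplied term enters the problem; the toy Poisson system, network size, and weight are illustrative assumptions, not part of this patch:

```julia
using ModelingToolkit, DomainSets, Sophon, Optimization, OptimizationOptimJL

@parameters x
@variables u(..)
Dxx = Differential(x)^2

# Illustrative 1D Poisson problem; any PDESystem goes through the same path.
eq = Dxx(u(x)) ~ -1.0
bcs = [u(0.0) ~ 0.0, u(1.0) ~ 0.0]
domains = [x ∈ Interval(0.0, 1.0)]
@named poisson = PDESystem([eq], bcs, domains, [x], [u(x)])

pinn = PINN(FullyConnected(1, 1, tanh; num_layers=4, hidden_dims=16))
sampler = QuasiRandomSampler(200, 20)
strategy = NonAdaptiveTraining()

# The two-argument regularizer is summed with the residual loss inside
# `discretize`; 1e-4 is an arbitrary weight chosen for the example.
l2_penalty(phi, θ) = 1e-4 * sum(abs2, θ)

prob = Sophon.discretize(poisson, pinn, sampler, strategy; additional_loss=l2_penalty)
res = Optimization.solve(prob, BFGS(); maxiters=500)
```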
diff --git a/src/compact/NeuralPDE/pinn_types.jl b/src/compact/NeuralPDE/pinn_types.jl
index f9eb6ad7..bc666d2e 100644
--- a/src/compact/NeuralPDE/pinn_types.jl
+++ b/src/compact/NeuralPDE/pinn_types.jl
@@ -19,21 +19,21 @@ The default element type of the parameters is `Float64`.
 ```julia
 using Random
 rng = Random.default_rng()
-Random.seed!(rng, 0)
+Random.seed!(rng, 0)
 ```
-    and pass `rng` to `PINN` as
+
+and pass `rng` to `PINN` as
+
 ```julia
 using Sophon
 
-chain = FullyConnected((1,6,6,1), sin);
+chain = FullyConnected((1, 6, 6, 1), sin);
 
 # single dependent variable
 pinn = PINN(chain, rng);
 
 # multiple dependent variables
-pinn = PINN(rng;
-            a = chain,
-            b = chain);
+pinn = PINN(rng; a=chain, b=chain);
 ```
 """
 struct PINN{PHI, P}
@@ -47,22 +47,20 @@ end
 
 function PINN(chain::NamedTuple, rng::AbstractRNG=Random.default_rng())
     phi = map(m -> ChainState(m, rng), chain)
-    init_params = ComponentArray(initialparameters(rng, phi)) .|> Float64
+    init_params = Lux.fmap(float64, initialparameters(rng, phi))
 
     return PINN{typeof(phi), typeof(init_params)}(phi, init_params)
 end
 
 function PINN(chain::AbstractExplicitLayer, rng::AbstractRNG=Random.default_rng())
     phi = ChainState(chain, rng)
-    init_params = ComponentArray(initialparameters(rng, phi)) .|> Float64
+    init_params = Lux.fmap(float64, initialparameters(rng, phi))
 
     return PINN{typeof(phi), typeof(init_params)}(phi, init_params)
 end
 
 function initialparameters(rng::AbstractRNG, pinn::PINN)
-    init_params = initialparameters(rng, pinn.phi)
-    init_params = ComponentArray(init_params) .|> Float64
-    return init_params
+    return Lux.fmap(float64, initialparameters(rng, pinn.phi))
 end
 
 """
@@ -212,7 +210,8 @@ struct ParametricPDESystem
     pvs::Vector
 end
 
-function ParametricPDESystem(eq::Pair{<:Symbolics.Equation, <:DomainSets.Domain}, bcs, ivs, dvs, pvs)
+function ParametricPDESystem(eq::Pair{<:Symbolics.Equation, <:DomainSets.Domain}, bcs, ivs,
+                             dvs, pvs)
     return PDESystem([eq], bcs, ivs, dvs, pvs)
 end
 
@@ -234,8 +233,8 @@ function Base.show(io::IO, ::MIME"text/plain", sys::ParametricPDESystem)
 end
 
 mutable struct PINOParameterHandler
-    cords
-    fs
+    cords::Any
+    fs::Any
 end
 
 get_local_ps(p::PINOParameterHandler) = p.cords
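A consequence of the `pinn_types.jl` change worth noting: `PINN` now keeps its initial parameters as a `Float64` `NamedTuple` (via `Lux.fmap(float64, ...)`) rather than eagerly building a `ComponentArray`; flattening is deferred to `discretize`, which calls `_ComponentArray`. A small sketch of the assumed user-visible behavior, mirroring the updated tests at the end of this patch:

```julia
using Random, Lux, Sophon

rng = Random.default_rng()
pinn = PINN(FullyConnected((1, 6, 6, 1), sin), rng)

# Parameters stay a NamedTuple of Float64 arrays until `discretize`
# flattens them into a ComponentArray (see test/runtests.jl below).
@assert Lux.initialparameters(rng, pinn) isa NamedTuple
```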
diff --git a/src/compact/NeuralPDE/pinnsampler.jl b/src/compact/NeuralPDE/pinnsampler.jl
index 2bd2a9c4..b4b17d71 100644
--- a/src/compact/NeuralPDE/pinnsampler.jl
+++ b/src/compact/NeuralPDE/pinnsampler.jl
@@ -29,7 +29,7 @@ function QuasiRandomSampler(pde_points, bcs_points=pde_points; sampling_alg=Sobo
                             sampling_alg)
 end
 
-function sample(pde::NeuralPDE.PDESystem, sampler::QuasiRandomSampler, strategy)
+function sample(pde::NeuralPDE.PDESystem, sampler::QuasiRandomSampler, strategy=nothing)
     (; pde_points, bcs_points, sampling_alg) = sampler
     pde_bounds, bcs_bounds = get_bounds(pde)
 
@@ -47,7 +47,7 @@ function sample(pde::NeuralPDE.PDESystem, sampler::QuasiRandomSampler, strategy)
     return [pde_datasets; boundary_datasets]
 end
 
-function sample(pde, sampler::QuasiRandomSampler, strategy)
+function sample(pde, sampler::QuasiRandomSampler, strategy=nothing)
     (; pde_points, bcs_points, sampling_alg) = sampler
     (; eqs, bcs) = pde
 
@@ -65,7 +65,7 @@ function sample(pde, sampler::QuasiRandomSampler, strategy)
 end
 
 function sample(pde::NeuralPDE.PDESystem, sampler::QuasiRandomSampler{P, B, SobolSample},
-                strategy) where {P, B}
+                strategy=nothing) where {P, B}
     (; pde_points, bcs_points) = sampler
     pde_bounds, bcs_bounds = get_bounds(pde)
diff --git a/src/compact/NeuralPDE/training_strategies.jl b/src/compact/NeuralPDE/training_strategies.jl
index 613f1b03..6f77e224 100644
--- a/src/compact/NeuralPDE/training_strategies.jl
+++ b/src/compact/NeuralPDE/training_strategies.jl
@@ -43,8 +43,7 @@ function scalarize(weights::NTuple{N, <:Real}, datafree_loss_functions::Tuple) w
                abs2.($(datafree_loss_functions[1])(local_ps[1], θ, gobal_ps))))
     for i in 2:N
         ex = :(mean($(weights[i]) .*
-                    abs2.($(datafree_loss_functions[i])(local_ps[$i], θ, gobal_ps))) +
-               $ex)
+                    abs2.($(datafree_loss_functions[i])(local_ps[$i], θ, gobal_ps))) + $ex)
     end
     push!(body.args, ex)
     loss_f = Expr(:function, Expr(:call, :(pinn_loss_function), :θ, :pp), body)
diff --git a/src/compact/NeuralPDE/utils.jl b/src/compact/NeuralPDE/utils.jl
index c8df7100..0864f448 100644
--- a/src/compact/NeuralPDE/utils.jl
+++ b/src/compact/NeuralPDE/utils.jl
@@ -70,13 +70,13 @@ end
     end
 end)
 ```
 """
-function build_symbolic_loss_function(pinnrep::NamedTuple{names}, eq::Symbolics.Equation) where names
+function build_symbolic_loss_function(pinnrep::NamedTuple{names},
+                                      eq::Symbolics.Equation) where {names}
     (; depvars, dict_depvars, dict_depvar_input, derivative, multioutput, dict_indvars) = pinnrep
 
     loss_function, pos, values = parse_equation(pinnrep, eq)
     this_eq_pair = pair(eq, depvars, dict_depvar_input)
-    this_eq_indvars = unique(vcat([getindex(this_eq_pair, v)
-                                   for v in keys(this_eq_pair)]...))
+    this_eq_indvars = unique(vcat([getindex(this_eq_pair, v) for v in keys(this_eq_pair)]...))
 
     vars = :(cord, θ, pfs)
     ex = Expr(:block)
@@ -86,9 +86,17 @@ function build_symbolic_loss_function(pinnrep::NamedTuple{names}, eq::Symbolics.
         push!(ex.args, Expr(:(=), :deeponet, pinn.phi))
         push!(ex.args, Expr(:(=), :derivative, derivative))
         push!(ex.args, Expr(:(=), :cord_branch_net, cord_branch_net))
-        push!(ex.args, Expr(:(=), :(get_pfs_output(x::AbstractMatrix)), :(ChainRulesCore.ignore_derivatives(mapreduce(f -> f.(x), vcat, pfs)))))
-        push!(ex.args, Expr(:(=), :(get_pfs_output(x::AbstractVector...)), :(ChainRulesCore.ignore_derivatives(mapreduce(f -> reshape(f.(x...), 1, :), vcat, pfs)))))
-        push!(ex.args, Expr(:(=), :(branch_net_input), :(transpose(get_pfs_output(cord_branch_net...)))))
+        push!(ex.args,
+              Expr(:(=), :(get_pfs_output(x::AbstractMatrix)),
+                   :(ChainRulesCore.ignore_derivatives(mapreduce(f -> f.(x), vcat, pfs)))))
+        push!(ex.args,
+              Expr(:(=), :(get_pfs_output(x::AbstractVector...)),
+                   :(ChainRulesCore.ignore_derivatives(mapreduce(f -> reshape(f.(x...), 1,
+                                                                              :), vcat,
+                                                                 pfs)))))
+        push!(ex.args,
+              Expr(:(=), :(branch_net_input),
+                   :(transpose(get_pfs_output(cord_branch_net...)))))
         push!(ex.args, Expr(:(=), :(phi(x_, θ_)), :(deeponet((branch_net_input, x_), θ_))))
     else
@@ -141,11 +149,10 @@
     vcat_expr = Expr(:block, :($(eq_pair_expr...)))
     vcat_expr_loss_functions = Expr(:block, vcat_expr, loss_function)
-
     indvars_ex = [:($:cord[[$i], :]) for (i, x) in enumerate(this_eq_indvars)]
     left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex
     vars_eq = Expr(:(=), NeuralPDE.build_expr(:tuple, left_arg_pairs),
-                    NeuralPDE.build_expr(:tuple, right_arg_pairs))
+                   NeuralPDE.build_expr(:tuple, right_arg_pairs))
 
     let_ex = Expr(:let, vars_eq, vcat_expr_loss_functions)
     push!(ex.args, let_ex)
@@ -155,7 +162,8 @@ end
 function build_loss_function(pinnrep::NamedTuple, eq::Symbolics.Equation, i)
     vars, ex = build_symbolic_loss_function(pinnrep, eq)
     expr = Expr(:function,
-                Expr(:call, Symbol(:residual_function_, i), vars.args[1], vars.args[2], :($(Expr(:kw, vars.args[3], :nothing)))), ex)
+                Expr(:call, Symbol(:residual_function_, i), vars.args[1], vars.args[2],
+                     :($(Expr(:kw, vars.args[3], :nothing)))), ex)
     return eval(expr)
 end
@@ -229,12 +237,14 @@ function parse_equation(pinnrep::NamedTuple, eq)
     end
 end
 
-function is_periodic_bc(bcs::Vector{<:Symbolics.Equation}, eq, depvars, left_expr::Expr, right_expr::Expr)
+function is_periodic_bc(bcs::Vector{<:Symbolics.Equation}, eq, depvars, left_expr::Expr,
+                        right_expr::Expr)
     eq ∉ bcs && return false
     return left_expr.args[1] ∈ depvars && left_expr.args[1] === right_expr.args[1]
 end
 
-function is_periodic_bc(bcs::Vector{<:Pair{<:Symbolics.Equation, <:DomainSets.Domain}}, eq, depvars, left_expr::Expr, right_expr::Expr)
+function is_periodic_bc(bcs::Vector{<:Pair{<:Symbolics.Equation, <:DomainSets.Domain}}, eq,
+                        depvars, left_expr::Expr, right_expr::Expr)
     eq ∉ bcs[1] && return false
     return left_expr.args[1] ∈ depvars && left_expr.args[1] === right_expr.args[1]
 end
@@ -260,10 +270,11 @@ function transform_expression(pinnrep::NamedTuple, ex)
     return ex
 end
 
-function _transform_expression(pinnrep::NamedTuple{names}, ex::Expr) where names
+function _transform_expression(pinnrep::NamedTuple{names}, ex::Expr) where {names}
     (; indvars, depvars, dict_indvars, dict_depvars, dict_depvar_input, multioutput, fdtype) = pinnrep
     fdtype = fdtype
-    dict_pmdepvars = :dict_pmdepvars in names ? pinnrep.dict_pmdepvars : Dict{Symbol, Symbol}()
+    dict_pmdepvars = :dict_pmdepvars in names ? pinnrep.dict_pmdepvars :
+                     Dict{Symbol, Symbol}()
     _args = ex.args
     for (i, e) in enumerate(_args)
         if !(e isa Expr)
@@ -368,23 +379,23 @@ function finitediff(phi, x, εs, order, θ)
 end
 
 @inline function finitediff(phi, x, ε::AbstractVector{T}, ::Val{1}, θ,
-                             h::T) where {T <: AbstractFloat}
+                            h::T) where {T <: AbstractFloat}
     return (phi(x .+ ε, θ) .- phi(x .- ε, θ)) .* h ./ 2
 end
 
 @inline function finitediff(phi, x, ε::AbstractVector{T}, ::Val{2}, θ,
-                             h::T) where {T <: AbstractFloat}
+                            h::T) where {T <: AbstractFloat}
     return (phi(x .+ ε, θ) .+ phi(x .- ε, θ) .- 2 .* phi(x, θ)) .* h^2
 end
 
 @inline function finitediff(phi, x, ε::AbstractVector{T}, ::Val{3}, θ,
-                             h::T) where {T <: AbstractFloat}
+                            h::T) where {T <: AbstractFloat}
     return (phi(x .+ 2 .* ε, θ) .- 2 .* phi(x .+ ε, θ) .+ 2 .* phi(x .- ε, θ) -
             phi(x .- 2 .* ε, θ)) .* h^3 ./ 2
 end
 
 @inline function finitediff(phi, x, ε::AbstractVector{T}, ::Val{4}, θ,
-                             h::T) where {T <: AbstractFloat}
+                            h::T) where {T <: AbstractFloat}
     return (phi(x .+ 2 .* ε, θ) .- 4 .* phi(x .+ ε, θ) .+ 6 .* phi(x, θ) .-
             4 .* phi(x .- ε, θ) .+ phi(x .- 2 .* ε, θ)) .* h^4
 end
diff --git a/src/compact/componentarrays.jl b/src/compact/componentarrays.jl
index c9621cea..7a95ba02 100644
--- a/src/compact/componentarrays.jl
+++ b/src/compact/componentarrays.jl
@@ -9,3 +9,7 @@ const AbstractGPUComponentMatrix{T, Ax} = ComponentArray{T, 2, Ax}
 
 const AbstractGPUComponentVecorMat{T, Ax} = Union{AbstractGPUComponentVector{T, Ax},
                                                   AbstractGPUComponentMatrix{T, Ax}}
+
+function _ComponentArray(nt::NamedTuple)
+    return isongpu(nt) ? adapt(CuArray, ComponentArray(cpu(nt))) : ComponentArray(nt)
+end
diff --git a/src/layers/nets.jl b/src/layers/nets.jl
index 59410ab0..d051a2c2 100644
--- a/src/layers/nets.jl
+++ b/src/layers/nets.jl
@@ -86,7 +86,7 @@ x → [FourierFeature(x); x] → PINNAttention
 ## Arguments
 
   - `in_dims`: The input dimension.
-
+ 
   + `out_dims`: The output dimension.
   + `activation`: The activation function.
   + `std`: See [`FourierFeature`](@ref).
diff --git a/src/utils.jl b/src/utils.jl
index f82c2cc0..fcdbad3a 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -85,3 +85,9 @@ end
 end
 
 ChainRulesCore.@non_differentiable init_normal(::Any...)
+
+function isongpu(nt::NamedTuple)
+    return any(x -> x isa AbstractGPUArray, Lux.fcollect(nt))
+end
+
+float64 = Base.Fix1(convert, AbstractArray{Float64})
diff --git a/test/runtests.jl b/test/runtests.jl
index f64fac6d..b2edf86b 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -254,8 +254,8 @@ rng = Random.default_rng()
         @test_nowarn PINN(u=chain, p=chain)
         @test_nowarn PINN(chain, rng)
        @test_nowarn PINN(rng; u=chain, p=chain)
-        @test Lux.initialparameters(rng, PINN(chain)) isa Lux.ComponentArray
-        @test Lux.initialparameters(rng, PINN(u=chain, p=chain)) isa Lux.ComponentArray
+        @test Lux.initialparameters(rng, PINN(chain)) isa NamedTuple
+        @test Lux.initialparameters(rng, PINN(; u=chain, p=chain)) isa NamedTuple
     end
 end
 end
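Finally, the helpers added in `src/utils.jl` and `src/compact/componentarrays.jl` compose as follows. A hedged sketch using the internal, unexported names; the parameter NamedTuple is a made-up example:

```julia
using Lux, Sophon

# Hypothetical NamedTuple of Float32 parameters.
nt = (; weight=rand(Float32, 3, 3), bias=zeros(Float32, 3))

# `float64` is Base.Fix1(convert, AbstractArray{Float64}), mapped over leaves:
nt64 = Lux.fmap(Sophon.float64, nt)
@assert eltype(nt64.weight) == Float64

# `_ComponentArray` flattens the NamedTuple; when `isongpu` detects a GPU
# leaf, it builds the ComponentArray on the CPU and adapts it back to CuArray.
θ = Sophon._ComponentArray(nt64)
```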