Skip to content

Commit

Permalink
Merge pull request #177 from YichengDWu/minibatch
Browse files Browse the repository at this point in the history
allow customizing training loop given the loss function
  • Loading branch information
YichengDWu authored Oct 21, 2022
2 parents 0c11473 + f09ed3b commit 745fcdb
Show file tree
Hide file tree
Showing 13 changed files with 133 additions and 123 deletions.
2 changes: 1 addition & 1 deletion docs/src/tutorials/SchrödingerEquation.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function train(pde_system, prob, sampler, strategy, resample_period = 500, n=10)
res = Optimization.solve(prob, bfgs; maxiters=2000)
for i in 1:n
data = Sophon.sample(pde_system, sampler, strategy)
data = Sophon.sample(pde_system, sampler)
prob = remake(prob; u0=res.u, p=data)
res = Optimization.solve(prob, bfgs; maxiters=resample_period)
end
Expand Down
2 changes: 1 addition & 1 deletion docs/src/tutorials/allen_cahn.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function train(allen, prob, sampler, strategy)
for tmax in [0.5, 0.75, 1.0]
allen.domain[2] = t ∈ 0.0..tmax
data = Sophon.sample(allen, sampler, strategy)
data = Sophon.sample(allen, sampler)
prob = remake(prob; u0=res.u, p=data)
res = Optimization.solve(prob, bfgs; maxiters=2000)
end
Expand Down
59 changes: 31 additions & 28 deletions docs/src/tutorials/burgers.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
using Sophon, ModelingToolkit
using DomainSets
using DomainSets: ×
using Optimization, OptimizationOptimJL, Optimisers
using Optimization, OptimizationOptimJL
import OptimizationFlux: Adam
using Interpolations, GaussianRandomFields
using Setfield

Expand All @@ -12,32 +13,34 @@ Dₜ = Differential(t)
Dₓ² = Dₓ^2

ν = 0.001
eq = Dₜ(u(x,t)) + u(x,t) * Dₓ(u(x,t)) ~ ν * Dₓ²(u(x,t))
eq = Dₜ(u(x, t)) + u(x, t) * Dₓ(u(x, t)) ~ ν * Dₓ²(u(x, t))
domain = (0.0 .. 1.0) × (0.0 .. 1.0)
eq = eq => domain

bcs = [(u(0.0, t) ~ u(1.0, t)) => (0.0 .. 1.0),
(u(x, t) ~ a(x)) => (0.0 .. 1.0) × (0.0 .. 0.0)]
bcs = [
(u(0.0, t) ~ u(1.0, t)) => (0.0 .. 1.0),
(u(x, t) ~ a(x)) => (0.0 .. 1.0) × (0.0 .. 0.0),
]

boundary = 0.0 .. 1.0

Burgers = Sophon.ParametricPDESystem([eq], bcs, [t, x], [u(x,t)], [a(x)])
Burgers = Sophon.ParametricPDESystem([eq], bcs, [t, x], [u(x, t)], [a(x)])

chain = DeepONet((50, 50, 50, 50), tanh, (2, 50, 50, 50, 50), tanh)
pinn = PINN(chain)
sampler = QuasiRandomSampler(500, 50)
strategy = NonAdaptiveTraining()

# Sampler of random input functions for DeepONet training.
# `pts` are the coordinates where each sampled function is evaluated,
# `grf` is the Gaussian random field the sample paths are drawn from,
# and `n` is the number of functions drawn per call.
# Fields are ::Any because they come from user-pluggable samplers.
struct MyFuncSampler <: Sophon.FunctionSampler
    pts::Any
    grf::Any
    n::Any
end

# Build a MyFuncSampler over coordinates `pts` drawing `n` sample paths from
# a 1-D Gaussian random field with Whittle(0.1) covariance, truncated to a
# 5-term Karhunen–Loève expansion.
function MyFuncSampler(pts, n)
    cov = CovarianceFunction(1, Whittle(0.1))
    grf = GaussianRandomField(cov, KarhunenLoeve(5), pts)
    return MyFuncSampler(pts, grf, n)
end

function Sophon.sample(sampler::MyFuncSampler)
Expand All @@ -48,41 +51,41 @@ function Sophon.sample(sampler::MyFuncSampler)
push!(ys, y)
end
return ys
end
end

cord_branch_net = range(0.0, 1.0; length=50)

cord_branch_net = range(0.0, 1.0, length=50)
fsampler = MyFuncSampler(cord_branch_net, 10)

fsampler = MyFuncSampler(cord_branch_net, 10)

prob = Sophon.discretize(Burgers, pinn, sampler, strategy, fsampler, cord_branch_net)

# Optimization callback: print the current loss `l` and return `false`
# so the solver keeps iterating. `p` (the parameters) is unused.
function callback(p, l)
    println("Loss: $l")
    return false
end

@time res = Optimization.solve(prob, BFGS(); maxiters=1000, callback=callback)

using ProgressMeter
n = 20000

n = 20000
k = 10
pg = Progress(n; showspeed=true)
# Progress-meter callback: advance the global progress bar `pg`, displaying
# the current loss, and return `false` so optimization continues.
function callback(p, l)
    ProgressMeter.next!(pg; showvalues=[(:loss, l)])
    return false
end
adam = Adam()

adam = Adam()
for i in 1:k
cords = Sophon.sample(Burgers, sampler, strategy)
fs = Sophon.sample(fsampler)
p = Sophon.PINOParameterHandler(cords, fs)
prob = remake(prob; u0 = res.u, p = p)
res = Optimization.solve(prob, adam; maxiters= n ÷ k, callback=callback)
prob = remake(prob; u0=res.u, p=p)
res = Optimization.solve(prob, adam; maxiters=n ÷ k, callback=callback)
end

using CairoMakie

phi = pinn.phi
Expand All @@ -93,4 +96,4 @@ f_test(x) = sinpi(2x)
u0 = reshape(f_test.(cord_branch_net), :, 1)
axis = (xlabel="t", ylabel="x", title="Prediction")
u_pred = [sum(pinn.phi((u0, [x, t]), res.u)) for x in xs, t in ts]
fig, ax, hm = heatmap(ts, xs, u_pred', axis=axis, colormap=:jet)
fig, ax, hm = heatmap(ts, xs, u_pred'; axis=axis, colormap=:jet)
2 changes: 1 addition & 1 deletion docs/src/tutorials/sod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ domains = [t ∈ Interval(t_min, t_max), x ∈ Interval(x_min, x_max)]

pinn = PINN(; u=FullyConnected(2, 1, tanh; num_layers=4, hidden_dims=16),
ρ=FullyConnected(2, 1, tanh; num_layers=4, hidden_dims=16),
p=FullyConnected(2, 1, tanh; num_layers=4, hidden_dims=16))
p=FullyConnected(2, 1, tanh; num_layers=4, hidden_dims=16))
sampler = QuasiRandomSampler(1000, 100)

function pde_weights(phi, x, θ)
Expand Down
94 changes: 41 additions & 53 deletions src/compact/NeuralPDE/discretize.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@

function get_datafree_pinn_loss_function(pde_system::NeuralPDE.PDESystem, pinn::PINN,
strategy::AbstractTrainingAlg;
additional_loss=Sophon.null_additional_loss,
derivative=finitediff)
function build_loss_function(pde_system::NeuralPDE.PDESystem, pinn::PINN,
strategy::AbstractTrainingAlg; derivative=finitediff)
(; eqs, bcs, domain, ps, defaults, indvars, depvars) = pde_system
(; phi, init_params) = pinn

Expand All @@ -26,11 +24,10 @@ function get_datafree_pinn_loss_function(pde_system::NeuralPDE.PDESystem, pinn::
bc_integration_vars = NeuralPDE.get_integration_variables(bcs, dict_indvars,
dict_depvars)

pinnrep = (; eqs, bcs, domain, ps, defaults, default_p, additional_loss, depvars,
indvars, dict_indvars, dict_depvars, dict_depvar_input, multioutput,
init_params, phi, derivative, strategy, pde_indvars, bc_indvars,
pde_integration_vars, bc_integration_vars, fdtype=Float64,
eq_params=SciMLBase.NullParameters())
pinnrep = (; eqs, bcs, domain, ps, defaults, default_p, depvars, indvars, dict_indvars,
dict_depvars, dict_depvar_input, multioutput, init_params, phi, derivative,
strategy, pde_indvars, bc_indvars, pde_integration_vars, bc_integration_vars,
fdtype=Float64, eq_params=SciMLBase.NullParameters())
integral = get_numeric_integral(pinnrep)
pinnrep = merge(pinnrep, (; integral))

Expand All @@ -46,17 +43,11 @@ function get_datafree_pinn_loss_function(pde_system::NeuralPDE.PDESystem, pinn::

pde_and_bcs_loss_function = scalarize(strategy, phi, datafree_pde_loss_functions,
datafree_bc_loss_functions)

function full_loss_function(θ, p)
return pde_and_bcs_loss_function(θ, p) + additional_loss(phi, θ)
end
return full_loss_function
return pde_and_bcs_loss_function
end

function get_datafree_pinn_loss_function(pde_system::PDESystem, pinn::PINN,
strategy::AbstractTrainingAlg;
additional_loss=Sophon.null_additional_loss,
derivative=finitediff)
function build_loss_function(pde_system::PDESystem, pinn::PINN,
strategy::AbstractTrainingAlg; derivative=finitediff)
(; eqs, bcs, ivs, dvs) = pde_system
(; phi, init_params) = pinn

Expand All @@ -74,10 +65,10 @@ function get_datafree_pinn_loss_function(pde_system::PDESystem, pinn::PINN,
bc_integration_vars = NeuralPDE.get_integration_variables(map(first, bcs), dict_indvars,
dict_depvars)

pinnrep = (; eqs, bcs, additional_loss, depvars, indvars, dict_indvars, dict_depvars,
dict_depvar_input, multioutput, init_params, phi, derivative, strategy,
pde_indvars, bc_indvars, pde_integration_vars, bc_integration_vars,
fdtype=Float64, eq_params=SciMLBase.NullParameters())
pinnrep = (; eqs, bcs, depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input,
multioutput, init_params, phi, derivative, strategy, pde_indvars, bc_indvars,
pde_integration_vars, bc_integration_vars, fdtype=Float64,
eq_params=SciMLBase.NullParameters())
integral = nothing
pinnrep = merge(pinnrep, (; integral))

Expand All @@ -93,17 +84,12 @@ function get_datafree_pinn_loss_function(pde_system::PDESystem, pinn::PINN,

pde_and_bcs_loss_function = scalarize(strategy, phi, datafree_pde_loss_functions,
datafree_bc_loss_functions)

function full_loss_function(θ, p)
return pde_and_bcs_loss_function(θ, p) + additional_loss(phi, θ)
end
return full_loss_function
return pde_and_bcs_loss_function
end

function get_datafree_pinn_loss_function(pde_system::ParametricPDESystem, pinn::PINN,
strategy::AbstractTrainingAlg, pfs, cord_branch_net;
additional_loss=Sophon.null_additional_loss,
derivative=finitediff)
function build_loss_function(pde_system::ParametricPDESystem, pinn::PINN,
strategy::AbstractTrainingAlg, cord_branch_net;
derivative=finitediff)
(; eqs, bcs, ivs, dvs, pvs) = pde_system
(; phi, init_params) = pinn

Expand All @@ -121,10 +107,10 @@ function get_datafree_pinn_loss_function(pde_system::ParametricPDESystem, pinn::
bc_integration_vars = NeuralPDE.get_integration_variables(map(first, bcs), dict_indvars,
dict_depvars)

pinnrep = (; eqs, bcs, additional_loss, depvars, indvars, dict_indvars, dict_depvars,
dict_depvar_input, dict_pmdepvars, dict_pmdepvar_input, multioutput, pvs,
init_params, pinn, derivative, strategy, pde_indvars, bc_indvars,
pde_integration_vars, bc_integration_vars, fdtype=Float64, cord_branch_net,
pinnrep = (; eqs, bcs, depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input,
dict_pmdepvars, dict_pmdepvar_input, multioutput, pvs, init_params, pinn,
derivative, strategy, pde_indvars, bc_indvars, pde_integration_vars,
bc_integration_vars, fdtype=Float64, cord_branch_net,
eq_params=SciMLBase.NullParameters())

datafree_pde_loss_functions = Tuple(build_loss_function(pinnrep, first(eq), i)
Expand All @@ -139,32 +125,32 @@ function get_datafree_pinn_loss_function(pde_system::ParametricPDESystem, pinn::

pde_and_bcs_loss_function = scalarize(strategy, phi, datafree_pde_loss_functions,
datafree_bc_loss_functions)

function full_loss_function(θ, p)
return pde_and_bcs_loss_function(θ, p) + additional_loss(phi, θ)
end
return full_loss_function
return pde_and_bcs_loss_function
end
"""
    discretize(pde_system, pinn::PINN, sampler::PINNSampler,
               strategy::AbstractTrainingAlg;
               additional_loss)

Convert the PDESystem into an `OptimizationProblem`. You will have access to each loss function
`Sophon.residual_function_1`, `Sophon.residual_function_2`... after calling this function.
"""
function discretize(pde_system, pinn::PINN, sampler::PINNSampler,
                    strategy::AbstractTrainingAlg;
                    additional_loss=Sophon.null_additional_loss, derivative=finitediff,
                    adtype=Optimization.AutoZygote())
    # Draw the initial collocation points for the PDE and boundary conditions.
    datasets = sample(pde_system, sampler, strategy)
    init_params = _ComponentArray(pinn.init_params)
    # If the parameters live on the GPU, move the datasets there as well.
    datasets = init_params isa AbstractGPUComponentVector ?
               map(Base.Fix1(adapt, CuArray), datasets) : datasets
    pde_and_bcs_loss_function = build_loss_function(pde_system, pinn, strategy;
                                                    derivative=derivative)

    # Total loss = PDE/BC residual loss + the user-supplied additional loss.
    function full_loss_function(θ, p)
        return pde_and_bcs_loss_function(θ, p) + additional_loss(pinn.phi, θ)
    end
    f = OptimizationFunction(full_loss_function, adtype)
    return Optimization.OptimizationProblem(f, init_params, datasets)
end

function discretize(pde_system::ParametricPDESystem, pinn::PINN, sampler::PINNSampler,
Expand All @@ -173,18 +159,20 @@ function discretize(pde_system::ParametricPDESystem, pinn::PINN, sampler::PINNSa
additional_loss=Sophon.null_additional_loss, derivative=finitediff,
adtype=Optimization.AutoZygote())
datasets = sample(pde_system, sampler, strategy)
datasets = pinn.init_params isa AbstractGPUComponentVector ?
init_params = _ComponentArray(pinn.init_params)
datasets = init_params isa AbstractGPUComponentVector ?
map(Base.Fix1(adapt, CuArray), datasets) : datasets

pfs = sample(functionsampler)
cord_branch_net = cord_branch_net isa Union{AbstractVector, StepRangeLen} ?
[cord_branch_net] : cord_branch_net
loss_function = get_datafree_pinn_loss_function(pde_system, pinn, strategy, pfs,
cord_branch_net;
additional_loss=additional_loss,
derivative=derivative)
f = OptimizationFunction(loss_function, adtype)
pde_and_bcs_loss_function = build_loss_function(pde_system, pinn, strategy,
cord_branch_net; derivative=derivative)
function full_loss_function(θ, p, pfs)
return pde_and_bcs_loss_function(θ, p, pfs) + additional_loss(pinn.phi, θ)
end
f = OptimizationFunction(full_loss_function, adtype)

p = PINOParameterHandler(datasets, pfs)
return Optimization.OptimizationProblem(f, pinn.init_params, p)
return Optimization.OptimizationProblem(f, init_params, p)
end
27 changes: 13 additions & 14 deletions src/compact/NeuralPDE/pinn_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@ The default element type of the parameters is `Float64`.
```julia
using Random
rng = Random.default_rng()
Random.seed!(rng, 0)
```
and pass `rng` to `PINN` as
```julia
using Sophon
chain = FullyConnected((1,6,6,1), sin);
chain = FullyConnected((1, 6, 6, 1), sin);
# single dependent variable
pinn = PINN(chain, rng);
# multiple dependent variables
pinn = PINN(rng;
a = chain,
b = chain);
pinn = PINN(rng; a=chain, b=chain);
```
"""
struct PINN{PHI, P}
Expand All @@ -47,22 +47,20 @@ end

# Build a PINN from a NamedTuple of networks, one per dependent variable.
# Initial parameters are converted to Float64 (see the docstring above:
# the default element type of the parameters is Float64).
function PINN(chain::NamedTuple, rng::AbstractRNG=Random.default_rng())
    phi = map(m -> ChainState(m, rng), chain)
    init_params = Lux.fmap(float64, initialparameters(rng, phi))

    return PINN{typeof(phi), typeof(init_params)}(phi, init_params)
end

# Build a PINN from a single network (single dependent variable).
# Initial parameters are converted to Float64 via `Lux.fmap`.
function PINN(chain::AbstractExplicitLayer, rng::AbstractRNG=Random.default_rng())
    phi = ChainState(chain, rng)
    init_params = Lux.fmap(float64, initialparameters(rng, phi))

    return PINN{typeof(phi), typeof(init_params)}(phi, init_params)
end

# Re-generate Float64 initial parameters for the networks wrapped by `pinn`.
function initialparameters(rng::AbstractRNG, pinn::PINN)
    return Lux.fmap(float64, initialparameters(rng, pinn.phi))
end

"""
Expand Down Expand Up @@ -212,7 +210,8 @@ struct ParametricPDESystem
pvs::Vector
end

# Convenience constructor: wrap a single `equation => domain` pair in a vector.
function ParametricPDESystem(eq::Pair{<:Symbolics.Equation, <:DomainSets.Domain}, bcs, ivs,
                             dvs, pvs)
    # Fix: this previously returned `PDESystem([eq], bcs, ivs, dvs, pvs)`, which
    # both dropped the parametric type (breaking dispatch on
    # `::ParametricPDESystem`, e.g. in `discretize`) and passed `pvs` to a
    # constructor with a different positional signature.
    return ParametricPDESystem([eq], bcs, ivs, dvs, pvs)
end

Expand All @@ -234,8 +233,8 @@ function Base.show(io::IO, ::MIME"text/plain", sys::ParametricPDESystem)
end

# Mutable pair of training data swapped between resampling rounds:
# `cords` holds the sampled collocation coordinates and `fs` the sampled
# parameter functions (branch-net inputs). Fields are ::Any because both
# come from user-pluggable samplers.
mutable struct PINOParameterHandler
    cords::Any
    fs::Any
end

# Return the collocation-coordinate datasets stored in the handler.
get_local_ps(p::PINOParameterHandler) = p.cords
Expand Down
Loading

0 comments on commit 745fcdb

Please sign in to comment.