Merge pull request #190 from YichengDWu/ResNet
Add a convenient ResNet constructor
YichengDWu authored Mar 18, 2023
2 parents 1760172 + 569e1b4 commit 7ebac78
Showing 4 changed files with 97 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/Sophon.jl
@@ -53,8 +53,8 @@ function __init__()
end

export gaussian, quadratic, laplacian, expsin, multiquadratic
-export FourierFeature, TriplewiseFusion, FullyConnected, Sine, RBF, DiscreteFourierFeature,
-       ConstantFunction, ScalarLayer, SplitFunction, FactorizedDense
+export FourierFeature, TriplewiseFusion, FullyConnected, ResNet, Sine, RBF,
+       DiscreteFourierFeature, ConstantFunction, ScalarLayer, SplitFunction, FactorizedDense
export PINNAttention, FourierNet, FourierAttention, Siren, FourierFilterNet, BACON
export DeepONet
export PINN, symbolic_discretize, discretize, QuasiRandomSampler, NonAdaptiveTraining,
6 changes: 3 additions & 3 deletions src/diff/training_strategies.jl
@@ -74,13 +74,13 @@ function AdaptiveTraining(pde_weights::Function, bcs_weights::NTuple{N, <:Real})
                                                                       _bcs_weights)
end

-function AdaptiveTraining(pde_weights::Tuple{Vararg{<:Function}}, bcs_weights::Int)
+function AdaptiveTraining(pde_weights::Tuple{Vararg{Function}}, bcs_weights::Int)
    _bcs_weights = (phi, cord, θ) -> bcs_weights
    return AdaptiveTraining{typeof(pde_weights), typeof(_bcs_weights)}(pde_weights,
                                                                       _bcs_weights)
end

-function AdaptiveTraining(pde_weights::Tuple{Vararg{<:Function}},
+function AdaptiveTraining(pde_weights::Tuple{Vararg{Function}},
                          bcs_weights::NTuple{N, <:Real}) where {N}
    _bcs_weights = map(w -> (phi, cord, θ) -> w, bcs_weights)
    return AdaptiveTraining{typeof(pde_weights), typeof(_bcs_weights)}(pde_weights,
@@ -103,7 +103,7 @@ function scalarize(strategy::AdaptiveTraining, phi, datafree_pde_loss_function,
    return f
end

-function scalarize(phi, weights::Tuple{Vararg{<:Function}}, datafree_loss_function::Tuple)
+function scalarize(phi, weights::Tuple{Vararg{Function}}, datafree_loss_function::Tuple)
    N = length(datafree_loss_function)
    body = Expr(:block)
    push!(body.args, Expr(:(=), :local_ps, :(get_local_ps(pp))))
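The change from `Tuple{Vararg{<:Function}}` to `Tuple{Vararg{Function}}` keeps dispatch behavior the same: tuple types are covariant, so `Tuple{Vararg{Function}}` already matches any tuple whose elements are (possibly different) subtypes of `Function`, and the simpler form avoids the `Vararg`-in-`UnionAll` deprecation warning on recent Julia versions. A minimal illustration (not part of this diff):

```julia
# Covariance of tuple types: a tuple of concrete function types matches the
# simpler signature just as well as the old `Vararg{<:Function}` form did.
pde_weights = ((phi, cord, θ) -> 1.0f0, (phi, cord, θ) -> 50.0f0)

pde_weights isa Tuple{Vararg{Function}}   # true
(sin, cos) isa Tuple{Vararg{Function}}    # true
```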
91 changes: 91 additions & 0 deletions src/layers/nets.jl
@@ -333,6 +333,97 @@ end
    return :(Chain($(layers...)))
end

"""
ResNet(layer_sizes::NTuple{N, Int}, activation; outermost = true,
init_weight = kaiming_uniform(activation),
init_bias = zeros32)
ResNet(in_dims::Int, out_dims::Int, activation::Function;
hidden_dims::Int, num_layers::Int, outermost=true,
init_weight = kaiming_uniform(activation),
init_bias = zeros32)
Create fully connected layers.

## Arguments

- `layer_sizes`: Number of dimensions of each layer.
- `in_dims`: Number of input dimensions.
- `out_dims`: Number of output dimensions.
- `hidden_dims`: Number of hidden dimensions.
- `num_layers`: Number of hidden layers.
- `activation`: Activation function.

## Keyword Arguments

- `outermost`: Whether the last layer is the output layer. If `true`, the last layer is a
  plain `Dense` layer with no activation function; if `false`, the last layer is a residual
  (`SkipConnection`) block that applies the activation function.
- `init_weight`: Initialization method for the weights.
- `init_bias`: Initialization method for the biases.

## Example

```julia
julia> ResNet((1, 12, 12, 32), relu)
Chain(
    layer_1 = Dense(1 => 12, relu),     # 24 parameters
    layer_2 = SkipConnection(
        Dense(12 => 12, relu),          # 156 parameters
        +
    ),
    layer_3 = Dense(12 => 32),          # 416 parameters
)         # Total: 596 parameters,
          #        plus 0 states, summarysize 48 bytes.

julia> ResNet(1, 10, relu; hidden_dims=20, num_layers=3)
Chain(
    layer_1 = Dense(1 => 20, relu),     # 40 parameters
    layer_2 = SkipConnection(
        Dense(20 => 20, relu),          # 420 parameters
        +
    ),
    layer_3 = SkipConnection(
        Dense(20 => 20, relu),          # 420 parameters
        +
    ),
    layer_4 = Dense(20 => 10),          # 210 parameters
)         # Total: 1_090 parameters,
          #        plus 0 states, summarysize 64 bytes.
```
"""
function ResNet(layer_sizes::NTuple{N, T}, activation::Function; outermost::Bool=true,
                init_bias=zeros32,
                init_weight::Function=kaiming_uniform(activation)) where {N, T <: Int}
    return ResNet(layer_sizes, activation, Val(outermost); init_weight=init_weight,
                  init_bias=init_bias)
end

function ResNet(in_dims::Int, out_dims::Int, activation::Function; hidden_dims::Int,
                num_layers::Int, outermost::Bool=true,
                init_weight::Function=kaiming_uniform(activation), init_bias=zeros32)
    return ResNet((in_dims, ntuple(_ -> hidden_dims, num_layers)..., out_dims), activation,
                  Val(outermost); init_weight=init_weight, init_bias=init_bias)
end

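# Build the network expression at compile time: the first layer is a plain `Dense`
# layer, each hidden layer becomes a residual `SkipConnection` block, and `F`
# (the value of `outermost`) selects whether the final layer is a plain `Dense`
# layer or one more residual block.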
@generated function ResNet(layer_sizes::NTuple{N, T}, activation::Function, ::Val{F};
                           init_weight, init_bias) where {N, T <: Int, F}
    N == 2 &&
        return :(Dense(layer_sizes[1], layer_sizes[2], activation; init_weight=init_weight,
                       init_bias=init_bias))
    function get_layer(i)
        return :(SkipConnection(Dense(layer_sizes[$i] => layer_sizes[$(i + 1)], activation;
                                      init_weight=init_weight, init_bias=init_bias), +))
    end
    layers = [
        :(Dense(layer_sizes[1] => layer_sizes[2], activation; init_weight=init_weight,
                init_bias=init_bias)),
    ]
    append!(layers, [get_layer(i) for i in 2:(N - 2)])
    append!(layers,
            F ?
            [
                :(Dense(layer_sizes[$(N - 1)] => layer_sizes[$N]; init_weight=init_weight,
                        init_bias=init_bias)),
            ] : [get_layer(N - 1)])
    return :(Chain($(layers...)))
end

struct MultiplicativeFilterNet{F, L, O} <:
       AbstractExplicitContainerLayer{(:filters, :linear_layers, :output_layer)}
    filters::F
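For reference, here is a minimal usage sketch of the new constructor (assumed setup, not part of this commit): the returned `Chain` is an ordinary Lux model, so it is initialized with `Lux.setup` and called like any other Lux layer.

```julia
# Minimal usage sketch (assumed setup; dimensions chosen for illustration only).
using Lux, Random, Sophon

rng = Random.default_rng()
model = ResNet(1, 10, relu; hidden_dims=20, num_layers=3)

ps, st = Lux.setup(rng, model)   # initialize parameters and states
x = rand(Float32, 1, 16)         # a batch of 16 one-dimensional inputs
y, _ = model(x, ps, st)          # forward pass; size(y) == (10, 16)
```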
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -51,6 +51,7 @@ rng = Random.default_rng()
                          outermost=false)
    @test fc5.layers[end].activation == sin
end

@testset "Sine" begin
    # first layer
    s = Sine(2, 3; omega=30.0f0)
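Only the insertion point in `test/runtests.jl` is visible in this hunk. A hypothetical testset exercising the new constructor (illustrative only; names and checks are not taken from this commit) might look like:

```julia
# Hypothetical checks, mirroring the style of the FullyConnected testset above.
@testset "ResNet" begin
    res1 = ResNet(1, 10, relu; hidden_dims=20, num_layers=3)
    @test length(res1.layers) == 4
    @test !(res1.layers[end] isa SkipConnection)   # outermost=true: plain Dense output

    res2 = ResNet((2, 8, 8, 8), sin; outermost=false)
    @test res2.layers[end] isa SkipConnection      # outermost=false: residual output block
end
```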
