Merge pull request #192 from YichengDWu/fast
Disable fast activation conversion for the sake of accuracy
YichengDWu authored Mar 18, 2023
2 parents 7ebac78 + 3b7c8d2 commit 124b292
Showing 2 changed files with 44 additions and 30 deletions.
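For context: before this commit, the constructors in this package let the underlying `Dense` layers replace supported activations with faster approximations via `NNlib.fast_act`; the commit makes that conversion opt-in. A minimal sketch of what the conversion does (assuming only NNlib is installed; `fast_act` and `tanh_fast` are NNlib functions, the sample input is made up):

```julia
using NNlib

# NNlib.fast_act maps supported activations to cheaper approximations
# and returns anything else unchanged.
NNlib.fast_act(tanh)  # tanh_fast
NNlib.fast_act(relu)  # relu (no fast variant)

# The fast variant trades a few bits of accuracy for speed, which is
# why the default here was flipped back to the exact activation.
x = 0.37f0
tanh(x) - tanh_fast(x)  # small, but not always zero
```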
72 changes: 43 additions & 29 deletions src/layers/nets.jl
@@ -250,12 +250,14 @@ end
 
 """
     FullyConnected(layer_sizes::NTuple{N, Int}, activation; outermost = true,
-                   init_weight = kaiming_uniform(activation),
-                   init_bias = zeros32)
+                   init_weight=kaiming_uniform(activation),
+                   init_bias=zeros32,
+                   allow_fast_activation=false)
     FullyConnected(in_dims::Int, out_dims::Int, activation::Function;
                    hidden_dims::Int, num_layers::Int, outermost=true,
-                   init_weight = kaiming_uniform(activation),
-                   init_bias = zeros32)
+                   init_weight=kaiming_uniform(activation),
+                   init_bias=zeros32,
+                   allow_fast_activation=false)
 
 Create fully connected layers.
 
@@ -271,6 +273,8 @@ Create fully connected layers.
 - `outermost`: Whether to use activation function for the last layer. If `false`, the
   activation function is applied to the output of the last layer.
 - `init_weight`: Initialization method for the weights.
+- `allow_fast_activation`: If true, then certain activations can be approximated with a
+  faster version. The new activation function will be given by NNlib.fast_act(activation)
 
 ## Example
 
@@ -295,52 +299,56 @@ Chain(
 """
 function FullyConnected(layer_sizes::NTuple{N, T}, activation::Function;
                         outermost::Bool=true, init_bias=zeros32,
-                        init_weight::Function=kaiming_uniform(activation)) where {N,
-                                                                                  T <: Int}
+                        init_weight::Function=kaiming_uniform(activation),
+                        allow_fast_activation::Bool=false) where {N, T <: Int}
     return FullyConnected(layer_sizes, activation, Val(outermost); init_weight=init_weight,
-                          init_bias=init_bias)
+                          init_bias=init_bias, allow_fast_activation=allow_fast_activation)
 end
 
 function FullyConnected(in_dims::Int, out_dims::Int, activation::Function; hidden_dims::Int,
                         num_layers::Int, outermost::Bool=true,
                         init_weight::Function=kaiming_uniform(activation),
-                        init_bias=zeros32)
+                        init_bias=zeros32, allow_fast_activation::Bool=false)
     return FullyConnected((in_dims, ntuple(_ -> hidden_dims, num_layers)..., out_dims),
                           activation, Val(outermost); init_weight=init_weight,
-                          init_bias=init_bias)
+                          init_bias=init_bias, allow_fast_activation=allow_fast_activation)
 end
 
 @generated function FullyConnected(layer_sizes::NTuple{N, T}, activation::Function,
-                                   ::Val{F}; init_weight, init_bias) where {N, T <: Int, F}
+                                   ::Val{F}; init_weight, init_bias,
+                                   allow_fast_activation) where {N, T <: Int, F}
     N == 2 &&
         return :(Dense(layer_sizes[1], layer_sizes[2], activation; init_weight=init_weight,
-                       init_bias=init_bias))
+                       init_bias=init_bias, allow_fast_activation=allow_fast_activation))
     function get_layer(i)
         return :(Dense(layer_sizes[$i] => layer_sizes[$(i + 1)], activation;
-                       init_weight=init_weight, init_bias=init_bias))
+                       init_weight=init_weight, init_bias=init_bias,
+                       allow_fast_activation=allow_fast_activation))
     end
     layers = [
         :(Dense(layer_sizes[1] => layer_sizes[2], activation; init_weight=init_weight,
-                init_bias=init_bias)),
+                init_bias=init_bias, allow_fast_activation=allow_fast_activation)),
     ]
     append!(layers, [get_layer(i) for i in 2:(N - 2)])
     append!(layers,
             F ?
             [
                 :(Dense(layer_sizes[$(N - 1)] => layer_sizes[$N]; init_weight=init_weight,
-                        init_bias=init_bias)),
+                        init_bias=init_bias, allow_fast_activation=allow_fast_activation)),
             ] : [get_layer(N - 1)])
     return :(Chain($(layers...)))
 end
 
 """
-    ResNet(layer_sizes::NTuple{N, Int}, activation; outermost = true,
-           init_weight = kaiming_uniform(activation),
-           init_bias = zeros32)
+    ResNet(layer_sizes::NTuple{N, Int}, activation; outermost=true,
+           init_weight=kaiming_uniform(activation),
+           init_bias=zeros32,
+           allow_fast_activation=false)
     ResNet(in_dims::Int, out_dims::Int, activation::Function;
            hidden_dims::Int, num_layers::Int, outermost=true,
-           init_weight = kaiming_uniform(activation),
-           init_bias = zeros32)
+           init_weight=kaiming_uniform(activation),
+           init_bias=zeros32,
+           allow_fast_activation=false)
 
 Create fully connected layers.
 
@@ -356,6 +364,8 @@ Create fully connected layers.
 - `outermost`: Whether to use activation function for the last layer. If `false`, the
   activation function is applied to the output of the last layer.
 - `init_weight`: Initialization method for the weights.
+- `allow_fast_activation`: If true, then certain activations can be approximated with a
+  faster version. The new activation function will be given by NNlib.fast_act(activation)
 
 ## Example
 
@@ -388,38 +398,42 @@ Chain(
 ```
 """
 function ResNet(layer_sizes::NTuple{N, T}, activation::Function; outermost::Bool=true,
-                init_bias=zeros32,
-                init_weight::Function=kaiming_uniform(activation)) where {N, T <: Int}
+                init_bias=zeros32, init_weight::Function=kaiming_uniform(activation),
+                allow_fast_activation::Bool=false) where {N, T <: Int}
     return ResNet(layer_sizes, activation, Val(outermost); init_weight=init_weight,
-                  init_bias=init_bias)
+                  init_bias=init_bias, allow_fast_activation=allow_fast_activation)
 end
 
 function ResNet(in_dims::Int, out_dims::Int, activation::Function; hidden_dims::Int,
                 num_layers::Int, outermost::Bool=true,
-                init_weight::Function=kaiming_uniform(activation), init_bias=zeros32)
+                init_weight::Function=kaiming_uniform(activation), init_bias=zeros32,
+                allow_fast_activation::Bool=false)
     return ResNet((in_dims, ntuple(_ -> hidden_dims, num_layers)..., out_dims), activation,
-                  Val(outermost); init_weight=init_weight, init_bias=init_bias)
+                  Val(outermost); init_weight=init_weight, init_bias=init_bias,
+                  allow_fast_activation=allow_fast_activation)
 end
 
 @generated function ResNet(layer_sizes::NTuple{N, T}, activation::Function, ::Val{F};
-                           init_weight, init_bias) where {N, T <: Int, F}
+                           init_weight, init_bias,
+                           allow_fast_activation) where {N, T <: Int, F}
     N == 2 &&
         return :(Dense(layer_sizes[1], layer_sizes[2], activation; init_weight=init_weight,
-                       init_bias=init_bias))
+                       init_bias=init_bias, allow_fast_activation=allow_fast_activation))
     function get_layer(i)
         return :(SkipConnection(Dense(layer_sizes[$i] => layer_sizes[$(i + 1)], activation;
-                                      init_weight=init_weight, init_bias=init_bias), +))
+                                      init_weight=init_weight, init_bias=init_bias,
+                                      allow_fast_activation=allow_fast_activation), +))
    end
    layers = [
        :(Dense(layer_sizes[1] => layer_sizes[2], activation; init_weight=init_weight,
-                init_bias=init_bias)),
+                init_bias=init_bias, allow_fast_activation=allow_fast_activation)),
    ]
    append!(layers, [get_layer(i) for i in 2:(N - 2)])
    append!(layers,
            F ?
            [
                :(Dense(layer_sizes[$(N - 1)] => layer_sizes[$N]; init_weight=init_weight,
-                        init_bias=init_bias)),
+                        init_bias=init_bias, allow_fast_activation=allow_fast_activation)),
            ] : [get_layer(N - 1)])
    return :(Chain($(layers...)))
 end
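A hypothetical usage sketch of the new keyword (layer sizes invented for illustration; assumes `FullyConnected` is in scope as defined above):

```julia
# Default after this commit: the exact activation is kept.
net = FullyConnected((2, 32, 32, 1), tanh)

# Speed over accuracy is now an explicit opt-in; each generated Dense
# receives allow_fast_activation=true and its activation becomes
# NNlib.fast_act(tanh), i.e. tanh_fast.
fast_net = FullyConnected((2, 32, 32, 1), tanh; allow_fast_activation=true)
```

Note the `Val(outermost)` pattern in both constructors: the Boolean keyword is lifted into the type domain so the `@generated` method can decide at compile time whether the final `Dense` is built without an activation.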
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -189,7 +189,7 @@ rng = Random.default_rng()
 @testset "Constructors" begin
     model = DeepONet((3, 5, 4), relu, (2, 6, 4, 4), tanh)
     @test model.branch_net.layers[end].activation == identity
-    @test model.trunk_net.layers[end].activation == tanh_fast
+    @test model.trunk_net.layers[end].activation == tanh
 
     branch = Chain(Dense(2, 3), Dense(3, 4))
     trunk = Chain(Dense(3, 4), Dense(4, 5))
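The updated assertion pins the exact function object: with conversion disabled, the trunk net's last layer holds `tanh` itself rather than `tanh_fast`. The same check at the REPL might look like this (a sketch, assuming the `DeepONet` constructor used in the test):

```julia
model = DeepONet((3, 5, 4), relu, (2, 6, 4, 4), tanh)

model.branch_net.layers[end].activation == identity  # outermost layer is linear
model.trunk_net.layers[end].activation == tanh       # was tanh_fast before this commit
```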
