diff --git a/src/compact/NeuralPDE/pinn_types.jl b/src/compact/NeuralPDE/pinn_types.jl index f5c37735..bc666d2e 100644 --- a/src/compact/NeuralPDE/pinn_types.jl +++ b/src/compact/NeuralPDE/pinn_types.jl @@ -47,21 +47,20 @@ end function PINN(chain::NamedTuple, rng::AbstractRNG=Random.default_rng()) phi = map(m -> ChainState(m, rng), chain) - init_params = Lux.fmap(float, initialparameters(rng, phi)) + init_params = Lux.fmap(float64, initialparameters(rng, phi)) return PINN{typeof(phi), typeof(init_params)}(phi, init_params) end function PINN(chain::AbstractExplicitLayer, rng::AbstractRNG=Random.default_rng()) phi = ChainState(chain, rng) - init_params = Lux.fmap(float, initialparameters(rng, phi)) + init_params = Lux.fmap(float64, initialparameters(rng, phi)) return PINN{typeof(phi), typeof(init_params)}(phi, init_params) end function initialparameters(rng::AbstractRNG, pinn::PINN) - init_params = Lux.fmap(float, initialparameters(rng, pinn.phi)) - return init_params + return Lux.fmap(float64, initialparameters(rng, pinn.phi)) end """ diff --git a/src/utils.jl b/src/utils.jl index 786ba5a3..76ca7e7c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -89,3 +89,5 @@ ChainRulesCore.@non_differentiable init_normal(::Any...) function isongpu(nt::NamedTuple) return any(x -> x isa AbstractGPUArray, Lux.fcollect(nt)) end + +float64 = Base.Fix1(Broadcast.broadcast, Float64)