Skip to content

Commit

Permalink
Merge pull request #212 from YichengDWu/broadcast
Browse files Browse the repository at this point in the history
Fix running on GPU
  • Loading branch information
YichengDWu authored May 28, 2023
2 parents e3b12ab + 888e842 commit 27f2e52
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 66 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Sobol = "ed01d8cd-4d21-5b2a-85b4-cc3bdc58bad4"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Expand Down Expand Up @@ -65,14 +66,15 @@ Optimization = "3"
OptimizationOptimisers = "0.1"
ParameterSchedulers = "0.3"
QuasiMonteCarlo = "0.2"
Requires = "1"
RuntimeGeneratedFunctions = "0.5"
SciMLBase = "1"
Sobol = "1"
StaticArraysCore = "1"
StatsBase = "0.33"
Symbolics = "4, 5"
julia = "1.7"
Requires = "1"
julia = "1.8"
StaticArrays = "1.5"

[extras]
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
Expand Down
98 changes: 84 additions & 14 deletions ext/SophonTaylorDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ end

function CRC.rrule(::typeof(*), A::AbstractMatrix{S},
t::AbstractMatrix{TaylorScalar{T,N}}) where {N, S <: Number, T}
project_t = CRC.ProjectTo(t)
project_A = CRC.ProjectTo(A)
function gemv_pullback(x̄)
= CRC.unthunk(x̄)
= reinterpret(reshape, T, X̄)
Expand All @@ -49,40 +49,110 @@ function CRC.rrule(::typeof(*), A::AbstractMatrix{S},
end
C
end
dB = CRC.@thunk(project_t(transpose(A)*))
CRC.NoTangent(), dA, dB
dB = CRC.@thunk(transpose(A)*X̄)
CRC.NoTangent(), project_A(dA), dB
end
return A * t, gemv_pullback
end

for N in 1:5
@eval begin
$(Symbol(:broadcasted_make_taylor_, N))(t0,t1) = CRC.@ignore_derivatives broadcast((t0, t1) -> make_taylor(t0, t1, $(Val(N))), t0, t1)

function CRC.rrule(f::typeof($(Symbol(:broadcasted_make_taylor_, N))), x::AbstractVector, y::AbstractVector)
o = f(x, y)
function f_pullback(x̄::AbstractVector{<:TaylorScalar{T}}) where {T}
x = reinterpret(reshape, T, x̄)
return CRC.NoTangent(), x[1, :], x[2, :]
end
return o, f_pullback
end

function CRC.rrule(f::typeof($(Symbol(:broadcasted_make_taylor_, N))), x::AbstractMatrix, y::AbstractVector)
o = f(x, y)
function broadcasted_make_taylor_pullback(x̄::AbstractMatrix{<:TaylorScalar{T}}) where {T}
x = reinterpret(reshape, T, x̄)
return CRC.NoTangent(), x[1, :, :], x[2, :, 1]
end
return o, broadcasted_make_taylor_pullback
end

$(Symbol(:broadcasted_extract_derivative_, N))(t) = CRC.@ignore_derivatives broadcast(Base.Fix2(extract_derivative, $(Val(N))), t)

function CRC.rrule(f::typeof($(Symbol(:broadcasted_extract_derivative_, N))), t::AbstractArray{TaylorScalar{T, L}}) where {T, L}
function broadcasted_extract_derivative_pullback(x̂)
Δ = broadcast(x̂) do d
TaylorScalar{T, L}(ntuple(j -> j === $N ? d : zero(T), Val{L}()))
end
return CRC.NoTangent(), Δ
end
return f(t), broadcasted_extract_derivative_pullback
end
end
end

@inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{T},
order::Int64) where {T <: Number}
derivative(f, x, l, Val{order + 1}())
end

@inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{T},
vN::Val{N}) where {T <: Number, N}
t = broadcast((t0, t1) -> make_taylor(t0, t1, vN), x, l)
return extract_derivative(f(t), N)
for N in 1:5
@eval @inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{T},
::Val{$N}) where {T <: Number}
t = $(Symbol(:broadcasted_make_taylor_, N))(x, l)
return extract_derivative(f(t), $N)
end
end

@inline extract_derivative(t::TaylorScalar, ::Val{N}) where {N} = value(t)[N]
# batched version
@inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T},
vN::Val{N}) where {T <: Number, N}
t = broadcast((t0, t1) -> TaylorDiff.make_taylor(t0, t1, vN), x, l)
return map(Base.Fix2(extract_derivative, vN), f(t))
for N in 1:5
@eval @inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T},
::Val{$N}) where {T <: Number}
t = $(Symbol(:broadcasted_make_taylor_, N))(x, l)
return $(Symbol(:broadcasted_extract_derivative_, N))(f(t))
end
end

@inline function taylordiff(phi, x, θ, ε::AbstractVector{T}, h::T, ::Val{N}) where {T <: Number, N}
@inline function taylordiff(phi, x, θ, ε_::AbstractVector{T}, h::T, ::Val{N}) where {T <: Number, N}
ε = Sophon.maybe_adapt(x, ε_)
return TaylorDiff.derivative(Base.Fix2(phi, θ), x, ε, Val{N+1}())
end

function Sophon.get_ε_h(::typeof(taylordiff), dim, der_num, fdtype, order)
function generate_ε(::typeof(taylordiff), dim, der_num, fdtype, order)
epsilon = one(fdtype)
ε = zeros(fdtype, dim)
ε[der_num] = epsilon
return ε, epsilon
return Sophon.SVector{dim}(ε)
end

for order in 1:4
for fdtype in (Float32, Float64)
@eval Sophon.get_h(::typeof(taylordiff), ::Type{$fdtype}, ::Val{$order}) = $(one(fdtype))
end
end

for l in 1:4
for d in 1:l
for order in 1:4
for fdtype in (Float32, Float64)
@eval const $(Symbol(:taylordiff_ε, :_, l, :_, d, :_, order, :_, fdtype)) =
$(generate_ε(taylordiff, l, d, fdtype, order))

@eval function Sophon.get_ε(::typeof(taylordiff), ::Val{$l}, ::Val{$d}, ::Type{$fdtype}, ::Val{$order})
return $(Symbol(:taylordiff_ε, :_, l, :_, d, :_, order, :_, fdtype))
end
end
end
end
end

# avoid NaN
function Base.:*(A::Union{Sophon.CuMatrix{T}, LinearAlgebra.Transpose{T, Sophon.CuArray}},
B::Sophon.CuMatrix{TaylorScalar{T, N}}) where {T, N}
C = similar(B, (size(A, 1), size(B, 2)))
fill!(C, zero(eltype(C)))
return LinearAlgebra.mul!(C, A, B)
end

function __init__()
Expand Down
2 changes: 2 additions & 0 deletions src/Sophon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import SciMLBase
import SciMLBase: parameterless_type, __solve, build_solution, NullParameters
using StatsBase, QuasiMonteCarlo
using Adapt, ChainRulesCore, CUDA, GPUArrays, GPUArraysCore
import GPUArraysCore: AbstractGPUArray
import QuasiMonteCarlo
import Sobol
using Distributions: Beta
Expand All @@ -28,6 +29,7 @@ using ForwardDiff
using MacroTools
using MacroTools: prewalk, postwalk
using Requires
using StaticArrays: SVector

RuntimeGeneratedFunctions.init(@__MODULE__)

Expand Down
21 changes: 10 additions & 11 deletions src/pde/sym_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,28 +219,27 @@ const derivative_patterns = (

function transform_expression(pinnrep::NamedTuple{names}, ex::Expr) where {names}
(; indvars, dict_depvars, dict_depvar_input, fdtype, init_params, derivative) = pinnrep
use_gpu = isongpu(init_params)

# Step 1: Replace all the derivatives with the derivative function
ex = prewalk(ex) do x
quoted_x = Meta.quot(x)

for ((order1, order2), pattern) in reverse(mixed_derivative_patterns)
if @eval @capture($quoted_x, $pattern) && dr1 !== dr2
ε1, h1 = get_ε_h(derivative, length(args), findfirst(==(dr1), dict_depvar_input[ff]), fdtype, order1+order2)
ε2, h2 = get_ε_h(derivative, length(args), findfirst(==(dr2), dict_depvar_input[ff]), fdtype, order1+order2)
ε1 = use_gpu ? adapt(CuArray, ε1) : ε1
ε2 = use_gpu ? adapt(CuArray, ε2) : ε2

return :(derivative((x,ps)->derivative(phi_u, x, ps, $ε2, $h2, $(Val(order2))),
coord_u, θ, $ε1, $h1, $(Val(order1))))
order = Val(order1 + order2)
l = Val(length(args))
h = get_h(derivative, fdtype, order)
ε1 = get_ε(derivative, l, Val(findfirst(==(dr1), dict_depvar_input[ff])), fdtype, order)
ε2 = get_ε(derivative, l, Val(findfirst(==(dr2), dict_depvar_input[ff])), fdtype, order)

return :(derivative((x,ps)->derivative(phi_u, x, ps, $ε2, $h, $(Val(order2))),
coord_u, θ, $ε1, $h, $(Val(order1))))
end
end

for (order, pattern) in reverse(derivative_patterns)
if @eval @capture($quoted_x, $pattern)
ε, h = get_ε_h(derivative, length(args), findfirst(==(dr), dict_depvar_input[ff]), fdtype, order)
ε = use_gpu ? adapt(CuArray, ε) : ε
h = get_h(derivative, fdtype, Val(order))
ε = get_ε(derivative, Val(length(args)), Val(findfirst(==(dr), dict_depvar_input[ff])), fdtype, Val(order))
return :(derivative($(Symbol(:phi, :_, ff)), $(Symbol(:coord, :_, ff)), $(Symbol(, :_, ff)), $ε, $h, $(Val(order))))
end
end
Expand Down
64 changes: 50 additions & 14 deletions src/pde/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,35 @@ This function is only used for the first order derivative.
"""
forwarddiff(phi, t, εs, order, θ) = ForwardDiff.gradient(sum Base.Fix2(phi, θ), t)

@inline function finitediff(phi, x, θ, ε::AbstractVector{T}, h::T, ::Val{1}) where {T<:AbstractFloat}
@inline maybe_adapt(x::AbstractGPUArray, ε_) = ChainRulesCore.@ignore_derivatives convert(CuArray, ε_)
@inline maybe_adapt(x, ε_) = ChainRulesCore.@ignore_derivatives ε_

@inline function finitediff(phi, x, θ, ε_::AbstractVector{T}, h::T, ::Val{1}) where {T<:AbstractFloat}
ε = maybe_adapt(x, ε_)
return (phi(x .+ ε, θ) .- phi(x .- ε, θ)) .* (h / 2)
end

@inline function finitediff(phi, x, θ, ε::AbstractVector{T}, h::T, ::Val{2}) where {T<:AbstractFloat}
@inline function finitediff(phi, x, θ, ε_::AbstractVector{T}, h::T, ::Val{2}) where {T<:AbstractFloat}
ε = maybe_adapt(x, ε_)
return (phi(x .+ ε, θ) .+ phi(x .- ε, θ) .- 2 .* phi(x, θ)) .* h^2
end

@inline function finitediff(phi, x, θ, ε::AbstractVector{T}, h::T, ::Val{3}) where {T<:AbstractFloat}
@inline function finitediff(phi, x, θ, ε_::AbstractVector{T}, h::T, ::Val{3}) where {T<:AbstractFloat}
ε = maybe_adapt(x, ε_)
return (phi(x .+ 2 .* ε, θ) .- 2 .* phi(x .+ ε, θ) .+ 2 .* phi(x .- ε, θ) -
phi(x .- 2 .* ε, θ)) .* h^3 ./ 2
end

@inline function finitediff(phi, x, θ, ε::AbstractVector{T}, h::T, ::Val{4}) where {T<:AbstractFloat}
@inline function finitediff(phi, x, θ, ε_::AbstractVector{T}, h::T, ::Val{4}) where {T<:AbstractFloat}
ε = maybe_adapt(x, ε_)
return (phi(x .+ 2 .* ε, θ) .- 4 .* phi(x .+ ε, θ) .+ 6 .* phi(x, θ) .-
4 .* phi(x .- ε, θ) .+ phi(x .- 2 .* ε, θ)) .* h^4
end

function finitediff(phi, x, θ, dim::Int, order::Int)
ε, h = ChainRulesCore.@ignore_derivatives get_ε_h(finitediff, size(x, 1), dim, eltype(θ), order)

ε = adapt(parameterless_type(ComponentArrays.getdata(θ)), ε)
ε = ChainRulesCore.@ignore_derivatives get_ε(finitediff, Val(size(x, 1)), Val(dim), eltype(θ), Val(order))
h = get_h(finitediff, eltype(x), Val(order))
ε = convert(parameterless_type(x), ε)

if order == 4
return (phi(x .+ 2 .* ε, θ) .- 4 .* phi(x .+ ε, θ) .+ 6 .* phi(x, θ) .-
Expand All @@ -53,22 +60,51 @@ function finitediff(phi, x, θ, dim::Int, order::Int)
end

# only order = 1 is supported
function upwind(phi, x, θ, ε::AbstractVector{T}, h::T, ::Val{1}) where {T<:AbstractFloat}
function upwind(phi, x, θ, ε_::AbstractVector{T}, h::T, ::Val{1}) where {T<:AbstractFloat}
ε = ChainRulesCore.@ignore_derivatives convert(parameterless_type(x), ε_)
return (3 .* phi(x, θ) .- 4 .* phi(x .- ε, θ) .+ phi(x .- 2 .* ε, θ)) .* (h / 2)
end

function get_ε_h(::typeof(finitediff), dim, der_num, fdtype, order)
epsilon = ^(eps(fdtype), one(fdtype) / (2 + order))

generate_epsilon(fdtype, order) = ^(eps(fdtype), one(fdtype) / (2 + order))

function generate_ε(::typeof(finitediff), dim, der_num, fdtype, order)
epsilon = generate_epsilon(fdtype, order)
ε = zeros(fdtype, dim)
ε[der_num] = epsilon
return ε, inv(epsilon)
return SVector{dim})
end

function get_ε_h(::typeof(upwind), dim, der_num, fdtype, order)
epsilon = ^(eps(fdtype), one(fdtype) / (2 + order))
function generate_ε(::typeof(upwind), dim, der_num, fdtype, order)
epsilon = generate_epsilon(fdtype, order)
ε = zeros(fdtype, dim)
ε[der_num] = epsilon
return ε, inv(epsilon)
return SVector{dim}(ε)
end

for order in 1:4
for fdtype in (Float32, Float64)
@eval const $(Symbol(:finitediff_h, :_, order, :_, fdtype)) =
$(inv(generate_epsilon(fdtype, order)))

@eval get_h(::typeof(finitediff), ::Type{$fdtype}, ::Val{$order}) =
$(Symbol(:finitediff_h, :_, order, :_, fdtype))
end
end

for l in 1:4
for d in 1:l
for order in 1:4
for fdtype in (Float32, Float64)
@eval const $(Symbol(:finitediff_ε, :_, l, :_, d, :_, order, :_, fdtype)) =
$(generate_ε(finitediff, l, d, fdtype, order))

@eval function get_ε(::typeof(finitediff), ::Val{$l}, ::Val{$d}, ::Type{$fdtype}, ::Val{$order})
return $(Symbol(:finitediff_ε, :_, l, :_, d, :_, order, :_, fdtype))
end
end
end
end
end

function Base.getproperty(d::Symbolics.VarDomainPairing, var::Symbol)
Expand Down
Loading

0 comments on commit 27f2e52

Please sign in to comment.