Skip to content

Commit

Permalink
Init working GPU_Plan based on package extension
Browse files Browse the repository at this point in the history
  • Loading branch information
nHackel committed Jun 27, 2024
1 parent 9cf01cd commit c03b9fa
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 0 deletions.
8 changes: 8 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
Adapt = "3, 4"
AbstractNFFTs = "0.8"
BasicInterpolators = "0.6.5, 0.7"
DataFrames = "1.3.1, 1.4.1"
FFTW = "1.5"
FINUFFT = "3.0.1"
FLoops = "0.2"
GPUArrays = "8, 9, 10"
Reexport = "1.0"
PrecompileTools = "1"
SpecialFunctions = "0.8, 0.10, 1, 2"
Expand All @@ -42,7 +44,13 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Ducc0 = "47ec601d-2729-4ac9-bed9-2b3ab5fca9ff"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"

[targets]
test = ["Test", "BenchmarkTools", "FINUFFT", "NFFT3", "CuNFFT", "Zygote",
"NFFTTools", "DataFrames", "Ducc0"] # "NFFTTools" "CuNFFT"

[extensions]
NFFTGPUArraysExt = ["Adapt", "GPUArrays"]
9 changes: 9 additions & 0 deletions ext/NFFTGPUArraysExt/NFFTGPUArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module NFFTGPUArraysExt

using NFFT, NFFT.AbstractNFFTs
using NFFT.SparseArrays, NFFT.LinearAlgebra, NFFT.FFTW
using GPUArrays, Adapt

include("implementation.jl")

end
126 changes: 126 additions & 0 deletions ext/NFFTGPUArraysExt/implementation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
mutable struct GPU_NFFTPlan{T,D, arrTc <: AbstractGPUArray{Complex{T}, D}, vecI <: AbstractGPUVector{Int32}, FP, BP, SM} <: AbstractNFFTPlan{T,D,1}
N::NTuple{D,Int64}
NOut::NTuple{1,Int64}
J::Int64
k::Matrix{T}
::NTuple{D,Int64}
dims::UnitRange{Int64}
params::NFFTParams{T}
forwardFFT::FP
backwardFFT::BP
tmpVec::arrTc
tmpVecHat::arrTc
deconvolveIdx::vecI
windowLinInterp::Vector{T}
windowHatInvLUT::arrTc
B::SM
end

function AbstractNFFTs.plan_nfft(arr::Type{<:AbstractGPUArray}, k::Matrix{T}, N::NTuple{D,Int}, rest...;

Check warning on line 19 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L19

Added line #L19 was not covered by tests
timing::Union{Nothing,TimingStats} = nothing, kargs...) where {T,D}
t = @elapsed begin
p = GPU_NFFTPlan(arr, k, N, rest...; kargs...)

Check warning on line 22 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L21-L22

Added lines #L21 - L22 were not covered by tests
end
if timing != nothing
timing.pre = t

Check warning on line 25 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L24-L25

Added lines #L24 - L25 were not covered by tests
end
return p

Check warning on line 27 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L27

Added line #L27 was not covered by tests
end

function GPU_NFFTPlan(arr, k::Matrix{T}, N::NTuple{D,Int}; dims::Union{Integer,UnitRange{Int64}}=1:D,

Check warning on line 30 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L30

Added line #L30 was not covered by tests
fftflags=nothing, kwargs...) where {T,D}

if dims != 1:D
error("GPU NFFT does not work along directions right now!")

Check warning on line 34 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L33-L34

Added lines #L33 - L34 were not covered by tests
end

params, N, NOut, J, Ñ, dims_ = NFFT.initParams(k, N, dims; kwargs...)
params.storeDeconvolutionIdx = true # GPU_NFFT only works this way
params.precompute = NFFT.FULL # GPU_NFFT only works this way

Check warning on line 39 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L37-L39

Added lines #L37 - L39 were not covered by tests

tmpVec = adapt(arr, zeros(Complex{T}, Ñ))

Check warning on line 41 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L41

Added line #L41 was not covered by tests

FP = plan_fft!(tmpVec, dims_)
BP = plan_bfft!(tmpVec, dims_)

Check warning on line 44 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L43-L44

Added lines #L43 - L44 were not covered by tests

windowLinInterp, windowPolyInterp, windowHatInvLUT, deconvolveIdx, B = NFFT.precomputation(k, N[dims_], Ñ[dims_], params)

Check warning on line 46 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L46

Added line #L46 was not covered by tests

U = params.storeDeconvolutionIdx ? N : ntuple(d->0,D)
tmpVecHat = adapt(arr, zeros(Complex{T}, U))

Check warning on line 49 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L48-L49

Added lines #L48 - L49 were not covered by tests

deconvIdx = adapt(arr, Int32.(deconvolveIdx))
winHatInvLUT = adapt(arr, windowHatInvLUT[1])
B_ = adapt(arr, Complex{T}.(B)) # Bit hacky

Check warning on line 53 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L51-L53

Added lines #L51 - L53 were not covered by tests

GPU_NFFTPlan{T,D, typeof(tmpVec), typeof(deconvIdx), typeof(FP), typeof(BP), typeof(B_)}(N, NOut, J, k, Ñ, dims_, params, FP, BP, tmpVec, tmpVecHat,

Check warning on line 55 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L55

Added line #L55 was not covered by tests
deconvIdx, windowLinInterp, winHatInvLUT, B_)
end

AbstractNFFTs.size_in(p::GPU_NFFTPlan) = p.N
AbstractNFFTs.size_out(p::GPU_NFFTPlan) = p.NOut

Check warning on line 60 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L59-L60

Added lines #L59 - L60 were not covered by tests


function AbstractNFFTs.convolve!(p::GPU_NFFTPlan{T,D, arrTc}, g::arrTc, fHat::arr) where {D,T,arr<: AbstractGPUArray, arrTc <: arr}
mul!(fHat, transpose(p.B), vec(g))
return

Check warning on line 65 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L63-L65

Added lines #L63 - L65 were not covered by tests
end

function AbstractNFFTs.convolve_transpose!(p::GPU_NFFTPlan{T,D, arrTc}, fHat::arr, g::arrTc) where {D,T,arr<: AbstractGPUArray, arrTc <: arr}
mul!(vec(g), p.B, fHat)
return

Check warning on line 70 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L68-L70

Added lines #L68 - L70 were not covered by tests
end

function AbstractNFFTs.deconvolve!(p::GPU_NFFTPlan{T,D, arrTc}, f::arr, g::arrTc) where {D,T,arr<: AbstractGPUArray, arrTc <: arr}
p.tmpVecHat[:] .= vec(f) .* p.windowHatInvLUT
g[p.deconvolveIdx] = p.tmpVecHat
return

Check warning on line 76 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L73-L76

Added lines #L73 - L76 were not covered by tests
end

function AbstractNFFTs.deconvolve_transpose!(p::GPU_NFFTPlan{T,D, arrTc}, g::arrTc, f::arr) where {D,T,arr<: AbstractGPUArray, arrTc <: arr}
p.tmpVecHat[:] = g[p.deconvolveIdx]
f[:] .= vec(p.tmpVecHat) .* p.windowHatInvLUT
return

Check warning on line 82 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L79-L82

Added lines #L79 - L82 were not covered by tests
end

""" in-place NFFT on the GPU"""
function LinearAlgebra.mul!(fHat::arr, p::GPU_NFFTPlan{T,D, arrT}, f::arr;

Check warning on line 86 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L86

Added line #L86 was not covered by tests
verbose=false, timing::Union{Nothing,TimingStats} = nothing) where {T,D,arr<: AbstractGPUArray, arrT <: arr}
NFFT.consistencyCheck(p, f, fHat)

Check warning on line 88 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L88

Added line #L88 was not covered by tests

fill!(p.tmpVec, zero(Complex{T}))
t1 = @elapsed @inbounds deconvolve!(p, f, p.tmpVec)
t2 = @elapsed p.forwardFFT * p.tmpVec
t3 = @elapsed @inbounds convolve!(p, p.tmpVec, fHat)
if verbose
@info "Timing: deconv=$t1 fft=$t2 conv=$t3"

Check warning on line 95 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L90-L95

Added lines #L90 - L95 were not covered by tests
end
if timing != nothing
timing.conv = t3
timing.fft = t2
timing.deconv = t1

Check warning on line 100 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L97-L100

Added lines #L97 - L100 were not covered by tests
end

return fHat

Check warning on line 103 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L103

Added line #L103 was not covered by tests
end

""" in-place adjoint NFFT on the GPU"""
function LinearAlgebra.mul!(f::arr, pl::Adjoint{Complex{T},<:GPU_NFFTPlan{T,D, arrT}}, fHat::arr;

Check warning on line 107 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L107

Added line #L107 was not covered by tests
verbose=false, timing::Union{Nothing,TimingStats} = nothing) where {T,D,arr<: AbstractGPUArray, arrT <: arr}
p = pl.parent
NFFT.consistencyCheck(p, f, fHat)

Check warning on line 110 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L109-L110

Added lines #L109 - L110 were not covered by tests

t1 = @elapsed @inbounds convolve_transpose!(p, fHat, p.tmpVec)
t2 = @elapsed p.backwardFFT * p.tmpVec
t3 = @elapsed @inbounds deconvolve_transpose!(p, p.tmpVec, f)
if verbose
@info "Timing: conv=$t1 fft=$t2 deconv=$t3"

Check warning on line 116 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L112-L116

Added lines #L112 - L116 were not covered by tests
end
if timing != nothing
timing.conv_adjoint = t1
timing.fft_adjoint = t2
timing.deconv_adjoint = t3

Check warning on line 121 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L118-L121

Added lines #L118 - L121 were not covered by tests
end

return f

Check warning on line 124 in ext/NFFTGPUArraysExt/implementation.jl

View check run for this annotation

Codecov / codecov/patch

ext/NFFTGPUArraysExt/implementation.jl#L124

Added line #L124 was not covered by tests
end

0 comments on commit c03b9fa

Please sign in to comment.