
[WIP] GPU support through KernelAbstractions #51

Draft
lkdvos wants to merge 17 commits into main from ld-gpu
Conversation

@lkdvos
Member

@lkdvos lkdvos commented Mar 16, 2026

This is an attempt to "solve" the GPU issues across all backends by implementing _mapreduce_kernel for GPU-backed StridedViews through KernelAbstractions.
Given that I have very little experience with this, and that I was helped heavily by AI, this should for now be considered very much WIP until I convince myself that all of it actually makes sense.

lkdvos and others added 10 commits March 15, 2026 08:48
Override _mapreduce_fuse! for GPU-backed StridedViews to dispatch
to a KernelAbstractions kernel instead of the CPU-specific threaded/SIMD
path. One GPU thread per output element with a sequential inner loop
over reduction dimensions. Handles pure map (op=nothing), reductions,
initop, and conj/adjoint views via ParentIndex semantics.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
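The commit above describes one GPU thread per output element with a sequential inner loop over the reduction dimensions. A minimal hedged sketch of that kernel shape (the name `gpu_mapreduce_kernel!` and the flat `(i, r)` indexing are illustrative, not the actual implementation, which handles ParentIndex semantics and arbitrary strides):

```julia
using KernelAbstractions

# One work-item per output element; each thread walks the reduction
# extent sequentially. The 2-D (output index, reduction index) view of
# the data is a simplification of the real stride/ParentIndex handling.
@kernel function gpu_mapreduce_kernel!(out, f, op, A, Rlen)
    i = @index(Global, Linear)        # one thread per output element
    acc = out[i]                      # assumes `out` was pre-initialized (initop)
    for r in 1:Rlen
        acc = op(acc, f(A[i, r]))     # sequential inner reduction loop
    end
    out[i] = acc
end
```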
Tests cover: pure map!, reduction over dim 1, reduction over dim 2,
conj/adjoint StridedView, and full scalar reduction. JLArrays provides
a CPU-backed GPU simulator so tests run without real GPU hardware.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
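For context on the test setup: JLArrays.jl ships `JLArray`, a plain CPU-backed array type that nevertheless enforces GPU semantics through the GPUArrays.jl interfaces, which is what lets these testsets exercise the GPU code paths without hardware. A minimal usage sketch:

```julia
using JLArrays

A = JLArray(rand(Float32, 4, 4))  # CPU-backed, but behaves like a GPU array
sum(A)                            # goes through the GPUArrays mapreduce machinery
# A[1, 1] would error (or warn) unless wrapped in GPUArrays.@allowscalar,
# just like indexing a real device array
```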
Three fixes:
- Add _mapreduce GPU override to avoid scalar indexing (first(A),
  out[ParentIndex(1)]) which JLArrays/real GPUs prohibit; uses zero(T)
  as proxy for type inference and similar(parent(A),...) to ensure the
  output stays on the GPU device
- Fix adjoint test expectation: copy!(adjoint(B), A) gives B = adjoint(A),
  not conj(A)
- Use qualified names Strided._init_reduction! and Strided._mapreducedim!
  since they are not exported into the extension module

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Replace zero(T) proxy with the same pattern GPUArrays uses:
infer the output element type via Broadcast.combine_eltypes +
Base.promote_op, then call GPUArrays.neutral_element(op, ET).
Unknown operators now produce a clear error message rather than
silently using zero(T). Also removes the dependency on the
unexported Strided._init_reduction!.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
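A hedged sketch of the inference pattern this commit describes (the helper name `reduction_init` is illustrative; `Broadcast.combine_eltypes` and `GPUArrays.neutral_element` are the functions named in the commit message):

```julia
using GPUArrays

function reduction_init(f, op, As::AbstractArray...)
    # element type of f applied elementwise across the inputs
    ET = Broadcast.combine_eltypes(f, As)
    # neutral element of op for that type; throws a descriptive error
    # for operators GPUArrays does not know, instead of silently
    # defaulting to zero(ET)
    return GPUArrays.neutral_element(op, ET)
end
```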
Add 7 new testsets covering:
- map! reading from stride-2 input (every other row)
- map! writing into stride-2 output, checking untouched rows stay zero
- map! on a subview with nonzero offset (2:6, 3:6 slice)
- map! with permuted (transposed) strides via permutedims
- sum over dim 1 with stride-2 input
- sum over dim 2 with offset subview
- full scalar reduction on stride-2 and offset subviews

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Change Vararg{StridedView} to Vararg{StridedView{<:Any, N, <:AnyGPUArray}}
so the GPU kernel is only dispatched when every input (not just the output)
is GPU-backed. Mixed CPU/GPU calls fall through to the CPU path.

Add a test confirming the GPU path is bypassed for mixed inputs: the CPU
fallback's scalar GPU indexing guard fires, proving the GPU kernel was
not called.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
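The dispatch change can be sketched as follows (the argument list of `_mapreduce_fuse!` is abbreviated here and may not match the real signature; `AnyGPUArray` is the union type from GPUArrays.jl):

```julia
using Strided, GPUArrays

# Before: the method matched whenever *any* argument was a StridedView:
#     arrays::Vararg{StridedView}
# After: it matches only when *every* array is GPU-backed; mixed CPU/GPU
# argument tuples fail to match and fall through to the CPU method.
function Strided._mapreduce_fuse!(
        f, op, initop, dims,
        arrays::Vararg{StridedView{<:Any, N, <:GPUArrays.AnyGPUArray}}
    ) where {N}
    # ... launch the KernelAbstractions kernel here ...
end
```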
Define const GPUStridedView{T,N} = StridedView{T, N, <:AnyGPUArray{T}}
and use it throughout the extension in place of the long-form
StridedView{T, N, <:AnyGPUArray{T}} / StridedView{<:Any, N, <:AnyGPUArray}
annotations on get_backend, BroadcastStyle, __mul!, _mapreduce, and
_mapreduce_fuse!.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
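The alias itself, as described in the commit (assuming `AnyGPUArray` comes from GPUArrays.jl):

```julia
using Strided: StridedView
using GPUArrays: AnyGPUArray

const GPUStridedView{T, N} = StridedView{T, N, <:AnyGPUArray{T}}
```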
@github-actions

github-actions bot commented Mar 16, 2026

Your PR no longer requires formatting changes. Thank you for your contribution!

@codecov

codecov bot commented Mar 16, 2026

Codecov Report

❌ Patch coverage is 1.23457% with 80 lines in your changes missing coverage. Please review.

Files with missing lines     Patch %   Lines
ext/StridedGPUArraysExt.jl   1.53%     64 Missing ⚠️
src/broadcast.jl             0.00%     11 Missing ⚠️
src/linalg.jl                0.00%     5 Missing ⚠️

Files with missing lines     Coverage Δ
src/linalg.jl                0.00% <0.00%> (ø)
src/broadcast.jl             0.00% <0.00%> (ø)
ext/StridedGPUArraysExt.jl   16.25% <1.53%> (-55.98%) ⬇️

return Array(out)[1]
end

function Strided._mapreduce_fuse!(
Member

I think the generic _mapreduce_fuse! step is still valid for GPUStridedView objects, so maybe _mapreduce_order! is where the lowering could be intercepted for GPUStridedView?

Comment on lines +143 to +147
out_total = prod(
ntuple(Val(N)) do d
@inbounds iszero(out.strides[d]) ? 1 : dims[d]
end
)
Member

I think size(out) is still valid. Also, why do you want to go via linear indexing if KernelAbstractions supports Cartesian or Tuple indices?
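For reference, KernelAbstractions does expose Cartesian global indices, so a pure map could sidestep the linear-index bookkeeping entirely. A minimal hedged sketch (the kernel name is illustrative):

```julia
using KernelAbstractions

@kernel function map_kernel!(out, f, A)
    I = @index(Global, Cartesian)   # a CartesianIndex into the ndrange
    @inbounds out[I] = f(A[I])
end

# launched over the output shape, e.g.:
#   map_kernel!(get_backend(out))(out, f, A; ndrange = size(out))
```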

@kshyatt
Member

kshyatt commented Mar 17, 2026

I guess conceptually what's confusing me here is why not use the existing GPUArrays mapreducedim! to avoid having to rewrite all the kernels? Is there a reason this approach cannot work with the GPUStridedView?

@lkdvos
Member Author

lkdvos commented Mar 17, 2026

I agree that it could, but I don't think we want to bypass all the Strided machinery that fuses and reorders strides to improve data locality and block sizes. From what I understood, that machinery also really does go through linear indexing, and I don't think that is what we want here.

In any case, I have to admit that I'm also just trying things out; as mentioned in the description, I'm pulling inspiration from the different GPU packages, Strided, and Claude to attempt to get something working here, and it clearly still needs some work.

@kshyatt
Member

kshyatt commented Mar 17, 2026

My other high level take is that where possible, we should use the high performance vendor libraries (e.g. batched or regular matmatmul). The built-in GPUArrays kernel for this is quite bad...

@Jutho
Member

Jutho commented Mar 17, 2026

If I find some time and @lkdvos agrees, I might want to try to play with and modify this PR a bit.

@lkdvos
Member Author

lkdvos commented Mar 17, 2026

Feel free to! This was mostly just me playing around and trying to come up with a somewhat more permanent solution that would also work on my Mac, even if it's not the fastest.
