Add some mul overrides too #46

Changes from all commits
```diff
@@ -2,6 +2,7 @@ module StridedGPUArraysExt
 using Strided, GPUArrays
 using GPUArrays: Adapt, KernelAbstractions
+using GPUArrays.KernelAbstractions: @kernel, @index

 ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)}

@@ -19,4 +20,27 @@ function Base.copy!(dst::AbstractArray{TD, ND}, src::StridedView{TS, NS, TAS, FS
     return dst
 end

+# lifted from GPUArrays.jl
+function Base.fill!(A::StridedView{T, N, TA, F}, x) where {T, N, TA <: AbstractGPUArray{T}, F <: ALL_FS}
+    isempty(A) && return A
+    @kernel function fill_kernel!(a, val)
+        idx = @index(Global, Linear)
+        @inbounds a[idx] = val
+    end
+    # ndims check for 0D support
+    kernel = fill_kernel!(KernelAbstractions.get_backend(A))
+    f_x = F <: Union{typeof(conj), typeof(adjoint)} ? conj(x) : x
+    kernel(A, f_x; ndrange = length(A))
+    return A
+end
+
+function Strided.__mul!(
+        C::StridedView{TC, 2, <:AnyGPUArray{TC}},
+        A::StridedView{TA, 2, <:AnyGPUArray{TA}},
+        B::StridedView{TB, 2, <:AnyGPUArray{TB}},
+        α::Number, β::Number
+    ) where {TC, TA, TB}
+    return GPUArrays.generic_matmatmul!(C, A, B, α, β)
+end
+
 end
```

Inline review comments on `idx = @index(Global, Linear)`:

**Member:** Would using

**Author:** Yes, we can do that instead, but it also affects the launch configuration for the kernel, of course. I'll try that in a fix PR that hopefully will also fix TagBot...
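As a rough illustration of what the new `fill!` override is for, here is a hypothetical usage sketch. It assumes CUDA.jl as the GPU backend (`CuArray` is not part of this PR, and any `AbstractGPUArray` backend would dispatch the same way); the point is only that the fill goes through one linear-index kernel launch rather than scalar GPU indexing.

```julia
# Hypothetical usage sketch, assuming CUDA.jl as the GPU backend.
using Strided, CUDA

# Wrap a GPU matrix in a StridedView; taking the adjoint stores
# F = adjoint lazily in the view's type parameter.
A = CuArray(zeros(ComplexF64, 4, 4))
V = StridedView(A)'

# Because F <: adjoint here, the override adjusts the fill value with conj
# on the host (f_x) and then launches a single kernel over ndrange = length(V),
# writing every element without falling back to scalar indexing.
fill!(V, 1.0 + 2.0im)
```

Since the kernel only ever writes `val` to each linear index, the `F`-dependent conjugation can be resolved once on the host before launch, which keeps the device code trivial.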
> I'm slightly hesitant about these kinds of overloads, but I think it is just ignorance: from what I know, the point of Strided is to have some logic such that, even though `A` might have strides that are not just the usual multidimensional ones, it tries to figure out a way to access the parent array "as linearly as possible". If I understand this overload correctly, it is simply doing a mapping and making use of the index computations, so in that sense it is somewhat similar to the implementation for a regular `SubArray`, I think? However, I don't actually know whether the same kinds of performance ideas about linear access are even valid on the GPU, so this might be completely reasonable. Maybe we could schedule a meeting somewhere next week to discuss this, and the possible implications of deciding what type of objects to use for our strided views?

> Sure, I think that makes sense. I think coming up with a good access pattern is probably even more important on the GPU, but here I've done the naive thing just to get it started, and then we should definitely do some benchmarking/profiling to see whether this is really slowing us down or not.
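To make the "as linearly as possible" point in this discussion concrete, here is a small CPU-only sketch in plain Julia (no GPU packages; the helper name `linear_to_offset` is made up for illustration, not Strided API). It shows the index arithmetic that turns a linear kernel index into a parent-array offset, and how the resulting memory access pattern depends on the strides:

```julia
# Map a 1-based linear index into an array of size `sz` to a 0-based
# offset into the parent array, given arbitrary strides `st`.
function linear_to_offset(i::Int, sz::NTuple{N, Int}, st::NTuple{N, Int}) where {N}
    offset = 0
    i -= 1
    for d in 1:N
        i, r = divrem(i, sz[d])
        offset += r * st[d]
    end
    return offset
end

# A contiguous 3×4 column-major matrix has strides (1, 3):
# consecutive linear indices touch consecutive memory locations.
[linear_to_offset(i, (3, 4), (1, 3)) for i in 1:6]   # 0, 1, 2, 3, 4, 5

# Its transpose is a 4×3 view with strides (3, 1): consecutive linear
# indices now jump through memory in steps of 3.
[linear_to_offset(i, (4, 3), (3, 1)) for i in 1:6]   # 0, 3, 6, 9, 1, 4
```

On the CPU, Strided reorders loops to avoid the strided pattern in the second case; whether the same reordering pays off for GPU threads (where coalesced access across a warp, not per-thread linearity, is what matters) is exactly the open question raised above.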