Skip to content

Commit

Permalink
Allow creating MtlArray that shares memory with Array
Browse files Browse the repository at this point in the history
GC.@preserve

Better? interface

Add tests

Improve error message
  • Loading branch information
christiangnrd committed Mar 20, 2024
1 parent bb33fa5 commit 240cc92
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 3 deletions.
14 changes: 11 additions & 3 deletions lib/mtl/buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,16 @@ function MTLBuffer(dev::Union{MTLDevice,MTLHeap}, bytesize::Integer;
end

function MTLBuffer(dev::Union{MTLDevice,MTLHeap}, bytesize::Integer, ptr::Ptr;
storage=Managed, hazard_tracking=DefaultTracking,
nocopy = false, storage=Shared, hazard_tracking=DefaultTracking,
cache_mode=DefaultCPUCache)
storage == Private && error("Can't create a Private copy-allocated buffer.")
storage == Private && error(LazyString("Cannot create a Private ", (nocopy ? "buffer that shares memory with an Array" : "copy-allocated buffer.")))
opts = convert(MTLResourceOptions, storage) | hazard_tracking | cache_mode

@assert 0 < bytesize <= dev.maxBufferLength # XXX: not supported by MTLHeap
ptr = alloc_buffer(dev, bytesize, opts, ptr)

alloc_f = nocopy ? alloc_buffer_nocopy : alloc_buffer

ptr = alloc_f(dev, bytesize, opts, ptr)

return MTLBuffer(ptr)
end
Expand All @@ -53,6 +56,11 @@ alloc_buffer(dev::MTLDevice, bytesize, opts, ptr::Ptr) =
@objc [dev::id{MTLDevice} newBufferWithBytes:ptr::Ptr{Cvoid}
length:bytesize::NSUInteger
options:opts::MTLResourceOptions]::id{MTLBuffer}
alloc_buffer_nocopy(dev::MTLDevice, bytesize, opts, ptr::Ptr) = # ptr MUST be page-aligned
@objc [dev::id{MTLDevice} newBufferWithBytesNoCopy:ptr::Ptr{Cvoid}
length:bytesize::NSUInteger
options:opts::MTLResourceOptions
deallocator:nil::id{Object}]::id{MTLBuffer}

# from heap
alloc_buffer(dev::MTLHeap, bytesize, opts) =
Expand Down
16 changes: 16 additions & 0 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ mutable struct MtlArray{T,N,S} <: AbstractGPUArray{T,N}
end
end

# Create MtlArray from MTLBuffer
function MtlArray{T,N}(buf::B, dims::Dims{N}; kwargs...) where {B<:MTLBuffer,T,N}
data = DataRef(buf) do buf
free(buf)
end
return MtlArray{T,N}(data, dims; kwargs...)
end

unsafe_free!(a::MtlArray) = GPUArrays.unsafe_free!(a.data)

device(A::MtlArray) = A.data[].device
Expand Down Expand Up @@ -491,6 +499,14 @@ function Base.unsafe_wrap(t::Type{<:Array{T}}, ptr::MtlPointer{T}, dims; own=fal
return unsafe_wrap(t, convert(Ptr{T}, ptr), dims; own)
end

function Base.unsafe_wrap(A::Type{<:MtlArray{T,N}},
arr::Array,
dims=size(arr);
dev=current_device(),
kwargs...) where {T,N}
return GC.@preserve arr A(MTLBuffer(dev, prod(dims) * sizeof(T), pointer(arr); nocopy=true, kwargs...),Dims(dims))
end

## resizing

"""
Expand Down
48 changes: 48 additions & 0 deletions test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -305,4 +305,52 @@ end
@test length(b) == 1
end

function _alignedvec(::Type{T}, n::Integer, alignment::Integer=16384) where {T}
ispow2(alignment) || throw(ArgumentError("$alignment is not a power of 2"))
alignment sizeof(Int) || throw(ArgumentError("$alignment is not a multiple of $(sizeof(Int))"))
isbitstype(T) || throw(ArgumentError("$T is not a bitstype"))
p = Ref{Ptr{T}}()
err = ccall(:posix_memalign, Cint, (Ref{Ptr{T}}, Csize_t, Csize_t), p, alignment, n*sizeof(T))
iszero(err) || throw(OutOfMemoryError())
return unsafe_wrap(Array, p[], n, own=true)
end

@testset "unsafe_wrap" begin
# Create page-aligned vector for testing
arr1 = _alignedvec(Float32, 18000);
fill!(arr1, zero(eltype(arr1)))
marr1 = unsafe_wrap(MtlVector{Float32}, arr1);

@test all(arr1 .== 0)
@test all(marr1 .== 0)

# XXX: Test fails when ordered as shown
# @test all(arr1 .== 1)
# @test all(marr1 .== 1)
marr1 .+= 1;
@test all(marr1 .== 1)
@test all(arr1 .== 1)

arr1 .+= 1;
@test all(marr1 .== 2)
@test all(arr1 .== 2)

marr2 = Metal.zeros(Float32, 18000; storage=Shared);
arr2 = unsafe_wrap(Vector{Float32}, marr2);

@test all(arr2 .== 0)
@test all(marr2 .== 0)

# XXX: Test fails when ordered as shown
# @test all(arr2 .== 1)
# @test all(marr2 .== 1)
marr2 .+= 1;
@test all(marr2 .== 1)
@test all(arr2 .== 1)

arr2 .+= 1;
@test all(arr2 .== 2)
@test all(marr2 .== 2)
end

end

0 comments on commit 240cc92

Please sign in to comment.