From 240cc92fd19908123bb40261bdb02b0498beb5ae Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Tue, 19 Mar 2024 14:11:24 -0300 Subject: [PATCH] Allow creating MtlArray that shares memory with Array GC.@preserve Better? interface Add tests Improve error message --- lib/mtl/buffer.jl | 14 +++++++++++--- src/array.jl | 16 ++++++++++++++++ test/array.jl | 48 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 3 deletions(-) diff --git a/lib/mtl/buffer.jl b/lib/mtl/buffer.jl index b9d36fe70..206283796 100644 --- a/lib/mtl/buffer.jl +++ b/lib/mtl/buffer.jl @@ -34,13 +34,16 @@ function MTLBuffer(dev::Union{MTLDevice,MTLHeap}, bytesize::Integer; end function MTLBuffer(dev::Union{MTLDevice,MTLHeap}, bytesize::Integer, ptr::Ptr; - storage=Managed, hazard_tracking=DefaultTracking, + nocopy = false, storage=Shared, hazard_tracking=DefaultTracking, cache_mode=DefaultCPUCache) - storage == Private && error("Can't create a Private copy-allocated buffer.") + storage == Private && error(LazyString("Cannot create a Private ", (nocopy ? "buffer that shares memory with an Array" : "copy-allocated buffer."))) opts = convert(MTLResourceOptions, storage) | hazard_tracking | cache_mode @assert 0 < bytesize <= dev.maxBufferLength # XXX: not supported by MTLHeap - ptr = alloc_buffer(dev, bytesize, opts, ptr) + + alloc_f = nocopy ? alloc_buffer_nocopy : alloc_buffer + + ptr = alloc_f(dev, bytesize, opts, ptr) return MTLBuffer(ptr) end @@ -53,6 +56,11 @@ alloc_buffer(dev::MTLDevice, bytesize, opts, ptr::Ptr) = @objc [dev::id{MTLDevice} newBufferWithBytes:ptr::Ptr{Cvoid} length:bytesize::NSUInteger options:opts::MTLResourceOptions]::id{MTLBuffer} +alloc_buffer_nocopy(dev::MTLDevice, bytesize, opts, ptr::Ptr) = # ptr MUST be page-aligned + @objc [dev::id{MTLDevice} newBufferWithBytesNoCopy:ptr::Ptr{Cvoid} + length:bytesize::NSUInteger + options:opts::MTLResourceOptions + deallocator:nil::id{Object}]::id{MTLBuffer} # from heap alloc_buffer(dev::MTLHeap, bytesize, opts) = diff --git a/src/array.jl b/src/array.jl index 70c429a10..3e2eb7d52 100644 --- a/src/array.jl +++ b/src/array.jl @@ -91,6 +91,14 @@ mutable struct MtlArray{T,N,S} <: AbstractGPUArray{T,N} end end +# Create MtlArray from MTLBuffer +function MtlArray{T,N}(buf::B, dims::Dims{N}; kwargs...) where {B<:MTLBuffer,T,N} + data = DataRef(buf) do buf + free(buf) + end + return MtlArray{T,N}(data, dims; kwargs...) +end + unsafe_free!(a::MtlArray) = GPUArrays.unsafe_free!(a.data) device(A::MtlArray) = A.data[].device @@ -491,6 +499,14 @@ function Base.unsafe_wrap(t::Type{<:Array{T}}, ptr::MtlPointer{T}, dims; own=fal return unsafe_wrap(t, convert(Ptr{T}, ptr), dims; own) end +function Base.unsafe_wrap(A::Type{<:MtlArray{T,N}}, + arr::Array, + dims=size(arr); + dev=current_device(), + kwargs...) where {T,N} + return GC.@preserve arr A(MTLBuffer(dev, prod(dims) * sizeof(T), pointer(arr); nocopy=true, kwargs...),Dims(dims)) +end + ## resizing """ diff --git a/test/array.jl b/test/array.jl index 9db5e5083..66bf721dc 100644 --- a/test/array.jl +++ b/test/array.jl @@ -305,4 +305,52 @@ end @test length(b) == 1 end +function _alignedvec(::Type{T}, n::Integer, alignment::Integer=16384) where {T} + ispow2(alignment) || throw(ArgumentError("$alignment is not a power of 2")) + alignment ≥ sizeof(Int) || throw(ArgumentError("$alignment is not a multiple of $(sizeof(Int))")) + isbitstype(T) || throw(ArgumentError("$T is not a bitstype")) + p = Ref{Ptr{T}}() + err = ccall(:posix_memalign, Cint, (Ref{Ptr{T}}, Csize_t, Csize_t), p, alignment, n*sizeof(T)) + iszero(err) || throw(OutOfMemoryError()) + return unsafe_wrap(Array, p[], n, own=true) +end + +@testset "unsafe_wrap" begin + # Create page-aligned vector for testing + arr1 = _alignedvec(Float32, 18000); + fill!(arr1, zero(eltype(arr1))) + marr1 = unsafe_wrap(MtlVector{Float32}, arr1); + + @test all(arr1 .== 0) + @test all(marr1 .== 0) + + # XXX: Test fails when ordered as shown + # @test all(arr1 .== 1) + # @test all(marr1 .== 1) + marr1 .+= 1; + @test all(marr1 .== 1) + @test all(arr1 .== 1) + + arr1 .+= 1; + @test all(marr1 .== 2) + @test all(arr1 .== 2) + + marr2 = Metal.zeros(Float32, 18000; storage=Shared); + arr2 = unsafe_wrap(Vector{Float32}, marr2); + + @test all(arr2 .== 0) + @test all(marr2 .== 0) + + # XXX: Test fails when ordered as shown + # @test all(arr2 .== 1) + # @test all(marr2 .== 1) + marr2 .+= 1; + @test all(marr2 .== 1) + @test all(arr2 .== 1) + + arr2 .+= 1; + @test all(arr2 .== 2) + @test all(marr2 .== 2) +end + end