diff --git a/lib/mtl/buffer.jl b/lib/mtl/buffer.jl
index cb3435b33..66b391f52 100644
--- a/lib/mtl/buffer.jl
+++ b/lib/mtl/buffer.jl
@@ -34,17 +34,36 @@ function MTLBuffer(dev::Union{MTLDevice,MTLHeap}, bytesize::Integer;
 end
 
 function MTLBuffer(dev::MTLDevice, bytesize::Integer, ptr::Ptr;
-                   storage=Managed, hazard_tracking=DefaultTracking,
+                   nocopy=false, storage=Shared, hazard_tracking=DefaultTracking,
                    cache_mode=DefaultCPUCache)
-    storage == Private && error("Can't create a Private copy-allocated buffer.")
+    storage == Private && error("Cannot allocate-and-initialize a Private buffer")
     opts = convert(MTLResourceOptions, storage) | hazard_tracking | cache_mode
 
     @assert 0 < bytesize <= dev.maxBufferLength
-    ptr = alloc_buffer(dev, bytesize, opts, ptr)
+    ptr = if nocopy
+        alloc_buffer_nocopy(dev, bytesize, opts, ptr)
+    else
+        alloc_buffer(dev, bytesize, opts, ptr)
+    end
 
     return MTLBuffer(ptr)
 end
 
+const PAGESIZE = ccall(:getpagesize, Cint, ())
+function can_alloc_nocopy(ptr::Ptr, bytesize::Integer)
+    # newBufferWithBytesNoCopy has several restrictions:
+    ## the pointer has to be page-aligned
+    if Int64(ptr) % PAGESIZE != 0
+        return false
+    end
+    ## the new buffer needs to be page-aligned
+    ## XXX: on macOS 14, this doesn't seem required; is this a documentation issue?
+    if bytesize % PAGESIZE != 0
+        return false
+    end
+    return true
+end
+
 # from device
 alloc_buffer(dev::MTLDevice, bytesize, opts) =
     @objc [dev::id{MTLDevice} newBufferWithLength:bytesize::NSUInteger
@@ -53,6 +72,14 @@ alloc_buffer(dev::MTLDevice, bytesize, opts, ptr::Ptr) =
     @objc [dev::id{MTLDevice} newBufferWithBytes:ptr::Ptr{Cvoid}
                               length:bytesize::NSUInteger
                               options:opts::MTLResourceOptions]::id{MTLBuffer}
+function alloc_buffer_nocopy(dev::MTLDevice, bytesize, opts, ptr::Ptr)
+    can_alloc_nocopy(ptr, bytesize) ||
+        throw(ArgumentError("Cannot allocate nocopy buffer from non-aligned memory"))
+    @objc [dev::id{MTLDevice} newBufferWithBytesNoCopy:ptr::Ptr{Cvoid}
+                              length:bytesize::NSUInteger
+                              options:opts::MTLResourceOptions
+                              deallocator:nil::id{Object}]::id{MTLBuffer}
+end
 
 # from heap
 alloc_buffer(dev::MTLHeap, bytesize, opts) =
diff --git a/src/array.jl b/src/array.jl
index 0158b7a58..b34cb5f2e 100644
--- a/src/array.jl
+++ b/src/array.jl
@@ -91,6 +91,14 @@ mutable struct MtlArray{T,N,S} <: AbstractGPUArray{T,N}
     end
 end
 
+# Create MtlArray from MTLBuffer
+function MtlArray{T,N}(buf::B, dims::Dims{N}; kwargs...) where {B<:MTLBuffer,T,N}
+    data = DataRef(buf) do buf
+        free(buf)
+    end
+    return MtlArray{T,N}(data, dims; kwargs...)
+end
+
 unsafe_free!(a::MtlArray) = GPUArrays.unsafe_free!(a.data)
 
 device(A::MtlArray) = A.data[].device
@@ -491,6 +499,14 @@ function Base.unsafe_wrap(t::Type{<:Array{T}}, ptr::MtlPtr{T}, dims; own=false)
     return unsafe_wrap(t, convert(Ptr{T}, ptr), dims; own)
 end
 
+function Base.unsafe_wrap(A::Type{<:MtlArray{T,N}}, arr::Array, dims=size(arr);
+                          dev=current_device(), kwargs...) where {T,N}
+    GC.@preserve arr begin
+        buf = MTLBuffer(dev, prod(dims) * sizeof(T), pointer(arr); nocopy=true, kwargs...)
+        return A(buf, Dims(dims))
+    end
+end
+
 ## resizing
 
 """
diff --git a/src/memory.jl b/src/memory.jl
index 031b71e85..74bce49f8 100644
--- a/src/memory.jl
+++ b/src/memory.jl
@@ -35,11 +35,12 @@ function Base.unsafe_copyto!(dev::MTLDevice, dst::MtlPtr{T}, src::Ptr{T}, N::Int
     storage_type = dst.buffer.storageMode
     if storage_type == MTL.MTLStorageModePrivate
         # stage through a shared buffer
-        # shared = alloc(dev, N*sizeof(T), src; storage=Shared)
-        # unsafe_copyto!(dev, dst, pointer(shared), N; queue, async=false)
-        # free(shared)
-        tmp_buf = alloc(dev, N*sizeof(T), src; storage=Shared) #CPU -> GPU (Shared)
-        unsafe_copyto!(dev, MtlPtr{T}(dst.buffer, dst.offset), MtlPtr{T}(tmp_buf, 0), N; queue, async=false) # GPU (Shared) -> GPU (Private)
+        nocopy = MTL.can_alloc_nocopy(src, N*sizeof(T))
+        tmp_buf = alloc(dev, N*sizeof(T), src; storage=Shared, nocopy)
+
+        # copy to the private buffer
+        unsafe_copyto!(dev, MtlPtr{T}(dst.buffer, dst.offset), MtlPtr{T}(tmp_buf, 0), N;
+                       queue, async=(nocopy && async))
         free(tmp_buf)
     elseif storage_type == MTL.MTLStorageModeShared
         unsafe_copyto!(convert(Ptr{T}, dst), src, N)
@@ -54,12 +55,22 @@ end
 function Base.unsafe_copyto!(dev::MTLDevice, dst::Ptr{T}, src::MtlPtr{T}, N::Integer;
                              queue::MTLCommandQueue=global_queue(dev), async::Bool=false) where T
     storage_type = src.buffer.storageMode
-    if storage_type == MTL.MTLStorageModePrivate
+    if storage_type == MTL.MTLStorageModePrivate
         # stage through a shared buffer
-        shared = alloc(dev, N*sizeof(T); storage=Shared)
-        unsafe_copyto!(dev, MtlPtr{T}(shared, 0), MtlPtr{T}(src.buffer, src.offset), N; queue, async=false)
-        unsafe_copyto!(dst, convert(Ptr{T}, shared), N)
-        free(shared)
+        nocopy = MTL.can_alloc_nocopy(dst, N*sizeof(T))
+        tmp_buf = if nocopy
+            alloc(dev, N*sizeof(T), dst; storage=Shared, nocopy)
+        else
+            alloc(dev, N*sizeof(T); storage=Shared)
+        end
+        unsafe_copyto!(dev, MtlPtr{T}(tmp_buf, 0), MtlPtr{T}(src.buffer, src.offset), N;
+                       queue, async=(nocopy && async))
+
+        # copy from the shared buffer
+        if !nocopy
+            unsafe_copyto!(dst, convert(Ptr{T}, tmp_buf), N)
+        end
+        free(tmp_buf)
     elseif storage_type == MTL.MTLStorageModeShared
         unsafe_copyto!(dst, convert(Ptr{T}, src), N)
     elseif storage_type == MTL.MTLStorageModeManaged
diff --git a/test/array.jl b/test/array.jl
index fb51ad5c2..272cdc50b 100644
--- a/test/array.jl
+++ b/test/array.jl
@@ -305,4 +305,52 @@ end
     @test length(b) == 1
 end
 
+function _alignedvec(::Type{T}, n::Integer, alignment::Integer=16384) where {T}
+    ispow2(alignment) || throw(ArgumentError("$alignment is not a power of 2"))
+    alignment ≥ sizeof(Int) || throw(ArgumentError("$alignment is not a multiple of $(sizeof(Int))"))
+    isbitstype(T) || throw(ArgumentError("$T is not a bitstype"))
+    p = Ref{Ptr{T}}()
+    err = ccall(:posix_memalign, Cint, (Ref{Ptr{T}}, Csize_t, Csize_t), p, alignment, n*sizeof(T))
+    iszero(err) || throw(OutOfMemoryError())
+    return unsafe_wrap(Array, p[], n, own=true)
+end
+
+@testset "unsafe_wrap" begin
+    # Create page-aligned vector for testing
+    arr1 = _alignedvec(Float32, 16384*2);
+    fill!(arr1, zero(eltype(arr1)))
+    marr1 = unsafe_wrap(MtlVector{Float32}, arr1);
+
+    @test all(arr1 .== 0)
+    @test all(marr1 .== 0)
+
+    # XXX: Test fails when ordered as shown
+    # @test all(arr1 .== 1)
+    # @test all(marr1 .== 1)
+    marr1 .+= 1;
+    @test all(marr1 .== 1)
+    @test all(arr1 .== 1)
+
+    arr1 .+= 1;
+    @test all(marr1 .== 2)
+    @test all(arr1 .== 2)
+
+    marr2 = Metal.zeros(Float32, 18000; storage=Shared);
+    arr2 = unsafe_wrap(Vector{Float32}, marr2);
+
+    @test all(arr2 .== 0)
+    @test all(marr2 .== 0)
+
+    # XXX: Test fails when ordered as shown
+    # @test all(arr2 .== 1)
+    # @test all(marr2 .== 1)
+    marr2 .+= 1;
+    @test all(marr2 .== 1)
+    @test all(arr2 .== 1)
+
+    arr2 .+= 1;
+    @test all(arr2 .== 2)
+    @test all(marr2 .== 2)
+end
+
 end
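
A minimal usage sketch of the zero-copy wrap added above, mirroring the new test's setup. It assumes the host allocation satisfies can_alloc_nocopy (page-aligned pointer, byte size a multiple of the page size); since unsafe_wrap passes nocopy=true, an unaligned allocation would throw the ArgumentError from alloc_buffer_nocopy.

using Metal

# Page-aligned host allocation; 16384 matches the alignment used in the tests
# and is a multiple of the page size, so can_alloc_nocopy returns true.
n = 16384                        # 16384 Float32s == 65536 bytes
p = Ref{Ptr{Float32}}()
err = ccall(:posix_memalign, Cint, (Ref{Ptr{Float32}}, Csize_t, Csize_t),
            p, 16384, n * sizeof(Float32))
iszero(err) || throw(OutOfMemoryError())
arr = unsafe_wrap(Array, p[], n; own=true)
fill!(arr, 0f0)

# Wrap the host memory as a Shared-storage MtlVector; no copy is made,
# so the CPU array and the GPU array alias the same pages.
marr = unsafe_wrap(MtlVector{Float32}, arr)

marr .+= 1               # mutate through the GPU array
@assert all(marr .== 1)  # read back through the GPU array first (see the XXX note above)
@assert all(arr .== 1)   # the host array observes the same memory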