diff --git a/src/array.jl b/src/array.jl index fe00a724f..2e103916f 100644 --- a/src/array.jl +++ b/src/array.jl @@ -446,3 +446,43 @@ end function Base.unsafe_wrap(t::Type{<:Array{T}}, ptr::MtlPointer{T}, dims; own=false) where T return unsafe_wrap(t, convert(Ptr{T}, ptr), dims; own) end + +## resizing + +""" + resize!(a::MtlVector, n::Integer) + +Resize `a` to contain `n` elements. If `n` is smaller than the current collection length, +the first `n` elements will be retained. If `n` is larger, the new elements are not +guaranteed to be initialized. +""" +function Base.resize!(A::MtlVector{T}, n::Integer) where T + # TODO: add additional space to allow for quicker resizing + maxsize = n * sizeof(T) + bufsize = if isbitstype(T) + maxsize + else + # type tag array past the data + maxsize + n + end + + # replace the data with a new one. this 'unshares' the array. + # as a result, we can safely support resizing unowned buffers. + buf = alloc(device(A), bufsize; storage=storagemode(A)) + ptr = MtlPointer{T}(buf) + m = min(length(A), n) + if m > 0 + unsafe_copyto!(device(A), ptr, pointer(A), m) + end + new_data = DataRef(buf) do buf + free(buf) + end + unsafe_free!(A) + + A.data = new_data + A.dims = (n,) + A.maxsize = maxsize + A.offset = 0 + + A +end diff --git a/test/array.jl b/test/array.jl index 4dd67b08f..4756932c7 100644 --- a/test/array.jl +++ b/test/array.jl @@ -230,4 +230,25 @@ end @test Metal.storagemode(e) == Shared end +@testset "resizing" begin + a = MtlArray([1,2,3]) + + resize!(a, 3) + @test length(a) == 3 + @test Array(a) == [1,2,3] + + resize!(a, 5) + @test length(a) == 5 + @test Array(a)[1:3] == [1,2,3] + + resize!(a, 2) + @test length(a) == 2 + @test Array(a)[1:2] == [1,2] + + b = MtlArray{Int}(undef, 0) + @test length(b) == 0 + resize!(b, 1) + @test length(b) == 1 +end + end