Skip to content

Commit

Permalink
Add resize! (#279)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Jan 31, 2024
1 parent f6df13d commit d03644a
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 0 deletions.
40 changes: 40 additions & 0 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -446,3 +446,43 @@ end
function Base.unsafe_wrap(t::Type{<:Array{T}}, ptr::MtlPointer{T}, dims; own=false) where T
return unsafe_wrap(t, convert(Ptr{T}, ptr), dims; own)
end

## resizing

"""
resize!(a::MtlVector, n::Integer)
Resize `a` to contain `n` elements. If `n` is smaller than the current collection length,
the first `n` elements will be retained. If `n` is larger, the new elements are not
guaranteed to be initialized.
"""
function Base.resize!(A::MtlVector{T}, n::Integer) where T
# TODO: add additional space to allow for quicker resizing
maxsize = n * sizeof(T)
bufsize = if isbitstype(T)
maxsize
else
# type tag array past the data
maxsize + n
end

# replace the data with a new one. this 'unshares' the array.
# as a result, we can safely support resizing unowned buffers.
buf = alloc(device(A), bufsize; storage=storagemode(A))
ptr = MtlPointer{T}(buf)
m = min(length(A), n)
if m > 0
unsafe_copyto!(device(A), ptr, pointer(A), m)
end
new_data = DataRef(buf) do buf
free(buf)
end
unsafe_free!(A)

A.data = new_data
A.dims = (n,)
A.maxsize = maxsize
A.offset = 0

A
end
21 changes: 21 additions & 0 deletions test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,4 +230,25 @@ end
@test Metal.storagemode(e) == Shared
end

@testset "resizing" begin
a = MtlArray([1,2,3])

resize!(a, 3)
@test length(a) == 3
@test Array(a) == [1,2,3]

resize!(a, 5)
@test length(a) == 5
@test Array(a)[1:3] == [1,2,3]

resize!(a, 2)
@test length(a) == 2
@test Array(a)[1:2] == [1,2]

b = MtlArray{Int}(undef, 0)
@test length(b) == 0
resize!(b, 1)
@test length(b) == 1
end

end

0 comments on commit d03644a

Please sign in to comment.