Skip to content

Commit

Permalink
adapt_storage-related improvements (#296)
Browse files Browse the repository at this point in the history
* Check for Int128 compatibility
* Remove unneeded adapt_storage definitions and add tests.
* Add forgotten adaptor from parameterizing storage mode and test
  • Loading branch information
christiangnrd authored Mar 1, 2024
1 parent 81f716c commit 356ee6c
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 14 deletions.
15 changes: 6 additions & 9 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ function check_eltype(T)
Base.allocatedinline(T) || error("MtlArray only supports element types that are stored inline")
Base.isbitsunion(T) && error("MtlArray does not yet support isbits-union arrays")
contains_eltype(T, Float64) && error("Metal does not support Float64 values, try using Float32 instead")
contains_eltype(T, Int128) && error("Metal does not support Int128 values, try using Int64 instead")
contains_eltype(T, UInt128) && error("Metal does not support UInt128 values, try using UInt64 instead")
end

"""
Expand Down Expand Up @@ -314,6 +316,8 @@ Adapt.adapt_storage(::Type{<:MtlArray{T}}, xs::AT) where {T, AT<:AbstractArray}
isbitstype(AT) ? xs : convert(MtlArray{T}, xs)
Adapt.adapt_storage(::Type{<:MtlArray{T, N}}, xs::AT) where {T, N, AT<:AbstractArray} =
isbitstype(AT) ? xs : convert(MtlArray{T,N}, xs)
Adapt.adapt_storage(::Type{<:MtlArray{T, N, S}}, xs::AT) where {T, N, S, AT<:AbstractArray} =
isbitstype(AT) ? xs : convert(MtlArray{T,N,S}, xs)


## opinionated gpu array adaptor
Expand All @@ -325,19 +329,12 @@ struct MtlArrayAdaptor{S} end
Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T,N,S} =
isbits(xs) ? xs : MtlArray{T,N,S}(xs)

Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:AbstractFloat,N,S} =
Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Float64,N,S} =
isbits(xs) ? xs : MtlArray{Float32,N,S}(xs)

Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Complex{<:AbstractFloat},N,S} =
Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Complex{<:Float64},N,S} =
isbits(xs) ? xs : MtlArray{ComplexF32,N,S}(xs)

# not for Float16
Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Float16,N,S} =
isbits(xs) ? xs : MtlArray{T,N,S}(xs)

Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Complex{Float16},N,S} =
isbits(xs) ? xs : MtlArray{T,N,S}(xs)

"""
mtl(A; storage=Private)
Expand Down
37 changes: 32 additions & 5 deletions test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,40 @@ end
@test Base.elsize(xs) == sizeof(Int)
@test pointer(MtlArray{Int, 2}(xs)) != pointer(xs)

# test aggressive conversion to Float32, but only for floats, and only with `mtl`
@test mtl([1]) isa MtlArray{Int}
@test mtl(Float64[1]) isa MtlArray{Float32}
@test mtl(ComplexF64[1+1im]) isa MtlArray{ComplexF32}
@test mtl(ComplexF16[1+1im]) isa MtlArray{ComplexF16}
# Page 22 of https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
# Only bfloat missing
supported_number_types = [Float16 => Float16,
Float32 => Float32,
Float64 => Float32,
Bool => Bool,
Int16 => Int16,
Int32 => Int32,
Int64 => Int64,
Int8 => Int8,
UInt16 => UInt16,
UInt32 => UInt32,
UInt64 => UInt64,
UInt8 => UInt8]
# Test supported types and ensure only Float64 get converted to Float32
for (SrcType, TargType) in supported_number_types
@test mtl(SrcType[1]) isa MtlArray{TargType}
@test mtl(Complex{SrcType}[1+1im]) isa MtlArray{Complex{TargType}}
end

# test the regular adaptor
@test Adapt.adapt(MtlArray, [1 2;3 4]) isa MtlArray{Int, 2, Private}
@test Adapt.adapt(MtlArray{Float32}, [1 2;3 4]) isa MtlArray{Float32, 2, Private}
@test Adapt.adapt(MtlArray{Float32, 2}, [1 2;3 4]) isa MtlArray{Float32, 2, Private}
@test Adapt.adapt(MtlArray{Float32, 2, Shared}, [1 2;3 4]) isa MtlArray{Float32, 2, Shared}
@test Adapt.adapt(MtlMatrix{ComplexF32, Shared}, [1 2;3 4]) isa MtlArray{ComplexF32, 2, Shared}
@test Adapt.adapt(MtlArray{Float16}, Float64[1]) isa MtlArray{Float16}

# Test a few explicitly unsupported types
@test_throws "MtlArray only supports element types that are stored inline" MtlArray(BigInt[1])
@test_throws "MtlArray only supports element types that are stored inline" MtlArray(BigFloat[1])
@test_throws "Metal does not support Float64 values" MtlArray(Float64[1])
@test_throws "Metal does not support Int128 values" MtlArray(Int128[1])
@test_throws "Metal does not support UInt128 values" MtlArray(UInt128[1])

@test collect(Metal.zeros(2, 2)) == zeros(Float32, 2, 2)
@test collect(Metal.ones(2, 2)) == ones(Float32, 2, 2)
Expand Down

0 comments on commit 356ee6c

Please sign in to comment.