From 356ee6cd8fbe8a6bc1578be1ab49b20079c34824 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Fri, 1 Mar 2024 03:52:20 -0400 Subject: [PATCH] `adapt_storage`-related improvements (#296) * Check for Int128 compatibility * Remove unneeded adapt_storage definitions and add tests. * Add forgotten adaptor from parameterizing storage mode and test --- src/array.jl | 15 ++++++--------- test/array.jl | 37 ++++++++++++++++++++++++++++++++----- 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/src/array.jl b/src/array.jl index 62a2ff507..7101f4dc5 100644 --- a/src/array.jl +++ b/src/array.jl @@ -30,6 +30,8 @@ function check_eltype(T) Base.allocatedinline(T) || error("MtlArray only supports element types that are stored inline") Base.isbitsunion(T) && error("MtlArray does not yet support isbits-union arrays") contains_eltype(T, Float64) && error("Metal does not support Float64 values, try using Float32 instead") + contains_eltype(T, Int128) && error("Metal does not support Int128 values, try using Int64 instead") + contains_eltype(T, UInt128) && error("Metal does not support UInt128 values, try using UInt64 instead") end """ @@ -314,6 +316,8 @@ Adapt.adapt_storage(::Type{<:MtlArray{T}}, xs::AT) where {T, AT<:AbstractArray} isbitstype(AT) ? xs : convert(MtlArray{T}, xs) Adapt.adapt_storage(::Type{<:MtlArray{T, N}}, xs::AT) where {T, N, AT<:AbstractArray} = isbitstype(AT) ? xs : convert(MtlArray{T,N}, xs) +Adapt.adapt_storage(::Type{<:MtlArray{T, N, S}}, xs::AT) where {T, N, S, AT<:AbstractArray} = + isbitstype(AT) ? xs : convert(MtlArray{T,N,S}, xs) ## opinionated gpu array adaptor @@ -325,19 +329,12 @@ struct MtlArrayAdaptor{S} end Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T,N,S} = isbits(xs) ? xs : MtlArray{T,N,S}(xs) -Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:AbstractFloat,N,S} = +Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Float64,N,S} = isbits(xs) ? xs : MtlArray{Float32,N,S}(xs) -Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Complex{<:AbstractFloat},N,S} = +Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Complex{<:Float64},N,S} = isbits(xs) ? xs : MtlArray{ComplexF32,N,S}(xs) -# not for Float16 -Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Float16,N,S} = - isbits(xs) ? xs : MtlArray{T,N,S}(xs) - -Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Complex{Float16},N,S} = - isbits(xs) ? xs : MtlArray{T,N,S}(xs) - """ mtl(A; storage=Private) diff --git a/test/array.jl b/test/array.jl index d48c8ad07..5271c5e29 100644 --- a/test/array.jl +++ b/test/array.jl @@ -22,13 +22,40 @@ end @test Base.elsize(xs) == sizeof(Int) @test pointer(MtlArray{Int, 2}(xs)) != pointer(xs) - # test aggressive conversion to Float32, but only for floats, and only with `mtl` - @test mtl([1]) isa MtlArray{Int} - @test mtl(Float64[1]) isa MtlArray{Float32} - @test mtl(ComplexF64[1+1im]) isa MtlArray{ComplexF32} - @test mtl(ComplexF16[1+1im]) isa MtlArray{ComplexF16} + # Page 22 of https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf + # Only bfloat missing + supported_number_types = [Float16 => Float16, + Float32 => Float32, + Float64 => Float32, + Bool => Bool, + Int16 => Int16, + Int32 => Int32, + Int64 => Int64, + Int8 => Int8, + UInt16 => UInt16, + UInt32 => UInt32, + UInt64 => UInt64, + UInt8 => UInt8] + # Test supported types and ensure only Float64 get converted to Float32 + for (SrcType, TargType) in supported_number_types + @test mtl(SrcType[1]) isa MtlArray{TargType} + @test mtl(Complex{SrcType}[1+1im]) isa MtlArray{Complex{TargType}} + end + + # test the regular adaptor + @test Adapt.adapt(MtlArray, [1 2;3 4]) isa MtlArray{Int, 2, Private} + @test Adapt.adapt(MtlArray{Float32}, [1 2;3 4]) isa MtlArray{Float32, 2, Private} + @test Adapt.adapt(MtlArray{Float32, 2}, [1 2;3 4]) isa MtlArray{Float32, 2, Private} + @test Adapt.adapt(MtlArray{Float32, 2, Shared}, [1 2;3 4]) isa MtlArray{Float32, 2, Shared} + @test Adapt.adapt(MtlMatrix{ComplexF32, Shared}, [1 2;3 4]) isa MtlArray{ComplexF32, 2, Shared} @test Adapt.adapt(MtlArray{Float16}, Float64[1]) isa MtlArray{Float16} + # Test a few explicitly unsupported types + @test_throws "MtlArray only supports element types that are stored inline" MtlArray(BigInt[1]) + @test_throws "MtlArray only supports element types that are stored inline" MtlArray(BigFloat[1]) + @test_throws "Metal does not support Float64 values" MtlArray(Float64[1]) + @test_throws "Metal does not support Int128 values" MtlArray(Int128[1]) + @test_throws "Metal does not support UInt128 values" MtlArray(UInt128[1]) @test collect(Metal.zeros(2, 2)) == zeros(Float32, 2, 2) @test collect(Metal.ones(2, 2)) == ones(Float32, 2, 2)