diff --git a/Project.toml b/Project.toml
index 83cf4b8..2831685 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,6 +6,7 @@ version = "0.1.0"
 [deps]
 CUDA_Compiler_jll = "d1e2174e-dfdc-576e-b43e-73b79eb1aca8"
 CUDA_Tile_jll = "2068806d-a867-5dbd-af0e-42c2eb5d895d"
+DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c"
 IRStructurizer = "93e32bba-5bb8-402b-805d-ffb066edee93"
 
 [weakdeps]
@@ -21,6 +22,7 @@ CUDAExt = "CUDA"
 julia = "1.11"
 CUDA_Compiler_jll = "0.4"
 CUDA_Tile_jll = "13.1"
+DLFP8Types = "0.1"
 
 [workspace]
 projects = ["test", "IRStructurizer", "FileCheck"]
diff --git a/src/bytecode/types.jl b/src/bytecode/types.jl
index 1b51769..f7e1f99 100644
--- a/src/bytecode/types.jl
+++ b/src/bytecode/types.jl
@@ -127,6 +127,8 @@ BF16(table::TypeTable) = simple_type!(table, SimpleType.BF16)
 F32(table::TypeTable) = simple_type!(table, SimpleType.F32)
 TF32(table::TypeTable) = simple_type!(table, SimpleType.TF32)
 F64(table::TypeTable) = simple_type!(table, SimpleType.F64)
+F8E4M3FN(table::TypeTable) = simple_type!(table, SimpleType.F8E4M3FN)
+F8E5M2(table::TypeTable) = simple_type!(table, SimpleType.F8E5M2)
 Token(table::TypeTable) = simple_type!(table, SimpleType.Token)
 
 function tile_type!(table::TypeTable, dtype::TypeId, shape::AbstractVector{<:Integer})
@@ -195,6 +197,10 @@ function julia_to_tile_dtype!(table::TypeTable, ::Type{T}) where T
         TF32(table)
     elseif T === Float64
         F64(table)
+    elseif T === Float8_E4M3FN
+        F8E4M3FN(table)
+    elseif T === Float8_E5M2
+        F8E5M2(table)
     elseif T <: Ptr
         elem_dtype = julia_to_tile_dtype!(table, eltype(T))
         pointer_type!(table, elem_dtype)
diff --git a/src/bytecode/writer.jl b/src/bytecode/writer.jl
index eb87585..a0005f7 100644
--- a/src/bytecode/writer.jl
+++ b/src/bytecode/writer.jl
@@ -290,13 +290,22 @@ function float_to_bits(value::Float64, ::Type{T}) where T
     reinterpret(UInt32, Float32(value))
 end
 
+# Float8 types (from DLFP8Types)
+function float_to_bits(value::Float64, ::Type{Float8_E4M3FN})
+    reinterpret(UInt8, Float8_E4M3FN(value))
+end
+
+function float_to_bits(value::Float64, ::Type{Float8_E5M2})
+    reinterpret(UInt8, Float8_E5M2(value))
+end
+
 """
     encode_signed_varint!(buf, value)
 
 Encode a signed integer as a variable-length integer.
 Uses zigzag encoding for signed values.
 """
-function encode_signed_varint!(buf::Vector{UInt8}, value::Union{UInt16, UInt32, UInt64, Int64})
+function encode_signed_varint!(buf::Vector{UInt8}, value::Union{UInt8, UInt16, UInt32, UInt64, Int64})
     # For float bits, encode as unsigned varint
     encode_varint!(buf, UInt64(value))
 end
@@ -544,7 +553,7 @@ function finalize_function!(func_buf::Vector{UInt8}, cb::CodeBuilder,
 end
 
 #=============================================================================
-  Optimization Hints 
+  Optimization Hints
 =============================================================================#
 
 """
diff --git a/src/compiler/codegen/values.jl b/src/compiler/codegen/values.jl
index f79839d..b44a83e 100644
--- a/src/compiler/codegen/values.jl
+++ b/src/compiler/codegen/values.jl
@@ -146,6 +146,8 @@ function constant_to_bytes(@nospecialize(value), @nospecialize(T::Type))
         return collect(reinterpret(UInt8, [Int32(value)]))
     elseif T === Int64 || T === UInt64
         return collect(reinterpret(UInt8, [Int64(value)]))
+    elseif T === Float16
+        return collect(reinterpret(UInt8, [Float16(value)]))
     elseif T === Float32
         return collect(reinterpret(UInt8, [Float32(value)]))
     elseif T === Float64
diff --git a/src/cuTile.jl b/src/cuTile.jl
index 375aaa2..c03602f 100644
--- a/src/cuTile.jl
+++ b/src/cuTile.jl
@@ -13,6 +13,9 @@ const CC = Core.Compiler
 
 using CUDA_Tile_jll
 
+using DLFP8Types: Float8_E4M3FN, Float8_E5M2
+public Float8_E4M3FN, Float8_E5M2
+
 # Bytecode infrastructure
 include("bytecode/basic.jl")
 include("bytecode/types.jl")
diff --git a/src/language/types.jl b/src/language/types.jl
index 813f446..ddd6c40 100644
--- a/src/language/types.jl
+++ b/src/language/types.jl
@@ -309,8 +309,8 @@ primitive type TFloat32 <: AbstractFloat 32 end
 """Scalar integer types supported by Tile IR (i8, i16, i32, i64)."""
 const ScalarInt = Union{Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64}
 
-"""Scalar floating-point types supported by Tile IR (f16, tf32, f32, f64)."""
-const ScalarFloat = Union{Float16, Float32, Float64, TFloat32}
+"""Scalar floating-point types supported by Tile IR (f16, tf32, f32, f64, f8e4m3fn, f8e5m2)."""
+const ScalarFloat = Union{Float16, TFloat32, Float32, Float64, Float8_E4M3FN, Float8_E5M2}
 
 """Integer tile types."""
 const TileInt{S} = Tile{T, S} where {T <: ScalarInt}
diff --git a/test/codegen.jl b/test/codegen.jl
index ae4b42e..028755a 100644
--- a/test/codegen.jl
+++ b/test/codegen.jl
@@ -443,6 +443,32 @@
                 return
             end
         end
+
+        # Float32 -> Float8_E4M3FN
+        @test @filecheck begin
+            @check_label "entry"
+            code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b
+                pid = ct.bid(1)
+                tile = ct.load(a, pid, (16,))
+                @check "ftof"
+                converted = convert(ct.Tile{ct.Float8_E4M3FN}, tile)
+                ct.store(b, pid, ct.astype(converted, Float32))
+                return
+            end
+        end
+
+        # Float32 -> Float8_E5M2
+        @test @filecheck begin
+            @check_label "entry"
+            code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b
+                pid = ct.bid(1)
+                tile = ct.load(a, pid, (16,))
+                @check "ftof"
+                converted = convert(ct.Tile{ct.Float8_E5M2}, tile)
+                ct.store(b, pid, ct.astype(converted, Float32))
+                return
+            end
+        end
     end
 end
 
diff --git a/test/execution.jl b/test/execution.jl
index 8297a9d..5ce46ab 100644
--- a/test/execution.jl
+++ b/test/execution.jl
@@ -788,6 +788,58 @@
     @test Array(c) ≈ Array(a) + Array(b)
 end
 
+@testset "Float8" begin
+
+    @testset "Float8_E4M3FN" begin
+        # Float8 addition not supported
+        function vadd_f8e4m3fn(a::ct.TileArray{ct.Float8_E4M3FN,1}, b::ct.TileArray{ct.Float8_E4M3FN,1},
+                               c::ct.TileArray{ct.Float8_E4M3FN,1})
+            pid = ct.bid(1)
+            tile_a = ct.astype(ct.load(a, pid, (16,)), Float16)
+            tile_b = ct.astype(ct.load(b, pid, (16,)), Float16)
+            tile_c = ct.astype(tile_a + tile_b, ct.Float8_E4M3FN)
+            ct.store(c, pid, tile_c)
+            return
+        end
+
+        n = 1024
+        tile_size = 16
+        T = ct.Float8_E4M3FN
+        a = T.(CUDA.rand(n))
+        b = T.(CUDA.rand(n))
+        c = CUDA.zeros(T, n)
+
+        ct.launch(vadd_f8e4m3fn, cld(n, tile_size), a, b, c)
+
+        @test Array(c) ≈ Array(a) + Array(b)
+    end
+
+    @testset "Float8_E5M2" begin
+        # Float8 addition not supported
+        function vadd_f8e5m2(a::ct.TileArray{ct.Float8_E5M2,1}, b::ct.TileArray{ct.Float8_E5M2,1},
+                             c::ct.TileArray{ct.Float8_E5M2,1})
+            pid = ct.bid(1)
+            tile_a = ct.astype(ct.load(a, pid, (16,)), Float16)
+            tile_b = ct.astype(ct.load(b, pid, (16,)), Float16)
+            tile_c = ct.astype(tile_a + tile_b, ct.Float8_E5M2)
+            ct.store(c, pid, tile_c)
+            return
+        end
+
+        n = 1024
+        tile_size = 16
+        T = ct.Float8_E5M2
+        a = T.(CUDA.rand(n))
+        b = T.(CUDA.rand(n))
+        c = CUDA.zeros(T, n)
+
+        ct.launch(vadd_f8e5m2, cld(n, tile_size), a, b, c)
+
+        @test Array(c) ≈ Array(a) + Array(b)
+    end
+
+end
+
 end
 
 @testset "compilation cache" begin