Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.1.0"
[deps]
CUDA_Compiler_jll = "d1e2174e-dfdc-576e-b43e-73b79eb1aca8"
CUDA_Tile_jll = "2068806d-a867-5dbd-af0e-42c2eb5d895d"
DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c"
IRStructurizer = "93e32bba-5bb8-402b-805d-ffb066edee93"

[weakdeps]
Expand All @@ -21,6 +22,7 @@ CUDAExt = "CUDA"
julia = "1.11"
CUDA_Compiler_jll = "0.4"
CUDA_Tile_jll = "13.1"
DLFP8Types = "0.1"

[workspace]
projects = ["test", "IRStructurizer", "FileCheck"]
6 changes: 6 additions & 0 deletions src/bytecode/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ BF16(table::TypeTable) = simple_type!(table, SimpleType.BF16)
F32(table::TypeTable) = simple_type!(table, SimpleType.F32)
TF32(table::TypeTable) = simple_type!(table, SimpleType.TF32)
F64(table::TypeTable) = simple_type!(table, SimpleType.F64)
# 8-bit float simple types; `julia_to_tile_dtype!` maps the Julia types
# Float8_E4M3FN and Float8_E5M2 to these two entries respectively.
F8E4M3FN(table::TypeTable) = simple_type!(table, SimpleType.F8E4M3FN)
F8E5M2(table::TypeTable) = simple_type!(table, SimpleType.F8E5M2)
Token(table::TypeTable) = simple_type!(table, SimpleType.Token)

function tile_type!(table::TypeTable, dtype::TypeId, shape::AbstractVector{<:Integer})
Expand Down Expand Up @@ -195,6 +197,10 @@ function julia_to_tile_dtype!(table::TypeTable, ::Type{T}) where T
TF32(table)
elseif T === Float64
F64(table)
elseif T === Float8_E4M3FN
F8E4M3FN(table)
elseif T === Float8_E5M2
F8E5M2(table)
elseif T <: Ptr
elem_dtype = julia_to_tile_dtype!(table, eltype(T))
pointer_type!(table, elem_dtype)
Expand Down
13 changes: 11 additions & 2 deletions src/bytecode/writer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,13 +290,22 @@ function float_to_bits(value::Float64, ::Type{T}) where T
reinterpret(UInt32, Float32(value))
end

# Float8 types (from DLFP8Types)
# Convert `value` to Float8_E4M3FN and return the raw byte of the 8-bit result.
function float_to_bits(value::Float64, ::Type{Float8_E4M3FN})
    narrowed = Float8_E4M3FN(value)
    return reinterpret(UInt8, narrowed)
end

# Convert `value` to Float8_E5M2 and return the raw byte of the 8-bit result.
function float_to_bits(value::Float64, ::Type{Float8_E5M2})
    narrowed = Float8_E5M2(value)
    return reinterpret(UInt8, narrowed)
end

"""
encode_signed_varint!(buf, value)

Encode a signed integer as a variable-length integer.
Uses zigzag encoding for signed values.
"""
function encode_signed_varint!(buf::Vector{UInt8}, value::Union{UInt16, UInt32, UInt64, Int64})
function encode_signed_varint!(buf::Vector{UInt8}, value::Union{UInt8, UInt16, UInt32, UInt64, Int64})
# For float bits, encode as unsigned varint
# This method performs no zigzag step: `value` is widened to UInt64 and handed
# to `encode_varint!` unchanged. UInt8 is accepted so that the single-byte
# results of the Float8 `float_to_bits` methods can be passed here.
# NOTE(review): `UInt64(value)` throws InexactError for a negative Int64, so
# callers presumably only pass non-negative Int64 values to this method -- TODO confirm.
encode_varint!(buf, UInt64(value))
end
Expand Down Expand Up @@ -544,7 +553,7 @@ function finalize_function!(func_buf::Vector{UInt8}, cb::CodeBuilder,
end

#=============================================================================
Optimization Hints
Optimization Hints
=============================================================================#

"""
Expand Down
2 changes: 2 additions & 0 deletions src/compiler/codegen/values.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ function constant_to_bytes(@nospecialize(value), @nospecialize(T::Type))
return collect(reinterpret(UInt8, [Int32(value)]))
elseif T === Int64 || T === UInt64
return collect(reinterpret(UInt8, [Int64(value)]))
elseif T === Float16
return collect(reinterpret(UInt8, [Float16(value)]))
elseif T === Float32
return collect(reinterpret(UInt8, [Float32(value)]))
elseif T === Float64
Expand Down
3 changes: 3 additions & 0 deletions src/cuTile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ const CC = Core.Compiler

using CUDA_Tile_jll

using DLFP8Types: Float8_E4M3FN, Float8_E5M2
public Float8_E4M3FN, Float8_E5M2

# Bytecode infrastructure
include("bytecode/basic.jl")
include("bytecode/types.jl")
Expand Down
4 changes: 2 additions & 2 deletions src/language/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,8 @@ primitive type TFloat32 <: AbstractFloat 32 end
"""Scalar integer types supported by Tile IR (i8, i16, i32, i64)."""
const ScalarInt = Union{Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64}

"""Scalar floating-point types supported by Tile IR (f16, tf32, f32, f64)."""
const ScalarFloat = Union{Float16, Float32, Float64, TFloat32}
"""Scalar floating-point types supported by Tile IR (f16, tf32, f32, f64, f8e4m3fn, f8e5m2)."""
const ScalarFloat = Union{Float16, TFloat32, Float32, Float64, Float8_E4M3FN, Float8_E5M2}

"""Integer tile types."""
const TileInt{S} = Tile{T, S} where {T <: ScalarInt}
Expand Down
26 changes: 26 additions & 0 deletions test/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,32 @@
return
end
end

# Float32 -> Float8_E4M3FN
# Codegen check: a Float32 tile is narrowed to Float8_E4M3FN with `convert`
# and widened back to Float32 with `astype` before being stored into `b`.
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b
pid = ct.bid(1)
tile = ct.load(a, pid, (16,))
# Expect an "ftof" match in the generated code (presumably the
# float-to-float conversion op emitted for the conversion below).
@check "ftof"
converted = convert(ct.Tile{ct.Float8_E4M3FN}, tile)
ct.store(b, pid, ct.astype(converted, Float32))
return
end
end

# Float32 -> Float8_E5M2
# Codegen check: a Float32 tile is narrowed to Float8_E5M2 with `convert`
# and widened back to Float32 with `astype` before being stored into `b`.
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b
pid = ct.bid(1)
tile = ct.load(a, pid, (16,))
# Expect an "ftof" match in the generated code (presumably the
# float-to-float conversion op emitted for the conversion below).
@check "ftof"
converted = convert(ct.Tile{ct.Float8_E5M2}, tile)
ct.store(b, pid, ct.astype(converted, Float32))
return
end
end
end
end

Expand Down
52 changes: 52 additions & 0 deletions test/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,58 @@ end
@test Array(c) ≈ Array(a) + Array(b)
end

# End-to-end checks for the two Float8 element types: each kernel loads Float8
# tiles, adds them in Float16, and stores the sum back as Float8.
@testset "Float8" begin

@testset "Float8_E4M3FN" begin
# Float8 addition is not supported, so both operands are widened to Float16,
# added there, and the sum is narrowed back to Float8_E4M3FN before the store.
function vadd_f8e4m3fn(a::ct.TileArray{ct.Float8_E4M3FN,1}, b::ct.TileArray{ct.Float8_E4M3FN,1},
c::ct.TileArray{ct.Float8_E4M3FN,1})
pid = ct.bid(1)
tile_a = ct.astype(ct.load(a, pid, (16,)), Float16)
tile_b = ct.astype(ct.load(b, pid, (16,)), Float16)
tile_c = ct.astype(tile_a + tile_b, ct.Float8_E4M3FN)
ct.store(c, pid, tile_c)
return
end

n = 1024
# Must agree with the (16,) tile shape hard-coded in the kernel above.
tile_size = 16
T = ct.Float8_E4M3FN
# Random inputs are generated on the device and converted element-wise to T.
a = T.(CUDA.rand(n))
b = T.(CUDA.rand(n))
c = CUDA.zeros(T, n)

# One block per tile: cld(n, tile_size) blocks in total.
ct.launch(vadd_f8e4m3fn, cld(n, tile_size), a, b, c)

# Reference sum is computed from the `Array` copies of the inputs.
@test Array(c) ≈ Array(a) + Array(b)
end

@testset "Float8_E5M2" begin
# Float8 addition is not supported, so both operands are widened to Float16,
# added there, and the sum is narrowed back to Float8_E5M2 before the store.
function vadd_f8e5m2(a::ct.TileArray{ct.Float8_E5M2,1}, b::ct.TileArray{ct.Float8_E5M2,1},
c::ct.TileArray{ct.Float8_E5M2,1})
pid = ct.bid(1)
tile_a = ct.astype(ct.load(a, pid, (16,)), Float16)
tile_b = ct.astype(ct.load(b, pid, (16,)), Float16)
tile_c = ct.astype(tile_a + tile_b, ct.Float8_E5M2)
ct.store(c, pid, tile_c)
return
end

n = 1024
# Must agree with the (16,) tile shape hard-coded in the kernel above.
tile_size = 16
T = ct.Float8_E5M2
# Random inputs are generated on the device and converted element-wise to T.
a = T.(CUDA.rand(n))
b = T.(CUDA.rand(n))
c = CUDA.zeros(T, n)

# One block per tile: cld(n, tile_size) blocks in total.
ct.launch(vadd_f8e5m2, cld(n, tile_size), a, b, c)

# Reference sum is computed from the `Array` copies of the inputs.
@test Array(c) ≈ Array(a) + Array(b)
end

end

end

@testset "compilation cache" begin
Expand Down