diff --git a/Project.toml b/Project.toml
index 83cf4b8..fb27276 100644
--- a/Project.toml
+++ b/Project.toml
@@ -4,6 +4,7 @@ authors = ["Tim Besard "]
 version = "0.1.0"
 
 [deps]
+BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 CUDA_Compiler_jll = "d1e2174e-dfdc-576e-b43e-73b79eb1aca8"
 CUDA_Tile_jll = "2068806d-a867-5dbd-af0e-42c2eb5d895d"
 IRStructurizer = "93e32bba-5bb8-402b-805d-ffb066edee93"
@@ -19,6 +20,7 @@ CUDAExt = "CUDA"
 
 [compat]
 julia = "1.11"
+BFloat16s = "0.6"
 CUDA_Compiler_jll = "0.4"
 CUDA_Tile_jll = "13.1"
diff --git a/examples/matmul.jl b/examples/matmul.jl
index 47a24e6..b222432 100644
--- a/examples/matmul.jl
+++ b/examples/matmul.jl
@@ -141,19 +141,19 @@ function test_matmul(::Type{T}, M, N, K, tm, tn, tk; name=nothing) where T
     println(" passed")
 end
 
-function main()
+function main(T=Float32)
     println("--- cuTile Matrix Multiplication Examples ---\n")
 
-    # Small matrices with Float32
-    test_matmul(Float32, 256, 256, 256, 32, 32, 32)
-    test_matmul(Float32, 512, 512, 512, 64, 64, 64)
+    # Small matrices
+    test_matmul(T, 256, 256, 256, 32, 32, 32)
+    test_matmul(T, 512, 512, 512, 64, 64, 64)
 
     # Non-square matrices
-    test_matmul(Float32, 256, 512, 128, 32, 32, 32)
-    test_matmul(Float32, 512, 256, 384, 64, 64, 64)
+    test_matmul(T, 256, 512, 128, 32, 32, 32)
+    test_matmul(T, 512, 256, 384, 64, 64, 64)
 
     # Larger matrices
-    test_matmul(Float32, 1024, 1024, 1024, 32, 32, 32)
+    test_matmul(T, 1024, 1024, 1024, 32, 32, 32)
 
     println("\n--- All matmul examples completed ---")
 end
diff --git a/src/bytecode/types.jl b/src/bytecode/types.jl
index 1b51769..abf6393 100644
--- a/src/bytecode/types.jl
+++ b/src/bytecode/types.jl
@@ -189,6 +189,8 @@ function julia_to_tile_dtype!(table::TypeTable, ::Type{T}) where T
         I64(table)
     elseif T === Float16
         F16(table)
+    elseif T === BFloat16
+        BF16(table)
     elseif T === Float32
         F32(table)
     elseif T === TFloat32
diff --git a/src/bytecode/writer.jl b/src/bytecode/writer.jl
index eb87585..a1be530 100644
--- a/src/bytecode/writer.jl
+++ b/src/bytecode/writer.jl
@@ -276,6 +276,10 @@ function float_to_bits(value::Float64, ::Type{Float16})
     reinterpret(UInt16, Float16(value))
 end
 
+function float_to_bits(value::Float64, ::Type{BFloat16})
+    reinterpret(UInt16, BFloat16(value))
+end
+
 function float_to_bits(value::Float64, ::Type{Float32})
     reinterpret(UInt32, Float32(value))
 end
diff --git a/src/cuTile.jl b/src/cuTile.jl
index 375aaa2..2bea086 100644
--- a/src/cuTile.jl
+++ b/src/cuTile.jl
@@ -13,6 +13,9 @@ const CC = Core.Compiler
 
 using CUDA_Tile_jll
 
+using BFloat16s: BFloat16
+public BFloat16
+
 # Bytecode infrastructure
 include("bytecode/basic.jl")
 include("bytecode/types.jl")
diff --git a/src/language/types.jl b/src/language/types.jl
index 813f446..c6459a7 100644
--- a/src/language/types.jl
+++ b/src/language/types.jl
@@ -309,8 +309,8 @@ primitive type TFloat32 <: AbstractFloat 32 end
 """Scalar integer types supported by Tile IR (i8, i16, i32, i64)."""
 const ScalarInt = Union{Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64}
 
-"""Scalar floating-point types supported by Tile IR (f16, tf32, f32, f64)."""
-const ScalarFloat = Union{Float16, Float32, Float64, TFloat32}
+"""Scalar floating-point types supported by Tile IR (f16, bf16, tf32, f32, f64)."""
+const ScalarFloat = Union{Float16, BFloat16, Float32, Float64, TFloat32}
 
 """Integer tile types."""
 const TileInt{S} = Tile{T, S} where {T <: ScalarInt}
diff --git a/test/codegen.jl b/test/codegen.jl
index ae4b42e..559f3d0 100644
--- a/test/codegen.jl
+++ b/test/codegen.jl
@@ -443,6 +443,19 @@
             return
         end
     end
+
+    # Float32 -> BFloat16
+    @test @filecheck begin
+        @check_label "entry"
+        code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b
+            pid = ct.bid(1)
+            tile = ct.load(a, pid, (16,))
+            @check "ftof"
+            converted = convert(ct.Tile{ct.BFloat16}, tile)
+            ct.store(b, pid, ct.astype(converted, Float32))
+            return
+        end
+    end
 end
 
 end
diff --git a/test/execution.jl b/test/execution.jl
index 8297a9d..1361442 100644
--- a/test/execution.jl
+++ b/test/execution.jl
@@ -788,6 +788,27 @@ end
     @test Array(c) ≈ Array(a) + Array(b)
 end
 
+@testset "BFloat16" begin
+    function vadd_bf16(a::ct.TileArray{ct.BFloat16,1}, b::ct.TileArray{ct.BFloat16,1},
+                       c::ct.TileArray{ct.BFloat16,1})
+        pid = ct.bid(1)
+        tile_a = ct.load(a, pid, (16,))
+        tile_b = ct.load(b, pid, (16,))
+        ct.store(c, pid, tile_a + tile_b)
+        return
+    end
+
+    n = 1024
+    tile_size = 16
+    a = CUDA.rand(ct.BFloat16, n)
+    b = CUDA.rand(ct.BFloat16, n)
+    c = CUDA.zeros(ct.BFloat16, n)
+
+    ct.launch(vadd_bf16, cld(n, tile_size), a, b, c)
+
+    @test Array(c) ≈ Array(a) + Array(b)
+end
+
 end
 
 @testset "compilation cache" begin
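
A minimal way to exercise the new BFloat16 support end to end, mirroring the new execution test above; the `import cuTile as ct` alias and the assumption that the matmul example's tile configurations also accept bf16 are not established by this diff:

    using CUDA
    import cuTile as ct  # alias assumed; the test suite refers to the package as `ct`

    # Standalone bf16 vector add in the style of the new test in test/execution.jl.
    function vadd_bf16(a::ct.TileArray{ct.BFloat16,1}, b::ct.TileArray{ct.BFloat16,1},
                       c::ct.TileArray{ct.BFloat16,1})
        pid = ct.bid(1)
        ct.store(c, pid, ct.load(a, pid, (16,)) + ct.load(b, pid, (16,)))
        return
    end

    n = 1024
    a = CUDA.rand(ct.BFloat16, n)
    b = CUDA.rand(ct.BFloat16, n)
    c = CUDA.zeros(ct.BFloat16, n)
    ct.launch(vadd_bf16, cld(n, 16), a, b, c)
    @assert Array(c) ≈ Array(a) + Array(b)

    # The matmul example's entry point is now parameterized over the element type,
    # so, assuming bf16 matmul is supported by the tile sizes used there:
    #   include("examples/matmul.jl"); main(ct.BFloat16)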