This PR adds BFloat16 (bf16) support to cuTile.jl: a new BFloat16s.jl dependency, bf16 plumbing in the bytecode writer and type system, a type-parameterized matmul example, and codegen plus execution tests.
Project.toml (+2 -0)

@@ -4,6 +4,7 @@ authors = ["Tim Besard <tim.besard@gmail.com>"]
 version = "0.1.0"
 
 [deps]
+BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 CUDA_Compiler_jll = "d1e2174e-dfdc-576e-b43e-73b79eb1aca8"
 CUDA_Tile_jll = "2068806d-a867-5dbd-af0e-42c2eb5d895d"
 IRStructurizer = "93e32bba-5bb8-402b-805d-ffb066edee93"
@@ -19,6 +20,7 @@ CUDAExt = "CUDA"
 
 [compat]
 julia = "1.11"
+BFloat16s = "0.6"
 CUDA_Compiler_jll = "0.4"
 CUDA_Tile_jll = "13.1"
examples/matmul.jl (+7 -7)

@@ -141,19 +141,19 @@ function test_matmul(::Type{T}, M, N, K, tm, tn, tk; name=nothing) where T
     println(" passed")
 end
 
-function main()
+function main(T=Float32)
     println("--- cuTile Matrix Multiplication Examples ---\n")
 
-    # Small matrices with Float32
-    test_matmul(Float32, 256, 256, 256, 32, 32, 32)
-    test_matmul(Float32, 512, 512, 512, 64, 64, 64)
+    # Small matrices
+    test_matmul(T, 256, 256, 256, 32, 32, 32)
+    test_matmul(T, 512, 512, 512, 64, 64, 64)
 
     # Non-square matrices
-    test_matmul(Float32, 256, 512, 128, 32, 32, 32)
-    test_matmul(Float32, 512, 256, 384, 64, 64, 64)
+    test_matmul(T, 256, 512, 128, 32, 32, 32)
+    test_matmul(T, 512, 256, 384, 64, 64, 64)
 
     # Larger matrices
-    test_matmul(Float32, 1024, 1024, 1024, 32, 32, 32)
+    test_matmul(T, 1024, 1024, 1024, 32, 32, 32)
 
     println("\n--- All matmul examples completed ---")
 end
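With `main` parameterized over the element type, the same driver exercises the new bf16 path. A minimal usage sketch (hypothetical session; assumes a CUDA-capable GPU and this branch of the package):

```julia
using cuTile

include("examples/matmul.jl")

main()                 # defaults to Float32, matching the old behavior
main(cuTile.BFloat16)  # same matmuls with bf16 elements
```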
src/bytecode/types.jl (+2 -0)

@@ -189,6 +189,8 @@ function julia_to_tile_dtype!(table::TypeTable, ::Type{T}) where T
         I64(table)
     elseif T === Float16
         F16(table)
+    elseif T === BFloat16
+        BF16(table)
     elseif T === Float32
         F32(table)
     elseif T === TFloat32
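For orientation: with this branch, `julia_to_tile_dtype!` maps Float16 → f16, BFloat16 → bf16, Float32 → f32, and TFloat32 → tf32 (plus the integer cases above), matching the dtype list in the `ScalarFloat` docstring further down.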
src/bytecode/writer.jl (+4 -0)

@@ -276,6 +276,10 @@ function float_to_bits(value::Float64, ::Type{Float16})
     reinterpret(UInt16, Float16(value))
 end
 
+function float_to_bits(value::Float64, ::Type{BFloat16})
+    reinterpret(UInt16, BFloat16(value))
+end
+
 function float_to_bits(value::Float64, ::Type{Float32})
     reinterpret(UInt32, Float32(value))
 end
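The new `float_to_bits` method mirrors its `Float16` sibling: convert the constant, then reinterpret the result as a 16-bit pattern. This works because BFloat16 is a 16-bit primitive type that keeps Float32's 8 exponent bits and truncates the significand to 7 stored bits, so for exactly representable values its bits equal the upper half of the corresponding Float32. A quick standalone check (plain BFloat16s.jl, nothing cuTile-specific):

```julia
using BFloat16s: BFloat16

x = 1.5f0  # exactly representable in bf16, so no rounding occurs
bits32 = reinterpret(UInt32, x)            # 0x3fc00000
bits16 = reinterpret(UInt16, BFloat16(x))  # 0x3fc0
@assert bits16 == UInt16(bits32 >> 16)     # bf16 bits == upper half of f32 bits
```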
src/cuTile.jl (+3 -0)

@@ -13,6 +13,9 @@ const CC = Core.Compiler
 
 using CUDA_Tile_jll
 
+using BFloat16s: BFloat16
+public BFloat16
+
 # Bytecode infrastructure
 include("bytecode/basic.jl")
 include("bytecode/types.jl")
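Marking the binding `public` (a Julia 1.11 feature, consistent with the `julia = "1.11"` compat bound) lets downstream code reach the type through cuTile itself instead of depending on BFloat16s.jl directly:

```julia
using cuTile

# No direct BFloat16s.jl dependency needed in user code:
x = cuTile.BFloat16(0.5)
```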
src/language/types.jl (+2 -2)

@@ -309,8 +309,8 @@ primitive type TFloat32 <: AbstractFloat 32 end
 """Scalar integer types supported by Tile IR (i8, i16, i32, i64)."""
 const ScalarInt = Union{Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64}
 
-"""Scalar floating-point types supported by Tile IR (f16, tf32, f32, f64)."""
-const ScalarFloat = Union{Float16, Float32, Float64, TFloat32}
+"""Scalar floating-point types supported by Tile IR (f16, bf16, tf32, f32, f64)."""
+const ScalarFloat = Union{Float16, BFloat16, Float32, Float64, TFloat32}
 
 """Integer tile types."""
 const TileInt{S} = Tile{T, S} where {T <: ScalarInt}
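Widening `ScalarFloat` is the switch that makes bf16 tiles type-check throughout the language layer: every definition constrained by `T <: ScalarFloat` (such as the float tile aliases) now accepts `BFloat16` automatically. A REPL sketch (assuming the unexported `ScalarFloat` is reached via the module):

```julia
julia> using cuTile

julia> cuTile.BFloat16 <: cuTile.ScalarFloat
true

julia> Float16 <: cuTile.ScalarFloat  # existing members unaffected
true
```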
test/codegen.jl (+13 -0)

@@ -443,6 +443,19 @@
                 return
             end
         end
+
+        # Float32 -> BFloat16
+        @test @filecheck begin
+            @check_label "entry"
+            code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b
+                pid = ct.bid(1)
+                tile = ct.load(a, pid, (16,))
+                @check "ftof"
+                converted = convert(ct.Tile{ct.BFloat16}, tile)
+                ct.store(b, pid, ct.astype(converted, Float32))
+                return
+            end
+        end
     end
 end
 
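The FileCheck pattern pins the Float32 → BFloat16 conversion to an `ftof` instruction in the generated Tile IR. Note that the round trip through bf16 is lossy: bf16 has only 8 significand bits, so the Float32 input is rounded to the nearest bf16 before being widened back. A standalone illustration (BFloat16s.jl only):

```julia
using BFloat16s: BFloat16

# 1 + 2^-8 lies halfway between two bf16 neighbors (spacing 2^-7 near 1.0);
# round-to-nearest-even lands on 1.0, so the round trip is not the identity.
Float32(BFloat16(1.0f0 + 2.0f0^-8))  # returns 1.0f0, not 1.00390625f0
```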
test/execution.jl (+21 -0)

@@ -788,6 +788,27 @@ end
     @test Array(c) ≈ Array(a) + Array(b)
 end
 
+@testset "BFloat16" begin
+    function vadd_bf16(a::ct.TileArray{ct.BFloat16,1}, b::ct.TileArray{ct.BFloat16,1},
+                       c::ct.TileArray{ct.BFloat16,1})
+        pid = ct.bid(1)
+        tile_a = ct.load(a, pid, (16,))
+        tile_b = ct.load(b, pid, (16,))
+        ct.store(c, pid, tile_a + tile_b)
+        return
+    end
+
+    n = 1024
+    tile_size = 16
+    a = CUDA.rand(ct.BFloat16, n)
+    b = CUDA.rand(ct.BFloat16, n)
+    c = CUDA.zeros(ct.BFloat16, n)
+
+    ct.launch(vadd_bf16, cld(n, tile_size), a, b, c)
+
+    @test Array(c) ≈ Array(a) + Array(b)
+end
+
 end
 
 @testset "compilation cache" begin
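A note on the final `≈`: `eps(BFloat16)` is 2^-7 ≈ 0.0078, so bf16 addition is coarse, but the comparison is like-for-like since `Array(a) + Array(b)` also performs bf16 addition on the host. A CPU-only sketch of the same oracle (assumes only BFloat16s.jl; names are illustrative):

```julia
using BFloat16s: BFloat16

a = BFloat16.(rand(Float32, 16))
b = BFloat16.(rand(Float32, 16))

# Elementwise bf16 addition: each sum is rounded to the nearest BFloat16,
# which is what the device-side tile addition is expected to produce too.
c = a + b
@assert c ≈ Float32.(a) + Float32.(b)  # agrees to within bf16 rounding
```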