diff --git a/examples/fmha.jl b/examples/fmha.jl
new file mode 100644
index 0000000..e069891
--- /dev/null
+++ b/examples/fmha.jl
@@ -0,0 +1,221 @@
+# Fused Multi-Head Attention example - Julia port of cuTile Python's AttentionFMHA.py sample
+#
+# SPDX-License-Identifier: Apache-2.0
+
+using CUDA
+import cuTile as ct
+
+import NNlib
+
+const INV_LOG_2 = Float32(1 / log(2))
+const ConstInt = ct.Constant{Int}
+const ConstBool = ct.Constant{Bool}
+
+# TODO: "latency"
+
+# cuTile kernel for Fused Multi-Head Attention.
+# Q is D_k x SeqLen_Q x Heads x Batch; K and V are D x SeqLen_KV x KV_heads x Batch.
+# Each block handles one TILE_M-wide chunk of queries for one (head, batch) pair.
+function fmha_kernel(
+    Q::ct.TileArray{T,4}, K::ct.TileArray{T,4}, V::ct.TileArray{T,4}, Out::ct.TileArray{T,4},
+    qk_scale::AbstractFloat,
+    input_pos::Integer,
+    TILE_D::ConstInt,
+    H::ConstInt, # number of query heads
+    TILE_M::ConstInt,
+    TILE_N::ConstInt,
+    QUERY_GROUP_SIZE::ConstInt,
+    CAUSAL::ConstBool,
+    EVEN_K::ConstBool
+) where T
+    bid_x = ct.bid(1)
+    bid_y = ct.bid(2)
+    batch_idx = cld(bid_y, H[])
+    head_idx = mod1(bid_y, H[])
+    off_kv_h = cld(head_idx, QUERY_GROUP_SIZE[]) # kv head shared by this query group
+
+    # fold 1/log(2) into the scale so exp2 can be used in place of exp below
+    qk_scale = Float32(qk_scale) * INV_LOG_2
+
+    # 0-based global offsets of this block's query positions (row vector, M-dimension)
+    offs_m = ct.reshape((bid_x - 1) * TILE_M[] .+ ct.arange((TILE_M[],), Int32) .+ input_pos,
+                        (1, TILE_M[]))
+
+    # 0-based local offsets within a key/value tile (column vector, N-dimension)
+    offs_n_tile = ct.reshape(ct.arange((TILE_N[],), Int32), (TILE_N[], 1))
+
+    # online softmax accumulators in Float32 for stability
+    m_i = ct.full((1, TILE_M[]), -Inf32, Float32)   # running row maxima
+    l_i = ct.zeros((1, TILE_M[]), Float32)          # running softmax denominators
+    acc = ct.zeros((TILE_D[], TILE_M[]), Float32)   # unnormalized output
+
+    # query tile for this batch, head, and M-chunk
+    q = ct.load(Q, (1, bid_x, head_idx, batch_idx), (TILE_D[], TILE_M[], 1, 1))
+    q = ct.reshape(q, (TILE_D[], TILE_M[]))
+
+    # one past the last (0-based) query position handled by this block
+    m_end = input_pos + bid_x * TILE_M[]
+    k_seqlen = K.sizes[2]
+    if CAUSAL[]
+        # first kv tile that can contain a position above the diagonal
+        mask_start = cld(input_pos + (bid_x - 1) * TILE_M[] + 1, TILE_N[])
+        # the ragged last tile needs bounds masking even before the diagonal
+        mask_start = min(mask_start, cld(k_seqlen, TILE_N[]))
+        Tc = cld(min(m_end, k_seqlen), TILE_N[])
+    else
+        Tc = cld(k_seqlen, TILE_N[])
+        mask_start = cld(k_seqlen, TILE_N[])
+    end
+
+    # loop over K, V tiles (N-dimension chunks)
+    j = Int32(1)
+    while j <= Tc
+        k = ct.load(K, (1, j, off_kv_h, batch_idx), (TILE_D[], TILE_N[], 1, 1))
+        k = ct.reshape(k, (TILE_D[], TILE_N[]))
+        k = ct.transpose(k)
+
+        # raw scores: qk[n, m] = dot(k_n, q_m)
+        qk = ct.zeros((TILE_N[], TILE_M[]), Float32)
+        qk = ct.muladd(k, q, qk)
+
+        if (CAUSAL[] || !EVEN_K[]) && j >= mask_start
+            # 0-based global kv positions covered by this tile
+            offs_n = (j - 1) * TILE_N[] + offs_n_tile
+            mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
+            if !EVEN_K[]
+                mask = mask .& (offs_n .< k_seqlen)
+            end
+            if CAUSAL[]
+                mask = mask .& (offs_m .>= offs_n)
+            end
+            # additive mask: 0 where a position is kept, -Inf32 where it is masked out
+            mask = ct.where(mask, -Inf32, Float32)
+            qk = qk .+ mask
+        end
+
+        # defer the qk_scale multiplication until after reduce_max
+        # (scaling by a positive constant commutes with the max)
+        m_ij = max.(m_i, (ct.reduce_max(qk, 1) * qk_scale))
+        qk = qk * qk_scale .- m_ij
+
+        # attention weights [TILE_N, TILE_M]
+        p = exp2.(qk) # might need to expose "flush_to_zero"
+        l_ij = ct.reduce_sum(p, 1)
+        alpha = exp2.(m_i .- m_ij) # flush to zero?
+
+        # rescale the running sums by alpha whenever the row maximum grows
+        l_i = l_i .* alpha .+ l_ij
+        acc = acc .* alpha
+
+        v = ct.load(V, (1, j, off_kv_h, batch_idx), (TILE_D[], TILE_N[], 1, 1))
+        v = ct.reshape(v, (TILE_D[], TILE_N[]))
+        p = ct.astype(p, eltype(q))
+        acc = ct.muladd(v, p, acc) # [TILE_D, TILE_M]
+        m_i = m_ij
+
+        j += Int32(1)
+    end
+
+    acc = acc ./ l_i # flush to zero? rounding mode?
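+    # At this point, for each query column m:
+    #     acc[:, m] = sum_j exp2(s_jm - m_i[m]) * v_j
+    #     l_i[m]    = sum_j exp2(s_jm - m_i[m])
+    # so acc ./ l_i is exactly softmax(qk_scale * scores) applied to V: the
+    # exp2(-m_i) factors cancel, which is what makes the online rescaling safe.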
+    acc = ct.reshape(acc, (TILE_D[], TILE_M[], 1, 1))
+    ct.store(Out, (1, bid_x, head_idx, batch_idx), acc)
+
+    return
+end
+
+function cutile_fmha(Q::AbstractArray{T,4}, K::AbstractArray{T,4}, V::AbstractArray{T,4};
+                     qk_scale::Union{AbstractFloat,Nothing} = nothing,
+                     input_pos::Integer = 0,
+                     tile_m::Integer = 128,
+                     tile_n::Integer = 128,
+                     query_group_size::Integer = 1,
+                     causal::Bool = false,
+) where T
+    if size(Q, 4) != size(K, 4) || size(Q, 4) != size(V, 4)
+        throw(ArgumentError("Batch dimensions must match for Q, K, V."))
+    end
+    if size(Q, 3) % query_group_size != 0
+        throw(ArgumentError("Number of query heads must be divisible by query_group_size."))
+    end
+    if size(K, 3) * query_group_size != size(Q, 3)
+        throw(ArgumentError("K_heads * query_group_size must equal Q_heads."))
+    end
+    if size(Q, 1) != size(K, 1)
+        throw(ArgumentError("D_k (first dim of Q and K) must match."))
+    end
+    if size(V, 1) != size(Q, 1)
+        throw(ArgumentError("D_v must equal D_k: the kernel uses a single TILE_D for Q, K, and V."))
+    end
+    if size(K, 2) != size(V, 2)
+        throw(ArgumentError("SeqLen_KV (dim 2 of K and V) must match."))
+    end
+
+    D_k, SeqLen_Q, Heads, Batch = size(Q)
+    D_v, SeqLen_KV, KV_heads, _ = size(V)
+    even_k = (SeqLen_KV % tile_n) == 0
+
+    isnothing(qk_scale) && (qk_scale = 1 / sqrt(D_k))
+
+    Out = CUDA.zeros(T, D_v, SeqLen_Q, Heads, Batch)
+
+    # one block per (M-chunk, head * batch) pair
+    grid_x = cld(SeqLen_Q, tile_m)
+    grid_y = Heads * Batch
+    grid = (grid_x, grid_y, 1)
+
+    ct.launch(fmha_kernel, grid,
+              Q, K, V, Out,
+              qk_scale, input_pos,
+              ct.Constant(D_k),
+              ct.Constant(Heads),
+              ct.Constant(tile_m),
+              ct.Constant(tile_n),
+              ct.Constant(query_group_size),
+              ct.Constant(causal),
+              ct.Constant(even_k))
+
+    return Out
+end
+
+# CPU reference: NNlib treats the trailing head and batch dimensions of 4-D
+# inputs as batch dimensions, so grouped-query attention is emulated by
+# repeating each kv head query_group_size times.
+function nnlib_fmha(Q::AbstractArray{T,4}, K::AbstractArray{T,4}, V::AbstractArray{T,4};
+                    query_group_size::Integer = 1,
+                    causal::Bool = false,
+) where T
+    mask = causal ? NNlib.make_causal_mask(Q; dims=2) : nothing
+    if query_group_size > 1
+        K, V = repeat.((K, V), inner=(1, 1, query_group_size, 1))
+    end
+    Out, _ = NNlib.dot_product_attention(Q, K, V; mask)
+    return Out
+end
+
+
+function test_fmha(::Type{T},
+                   D_k, SeqLen_Q, Heads, Batch,
+                   D_v, SeqLen_KV, KV_heads,
+                   causal, tile_m, tile_n,
+) where T
+    query_group_size = Heads ÷ KV_heads
+
+    println("$T: D_k=$D_k SeqLen_Q=$SeqLen_Q SeqLen_KV=$SeqLen_KV Heads=$Heads " *
+            "KV_heads=$KV_heads Batch=$Batch causal=$causal tile=($tile_m, $tile_n)")
+
+    Q = CUDA.randn(T, D_k, SeqLen_Q, Heads, Batch)
+    K = CUDA.randn(T, D_k, SeqLen_KV, KV_heads, Batch)
+    V = CUDA.randn(T, D_v, SeqLen_KV, KV_heads, Batch)
+
+    out_cutile = cutile_fmha(Q, K, V;
+                             causal=causal,
+                             tile_m=tile_m, tile_n=tile_n,
+                             query_group_size=query_group_size)
+
+    Q_cpu = Array(Q)
+    K_cpu = Array(K)
+    V_cpu = Array(V)
+    expected = nnlib_fmha(Q_cpu, K_cpu, V_cpu; query_group_size, causal)
+    result = Array(out_cutile)
+
+    if isapprox(result, expected, rtol=1e-2, atol=1e-2)
+        println(" passed")
+    else
+        max_diff = maximum(abs.(result - expected))
+        println(" FAILED (max diff: $max_diff)")
+    end
+end
+
+function main()
+    println("--- cuTile Fused Multi-Head Attention Examples ---\n")
+
+    # Float32 tests, causal=false (the last case exercises grouped-query attention)
+    test_fmha(Float32, 64, 256, 8, 2, 64, 256, 8, false, 32, 32)
+    test_fmha(Float32, 64, 256, 8, 2, 64, 128, 8, false, 32, 32)
+    test_fmha(Float32, 64, 256, 8, 2, 64, 128, 4, false, 32, 32)
+
+    println("\n--- All FMHA examples completed ---")
+end
+
+isinteractive() || main()
diff --git a/src/language/arithmetic.jl b/src/language/arithmetic.jl
index 1a69f39..99a040b 100644
--- a/src/language/arithmetic.jl
+++ b/src/language/arithmetic.jl
@@ -29,6 +29,7 @@
 @overlay Base.div(x::T, y::T, ::typeof(RoundUp)) where {T <: Unsigned} = Intrinsics.cldi(x, y, SignednessUnsigned)
 @overlay Base.rem(x::T, y::T) where {T <: Signed} = Intrinsics.remi(x, y, SignednessSigned)
 @overlay Base.rem(x::T, y::T) where {T <: Unsigned} = Intrinsics.remi(x, y, SignednessUnsigned)
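+# 1-based modulus for kernel code (used by the FMHA example to map a linear
+# block id onto a head index): mod1(x, y) lies in 1:y, e.g. mod1(8, 8) == 8.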
+@overlay Base.mod1(x::T, y::T) where {T <: ScalarInt} = (m = mod(x, y); m == zero(m) ? y : m)
 
 # float
 @overlay Base.:+(x::T, y::T) where {T <: ScalarFloat} = Intrinsics.addf(x, y)
@@ -77,7 +78,7 @@
 @inline Base.:(-)(a::Tile{T, S}, b::Tile{T, S}) where {T <: Integer, S} = Intrinsics.subi(a, b)
 
 # broadcasted arithmetic (float)
-for (op, intrinsic) in ((:+, :addf), (:-, :subf), (:*, :mulf), (:/, :divf))
+for (op, intrinsic) in ((:+, :addf), (:-, :subf), (:*, :mulf), (:/, :divf), (:max, :maxf), (:min, :minf))
     @eval @inline function Base.Broadcast.broadcasted(::TileStyle, ::typeof($op), a::Tile{T,S1}, b::Tile{T,S2}) where {T<:AbstractFloat,S1,S2}
         S = broadcast_shape(S1, S2)
         Intrinsics.$intrinsic(broadcast_to(a, S), broadcast_to(b, S))
@@ -157,7 +158,7 @@ end
 @inline Base.:(/)(a::Tile{T, S}, b::Number) where {T <: AbstractFloat, S} = Intrinsics.divf(a, broadcast_to(Tile(T(b)), S))
 
 # broadcasted arithmetic (float)
-for (op, intrinsic) in ((:+, :addf), (:-, :subf), (:*, :mulf), (:/, :divf))
+for (op, intrinsic) in ((:+, :addf), (:-, :subf), (:*, :mulf), (:/, :divf), (:max, :maxf), (:min, :minf))
     @eval begin
         @inline Base.Broadcast.broadcasted(::TileStyle, ::typeof($op), a::Tile{T,S}, b::Number) where {T<:AbstractFloat,S} =
             Intrinsics.$intrinsic(a, broadcast_to(Tile(T(b)), S))
diff --git a/test/Project.toml b/test/Project.toml
index c98454c..e340931 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -3,6 +3,7 @@
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
 FileCheck = "4e644321-382b-4b05-b0b6-5d23c3d944fb"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
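
# ---------------------------------------------------------------------------
# Reference sketch (illustrative, not part of the patch): the online-softmax
# recurrence that fmha_kernel implements, written as plain Julia for a single
# query vector. `scores[j]` stands for qk_scale * dot(k_j, q); the kernel uses
# exp2 with 1/log(2) folded into the scale, which is equivalent to exp here.

function online_softmax_attention(scores::Vector{Float32}, V::Matrix{Float32})
    m = -Inf32                      # running maximum of the scores seen so far
    l = 0.0f0                       # running softmax denominator
    acc = zeros(Float32, size(V, 1))
    for j in eachindex(scores)
        m_new = max(m, scores[j])
        alpha = exp(m - m_new)      # rescales earlier partial sums when m grows
        p = exp(scores[j] - m_new)
        l = l * alpha + p
        acc .= acc .* alpha .+ p .* view(V, :, j)
        m = m_new
    end
    return acc ./ l                 # == V * softmax(scores)
end

# e.g. online_softmax_attention(randn(Float32, 16), randn(Float32, 8, 16))
# matches V * NNlib.softmax(scores) up to rounding.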