Skip to content

Conversation

@AntonOresten
Copy link
Contributor

@AntonOresten AntonOresten commented Jan 10, 2026

See also #15

Seems to fall slightly short of my NNop / ONIONop baseline (no WMMA), although I haven't compared it to the Python version. On my GPU, it compiles and runs fastest with tile_n=32 and tile_m=32:

julia> begin
           T = Float32
           D, QL, KL, H, B = 64, 4096, 4096, 4, 4
           q = CUDA.randn(T, D, QL, H, B)
           k = CUDA.randn(T, D, KL, H, B)  
           v = CUDA.randn(T, D, KL, H, B)
       end;

julia> @b CUDA.@sync ONIONop.flash_attention(q, k, v, causal=false)
9.559 ms (339 allocs: 7.875 KiB)

julia> @b CUDA.@sync cutile_fmha(q, k, v, causal=false, tile_m=32, tile_n=32)
11.058 ms (540 allocs: 23.109 KiB)

Notably, cutile-python has a latency argument for ct.load, as well as num_ctas and occupancy arguments for the kernel, which might affect performance. The python version also does a kernel config autotune by searching a space of hand-picked configurations.

Currently only tested on Float32.

Another thing that might be important for correctness or covering edge cases is exposing flush_to_zero? Used in e.g. exp2.

@AntonOresten
Copy link
Contributor Author

AntonOresten commented Jan 17, 2026

Seeing some weird erroring when branching:

This works:

        qk = if !EVEN_K[] && j >= mask_start
            offs_n = ((j-Int32(1)) * TILE_N[]) .+ offs_n_tile
            mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
            mask = mask .& (offs_n .<= k_seqlen)
            mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[],), Float32), ct.full((TILE_N[], TILE_M[],), -Inf32, Float32))
            qk .+ mask
        else
            qk
        end

but this doesn't:

        if !EVEN_K[] && j >= mask_start
            offs_n = ((j-Int32(1)) * TILE_N[]) .+ offs_n_tile
            mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
            mask = mask .& (offs_n .<= k_seqlen)
            mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[],), Float32), ct.full((TILE_N[], TILE_M[],), -Inf32, Float32))
            qk = qk .+ mask
        end

nor does this:

        qk = if !EVEN_K[] && j >= mask_start
            offs_n = ((j-Int32(1)) * TILE_N[]) .+ offs_n_tile
            mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
            if !EVEN_K[]
                mask .& (offs_n .<= k_seqlen)
            end
            mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[],), Float32), ct.full((TILE_N[], TILE_M[],), -Inf32, Float32))
            qk .+ mask
        else
            qk
        end

In the second and third block, I get "ERROR: SSAValue %___ not found in context"

after removing the second condition, I can suddenly have a nested if block, and I don't need the outer else block:

        if !EVEN_K[]
            offs_n = ((j-Int32(1)) * TILE_N[]) .+ offs_n_tile
            mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
            if !EVEN_K[]
                mask = mask .& (offs_n .<= k_seqlen)
            end
            mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[],), Float32), ct.full((TILE_N[], TILE_M[],), -Inf32, Float32))
            qk = qk .+ mask
        end

Does the if block need to depend on compile time constants?

I'd need this to make the padding and causal mask properly.

@maleadt
Copy link
Member

maleadt commented Jan 19, 2026

In the second and third block, I get "ERROR: SSAValue %___ not found in context"

That's an IRStructurizer error. Can you provide an MWE?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants