[Issue]: Program crashed on Triton backend when dropout != 0.0 #167

@Logiquo

Description

Problem Description

The Triton-backend flash attention crashes when dropout_p is set to a non-zero value, but runs fine when dropout_p is 0.0 and return_attn_probs=False.

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
File ~/miniforge3/envs/.../lib/python3.12/site-packages/triton/language/core.py:35, in builtin.<locals>.wrapper(*args, **kwargs)
     33     raise ValueError("Did you forget to add @triton.jit ? "
     34                      "(`_builder` argument must be provided outside of JIT functions.)")
---> 35 return fn(*args, **kwargs)

File ~/miniforge3/envs/.../lib/python3.12/site-packages/triton/language/core.py:1710, in store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder)
   1709 eviction_policy = _constexpr_to_value(eviction_policy)
-> 1710 return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder)

File ~/miniforge3/envs/.../lib/python3.12/site-packages/triton/language/semantic.py:1272, in store(ptr, val, mask, boundary_check, cache_modifier, eviction_policy, builder)
   1270 eviction = _str_to_eviction_policy(eviction_policy)
-> 1272 if ptr.type.is_const() or ptr.type.scalar.is_const():
   1273     raise ValueError("Cannot store to a constant pointer")

AttributeError: 'constexpr' object has no attribute 'type'

The above exception was the direct cause of the following exception:

CompilationError                          Traceback (most recent call last)
CompilationError: at 95:12:
        if ENABLE_DROPOUT:
            if tl_DROPOUT_USE_PYTORCH:
                dropout_mask = tl.load(dropout_mask_ptrs, mask=p_mask)
            else:
                rng_output = tl.rand(philox_seed, philox_ptrs)  # TODO: use tl.randint for better performance
                dropout_mask = rng_output > dropout_p
                if tl_DROPOUT_DUMP:
                    tl.store(dropout_mask_ptrs, dropout_mask, mask=p_mask)

            # return scores with negative values for dropped vals
            sd_mask = tl.where(dropout_mask, p, -p)
            tl.store(sd_mask_ptrs, sd_mask, mask=p_mask)
            ^

The above exception was the direct cause of the following exception:

CompilationError                          Traceback (most recent call last)
Cell In[8], line 2
      1 # %%
----> 2 head, seq = model(timestep.cuda(), token_id.cuda(), count.cuda(), keep_seq=True)

File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/nn/modules/module.py:1775, in Module._wrapped_call_impl(self, *args, **kwargs)
   1773     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1774 else:
-> 1775     return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/nn/modules/module.py:1786, in Module._call_impl(self, *args, **kwargs)
   1781 # If we don't have any hooks, we want to skip the rest of the logic in
   1782 # this function, and just call forward.
   1783 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1784         or _global_backward_pre_hooks or _global_backward_hooks
   1785         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786     return forward_call(*args, **kwargs)
   1788 result = None
   1789 called_always_called_hooks = set()

...

File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/nn/modules/module.py:1775, in Module._wrapped_call_impl(self, *args, **kwargs)
   1773     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1774 else:
-> 1775     return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/nn/modules/module.py:1786, in Module._call_impl(self, *args, **kwargs)
   1781 # If we don't have any hooks, we want to skip the rest of the logic in
   1782 # this function, and just call forward.
   1783 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1784         or _global_backward_pre_hooks or _global_backward_hooks
   1785         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786     return forward_call(*args, **kwargs)
   1788 result = None
   1789 called_always_called_hooks = set()

...

File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/nn/modules/module.py:1775, in Module._wrapped_call_impl(self, *args, **kwargs)
   1773     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1774 else:
-> 1775     return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/nn/modules/module.py:1786, in Module._call_impl(self, *args, **kwargs)
   1781 # If we don't have any hooks, we want to skip the rest of the logic in
   1782 # this function, and just call forward.
   1783 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1784         or _global_backward_pre_hooks or _global_backward_hooks
   1785         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786     return forward_call(*args, **kwargs)
   1788 result = None
   1789 called_always_called_hooks = set()

...

File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/nn/modules/module.py:1775, in Module._wrapped_call_impl(self, *args, **kwargs)
   1773     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1774 else:
-> 1775     return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/nn/modules/module.py:1786, in Module._call_impl(self, *args, **kwargs)
   1781 # If we don't have any hooks, we want to skip the rest of the logic in
   1782 # this function, and just call forward.
   1783 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1784         or _global_backward_pre_hooks or _global_backward_hooks
   1785         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786     return forward_call(*args, **kwargs)
   1788 result = None
   1789 called_always_called_hooks = set()

...
--> 119 flash_attn_out = flash_attn.flash_attn_varlen_qkvpacked_func(
    120     qkv=qkv.contiguous(),
    121     cu_seqlens=offsets.to(torch.int32),
    122     max_seqlen=max_len,
    123     dropout_p=dropout,
    124     softmax_scale=1.0 / (head_dim**0.5),
    125     causal=is_causal,
    126     return_attn_probs=need_weights,
    127 )
    128 if need_weights:
    129     out: Tensor

File ~/miniforge3/envs/.../lib/python3.12/site-packages/flash_attn/flash_attn_interface.py:1272, in flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, deterministic, return_attn_probs)
   1222 def flash_attn_varlen_qkvpacked_func(
   1223     qkv,
   1224     cu_seqlens,
   (...)   1233     return_attn_probs=False,
   1234 ):
   1235     """dropout_p should be set to 0.0 during evaluation
   1236     If Q, K, V are already stacked into 1 tensor, this function will be faster than
   1237     calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
   (...)   1270             pattern (negative means that location was dropped, nonnegative means it was kept).
   1271     """
-> 1272     return FlashAttnVarlenQKVPackedFunc.apply(
   1273         qkv,
   1274         cu_seqlens,
   1275         max_seqlen,
   1276         dropout_p,
   1277         softmax_scale,
   1278         causal,
   1279         window_size,
   1280         softcap,
   1281         alibi_slopes,
   1282         deterministic,
   1283         return_attn_probs,
   1284         torch.is_grad_enabled(),
   1285     )

File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/autograd/function.py:581, in Function.apply(cls, *args, **kwargs)
    578 if not torch._C._are_functorch_transforms_active():
    579     # See NOTE: [functorch vjp and autograd interaction]
    580     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 581     return super().apply(*args, **kwargs)  # type: ignore[misc]
    583 if not is_setup_ctx_defined:
    584     raise RuntimeError(
    585         "In order to use an autograd.Function with functorch transforms "
    586         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    587         "staticmethod. For more details, please see "
    588         "https://pytorch.org/docs/main/notes/extending.func.html"
    589     )

File ~/miniforge3/envs/.../lib/python3.12/site-packages/flash_attn/flash_attn_interface.py:558, in FlashAttnVarlenQKVPackedFunc.forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, deterministic, return_softmax, is_grad_enabled)
    556     k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
    557     v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
--> 558 out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
    559     q,
    560     k,
    561     v,
    562     cu_seqlens,
    563     cu_seqlens,
    564     max_seqlen,
    565     max_seqlen,
    566     dropout_p,
    567     softmax_scale,
    568     causal=causal,
    569     window_size_left=window_size[0],
    570     window_size_right=window_size[1],
    571     softcap=softcap,
    572     alibi_slopes=alibi_slopes,
    573     return_softmax=return_softmax and dropout_p > 0,
    574     block_table=None,
    575 )
    576 if is_grad:
    577     ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)

File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/_ops.py:1255, in OpOverloadPacket.__call__(self, *args, **kwargs)
   1253 if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
   1254     return _call_overload_packet_from_python(self, *args, **kwargs)
-> 1255 return self._op(*args, **kwargs)

File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/_library/autograd.py:111, in make_autograd_impl.<locals>.autograd_impl(keyset, *args, **keyword_only_args)
    109     result = Generated.apply(*args, Metadata(keyset, keyword_only_args))  # type: ignore[attr-defined]
    110 else:
--> 111     result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
    112 return result

File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/_library/autograd.py:40, in make_autograd_impl.<locals>.forward_no_grad(*args)
     38 keyset = metadata.keyset
     39 kwargs = metadata.keyword_only_args
---> 40 result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
     41 return result

File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/_ops.py:848, in OpOverload.redispatch(self, keyset, *args, **kwargs)
    845 def redispatch(
    846     self, /, keyset: torch._C.DispatchKeySet, *args: _P.args, **kwargs: _P.kwargs
    847 ) -> _T:
--> 848     return self._handle.redispatch_boxed(keyset, *args, **kwargs)

File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/_library/custom_ops.py:343, in CustomOpDef.register_kernel.<locals>.inner.<locals>.backend_impl(*args, **kwargs)
    342 def backend_impl(*args, **kwargs):
--> 343     result = self._backend_fns[device_type](*args, **kwargs)
    345     def get_module():
    346         fn = self._backend_fns[device_type]

File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/_compile.py:53, in _disable_dynamo.<locals>.inner(*args, **kwargs)
     50     disable_fn = torch._dynamo.disable(fn, recursive, wrapping=False)
     51     fn.__dynamo_disable = disable_fn  # type: ignore[attr-defined]
---> 53 return disable_fn(*args, **kwargs)

File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:1044, in DisableContext.__call__.<locals>._fn(*args, **kwargs)
   1042 _maybe_set_eval_frame(_callback_from_stance(self.callback))
   1043 try:
-> 1044     return fn(*args, **kwargs)
   1045 finally:
   1046     set_eval_frame(None)

File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/_library/custom_ops.py:376, in CustomOpDef.register_kernel.<locals>.inner.<locals>.wrapped_fn(*args, **kwargs)
    374     return self._init_fn(*args, **kwargs)
    375 else:
--> 376     return fn(*args, **kwargs)

File ~/miniforge3/envs/.../lib/python3.12/site-packages/flash_attn/flash_attn_interface.py:168, in _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size_left, window_size_right, softcap, alibi_slopes, return_softmax, block_table, leftpad_k, seqused_k, zero_tensors)
    145 @_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda")
    146 def _flash_attn_varlen_forward(
    147     q: torch.Tensor,
   (...)    165     zero_tensors: bool = False,
    166 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    167     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
--> 168     out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
    169         q,
    170         k,
    171         v,
    172         None,
    173         cu_seqlens_q,
    174         cu_seqlens_k,
    175         seqused_k,
    176         leftpad_k,
    177         block_table,
    178         alibi_slopes,
    179         max_seqlen_q,
    180         max_seqlen_k,
    181         dropout_p,
    182         softmax_scale,
    183         zero_tensors,
    184         causal,
    185         window_size_left,
    186         window_size_right,
    187         softcap,
    188         return_softmax,
    189         None,
    190     )
    191     # if out.isnan().any() or softmax_lse.isnan().any():
    192     #     breakpoint()
    193     return out, softmax_lse, S_dmask, rng_state

File ~/miniforge3/envs/.../lib/python3.12/site-packages/flash_attn/flash_attn_triton_amd/interface_fa.py:456, in varlen_fwd(q, k, v, out, cu_seqlens_q, cu_seqlens_k, seqused_k, leftpad_k, block_table_, alibi_slopes, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, zero_tensors, causal, window_size_left, window_size_right, softcap, return_softmax, gen_, descale_q, descale_k, descale_v, descale_o)
    454 if DEBUG:
    455     print("Using Triton implementation")
--> 456 softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl(
    457                                                     q,
    458                                                     k,
    459                                                     v,
    460                                                     out,
    461                                                     metadata.sm_scale,
    462                                                     metadata.alibi_slopes,
    463                                                     metadata.causal,
    464                                                     None,
    465                                                     metadata.layout,
    466                                                     metadata.cu_seqlens_q,
    467                                                     metadata.cu_seqlens_k,
    468                                                     metadata.max_seqlens_q,
    469                                                     metadata.max_seqlens_k,
    470                                                     metadata.cache_seqlens,
    471                                                     metadata.cache_batch_idx,
    472                                                     metadata.dropout_p,
    473                                                     metadata.philox_seed,
    474                                                     metadata.philox_offset,
    475                                                     metadata.return_scores,
    476                                                     metadata.use_exp2,
    477                                                     descale_q,
    478                                                     descale_k,
    479                                                     descale_v,
    480                                                     descale_o)
    481 softmax_lse=softmax_lse_triton
    482 sd_mask=sd_mask_triton

File ~/miniforge3/envs/.../lib/python3.12/site-packages/flash_attn/flash_attn_triton_amd/fwd_prefill.py:638, in attention_prefill_forward_triton_impl(q, k, v, o, sm_scale, alibi_slopes, causal, bias, layout, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k, cache_seqlens, cache_batch_idx, dropout_p, philox_seed, philox_offset, return_softmax, use_exp2, descale_q, descale_k, descale_v, descale_o)
    635 else:
    636     bias_strides = (0, 0, 0, 0)
--> 638 attn_fwd[grid](q, k, v, bias, cache_seqlens, cache_batch_idx,
    639                 descale_q, descale_k, descale_v, descale_o, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z,
    640                 sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
    641                 *bias_strides, stride_az, stride_ah, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
    642                 dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes, 
    643                 HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
    644                 MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, IS_VARLEN=is_varlen, IS_INFERENCE=is_inference,
    645                 BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True,
    646                 USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p
    647                 > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT)
    649 return softmax_lse, sd_mask if return_softmax else None

File ~/miniforge3/envs/.../lib/python3.12/site-packages/triton/runtime/jit.py:330, in KernelInterface.__getitem__.<locals>.<lambda>(*args, **kwargs)
    324 def __getitem__(self, grid) -> T:
    325     """
    326     A JIT function is launched with: fn[grid](*args, **kwargs).
    327     Hence JITFunction.__getitem__ returns a callable proxy that
    328     memorizes the grid.
    329     """
--> 330     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)

File ~/miniforge3/envs/.../lib/python3.12/site-packages/triton/runtime/autotuner.py:203, in Autotuner.run(self, *args, **kwargs)
    201     full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
    202     config.pre_hook(full_nargs)
--> 203 ret = self.fn.run(
    204     *args,
    205     **kwargs,
    206     **config.all_kwargs(),
    207 )
    208 self.nargs = None
    209 return ret

File ~/miniforge3/envs/.../lib/python3.12/site-packages/triton/runtime/jit.py:623, in JITFunction.run(self, grid, warmup, *args, **kwargs)
    621 # compile the kernel
    622 src = self.ASTSource(self, signature, constants, configs[0])
--> 623 kernel = self.compile(
    624     src,
    625     target=target,
    626     options=options.__dict__,
    627 )
    628 self.cache[device][key] = kernel
    629 self._call_hook(key, signature, device, constants, options, configs, warmup, before=False)

File ~/miniforge3/envs/.../lib/python3.12/site-packages/triton/compiler/compiler.py:273, in compile(src, target, options)
    271 module_map = backend.get_module_map()
    272 try:
--> 273     module = src.make_ir(options, codegen_fns, module_map, context)
    274 except Exception as e:
    275     filter_traceback(e)

File ~/miniforge3/envs/.../lib/python3.12/site-packages/triton/compiler/compiler.py:100, in ASTSource.make_ir(self, options, codegen_fns, module_map, context)
     99 def make_ir(self, options, codegen_fns, module_map, context):
--> 100     return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
    101                        module_map=module_map)

CompilationError: at 178:24:
        masked_blocks = padded_block_k
    # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block.
    # In this case we might exceed n_blocks so pick the min.
    masked_blocks = min(masked_blocks, n_blocks)
    n_full_blocks = n_blocks - masked_blocks
    block_min = 0
    block_max = n_blocks * BLOCK_N
    # Compute for full blocks. Here we set causal to false regardless of its actual
    # value because there is no masking. Similarly we do not need padding.
    if n_full_blocks > 0:
        block_max = (n_blocks - masked_blocks) * BLOCK_N
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn,   
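
The first CompilationError above points at the unguarded `tl.store(sd_mask_ptrs, sd_mask, mask=p_mask)` inside the `ENABLE_DROPOUT` branch. A plausible reading (an assumption, not verified against the kernel source): when `return_attn_probs=False`, the host passes `sd_mask=None`, Triton captures that `None` as a `tl.constexpr`, and the store then asks the constexpr "pointer" for `.type`, which surfaces as `AttributeError: 'constexpr' object has no attribute 'type'`. A plain-Python sketch of that failure shape; `Constexpr`, `TensorPtr`, `store`, and the two kernel stand-ins are all hypothetical, not the real Triton types:

```python
class Constexpr:
    """Stand-in for triton.language.constexpr: wraps a plain Python value
    (here, the None passed for sd_mask when return_attn_probs=False)."""
    def __init__(self, value):
        self.value = value

class TensorPtr:
    """Stand-in for a real tensor pointer argument, which does have a `.type`."""
    def __init__(self, type_name):
        self.type = type_name

def store(ptr, value):
    # Mirrors the failing check in triton/language/semantic.py: it assumes
    # `ptr` has a `.type`, which a constexpr does not.
    return f"stored {value} to {ptr.type}"

def kernel_store_sd_mask(sd_mask_ptr, enable_dropout, return_scores):
    # The unguarded version stores whenever ENABLE_DROPOUT is set,
    # even if no sd_mask buffer was allocated...
    if enable_dropout:
        return store(sd_mask_ptr, "sd_mask")
    return "skipped"

def kernel_store_sd_mask_guarded(sd_mask_ptr, enable_dropout, return_scores):
    # ...while a guarded version also requires RETURN_SCORES, so the
    # constexpr None pointer is never dereferenced.
    if enable_dropout and return_scores:
        return store(sd_mask_ptr, "sd_mask")
    return "skipped"
```

With a real buffer both variants work; with `Constexpr(None)` only the guarded variant survives, which matches the observation that the crash needs dropout_p > 0 while return_attn_probs stays False.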

Operating System

Ubuntu 25.10 (Questing Quokka)

CPU

AMD Ryzen 9 5950X 16-Core Processor

GPU

Radeon RX 7900 XTX

ROCm Version

ROCm 6.4

ROCm Component

No response

Steps to Reproduce

import torch
import os
if torch.cuda.is_available():
    if (
        "Radeon" in torch.cuda.get_device_name()
        or "Instinct" in torch.cuda.get_device_name()
    ):
        os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
import flash_attn

def reproduce():
    # Setup Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16

    print(f"Running on {device} with {dtype}")

    # --- 1. Simulate Data Parameters ---
    # Based on the shapes used in the original model
    batch_size = 2
    seqlens = [16, 32] 
    total_len = sum(seqlens) # 48
    num_heads = 4
    head_dim = 64
    
    # CRITICAL: The crash happens when dropout > 0.0
    dropout_p = 0.1 

    # --- 2. Construct the "Trick" Tensor ---
    # The original model concatenates an extra slice onto Q, K, and V
    # Normal QKV: (Total_Len, 3, Num_Heads, Head_Dim)
    qkv_normal = torch.randn(total_len, 3, num_heads, head_dim, device=device, dtype=dtype)
    
    # The "Extension" for bias: (Total_Len, 3, Num_Heads, 1)
    qkv_ext = torch.randn(total_len, 3, num_heads, 1, device=device, dtype=dtype)
    
    # Concatenate -> Effective Head Dim becomes 65 (64 + 1)
    # This odd head dimension forces flash-attn to pad the heads
    # (see the F.pad calls in the traceback) and Triton to compile a padded kernel path
    qkv_packed = torch.cat([qkv_normal, qkv_ext], dim=-1).contiguous()
    
    print(f"Input Shape (with bias trick): {qkv_packed.shape}")
    # Expected: [48, 3, 4, 65]

    # --- 3. Prepare Metadata ---
    # offsets / cu_seqlens
    cu_seqlens = torch.tensor([0, 16, 48], dtype=torch.int32, device=device)
    max_seqlen = 32

    # --- 4. Trigger the Crash ---
    print("\nAttempting Flash Attention call...")
    
    flash_attn.flash_attn_varlen_qkvpacked_func(
        qkv=qkv_packed,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        dropout_p=dropout_p,              # <--- THE TRIGGER
        softmax_scale=1.0 / (head_dim**0.5),
        causal=False,
        return_attn_probs=False,
    )
    print("Success: Kernel ran without error.")

if __name__ == "__main__":
    reproduce()
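
Until this is fixed upstream, a defensive wrapper can keep runs alive on the Triton backend: try the requested dropout_p, and on a failure retry once with dropout_p=0.0. Note that this silently disables attention dropout, so it changes training semantics; the wrapper below is a hypothetical workaround sketch, not part of the flash_attn API:

```python
import warnings

def call_with_dropout_fallback(attn_fn, *args, dropout_p=0.0, **kwargs):
    """Call `attn_fn` with the requested dropout_p; if the backend fails
    (the Triton path surfaces this as a CompilationError), retry once
    with dropout_p=0.0 and warn that attention dropout was disabled."""
    try:
        return attn_fn(*args, dropout_p=dropout_p, **kwargs)
    except Exception as exc:
        if dropout_p == 0.0:
            raise  # nothing left to fall back to
        warnings.warn(
            f"attention call failed with dropout_p={dropout_p} ({exc!r}); "
            "retrying with dropout_p=0.0 (attention dropout disabled)"
        )
        return attn_fn(*args, dropout_p=0.0, **kwargs)
```

Usage would look like `call_with_dropout_fallback(flash_attn.flash_attn_varlen_qkvpacked_func, qkv=qkv_packed, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, dropout_p=0.1, causal=False)`.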

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

:~$ /opt/rocm/bin/rocminfo --support
ROCk module is loaded

HSA System Attributes

Runtime Version: 1.18
Runtime Ext Version: 1.11
System Timestamp Freq.: 1000.000000MHz
Sig. Max Wait Duration: 18446744073709551615 (0xFFFFFFFFFFFFFFFF) (timestamp count)
Machine Model: LARGE
System Endianness: LITTLE
Mwaitx: DISABLED
XNACK enabled: NO
DMAbuf Support: YES
VMM Support: YES

==========
HSA Agents


Agent 1


Name: AMD Ryzen 9 5950X 16-Core Processor
Uuid: CPU-XX
Marketing Name: AMD Ryzen 9 5950X 16-Core Processor
Vendor Name: CPU
Feature: None specified
Profile: FULL_PROFILE
Float Round Mode: NEAR
Max Queue Number: 0(0x0)
Queue Min Size: 0(0x0)
Queue Max Size: 0(0x0)
Queue Type: MULTI
Node: 0
Device Type: CPU
Cache Info:
L1: 32768(0x8000) KB
Chip ID: 0(0x0)
ASIC Revision: 0(0x0)
Cacheline Size: 64(0x40)
Max Clock Freq. (MHz): 5086
BDFID: 0
Internal Node ID: 0
Compute Unit: 32
SIMDs per CU: 0
Shader Engines: 0
Shader Arrs. per Eng.: 0
WatchPts on Addr. Ranges:1
Memory Properties:
Features: None
Pool Info:
Pool 1
Segment: GLOBAL; FLAGS: FINE GRAINED
Size: 65751044(0x3eb4804) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 2
Segment: GLOBAL; FLAGS: EXTENDED FINE GRAINED
Size: 65751044(0x3eb4804) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 3
Segment: GLOBAL; FLAGS: KERNARG, FINE GRAINED
Size: 65751044(0x3eb4804) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 4
Segment: GLOBAL; FLAGS: COARSE GRAINED
Size: 65751044(0x3eb4804) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
ISA Info:


Agent 2


Name: gfx1100
Uuid: GPU-e54cff94d6bbb381
Marketing Name: Radeon RX 7900 XTX
Vendor Name: AMD
Feature: KERNEL_DISPATCH
Profile: BASE_PROFILE
Float Round Mode: NEAR
Max Queue Number: 128(0x80)
Queue Min Size: 64(0x40)
Queue Max Size: 131072(0x20000)
Queue Type: MULTI
Node: 1
Device Type: GPU
Cache Info:
L1: 32(0x20) KB
L2: 6144(0x1800) KB
L3: 98304(0x18000) KB
Chip ID: 29772(0x744c)
ASIC Revision: 0(0x0)
Cacheline Size: 128(0x80)
Max Clock Freq. (MHz): 2304
BDFID: 2304
Internal Node ID: 1
Compute Unit: 96
SIMDs per CU: 2
Shader Engines: 6
Shader Arrs. per Eng.: 2
WatchPts on Addr. Ranges:4
Coherent Host Access: FALSE
Memory Properties:
Features: KERNEL_DISPATCH
Fast F16 Operation: TRUE
Wavefront Size: 32(0x20)
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Max Waves Per CU: 32(0x20)
Max Work-item Per CU: 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 2147483647(0x7fffffff)
y 65535(0xffff)
z 65535(0xffff)
Max fbarriers/Workgrp: 32
Packet Processor uCode:: 602
SDMA engine uCode:: 27
IOMMU Support:: None
Pool Info:
Pool 1
Segment: GLOBAL; FLAGS: COARSE GRAINED
Size: 25149440(0x17fc000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:2048KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 2
Segment: GLOBAL; FLAGS: EXTENDED FINE GRAINED
Size: 25149440(0x17fc000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:2048KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 3
Segment: GROUP
Size: 64(0x40) KB
Allocatable: FALSE
Alloc Granule: 0KB
Alloc Recommended Granule:0KB
Alloc Alignment: 0KB
Accessible by all: FALSE
ISA Info:
ISA 1
Name: amdgcn-amd-amdhsa--gfx1100
Machine Models: HSA_MACHINE_MODEL_LARGE
Profiles: HSA_PROFILE_BASE
Default Rounding Mode: NEAR
Default Rounding Mode: NEAR
Fast f16: TRUE
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 2147483647(0x7fffffff)
y 65535(0xffff)
z 65535(0xffff)
FBarrier Max Size: 32
ISA 2
Name: amdgcn-amd-amdhsa--gfx11-generic
Machine Models: HSA_MACHINE_MODEL_LARGE
Profiles: HSA_PROFILE_BASE
Default Rounding Mode: NEAR
Default Rounding Mode: NEAR
Fast f16: TRUE
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 2147483647(0x7fffffff)
y 65535(0xffff)
z 65535(0xffff)
FBarrier Max Size: 32
*** Done ***

Additional Information

pytorch-triton-rocm==3.5.1
torch==2.9.1+rocm6.4
torchaudio==2.9.1+rocm6.4
torchmetrics==1.8.2
torchvision==0.24.1+rocm6.4
