Problem Description
The Triton backend flash attention crashes when dropout_p is set to a non-zero value, but runs fine when dropout_p is 0 and return_attn_probs=False.
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
File ~/miniforge3/envs/.../lib/python3.12/site-packages/triton/language/core.py:35, in builtin.<locals>.wrapper(*args, **kwargs)
33 raise ValueError("Did you forget to add @triton.jit ? "
34 "(`_builder` argument must be provided outside of JIT functions.)")
---> 35 return fn(*args, **kwargs)
File ~/miniforge3/envs/.../lib/python3.12/site-packages/triton/language/core.py:1710, in store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder)
1709 eviction_policy = _constexpr_to_value(eviction_policy)
-> 1710 return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder)
File ~/miniforge3/envs/.../lib/python3.12/site-packages/triton/language/semantic.py:1272, in store(ptr, val, mask, boundary_check, cache_modifier, eviction_policy, builder)
1270 eviction = _str_to_eviction_policy(eviction_policy)
-> 1272 if ptr.type.is_const() or ptr.type.scalar.is_const():
1273 raise ValueError("Cannot store to a constant pointer")
AttributeError: 'constexpr' object has no attribute 'type'
The above exception was the direct cause of the following exception:
CompilationError Traceback (most recent call last)
CompilationError: at 95:12:
if ENABLE_DROPOUT:
if tl_DROPOUT_USE_PYTORCH:
dropout_mask = tl.load(dropout_mask_ptrs, mask=p_mask)
else:
rng_output = tl.rand(philox_seed, philox_ptrs) # TODO: use tl.randint for better performance
dropout_mask = rng_output > dropout_p
if tl_DROPOUT_DUMP:
tl.store(dropout_mask_ptrs, dropout_mask, mask=p_mask)
# return scores with negative values for dropped vals
sd_mask = tl.where(dropout_mask, p, -p)
tl.store(sd_mask_ptrs, sd_mask, mask=p_mask)
^
The above exception was the direct cause of the following exception:
CompilationError Traceback (most recent call last)
Cell In[8], line 2
1 # %%
----> 2 head, seq = model(timestep.cuda(), token_id.cuda(), count.cuda(), keep_seq=True)
File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/nn/modules/module.py:1775, in Module._wrapped_call_impl(self, *args, **kwargs)
1773 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1774 else:
-> 1775 return self._call_impl(*args, **kwargs)
File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/nn/modules/module.py:1786, in Module._call_impl(self, *args, **kwargs)
1781 # If we don't have any hooks, we want to skip the rest of the logic in
1782 # this function, and just call forward.
1783 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1784 or _global_backward_pre_hooks or _global_backward_hooks
1785 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786 return forward_call(*args, **kwargs)
1788 result = None
1789 called_always_called_hooks = set()
...
... (the Module._wrapped_call_impl / _call_impl frames above repeat for each nested module) ...
--> 119 flash_attn_out = flash_attn.flash_attn_varlen_qkvpacked_func(
120 qkv=qkv.contiguous(),
121 cu_seqlens=offsets.to(torch.int32),
122 max_seqlen=max_len,
123 dropout_p=dropout,
124 softmax_scale=1.0 / (head_dim**0.5),
125 causal=is_causal,
126 return_attn_probs=need_weights,
127 )
128 if need_weights:
129 out: Tensor
File ~/miniforge3/envs/.../lib/python3.12/site-packages/flash_attn/flash_attn_interface.py:1272, in flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, deterministic, return_attn_probs)
1222 def flash_attn_varlen_qkvpacked_func(
1223 qkv,
1224 cu_seqlens,
(...) 1233 return_attn_probs=False,
1234 ):
1235 """dropout_p should be set to 0.0 during evaluation
1236 If Q, K, V are already stacked into 1 tensor, this function will be faster than
1237 calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
(...) 1270 pattern (negative means that location was dropped, nonnegative means it was kept).
1271 """
-> 1272 return FlashAttnVarlenQKVPackedFunc.apply(
1273 qkv,
1274 cu_seqlens,
1275 max_seqlen,
1276 dropout_p,
1277 softmax_scale,
1278 causal,
1279 window_size,
1280 softcap,
1281 alibi_slopes,
1282 deterministic,
1283 return_attn_probs,
1284 torch.is_grad_enabled(),
1285 )
File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/autograd/function.py:581, in Function.apply(cls, *args, **kwargs)
578 if not torch._C._are_functorch_transforms_active():
579 # See NOTE: [functorch vjp and autograd interaction]
580 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 581 return super().apply(*args, **kwargs) # type: ignore[misc]
583 if not is_setup_ctx_defined:
584 raise RuntimeError(
585 "In order to use an autograd.Function with functorch transforms "
586 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
587 "staticmethod. For more details, please see "
588 "https://pytorch.org/docs/main/notes/extending.func.html"
589 )
File ~/miniforge3/envs/.../lib/python3.12/site-packages/flash_attn/flash_attn_interface.py:558, in FlashAttnVarlenQKVPackedFunc.forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, deterministic, return_softmax, is_grad_enabled)
556 k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
557 v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
--> 558 out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
559 q,
560 k,
561 v,
562 cu_seqlens,
563 cu_seqlens,
564 max_seqlen,
565 max_seqlen,
566 dropout_p,
567 softmax_scale,
568 causal=causal,
569 window_size_left=window_size[0],
570 window_size_right=window_size[1],
571 softcap=softcap,
572 alibi_slopes=alibi_slopes,
573 return_softmax=return_softmax and dropout_p > 0,
574 block_table=None,
575 )
576 if is_grad:
577 ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/_ops.py:1255, in OpOverloadPacket.__call__(self, *args, **kwargs)
1253 if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
1254 return _call_overload_packet_from_python(self, *args, **kwargs)
-> 1255 return self._op(*args, **kwargs)
File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/_library/autograd.py:111, in make_autograd_impl.<locals>.autograd_impl(keyset, *args, **keyword_only_args)
109 result = Generated.apply(*args, Metadata(keyset, keyword_only_args)) # type: ignore[attr-defined]
110 else:
--> 111 result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
112 return result
File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/_library/autograd.py:40, in make_autograd_impl.<locals>.forward_no_grad(*args)
38 keyset = metadata.keyset
39 kwargs = metadata.keyword_only_args
---> 40 result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
41 return result
File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/_ops.py:848, in OpOverload.redispatch(self, keyset, *args, **kwargs)
845 def redispatch(
846 self, /, keyset: torch._C.DispatchKeySet, *args: _P.args, **kwargs: _P.kwargs
847 ) -> _T:
--> 848 return self._handle.redispatch_boxed(keyset, *args, **kwargs)
File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/_library/custom_ops.py:343, in CustomOpDef.register_kernel.<locals>.inner.<locals>.backend_impl(*args, **kwargs)
342 def backend_impl(*args, **kwargs):
--> 343 result = self._backend_fns[device_type](*args, **kwargs)
345 def get_module():
346 fn = self._backend_fns[device_type]
File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/_compile.py:53, in _disable_dynamo.<locals>.inner(*args, **kwargs)
50 disable_fn = torch._dynamo.disable(fn, recursive, wrapping=False)
51 fn.__dynamo_disable = disable_fn # type: ignore[attr-defined]
---> 53 return disable_fn(*args, **kwargs)
File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:1044, in DisableContext.__call__.<locals>._fn(*args, **kwargs)
1042 _maybe_set_eval_frame(_callback_from_stance(self.callback))
1043 try:
-> 1044 return fn(*args, **kwargs)
1045 finally:
1046 set_eval_frame(None)
File ~/miniforge3/envs/.../lib/python3.12/site-packages/torch/_library/custom_ops.py:376, in CustomOpDef.register_kernel.<locals>.inner.<locals>.wrapped_fn(*args, **kwargs)
374 return self._init_fn(*args, **kwargs)
375 else:
--> 376 return fn(*args, **kwargs)
File ~/miniforge3/envs/.../lib/python3.12/site-packages/flash_attn/flash_attn_interface.py:168, in _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size_left, window_size_right, softcap, alibi_slopes, return_softmax, block_table, leftpad_k, seqused_k, zero_tensors)
145 @_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda")
146 def _flash_attn_varlen_forward(
147 q: torch.Tensor,
(...) 165 zero_tensors: bool = False,
166 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
167 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
--> 168 out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
169 q,
170 k,
171 v,
172 None,
173 cu_seqlens_q,
174 cu_seqlens_k,
175 seqused_k,
176 leftpad_k,
177 block_table,
178 alibi_slopes,
179 max_seqlen_q,
180 max_seqlen_k,
181 dropout_p,
182 softmax_scale,
183 zero_tensors,
184 causal,
185 window_size_left,
186 window_size_right,
187 softcap,
188 return_softmax,
189 None,
190 )
191 # if out.isnan().any() or softmax_lse.isnan().any():
192 # breakpoint()
193 return out, softmax_lse, S_dmask, rng_state
File ~/miniforge3/envs/.../lib/python3.12/site-packages/flash_attn/flash_attn_triton_amd/interface_fa.py:456, in varlen_fwd(q, k, v, out, cu_seqlens_q, cu_seqlens_k, seqused_k, leftpad_k, block_table_, alibi_slopes, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, zero_tensors, causal, window_size_left, window_size_right, softcap, return_softmax, gen_, descale_q, descale_k, descale_v, descale_o)
454 if DEBUG:
455 print("Using Triton implementation")
--> 456 softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl(
457 q,
458 k,
459 v,
460 out,
461 metadata.sm_scale,
462 metadata.alibi_slopes,
463 metadata.causal,
464 None,
465 metadata.layout,
466 metadata.cu_seqlens_q,
467 metadata.cu_seqlens_k,
468 metadata.max_seqlens_q,
469 metadata.max_seqlens_k,
470 metadata.cache_seqlens,
471 metadata.cache_batch_idx,
472 metadata.dropout_p,
473 metadata.philox_seed,
474 metadata.philox_offset,
475 metadata.return_scores,
476 metadata.use_exp2,
477 descale_q,
478 descale_k,
479 descale_v,
480 descale_o)
481 softmax_lse=softmax_lse_triton
482 sd_mask=sd_mask_triton
File ~/miniforge3/envs/.../lib/python3.12/site-packages/flash_attn/flash_attn_triton_amd/fwd_prefill.py:638, in attention_prefill_forward_triton_impl(q, k, v, o, sm_scale, alibi_slopes, causal, bias, layout, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k, cache_seqlens, cache_batch_idx, dropout_p, philox_seed, philox_offset, return_softmax, use_exp2, descale_q, descale_k, descale_v, descale_o)
635 else:
636 bias_strides = (0, 0, 0, 0)
--> 638 attn_fwd[grid](q, k, v, bias, cache_seqlens, cache_batch_idx,
639 descale_q, descale_k, descale_v, descale_o, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z,
640 sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
641 *bias_strides, stride_az, stride_ah, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
642 dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes,
643 HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
644 MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, IS_VARLEN=is_varlen, IS_INFERENCE=is_inference,
645 BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True,
646 USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p
647 > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT)
649 return softmax_lse, sd_mask if return_softmax else None
File ~/miniforge3/envs/.../lib/python3.12/site-packages/triton/runtime/jit.py:330, in KernelInterface.__getitem__.<locals>.<lambda>(*args, **kwargs)
324 def __getitem__(self, grid) -> T:
325 """
326 A JIT function is launched with: fn[grid](*args, **kwargs).
327 Hence JITFunction.__getitem__ returns a callable proxy that
328 memorizes the grid.
329 """
--> 330 return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File ~/miniforge3/envs/.../lib/python3.12/site-packages/triton/runtime/autotuner.py:203, in Autotuner.run(self, *args, **kwargs)
201 full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
202 config.pre_hook(full_nargs)
--> 203 ret = self.fn.run(
204 *args,
205 **kwargs,
206 **config.all_kwargs(),
207 )
208 self.nargs = None
209 return ret
File ~/miniforge3/envs/.../lib/python3.12/site-packages/triton/runtime/jit.py:623, in JITFunction.run(self, grid, warmup, *args, **kwargs)
621 # compile the kernel
622 src = self.ASTSource(self, signature, constants, configs[0])
--> 623 kernel = self.compile(
624 src,
625 target=target,
626 options=options.__dict__,
627 )
628 self.cache[device][key] = kernel
629 self._call_hook(key, signature, device, constants, options, configs, warmup, before=False)
File ~/miniforge3/envs/.../lib/python3.12/site-packages/triton/compiler/compiler.py:273, in compile(src, target, options)
271 module_map = backend.get_module_map()
272 try:
--> 273 module = src.make_ir(options, codegen_fns, module_map, context)
274 except Exception as e:
275 filter_traceback(e)
File ~/miniforge3/envs/.../lib/python3.12/site-packages/triton/compiler/compiler.py:100, in ASTSource.make_ir(self, options, codegen_fns, module_map, context)
99 def make_ir(self, options, codegen_fns, module_map, context):
--> 100 return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
101 module_map=module_map)
CompilationError: at 178:24:
masked_blocks = padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional block.
# In this case we might exceed n_blocks so pick the min.
masked_blocks = min(masked_blocks, n_blocks)
n_full_blocks = n_blocks - masked_blocks
block_min = 0
block_max = n_blocks * BLOCK_N
# Compute for full blocks. Here we set causal to false regardless of its actual
# value because there is no masking. Similarly we do not need padding.
if n_full_blocks > 0:
block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn,
Operating System
Ubuntu 25.10 (Questing Quokka)
CPU
AMD Ryzen 9 5950X 16-Core Processor
GPU
Radeon RX 7900 XTX
ROCm Version
ROCm 6.4
ROCm Component
No response
Steps to Reproduce
import torch
import os

if torch.cuda.is_available():
    if (
        "Radeon" in torch.cuda.get_device_name()
        or "Instinct" in torch.cuda.get_device_name()
    ):
        os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"

import flash_attn


def reproduce():
    # Set up device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16
    print(f"Running on {device} with {dtype}")

    # --- 1. Simulate data parameters ---
    # Based on the original model's logic
    batch_size = 2
    seqlens = [16, 32]
    total_len = sum(seqlens)  # 48
    num_heads = 4
    head_dim = 64
    # CRITICAL: the crash happens when dropout_p > 0.0
    dropout_p = 0.1

    # --- 2. Construct the "trick" tensor ---
    # The model concatenates an extra dimension to Q, K, and V.
    # Normal QKV: (total_len, 3, num_heads, head_dim)
    qkv_normal = torch.randn(total_len, 3, num_heads, head_dim, device=device, dtype=dtype)
    # The "extension" for bias: (total_len, 3, num_heads, 1)
    qkv_ext = torch.randn(total_len, 3, num_heads, 1, device=device, dtype=dtype)
    # Concatenate -> effective head dim becomes 65 (64 + 1).
    # This odd dimension often forces Triton to compile a specific kernel path.
    qkv_packed = torch.cat([qkv_normal, qkv_ext], dim=-1).contiguous()
    print(f"Input shape (with bias trick): {qkv_packed.shape}")
    # Expected: [48, 3, 4, 65]

    # --- 3. Prepare metadata (offsets / cu_seqlens) ---
    cu_seqlens = torch.tensor([0, 16, 48], dtype=torch.int32, device=device)
    max_seqlen = 32

    # --- 4. Trigger the crash ---
    print("\nAttempting flash attention call...")
    flash_attn.flash_attn_varlen_qkvpacked_func(
        qkv=qkv_packed,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        dropout_p=dropout_p,  # <--- THE TRIGGER
        softmax_scale=1.0 / (head_dim**0.5),
        causal=False,
        return_attn_probs=False,
    )
    print("Success: kernel ran without error.")


if __name__ == "__main__":
    reproduce()
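For reference, the hard-coded cu_seqlens tensor in the reproducer is just the cumulative sum of the per-sequence lengths with a leading zero, which is how the varlen API expects offsets. A hedged, CPU-only sketch of deriving it (no flash_attn import needed):

```python
import torch

# cu_seqlens for the varlen API: cumulative offsets with a leading zero.
seqlens = [16, 32]
cu_seqlens = torch.cumsum(torch.tensor([0] + seqlens), dim=0).to(torch.int32)
print(cu_seqlens.tolist())  # [0, 16, 48]
max_seqlen = max(seqlens)
print(max_seqlen)  # 32
```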
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
:~$ /opt/rocm/bin/rocminfo --support
ROCk module is loaded
HSA System Attributes
Runtime Version: 1.18
Runtime Ext Version: 1.11
System Timestamp Freq.: 1000.000000MHz
Sig. Max Wait Duration: 18446744073709551615 (0xFFFFFFFFFFFFFFFF) (timestamp count)
Machine Model: LARGE
System Endianness: LITTLE
Mwaitx: DISABLED
XNACK enabled: NO
DMAbuf Support: YES
VMM Support: YES
==========
HSA Agents
Agent 1
Name: AMD Ryzen 9 5950X 16-Core Processor
Uuid: CPU-XX
Marketing Name: AMD Ryzen 9 5950X 16-Core Processor
Vendor Name: CPU
Feature: None specified
Profile: FULL_PROFILE
Float Round Mode: NEAR
Max Queue Number: 0(0x0)
Queue Min Size: 0(0x0)
Queue Max Size: 0(0x0)
Queue Type: MULTI
Node: 0
Device Type: CPU
Cache Info:
L1: 32768(0x8000) KB
Chip ID: 0(0x0)
ASIC Revision: 0(0x0)
Cacheline Size: 64(0x40)
Max Clock Freq. (MHz): 5086
BDFID: 0
Internal Node ID: 0
Compute Unit: 32
SIMDs per CU: 0
Shader Engines: 0
Shader Arrs. per Eng.: 0
WatchPts on Addr. Ranges:1
Memory Properties:
Features: None
Pool Info:
Pool 1
Segment: GLOBAL; FLAGS: FINE GRAINED
Size: 65751044(0x3eb4804) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 2
Segment: GLOBAL; FLAGS: EXTENDED FINE GRAINED
Size: 65751044(0x3eb4804) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 3
Segment: GLOBAL; FLAGS: KERNARG, FINE GRAINED
Size: 65751044(0x3eb4804) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 4
Segment: GLOBAL; FLAGS: COARSE GRAINED
Size: 65751044(0x3eb4804) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
ISA Info:
Agent 2
Name: gfx1100
Uuid: GPU-e54cff94d6bbb381
Marketing Name: Radeon RX 7900 XTX
Vendor Name: AMD
Feature: KERNEL_DISPATCH
Profile: BASE_PROFILE
Float Round Mode: NEAR
Max Queue Number: 128(0x80)
Queue Min Size: 64(0x40)
Queue Max Size: 131072(0x20000)
Queue Type: MULTI
Node: 1
Device Type: GPU
Cache Info:
L1: 32(0x20) KB
L2: 6144(0x1800) KB
L3: 98304(0x18000) KB
Chip ID: 29772(0x744c)
ASIC Revision: 0(0x0)
Cacheline Size: 128(0x80)
Max Clock Freq. (MHz): 2304
BDFID: 2304
Internal Node ID: 1
Compute Unit: 96
SIMDs per CU: 2
Shader Engines: 6
Shader Arrs. per Eng.: 2
WatchPts on Addr. Ranges:4
Coherent Host Access: FALSE
Memory Properties:
Features: KERNEL_DISPATCH
Fast F16 Operation: TRUE
Wavefront Size: 32(0x20)
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Max Waves Per CU: 32(0x20)
Max Work-item Per CU: 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 2147483647(0x7fffffff)
y 65535(0xffff)
z 65535(0xffff)
Max fbarriers/Workgrp: 32
Packet Processor uCode:: 602
SDMA engine uCode:: 27
IOMMU Support:: None
Pool Info:
Pool 1
Segment: GLOBAL; FLAGS: COARSE GRAINED
Size: 25149440(0x17fc000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:2048KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 2
Segment: GLOBAL; FLAGS: EXTENDED FINE GRAINED
Size: 25149440(0x17fc000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:2048KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 3
Segment: GROUP
Size: 64(0x40) KB
Allocatable: FALSE
Alloc Granule: 0KB
Alloc Recommended Granule:0KB
Alloc Alignment: 0KB
Accessible by all: FALSE
ISA Info:
ISA 1
Name: amdgcn-amd-amdhsa--gfx1100
Machine Models: HSA_MACHINE_MODEL_LARGE
Profiles: HSA_PROFILE_BASE
Default Rounding Mode: NEAR
Default Rounding Mode: NEAR
Fast f16: TRUE
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 2147483647(0x7fffffff)
y 65535(0xffff)
z 65535(0xffff)
FBarrier Max Size: 32
ISA 2
Name: amdgcn-amd-amdhsa--gfx11-generic
Machine Models: HSA_MACHINE_MODEL_LARGE
Profiles: HSA_PROFILE_BASE
Default Rounding Mode: NEAR
Default Rounding Mode: NEAR
Fast f16: TRUE
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 2147483647(0x7fffffff)
y 65535(0xffff)
z 65535(0xffff)
FBarrier Max Size: 32
*** Done ***
Additional Information
pytorch-triton-rocm==3.5.1
torch==2.9.1+rocm6.4
torchaudio==2.9.1+rocm6.4
torchmetrics==1.8.2
torchvision==0.24.1+rocm6.4