ATOM (AiTer Optimized Model) is AMD's lightweight LLM inference engine built on AITER kernels for ROCm/HIP GPUs. This guide covers the supported model architectures, weight loading, and how to add new models.
The model registry lives in `atom/model_engine/model_runner.py` as `support_model_arch_dict`:

```python
support_model_arch_dict = {
    "Qwen3ForCausalLM": "atom.models.qwen3.Qwen3ForCausalLM",
    "Qwen3MoeForCausalLM": "atom.models.qwen3_moe.Qwen3MoeForCausalLM",
    "LlamaForCausalLM": "atom.models.llama.LlamaForCausalLM",
    "MixtralForCausalLM": "atom.models.mixtral.MixtralForCausalLM",
    "DeepseekV3ForCausalLM": "atom.models.deepseek_v2.DeepseekV2ForCausalLM",
    "DeepseekV32ForCausalLM": "atom.models.deepseek_v2.DeepseekV2ForCausalLM",
    "GptOssForCausalLM": "atom.models.gpt_oss.GptOssForCausalLM",
    "Glm4MoeForCausalLM": "atom.models.glm4_moe.Glm4MoeForCausalLM",
}
```

ATOM resolves the HuggingFace `architectures` field from a model's `config.json` against this dictionary. If the architecture string matches a key, ATOM imports and instantiates the corresponding class.
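The lookup itself is plain string resolution: split the registered class path into a module path and class name, then import. A minimal sketch of the idea (the helper name `resolve_architecture` is hypothetical; ATOM's actual import logic lives in `model_runner.py`):

```python
# Mirror of the registry: HF architecture name -> fully qualified class path.
SUPPORT_MODEL_ARCH_DICT = {
    "Qwen3ForCausalLM": "atom.models.qwen3.Qwen3ForCausalLM",
    "LlamaForCausalLM": "atom.models.llama.LlamaForCausalLM",
}

def resolve_architecture(architectures: list[str]) -> tuple[str, str]:
    """Return (module_path, class_name) for the first supported architecture."""
    for arch in architectures:
        if arch in SUPPORT_MODEL_ARCH_DICT:
            module_path, _, class_name = SUPPORT_MODEL_ARCH_DICT[arch].rpartition(".")
            return module_path, class_name
    raise ValueError(f"Unsupported architectures: {architectures}")

# Instantiation would then be roughly:
#   cls = getattr(importlib.import_module(module_path), class_name)
```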
| HF Architecture | ATOM Module | ATOM Class | MoE | MLA | Key Features |
|---|---|---|---|---|---|
| `Qwen3ForCausalLM` | `atom.models.qwen3` | `Qwen3ForCausalLM` | No | No | GQA, QK norm, RoPE |
| `Qwen3MoeForCausalLM` | `atom.models.qwen3_moe` | `Qwen3MoeForCausalLM` | Yes | No | GQA, QK norm, FusedMoE, sparse+dense layer mixing, QK norm+RoPE+cache+quant fusion |
| `LlamaForCausalLM` | `atom.models.llama` | `LlamaForCausalLM` | No | No | GQA, RoPE, fused RMSNorm+quant, fused SiLU+mul+quant |
| `MixtralForCausalLM` | `atom.models.mixtral` | `MixtralForCausalLM` | Yes | No | GQA, RoPE, FusedMoE with TP sharding |
| `DeepseekV3ForCausalLM` | `atom.models.deepseek_v2` | `DeepseekV2ForCausalLM` | Yes | Yes | MLA attention, LoRA-compressed QKV, FusedMoE with shared experts, FP4/FP8 fused kernels |
| `DeepseekV32ForCausalLM` | `atom.models.deepseek_v2` | `DeepseekV2ForCausalLM` | Yes | Yes | Same as above with V3.2 index-based top-k routing |
| `GptOssForCausalLM` | `atom.models.gpt_oss` | `GptOssForCausalLM` | Yes | No | GQA, RoPE, sliding window attention (every other layer), attention sinks, bias in QKV and MoE |
| `Glm4MoeForCausalLM` | `atom.models.glm4_moe` | `Glm4MoeForCausalLM` | Yes | No | GQA, partial RoPE (0.5 factor), QK norm, shared+routed experts, sigmoid scoring, grouped top-k |
Note: `DeepSeekMTP` (`atom.models.deepseek_mtp.DeepSeekMTP`) is not in the registry -- it is used exclusively as a speculative draft model for DeepSeek multi-token prediction and is loaded separately.
### Qwen3 (`Qwen3ForCausalLM`)

- Architecture: Dense transformer with Grouped-Query Attention (GQA).
- Layer structure: `Qwen3DecoderLayer` containing `Qwen3Attention` + `Qwen3MLP`.
- Attention: `QKVParallelLinear` for fused QKV projection, per-head QK RMSNorm (`q_norm`, `k_norm`), RoPE, `RowParallelLinear` for output projection.
- MLP: `MergedColumnParallelLinear` for gate+up projection, SiLU activation, `RowParallelLinear` for down projection.
- Normalization: RMSNorm on input and post-attention.
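RMSNorm itself is easy to state. A pure-Python sketch of the reference computation (the production path is a fused AITER/HIP kernel operating on tensors, not this scalar loop):

```python
import math

def rms_norm(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    """Reference RMSNorm: x / sqrt(mean(x^2) + eps), scaled by a learned weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]
```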
### Qwen3-MoE (`Qwen3MoeForCausalLM`)

- Architecture: Mixture-of-Experts transformer with GQA.
- Layer structure: `Qwen3MoeDecoderLayer` containing `Qwen3MoeAttention` + either `Qwen3MoeSparseMoeBlock` (MoE layers) or `Qwen3MoeMLP` (dense layers, controlled by `mlp_only_layers` and `decoder_sparse_step`).
- Attention: Same QKV structure as Qwen3 with QK norm. Supports QK norm + RoPE + cache + quant fusion when `ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION` is set -- this precomputes a joint `cos_sin_cache` and passes `q_norm`/`k_norm` to the `Attention` module.
- MoE: `FusedMoE` with a `ReplicatedLinear` gate router. Supports allreduce+RMSNorm fusion (`ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION`).
- Normalization: RMSNorm with optional fused allreduce.
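The sparse/dense layer mixing can be sketched as a predicate over the layer index, following the convention used by the HuggingFace Qwen3-MoE implementation (treat this as a sketch; ATOM's exact condition lives in `qwen3_moe.py`):

```python
def is_moe_layer(layer_idx: int, mlp_only_layers: list[int],
                 decoder_sparse_step: int, num_experts: int) -> bool:
    """True if this decoder layer uses the sparse MoE block instead of the dense MLP."""
    if layer_idx in mlp_only_layers:
        return False  # explicitly forced dense
    # Every decoder_sparse_step-th layer (1-indexed) is sparse, if experts exist.
    return num_experts > 0 and (layer_idx + 1) % decoder_sparse_step == 0
```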
### Llama (`LlamaForCausalLM`)

- Architecture: Dense transformer with GQA. Covers Llama 2/3 and compatible architectures (InternLM, Mistral-Nemo via optional `head_dim`).
- Layer structure: `LlamaDecoderLayer` containing `LlamaAttention` + `LlamaMLP`.
- Attention: `QKVParallelLinear`, RoPE (NeoX or original style based on GGUF), per-layer sliding window support via the `layer_types` config.
- MLP: `MergedColumnParallelLinear` for gate+up, SiLU+mul activation, `RowParallelLinear` for down.
- Fused optimizations, controlled by environment variables:
  - `ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT` -- fuses RMSNorm with FP8/MXFP4 quantization.
  - `ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT` -- fuses SiLU+mul activation with quantization.
- Pipeline parallelism: Full PP support with `PPMissingLayer` placeholders and `IntermediateTensors` for cross-stage communication. Supports auxiliary hidden state extraction for speculative decoding.
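The `PPMissingLayer` idea is a pass-through placeholder: layers owned by other pipeline stages are replaced with an identity module so that layer indexing (and weight-name matching) stays consistent on every rank. A minimal sketch of the concept, not ATOM's actual class:

```python
class PPMissingLayerSketch:
    """Identity placeholder for decoder layers that live on another PP rank."""
    def __call__(self, hidden_states, *args, **kwargs):
        # Nothing to compute locally; the real output is produced on the owning rank.
        return hidden_states

def build_layers(num_layers: int, start: int, end: int):
    """Materialize only layers [start, end) on this rank; placeholders elsewhere.

    Real layers are stand-in strings here, just to show the index bookkeeping.
    """
    return [
        f"layer_{i}" if start <= i < end else PPMissingLayerSketch()
        for i in range(num_layers)
    ]
```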
### Mixtral (`MixtralForCausalLM`)

- Architecture: Sparse Mixture-of-Experts with GQA.
- Layer structure: `MixtralDecoderLayer` containing `MixtralAttention` + `MixtralMoE`.
- Attention: Standard GQA with `QKVParallelLinear`, RoPE (NeoX style), `RowParallelLinear`.
- MoE: `MixtralMoE` wraps a `ReplicatedLinear` gate + `FusedMoE`. Experts are sharded across TP ranks with a full reduce. Expert checkpoint names use the `w1`/`w2`/`w3` convention (mapped to `gate_proj`/`down_proj`/`up_proj`).
- Normalization: RMSNorm.
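The `w1`/`w2`/`w3` convention maps onto the familiar gate/down/up projection names. A sketch of that rename as a string rewrite (the helper name is illustrative; ATOM performs the mapping inside its weight loading path):

```python
# Mixtral expert checkpoint convention -> conventional projection names.
MIXTRAL_EXPERT_NAME_MAP = {
    "w1": "gate_proj",
    "w2": "down_proj",
    "w3": "up_proj",
}

def rename_expert_weight(name: str) -> str:
    """Rewrite e.g. '...experts.0.w1.weight' to '...experts.0.gate_proj.weight'."""
    parts = name.split(".")
    return ".".join(MIXTRAL_EXPERT_NAME_MAP.get(p, p) for p in parts)
```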
### DeepSeek V2/V3 (`DeepseekV2ForCausalLM`)

- Architecture: MoE transformer with Multi-head Latent Attention (MLA).
- Layer structure: `DeepseekV2DecoderLayer` containing `DeepseekV2MLAAttention` + either `DeepseekV2MoE` (MoE layers) or `DeepseekV2MLP` (dense layers).
- MLA Attention: Uses LoRA-compressed QKV (`q_lora_rank`, `kv_lora_rank`), with separate `qk_nope_head_dim` and `qk_rope_head_dim` for the non-positional and rotary-embedded components. Backed by `MLAModules` from `atom.model_ops.attention_mla`.
- MoE: `DeepseekV2MoE` with routed + shared experts. Supports shared expert fusion (`is_rocm_aiter_fusion_shared_expert_enabled`), routed scaling factor fusion (`is_rocm_aiter_fuse_routed_scaling_factor`), and grouped top-k routing.
- Fused optimizations:
  - `ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION` -- fuses input RMSNorm with FP8/FP4 quantization.
  - `ATOM_ENABLE_DS_QKNORM_QUANT_FUSION` -- fuses QK norm with quantization.
  - `ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION` -- fuses allreduce with RMSNorm.
  - Dedicated Triton kernels for FP8 MQA logits (`fp8_mqa_logits`), paged MQA logits (`deepgemm_fp8_paged_mqa_logits`), and fused RMSNorm+quantization (`_fuse_rmsnorm_quant`).
- V3.2 extension: `DeepseekV32ForCausalLM` is an alias. The `DeepseekV2Model` detects V3.2 via `config.index_topk` and allocates a `topk_indices_buffer` for index-based routing.
- Note: `DeepseekV3ForCausalLM` is a subclass of `DeepseekV2ForCausalLM` (pass-through, no override).
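The point of the LoRA compression is the KV cache footprint: MLA caches only the compressed latent plus the shared rotary key per token, instead of full per-head K and V. A back-of-the-envelope comparison using DeepSeek-V3-like dimensions (the concrete numbers are illustrative assumptions, not read from ATOM's code):

```python
# Assumed DeepSeek-V3-like dimensions (illustrative).
num_heads = 128
qk_nope_head_dim = 128
qk_rope_head_dim = 64
v_head_dim = 128
kv_lora_rank = 512

# MLA caches the compressed KV latent + the shared rotary key per token.
mla_elems_per_token = kv_lora_rank + qk_rope_head_dim  # 576 elements

# A standard MHA cache would store full K and V for every head.
mha_elems_per_token = num_heads * (qk_nope_head_dim + qk_rope_head_dim + v_head_dim)
```

Under these assumptions the MLA cache is roughly 70x smaller per token, which is what makes long-context decode on DeepSeek-sized models tractable.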
### DeepSeek MTP (`DeepSeekMTP`)

- Architecture: Multi-Token Prediction draft model for speculative decoding.
- Layer structure: `DeepSeekMultiTokenPredictor` containing one or more `DeepSeekMultiTokenPredictorLayer`, each with `enorm` (embedding norm), `hnorm` (hidden state norm), `eh_proj` (linear projection joining embedded+hidden), `mtp_block` (a `DeepseekV2DecoderLayer`), and a `SharedHead` (norm + LM head).
- Usage: Not registered in `support_model_arch_dict`. Loaded separately with `spec_decode=True` in `load_model()`, which invokes `rewrite_spec_layer_name()` to remap MTP weight names (e.g., adding the `.mtp_block.` prefix for transformer layer weights, remapping `embed_tokens` to top-level).
- MTP layers start at `config.num_hidden_layers` (i.e., the layer indices following the main model layers).
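The MTP weight-name remapping can be pictured as a string rewrite keyed on the layer index. A simplified, illustrative sketch of the kind of transformation `rewrite_spec_layer_name()` performs (the exact rules live in the loader; this helper is hypothetical):

```python
def rewrite_mtp_name(name: str, num_hidden_layers: int) -> str:
    """Illustrative MTP rename: insert '.mtp_block.' for transformer sub-weights
    of layers at index >= num_hidden_layers, and hoist embed_tokens to top level."""
    parts = name.split(".")
    if "embed_tokens" in parts:
        return "model.embed_tokens.weight"
    if parts[:2] == ["model", "layers"] and int(parts[2]) >= num_hidden_layers:
        # e.g. model.layers.61.self_attn.* -> model.layers.61.mtp_block.self_attn.*
        return ".".join(parts[:3] + ["mtp_block"] + parts[3:])
    return name
```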
### GPT-OSS (`GptOssForCausalLM`)

- Architecture: MoE transformer with GQA and alternating sliding window attention.
- Layer structure: `TransformerBlock` containing `OAIAttention` + `MLPBlock`.
- Attention: `OAIAttention` with bias on QKV and output projections, attention sinks (learnable per-head parameters), and a sliding window applied on even-indexed layers only.
- MoE: `MLPBlock` wraps a `ReplicatedLinear` router (with bias) + `FusedMoE` with SwiGLU activation and bias support. A custom `weights_mapping` translates checkpoint names (`gate_up_proj_blocks` to `w13_weight`, etc.).
- Normalization: RMSNorm with eps=1e-5; the post-attention norm uses `x_pad_to_multiple=256`.
- Pipeline parallelism: Supports auxiliary hidden state layers for EAGLE3 speculative decoding (`get_eagle3_aux_hidden_state_layers`).
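Attention sinks can be understood as an extra learnable logit per head that participates in the softmax but contributes no value output, absorbing probability mass the model wants to route nowhere. A pure-Python sketch of one common formulation (ATOM's kernel-level handling may differ):

```python
import math

def softmax_with_sink(logits: list[float], sink: float) -> list[float]:
    """Softmax over real positions, with an extra sink logit in the denominator."""
    m = max(logits + [sink])                  # shift for numerical stability
    exps = [math.exp(l - m) for l in logits]
    denom = sum(exps) + math.exp(sink - m)    # the sink contributes only here
    return [e / denom for e in exps]
```

Because the sink term appears only in the denominator, the returned probabilities sum to less than 1; the missing mass is the attention "parked" on the sink.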
### GLM4-MoE (`Glm4MoeForCausalLM`)

- Architecture: MoE transformer with GQA, shared + routed experts, partial RoPE.
- Layer structure: `Glm4MoeDecoderLayer` containing `Glm4MoeAttention` + either `Glm4MoE` (MoE layers, from `first_k_dense_replace` onward) or `Glm4MoeMLP` (dense layers).
- Attention: `Glm4MoeAttention` with optional QK norm (`use_qk_norm`) and a partial rotary factor of 0.5.
- MoE: `Glm4MoE` with sigmoid scoring, `e_score_correction_bias`, grouped top-k routing (`n_group`, `topk_group`), and a routed scaling factor. Shared experts are handled separately or fused into `FusedMoE` via `is_rocm_aiter_fusion_shared_expert_enabled()`. Expert parallelism (EP) support is built in.
- Inherits: the `Glm4MixtureOfExperts` mixin for MoE metadata management and expert load balancing (EPLB) support.
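Grouped top-k routing first selects the best `topk_group` groups of experts, then takes the final top-k only among experts in those groups. A simplified pure-Python sketch (group scoring here uses each group's max expert score; the production kernel also folds in `e_score_correction_bias` and the routed scaling factor):

```python
def grouped_topk(scores: list[float], n_group: int,
                 topk_group: int, top_k: int) -> list[int]:
    """Return the indices of the top_k experts, restricted to the best groups."""
    group_size = len(scores) // n_group
    # Score each group by its best expert (simplified group-scoring rule).
    group_scores = [max(scores[g * group_size:(g + 1) * group_size])
                    for g in range(n_group)]
    keep = sorted(range(n_group), key=lambda g: group_scores[g],
                  reverse=True)[:topk_group]
    # Mask out experts belonging to dropped groups, then take the global top-k.
    masked = [s if i // group_size in keep else float("-inf")
              for i, s in enumerate(scores)]
    return sorted(range(len(scores)), key=lambda i: masked[i],
                  reverse=True)[:top_k]
```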
Weight loading is handled by `load_model()` in `atom/model_loader/loader.py`:

```python
def load_model(
    model: nn.Module,
    model_name_or_path: str,
    hf_config: AutoConfig,
    load_dummy: bool = False,
    spec_decode: bool = False,
):
```

- SafeTensors iteration: `safetensors_weights_iterator()` discovers and iterates over all `*.safetensors` files in the model directory (or downloads them from HuggingFace Hub via `download_weights_from_hf()`). Duplicate files are filtered using the `model.safetensors.index.json` weight map. Memory-mapped loading is used by default; set `ATOM_DISABLE_MMAP=true` to disable it.
- Weight name rewriting: Each weight name goes through several transformations:
  - `weight_scale_inv` is renamed to `weight_scale`.
  - Model-specific `weights_mapping` is applied (e.g., GPT-OSS maps `gate_up_proj_blocks` to `w13_weight`).
  - For speculative decoding (`spec_decode=True`), MTP layer weights are rewritten via `rewrite_spec_layer_name()`.
  - Shared expert fusion: when enabled, `mlp.shared_experts` is remapped to `mlp.experts.<n_routed_experts>` so the shared expert is loaded as the last expert in the `FusedMoE` module.
- Packed module resolution: The `packed_modules_mapping` dict on each model class defines how HuggingFace checkpoint weight names map to ATOM's fused parameter names. For example, Llama maps:

  ```python
  "q_proj": ("qkv_proj", "q"),
  "k_proj": ("qkv_proj", "k"),
  "v_proj": ("qkv_proj", "v"),
  "gate_proj": ("gate_up_proj", 0),
  "up_proj": ("gate_up_proj", 1),
  ```

  Each packed parameter has a `weight_loader` attribute that knows how to shard and place the weight into the correct slice.
- Expert parameter loading: If the model has a `get_expert_mapping()` method, expert weights are loaded using `FusedMoE.make_expert_params_mapping()`, which generates `(param_name, weight_name, expert_id, shard_id)` tuples. This handles per-expert sharding across TP ranks.
- TP sharding: Parallel linear layers (`ColumnParallelLinear`, `RowParallelLinear`, `QKVParallelLinear`) have custom `weight_loader` methods that automatically select the correct shard for the current TP rank during loading. The default fallback `default_weight_loader` handles simple cases where weights need to be sliced by TP rank.
- Concurrent loading: All weight loading calls are submitted to a `ThreadPoolExecutor` for parallel execution.
- Post-processing: After all weights are loaded, `process_weights_after_loading()` is called on each module (e.g., for weight pre-shuffling, scale computation), and `quant_method.process_weights_after_loading()` is invoked for quantized modules. For `FusedMoEMethodBase`, `init_prepare_finalize()` is also called.
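The generic renaming steps above compose into a small name pipeline. An illustrative sketch (the function name is hypothetical; the real logic lives in `load_model()`):

```python
def rewrite_weight_name(name: str, n_routed_experts: int,
                        fuse_shared_expert: bool) -> str:
    """Apply the generic renames described above to one checkpoint weight name."""
    # FP8 block-scale checkpoints store inverse scales under weight_scale_inv.
    name = name.replace("weight_scale_inv", "weight_scale")
    # Shared-expert fusion: load the shared expert as the last FusedMoE expert.
    if fuse_shared_expert:
        name = name.replace("mlp.shared_experts",
                            f"mlp.experts.{n_routed_experts}")
    return name
```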
### Layers Beyond `num_hidden_layers`

Weights for layers with index >= `config.num_hidden_layers` are skipped during normal loading. These layers (the MTP layers) are only loaded when `spec_decode=True`.
Follow these steps to add support for a new model architecture:
Create a new file in `atom/models/`, e.g., `atom/models/my_model.py`. Follow the existing patterns:

```python
from atom.config import Config, QuantizationConfig
from atom.model_ops.base_attention import Attention
from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding
from atom.model_ops.layernorm import RMSNorm
from atom.model_ops.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from atom.models.utils import (
    IntermediateTensors,
    PPMissingLayer,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
from atom.utils.decorators import support_torch_compile
```

Each model typically defines the following core classes:
- Attention module (e.g., `MyModelAttention`):
  - Initialize `QKVParallelLinear` for query/key/value.
  - Initialize `RowParallelLinear` for the output projection.
  - Set up rotary embeddings via `aiter.rotary_embedding.get_rope()`.
  - Create `Attention` from `atom.model_ops.base_attention`.
- MLP module (e.g., `MyModelMLP`):
  - Use `MergedColumnParallelLinear` for the gate+up projections.
  - Use `RowParallelLinear` for the down projection.
  - For MoE models, use `FusedMoE` from `atom.model_ops.moe`.
- Decoder layer (e.g., `MyModelDecoderLayer`):
  - Combine attention + MLP with RMSNorm layers.
  - Implement the forward pass with residual connections.
- Backbone model (e.g., `MyModel`):
  - Decorate with `@support_torch_compile`.
  - Initialize `VocabParallelEmbedding`, decoder layers via `make_layers()`, and a final `RMSNorm`.
  - Support pipeline parallelism with `PPMissingLayer` and `IntermediateTensors`.
- CausalLM wrapper (e.g., `MyModelForCausalLM`):
  - Define `packed_modules_mapping` to map checkpoint weight names to ATOM's fused parameter names.
  - Initialize the backbone model and `ParallelLMHead`.
  - Implement `forward()` (returns hidden states) and `compute_logits()` (returns logits via `lm_head`).
  - If the model uses MoE, implement `get_expert_mapping()` returning `FusedMoE.make_expert_params_mapping(...)`.
Add an entry to `support_model_arch_dict` in `atom/model_engine/model_runner.py`:

```python
support_model_arch_dict = {
    ...
    "MyModelForCausalLM": "atom.models.my_model.MyModelForCausalLM",
}
```

The key must exactly match the `architectures` field in the HuggingFace model's `config.json`.
Ensure your `packed_modules_mapping` correctly maps all checkpoint weight names that differ from ATOM's internal names. Common patterns:

| Checkpoint Name | ATOM Parameter | Shard ID |
|---|---|---|
| `q_proj` | `qkv_proj` | `"q"` |
| `k_proj` | `qkv_proj` | `"k"` |
| `v_proj` | `qkv_proj` | `"v"` |
| `gate_proj` | `gate_up_proj` | `0` |
| `up_proj` | `gate_up_proj` | `1` |
For MoE models, add `get_expert_mapping()` to delegate to `FusedMoE.make_expert_params_mapping()` with the correct gate/down/up projection names and expert count.
If the checkpoint uses non-standard weight names (like GPT-OSS), define a `weights_mapping` class attribute to rename them at load time.
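Applying `packed_modules_mapping` during loading amounts to a substring rewrite plus a shard id. A sketch of the lookup (the helper is illustrative; the real loader also hands the shard id to the target parameter's `weight_loader`):

```python
# Llama-style packed mapping: checkpoint name -> (fused ATOM parameter, shard id).
PACKED_MODULES_MAPPING = {
    "q_proj": ("qkv_proj", "q"),
    "k_proj": ("qkv_proj", "k"),
    "v_proj": ("qkv_proj", "v"),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

def map_checkpoint_name(name: str):
    """Return (atom_param_name, shard_id); shard_id is None for unpacked weights."""
    for ckpt_name, (packed_name, shard_id) in PACKED_MODULES_MAPPING.items():
        if ckpt_name in name:
            return name.replace(ckpt_name, packed_name), shard_id
    return name, None
```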
Llama supports two AITER Triton fused kernel optimizations:
- `ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT`: Fuses the RMSNorm normalization with FP8 or MXFP4 quantization in a single kernel call. Applied to both `input_layernorm` and `post_attention_layernorm`. Eliminates an extra read/write pass over the hidden states.
- `ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT`: Fuses the SiLU activation, element-wise multiply, and quantization in the MLP. The `SiluAndMul` module receives the `fused_quant=True` flag and the quant config, producing quantized output directly for the down projection.

Both are controlled by environment variables and read from `atom.utils.envs`.
DeepSeek models use Multi-head Latent Attention (MLA) with LoRA-compressed projections (`q_lora_rank`, `kv_lora_rank`). Several fusion optimizations are available:

- `ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION`: Fuses the input RMSNorm with quantization. Implemented via `_fuse_rmsnorm_quant()`, which dispatches to either `_fuse_rmsnorm_fp4_quant()` or `_fused_rms_fp8_group_quant()` based on the quant dtype. When enabled, the allreduce+RMSNorm fusion is disabled for `input_layernorm` but kept for `post_attention_layernorm`.
- `ATOM_ENABLE_DS_QKNORM_QUANT_FUSION`: Fuses the Q/K LoRA layernorm with quantization via `_fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4()` or the FP8 variant, which performs the fused QKV-A projection, RMSNorm on the Q and KV components, and quantization in a single fused operation.
- `ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION`: Fuses tensor-parallel allreduce with RMSNorm.
- FP8 MQA logits: `fp8_mqa_logits` and `deepgemm_fp8_paged_mqa_logits` implement FP8-precision attention score computation for MLA decode.
- FP4 support: MXFP4 quantized GEMM kernels (`gemm_afp4wfp4_preshuffle`, `gemm_a16wfp4_preshuffle`) and FP4 block-scale BMM via `is_rocm_aiter_fp4bmm_enabled()`.
When `ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION` is enabled, the `Qwen3MoeAttention` module:

- Precomputes a joint `cos_sin_cache` by concatenating the cosine and sine RoPE caches.
- Passes `q_norm` and `k_norm` directly to the `Attention` module.
- The attention backend then fuses QK normalization, RoPE application, KV cache write, and optional quantization into a single kernel pass.

Additionally, `ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION` fuses allreduce with RMSNorm for both the attention output and the MoE output, reducing communication overhead.
The `DeepSeekMTP` model serves as a speculative draft model:

- Each `DeepSeekMultiTokenPredictorLayer` takes the previous hidden state and the next token's embedding, normalizes both (`enorm`, `hnorm`), concatenates them, and passes the result through a linear projection (`eh_proj`) followed by a standard `DeepseekV2DecoderLayer`.
- The `SharedHead` provides a per-layer norm + LM head for logit computation.
- For FP4-quantized main models, the MTP blocks fall back to a non-FP4 quantization config to maintain draft model accuracy.
| File | Description |
|---|---|
| `atom/model_engine/model_runner.py` | Model registry (`support_model_arch_dict`) and the `ModelRunner` class |
| `atom/models/llama.py` | Llama model: `LlamaForCausalLM`, `LlamaModel`, `LlamaDecoderLayer`, `LlamaAttention`, `LlamaMLP` |
| `atom/models/qwen3.py` | Qwen3 model: `Qwen3ForCausalLM`, `Qwen3Model`, `Qwen3DecoderLayer`, `Qwen3Attention`, `Qwen3MLP` |
| `atom/models/qwen3_moe.py` | Qwen3-MoE model: `Qwen3MoeForCausalLM`, `Qwen3MoeModel`, `Qwen3MoeDecoderLayer`, `Qwen3MoeAttention`, `Qwen3MoeSparseMoeBlock`, `Qwen3MoeMLP` |
| `atom/models/deepseek_v2.py` | DeepSeek V2/V3 model: `DeepseekV2ForCausalLM`, `DeepseekV3ForCausalLM`, `DeepseekV2Model`, `DeepseekV2DecoderLayer`, `DeepseekV2MLAAttention`, `DeepseekV2MoE`, `DeepseekV2MLP` |
| `atom/models/deepseek_mtp.py` | DeepSeek MTP draft model: `DeepSeekMTP`, `DeepSeekMultiTokenPredictor`, `DeepSeekMultiTokenPredictorLayer`, `SharedHead` |
| `atom/models/mixtral.py` | Mixtral model: `MixtralForCausalLM`, `MixtralModel`, `MixtralDecoderLayer`, `MixtralAttention`, `MixtralMoE` |
| `atom/models/gpt_oss.py` | GPT-OSS model: `GptOssForCausalLM`, `GptOssModel`, `TransformerBlock`, `OAIAttention`, `MLPBlock` |
| `atom/models/glm4_moe.py` | GLM4-MoE model: `Glm4MoeForCausalLM`, `Glm4MoeModel`, `Glm4MoeDecoderLayer`, `Glm4MoeAttention`, `Glm4MoE`, `Glm4MoeMLP` |
| `atom/models/utils.py` | Model utilities: `IntermediateTensors`, `PPMissingLayer`, `make_layers`, `maybe_prefix`, `extract_layer_index`, `should_ignore_layer`, `get_quant_config_for_layer` |
| `atom/model_loader/loader.py` | Weight loading: `load_model`, `safetensors_weights_iterator`, `default_weight_loader` |
| `atom/model_loader/weight_utils.py` | Weight utilities: `download_weights_from_hf`, `set_weight_attrs`, `filter_duplicate_safetensors_files` |