ORION provides an end-to-end machine learning pipeline for predicting atomic transport in crystalline materials:
- Generate grouped defect data with mobility labels (vacancies, interstitials, or both),
- Train a mobility classifier to identify which atoms can hop in a given structure, and
- Train multi-modal displacement models (1-, 2-, 3-atom) that predict multiple possible outcomes for each atomic configuration.
The pipeline handles the inherent multi-modality of defect transport: a single initial structure can lead to many distinct valid final states (e.g., multiple vacancy hop sites, different interstitial positions). The displacement models use multi-head outputs to capture this diversity, and a target replication strategy ensures all model heads learn valid displacements.
The repository ships with an AgCl example structure, but the workflow is fully general and applies to any crystalline material once a structure file is provided.
Installation:
pip install -e .

Train models specifically for vacancy-mediated transport:
# 1) Generate vacancy data (train/test split)
python scripts/generate_data.py \
--structure data/AgCl.cif \
--element Ag \
--supercell 2 2 2 \
--defect-type vacancy \
--max-distance 4.0 \
--output-dir output
# 2) Train mobility classifier
python scripts/train_mobility.py \
--data-dir output/vacancy_data \
--output-dir output/mobility_vacancy \
--epochs 50 \
--batch-size 8 \
--lr 1e-3
# 3) Train multi-hop displacement models
python scripts/train_multi_hop_models.py \
--data-dir output/vacancy_data \
--output-dir output/multi_hop_vacancy \
--epochs 500 \
--batch-size 64 \
--mobility-threshold 0.5 \
--save-predictions \
--pred-max-groups 10 \
--pred-max-per-group 20

Train models for interstitial-mediated transport (defaults tuned to keep ~6 neighbors within ~3 Å):
# 1) Generate interstitial data
python scripts/generate_data.py \
--structure data/AgCl.cif \
--element Ag \
--supercell 2 2 2 \
--defect-type interstitial \
--energy-threshold 5.0 \
--max-pair-distance 3.0 \
--max-calculations 200 \
--min-neighbors 6 \
--max-distance 4.0 \
--output-dir output
# 2) Train mobility classifier
python scripts/train_mobility.py \
--data-dir output/interstitial_data \
--output-dir output/mobility_interstitial \
--epochs 50 \
--batch-size 8 \
--lr 1e-3
# 3) Train multi-hop displacement models
python scripts/train_multi_hop_models.py \
--data-dir output/interstitial_data \
--output-dir output/multi_hop_interstitial \
--epochs 500 \
--batch-size 64 \
--mobility-threshold 0.5 \
--save-predictions \
--pred-max-groups 10 \
--pred-max-per-group 20

Train a unified model on both defect types simultaneously. This is the recommended approach for materials where both mechanisms may be active:
# Use the combined_workflow.sh script (generates data, trains mobility, trains displacement)
bash combined_workflow.sh
# Or run manually:
# 1) Generate both defect types
python scripts/generate_data.py \
--structure data/AgCl.cif \
--element Ag \
--supercell 2 2 2 \
--defect-type both \
--energy-threshold 5.0 \
--max-pair-distance 3.0 \
--max-calculations 200 \
--min-neighbors 6 \
--max-distance 4.0 \
--output-dir output
# (Script automatically merges output/vacancy_data and output/interstitial_data into output/combined_data)
# 2) Train mobility classifier on combined data
python scripts/train_mobility.py \
--data-dir output/combined_data \
--output-dir output/mobility_combined \
--epochs 50 \
--batch-size 8 \
--lr 1e-3
# 3) Train multi-hop models on combined data
python scripts/train_multi_hop_models.py \
--data-dir output/combined_data \
--output-dir output/multi_hop_combined \
--epochs 500 \
--batch-size 64 \
--mobility-threshold 0.5 \
--save-predictions \
--pred-max-groups 10 \
--pred-max-per-group 20 \
--pred-min-disp 0.1

Generated artifacts live under the chosen `--output-dir` values:
- `vacancy_data/train|test/group_xxxx/`: `initial.cif`, `final_*.cif`, `mobility_labels.pt`
- `mobility/`: mobility checkpoints (`best_model.pt`, `final_model.pt`) and training logs
- `multi_hop/`: hop-specific checkpoints under `hop_1/`, `hop_2/`, `hop_3/`, plus optional predicted CIFs under `predictions/`
- `scripts/generate_data.py`: Generate grouped vacancy or interstitial data (COMET-backed), label mobility, and split train/test. Use `--random-sampling` to fall back to the older vacancy generator; omit `--num-pairs` to keep all valid interstitial pairs; use `--min-neighbors` to warn on sparse neighbor counts per interstitial site.
- `scripts/train_mobility.py`: GNN-based per-atom classifier that ingests MACE embeddings and predicts mobility masks.
- `scripts/train_multi_hop_models.py`: Trains separate 1-, 2-, and 3-atom displacement heads and optionally exports predicted CIFs.
Options for scripts/generate_data.py:
- `--structure`: Path to input CIF file (e.g., `data/AgCl.cif`)
- `--element`: Mobile element symbol (e.g., `Ag`)
- `--supercell`: Supercell size (e.g., `2 2 2`)
- `--defect-type`: One of `vacancy`, `interstitial`, or `both`
- `--max-distance`: Max neighbor distance for grouping (Å)
- `--mobility-threshold`: Min displacement to label an atom as mobile (Å, default: 0.5)
- `--train-fraction`: Train/test split ratio (default: 0.8)
- `--energy-threshold`: Max energy above ground state for interstitials (eV, default: 5.0)
- `--max-pair-distance`: Max distance between interstitial pairs (Å, default: 3.0)
- `--max-calculations`: Max COMET calculations per interstitial config (default: 200)
- `--min-neighbors`: Warn if an interstitial site has fewer neighbors (default: 6)
- `--output-dir`: Output directory root (creates subdirectories for each defect type)
Options for scripts/train_mobility.py:
- `--data-dir`: Path to grouped dataset (e.g., `output/vacancy_data`)
- `--output-dir`: Where to save mobility model checkpoints
- `--cutoff`: Graph edge cutoff radius (Å, default: 5.0)
- `--hidden-dim`: GNN hidden dimension (default: 128)
- `--num-layers`: Number of message-passing layers (default: 3)
- `--batch-size`: Training batch size (default: 8)
- `--epochs`: Number of training epochs (default: 50)
- `--lr`: Learning rate (default: 1e-3)
- `--no-focal`: Disable focal loss (use BCE instead)
Options for scripts/train_multi_hop_models.py:
- `--data-dir`: Path to grouped dataset
- `--output-dir`: Where to save displacement model checkpoints
- `--hidden-dim`: MLP hidden dimension (default: 256)
- `--num-layers`: Number of MLP layers (default: 4)
- `--batch-size`: Training batch size (default: 64)
- `--epochs`: Number of training epochs (default: 500)
- `--lr`: Learning rate (default: 1e-3)
- `--mobility-threshold`: Min displacement to consider an atom mobile (Å, default: 0.5)
- `--num-modes`: Number of output heads (auto-estimated if not provided)
- `--max-modes`: Ceiling for auto-estimation (default: 12)
- `--max-combinations-per-group`: Limit combinations per group (for memory)
- `--save-predictions`: Export predicted structures as CIFs
- `--pred-max-groups`: Max test groups to generate predictions for (default: 10)
- `--pred-max-per-group`: Max predictions per group (default: 20)
- `--pred-min-disp`: Min displacement threshold for saving predictions (Å, default: 0.1)
- MACE embeddings are required for both stages; ensure `mace-torch` is installed via `pip install mace-torch`
- Supercell size trades off dataset coverage vs. runtime/memory. Start with `2 2 2` for testing, use `3 3 3` or larger for production
- `--max-distance` controls the neighbor grouping radius. Typical values: 3-5 Å
- The grouped dataset format is fully general: use any CIF structure and specify the mobile element
- Vacancy generation is deterministic (removes one atom at a time)
- Typical vacancy datasets have 1-3 unique displacement modes after clustering
- Higher `--mobility-threshold` (e.g., 1.0 Å) filters out small relaxations, keeping only true hops
- Interstitial workflows use COMET to find low-energy sites (requires the `comet-ml` package)
  - `--energy-threshold` (eV): Controls how far above the ground state to search (5.0 = good balance)
  - `--max-pair-distance` (Å): Max distance between interstitial neighbors (3.0 recommended for AgCl)
  - `--min-neighbors`: Warns if a site has sparse connectivity (suggests increasing `--max-pair-distance`)
- Interstitial pairs are directed (A→B and B→A treated separately)
- Typical interstitial datasets have 4-8 unique displacement modes
- The combined (both-defect) workflow is recommended for real applications where both mechanisms may be active
- The `combined_workflow.sh` script automatically:
  - Generates both defect types
  - Merges datasets with proper group renumbering (avoids overwrites)
  - Trains unified models on combined data
- Models trained on combined data generalize to both vacancy and interstitial transport
- Epochs: Mobility classifier converges quickly (50 epochs). Displacement models need more training (500 epochs recommended)
- Batch size: Mobility uses small batches (8) due to graph size. Displacement models can use larger batches (64)
- num_modes: Auto-estimation works well. For vacancies, 4-6 modes suffice. For interstitials, 8-12 captures diversity
- Displacement clustering: 0.8 Å threshold balances noise reduction vs. mode preservation
- Learning rate: 1e-3 works well with a ReduceLROnPlateau scheduler that auto-reduces when validation plateaus (see the sketch after these notes)
- Large supercells (>4×4×4) may require `--max-combinations-per-group` to limit memory
- COMET calculations are the bottleneck for interstitial generation; use `--max-calculations` to cap runtime
- For smoke tests: `--supercell 2 2 2`, `--max-outcomes 20`, `--max-calculations 50`
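For reference, the documented learning-rate schedule (1e-3 with ReduceLROnPlateau, factor 0.5, patience 10) corresponds to a PyTorch setup along these lines; the model and validation routine here are placeholders, not ORION's training loop:

```python
import torch
import torch.nn as nn

def validate(model):
    """Stand-in for a real validation pass; returns a scalar loss."""
    return torch.rand(1).item()

model = nn.Linear(512, 3)                           # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=10  # halve the LR after 10 stagnant epochs
)

for epoch in range(500):
    # ... training steps (forward, backward, optimizer.step()) would go here ...
    scheduler.step(validate(model))                 # reduce LR when validation plateaus
```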
ORION tackles a fundamental challenge in atomic transport modeling: learning to predict multiple possible outcomes from a single initial configuration. Unlike regression to a single target, defect-mediated transport is inherently multi-modal:
- Vacancies: A single vacancy configuration has multiple neighboring atoms that could hop into the empty site, each representing a distinct valid outcome.
- Interstitials: An interstitial atom can relocate to several nearby low-energy sites, each defining a different final state.
- Multi-atom hops: Concerted motion (2- or 3-atom exchanges) further multiplies the number of possible outcomes.
Key challenges:
- One-to-many mapping: A single initial structure maps to many distinct final structures, not a single average displacement.
- Variable modality: Different configurations have different numbers of outcomes (e.g., vacancy with 1 clustered mode vs. interstitial with 5-6 distinct hop sites).
- Generalization: Models must work across different defect types, crystal structures, and coordination environments without hard-coded site information.
Training data format: Each training sample is a group: one initial structure paired with multiple final structures (outcomes), plus per-atom mobility labels derived from observed displacements. MACE embeddings encode local chemistry and coordination beyond nearest neighbors.
Example scenarios:
- Vacancy: AgCl supercell with one Ag vacancy. Nearby Ag atoms at different crystallographic positions can hop into the vacancy, creating ~3-5 distinct outcomes. After displacement clustering, these may reduce to 1-2 unique modes.
- Interstitial: AgCl supercell with one additional Ag atom at an interstitial site. COMET finds 5-6 low-energy neighboring interstitial sites within 3 Å, each representing a distinct hop outcome.
- Combined: Training on both vacancy and interstitial data simultaneously allows the model to handle both transport mechanisms in a single unified framework.
ORION uses a two-stage pipeline that explicitly handles multi-modal outcomes:
Stage 1: A graph neural network (`MobilityClassifier`) identifies which atoms are mobile in a given defect configuration:
- Input: Per-atom MACE embeddings (512-dim) + radius graph (edges within `cutoff`, with distance as edge feature)
- Architecture: Several message-passing layers aggregate local coordination information
- Output: Per-atom mobility probability (0 = immobile, 1 = mobile)
- Loss: Focal loss to emphasize the rare mobile atoms (a minimal sketch follows this list)
- Inference: Thresholding (typically 0.5) produces a binary mobility mask indicating which atoms could participate in a hop
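The focal-loss idea can be sketched in a few lines of PyTorch. This is an illustration, not ORION's exact `MobilityClassifier` loss; the `gamma` and `alpha` values are assumed defaults:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, labels, gamma=2.0, alpha=0.25):
    """Binary focal loss over per-atom mobility logits.

    Down-weights easy (mostly immobile) atoms so that the rare mobile
    atoms dominate the gradient. gamma and alpha here are illustrative.
    """
    bce = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
    p = torch.sigmoid(logits)
    p_t = p * labels + (1 - p) * (1 - labels)             # probability of the true class
    alpha_t = alpha * labels + (1 - alpha) * (1 - labels)
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()

# Example: 100 atoms, roughly 5% labeled mobile
logits = torch.randn(100)
labels = (torch.rand(100) < 0.05).float()
print(focal_loss(logits, labels))
```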
Stage 2: Separate MLP models per hop size (1, 2, 3 atoms) predict multiple possible displacement outcomes using multi-head outputs:
Architecture:
- Input: Concatenated MACE embeddings + Fourier positional features for selected atom combinations
- Output: `num_modes` displacement vectors (e.g., 12 heads × hop_size × 3 coordinates)
- Each head represents one possible outcome mode (see the sketch below)
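A minimal sketch of such a multi-head displacement MLP, with layer sizes following the documented defaults. The real model also concatenates Fourier positional features (omitted here) and its internals may differ:

```python
import torch
import torch.nn as nn

class MultiHeadDisplacementMLP(nn.Module):
    """Predicts num_modes candidate displacement sets for a hop_size-atom combination.

    Input: concatenated per-atom MACE embeddings (512-dim each) for the selected
    atoms. Output shape: (batch, num_modes, hop_size, 3). Details are illustrative.
    """
    def __init__(self, hop_size, embed_dim=512, hidden_dim=256, num_layers=4, num_modes=12):
        super().__init__()
        self.hop_size, self.num_modes = hop_size, num_modes
        dims = [hop_size * embed_dim] + [hidden_dim] * num_layers
        layers = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layers += [nn.Linear(d_in, d_out), nn.SiLU()]
        self.backbone = nn.Sequential(*layers)
        # One linear "head" per mode, each emitting hop_size * 3 coordinates
        self.heads = nn.ModuleList(nn.Linear(hidden_dim, hop_size * 3) for _ in range(num_modes))

    def forward(self, x):                            # x: (batch, hop_size * embed_dim)
        h = self.backbone(x)
        out = torch.stack([head(h) for head in self.heads], dim=1)
        return out.view(-1, self.num_modes, self.hop_size, 3)

model = MultiHeadDisplacementMLP(hop_size=2)
print(model(torch.randn(4, 2 * 512)).shape)          # torch.Size([4, 12, 2, 3])
```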
Handling variable modality (key innovation):
The challenge: training samples have variable numbers of target outcomes (T), but the model has a fixed number of output heads (M, typically 12).
Solution — Target replication strategy:
- Displacement clustering: Similar displacement vectors (within 0.8 Å threshold) are clustered and averaged, reducing noisy variation while preserving distinct modes.
- Target replication: If T < M, targets are replicated cyclically as `[t0, t1, ..., t0, t1, ...]` to fill all M heads. If T > M, only the first M targets are used.
- Hungarian matching: During training, each predicted mode is matched to a (potentially replicated) target via bipartite matching with an L2 distance cost.
- All modes trained: Every head learns from real targets every iteration—no "unused" heads that learn garbage.
Examples:
- Vacancy (T=1): One clustered displacement → replicated 12 times → all 12 heads learn the same correct displacement
- Interstitial (T=5): Five distinct hop sites → replicated as `[t0, t1, t2, t3, t4, t0, t1, t2, t3, t4, t0, t1]` → heads learn valid sites with duplicates
- At inference: All modes are valid (duplicates are expected and harmless); a code sketch of replication and matching follows below
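To make the replication and matching steps concrete, here is a minimal NumPy/SciPy sketch. Names and shapes are illustrative, not ORION's actual API, and the 0.8 Å displacement clustering step is omitted:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def replicate_targets(targets, num_modes):
    """Cyclically replicate T target displacement sets to fill num_modes heads.

    targets: (T, hop_size, 3). Returns (num_modes, hop_size, 3).
    If T > num_modes, only the first num_modes targets are kept.
    """
    idx = np.arange(num_modes) % len(targets)
    return targets[idx]

def match_modes(preds, targets):
    """Hungarian matching between predicted modes and (replicated) targets.

    preds, targets: (num_modes, hop_size, 3). The cost is the L2 distance
    between flattened displacement sets. Returns matched index pairs.
    """
    cost = np.linalg.norm(
        preds.reshape(len(preds), 1, -1) - targets.reshape(1, len(targets), -1),
        axis=-1,
    )
    rows, cols = linear_sum_assignment(cost)
    return rows, cols, cost[rows, cols].mean()

# Example: 5 observed outcomes, 12 model heads, 1-atom hops
rng = np.random.default_rng(0)
observed = rng.normal(size=(5, 1, 3))
targets = replicate_targets(observed, num_modes=12)
preds = rng.normal(size=(12, 1, 3))
rows, cols, mean_cost = match_modes(preds, targets)
print(rows, cols, round(mean_cost, 3))
```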
Training details:
- Combinations: For each group, enumerate all combinations of mobile atoms (filtered by mobility threshold)
- PBC-aware displacements: Minimum-image convention with species-aware Hungarian matching for atom correspondence (minimum-image sketch below)
- Loss: L2 (Euclidean) distance between matched prediction/target pairs
- Regularization: Unobserved combinations (no hop in data) are pushed toward zero displacement with reduced weight
- Learning rate schedule: ReduceLROnPlateau with patience=10, factor=0.5
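A minimal sketch of the minimum-image displacement computation and the threshold-based mobility check; the species-aware atom matching is omitted, and `cell` is assumed to hold the lattice vectors as rows:

```python
import numpy as np

def min_image_displacements(initial_frac, final_frac, cell):
    """Per-atom displacements under the minimum-image convention.

    initial_frac, final_frac: (N, 3) fractional coordinates of corresponding atoms.
    cell: (3, 3) lattice vectors as rows. Returns Cartesian displacements in Å.
    """
    d_frac = final_frac - initial_frac
    d_frac -= np.round(d_frac)                      # wrap each component into [-0.5, 0.5)
    return d_frac @ cell

# Example: cubic 10 Å cell, one atom hopping across the periodic boundary
cell = np.eye(3) * 10.0
init = np.array([[0.95, 0.5, 0.5]])
final = np.array([[0.05, 0.5, 0.5]])
disps = min_image_displacements(init, final, cell)
print(disps)                                        # ≈ [[1.0, 0.0, 0.0]]
print(np.linalg.norm(disps, axis=1) > 0.5)          # mobility check at the 0.5 Å default
```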
Output: Models predict all num_modes displacements per atom combination. Can be exported as CIF files for visualization or downstream kinetic Monte Carlo sampling.
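As an illustration of what such an export involves, a predicted mode can be applied to an initial structure with ASE and saved as a CIF. This is a sketch, not ORION's exporter; the path, atom index, and displacement values are made up:

```python
import numpy as np
from ase.io import read, write

# Apply one predicted displacement mode to an initial structure and save it
# for visualization (illustrative paths, indices, and values).
atoms = read("output/vacancy_data/test/group_0000/initial.cif")

atom_indices = [11]                              # hypothetical mobile-atom combination
predicted = np.array([[2.8, 0.0, 0.0]])          # hypothetical mode output (Å)

atoms.positions[atom_indices] += predicted
atoms.wrap()                                     # keep atoms inside the periodic cell
write("pred_hop1_combo0000_mode0_idx11.cif", atoms)
```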
scripts/generate_data.py creates grouped datasets:
- Vacancy mode: Creates supercells with one vacancy, computes neighbors, stores initial + outcomes
- Interstitial mode: Uses COMET to find low-energy interstitial sites and pair-wise hops
- Both mode: Generates both datasets and optionally merges them with proper group renumbering
- Output: Each group contains `initial.cif`, `final_*.cif` (multiple outcomes), and `mobility_labels.pt` (see the loading sketch below)
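For example, a group can be inspected with ASE and PyTorch along these lines. The path is illustrative, and the exact contents of `mobility_labels.pt` are an assumption here; it may store additional metadata:

```python
from pathlib import Path
from ase.io import read
import torch

# Inspect one generated group (paths illustrative)
group = Path("output/vacancy_data/train/group_0000")

initial = read(str(group / "initial.cif"))                         # ase.Atoms
finals = [read(str(p)) for p in sorted(group.glob("final_*.cif"))]
labels = torch.load(group / "mobility_labels.pt")                  # per-atom mobility info (format assumed)

print(len(initial), "atoms,", len(finals), "outcome structures,", type(labels))
```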
Nothing is hard-coded about vacancy positions, interstitial sites, or hop mechanisms:
- Mobility labels come from observed displacements (any atom with >threshold motion is flagged)
- MACE embeddings encode local chemistry and coordination, generalizing across defect types
- Multi-modal prediction with target replication handles variable outcome counts automatically
- Works for: Single-atom hops, multi-atom exchanges, vacancies, interstitials, or mixed datasets
After running the full pipeline, your output directory will contain:
output/
├── vacancy_data/ # (if vacancy mode used)
│ ├── train/
│ │ ├── group_0000/
│ │ │ ├── initial.cif
│ │ │ ├── final_0.cif
│ │ │ ├── final_1.cif
│ │ │ └── mobility_labels.pt
│ │ └── group_0001/ ...
│ └── test/ ...
├── interstitial_data/ # (if interstitial mode used)
│ └── [same structure]
├── combined_data/ # (if both mode used, after merging)
│ ├── train/
│ └── test/
├── mobility_[type]/
│ ├── best_model.pt # Best validation checkpoint
│ ├── final_model.pt # Final epoch checkpoint
│ └── latest.pt # Latest checkpoint with optimizer state
└── multi_hop_[type]/
├── hop_1/
│ ├── best_model.pt
│ ├── final_model.pt
│ ├── latest.pt
│ └── training_history.json
├── hop_2/ ...
├── hop_3/ ...
└── predictions/ # (if --save-predictions used)
├── group_0000/
│ ├── initial.cif
│ ├── pred_hop1_combo0000_mode0_idx11.cif
│ ├── pred_hop1_combo0001_mode1_idx23.cif
│ └── ...
└── group_0001/ ...
Mobility classifier:
- Loss: Focal loss (or BCE if `--no-focal`), should decrease to ~0.1-0.3
- Accuracy: Overall per-atom accuracy, typically >95%
- Precision/Recall: More meaningful than accuracy due to class imbalance (few mobile atoms)
- Target: High recall (catch all mobile atoms) with acceptable precision
Displacement models:
- Loss: L2 distance between predicted and target displacements (Å)
- pos_mae: Mean absolute error for positive (mobile) combinations, target <0.01 Å
- neg_mae: Mean absolute error for negative (immobile) combinations, should be near zero
- Validation loss: Should track training loss; large gap suggests overfitting
Predicted CIF files show possible outcomes for test structures:
Naming convention: `pred_hop{N}_combo{XXXX}_mode{M}_idx{atoms}.cif` (a parsing sketch follows the list):
- `hop{N}`: Hop size (1, 2, or 3 atoms)
- `combo{XXXX}`: Combination index
- `mode{M}`: Which output head (0 to num_modes-1)
- `idx{atoms}`: Atom indices involved (e.g., `idx11` or `idx11-23-45`)
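When post-processing predictions, these filenames can be parsed with a small regex like the following (an illustrative helper, not part of ORION):

```python
import re

PATTERN = re.compile(
    r"pred_hop(?P<hop>\d+)_combo(?P<combo>\d+)_mode(?P<mode>\d+)_idx(?P<atoms>[\d-]+)\.cif"
)

def parse_prediction_name(name):
    """Split a predicted-CIF filename into hop size, combination, mode, and atom indices."""
    m = PATTERN.fullmatch(name)
    if m is None:
        raise ValueError(f"unrecognized prediction filename: {name}")
    return {
        "hop": int(m["hop"]),
        "combo": int(m["combo"]),
        "mode": int(m["mode"]),
        "atoms": [int(i) for i in m["atoms"].split("-")],
    }

print(parse_prediction_name("pred_hop1_combo0001_mode1_idx23.cif"))
# {'hop': 1, 'combo': 1, 'mode': 1, 'atoms': [23]}
```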
Expected patterns:
- Vacancies: Many modes will be near-identical (duplicates from target replication). Typical: 1-3 unique predictions.
- Interstitials: More diversity across modes, showing different hop sites. Typical: 4-8 unique predictions.
- Combined datasets: Mix of both patterns depending on the defect type in that group
Quality checks:
- Visualize predicted structures in a viewer (e.g., VESTA, Ovito)
- Verify displaced atoms land on valid lattice sites (not in interstitial regions or on wrong species)
- Check that predicted displacements are physically reasonable (typically 2-4 Å for nearest-neighbor hops)
- Compare predictions to ground-truth `final_*.cif` files in the test set
Mobility classifier:
- Low recall → Increase focal loss gamma, or use more training data
- High false positives → Decrease `--mobility-threshold` during data generation
- Poor generalization → Increase `--cutoff` or `--num-layers` to capture more environment
Displacement models:
- High training loss (>0.1 Å) → Increase `--epochs`, decrease `--lr`, or check data quality
- Predictions on wrong sites → Likely insufficient training data or an inappropriate `num_modes`
- All predictions identical → Check that the displacement clustering threshold (0.8 Å) isn't too aggressive
- Validation loss plateaus → Normal with ReduceLROnPlateau; training will continue at reduced LR
Data quality:
- Few mobile atoms detected → Adjust `--mobility-threshold` or `--max-distance`
- Interstitials have too few neighbors → Increase `--max-pair-distance`
- Training very slow → Reduce `--max-combinations-per-group` or use smaller supercells