njszym/orion

ORION: Optimal Rare-event Inference for Atomic Transitions

ORION provides an end-to-end machine learning pipeline for predicting atomic transport in crystalline materials:

  1. Generate grouped defect data with mobility labels (vacancies, interstitials, or both),
  2. Train a mobility classifier to identify which atoms can hop in a given structure, and
  3. Train multi-modal displacement models (1-, 2-, 3-atom) that predict multiple possible outcomes for each atomic configuration.

The pipeline handles the inherent multi-modality of defect transport: a single initial structure can lead to many distinct valid final states (e.g., multiple vacancy hop sites, different interstitial positions). The displacement models use multi-head outputs to capture this diversity, and a target replication strategy ensures all model heads learn valid displacements.

The repository ships with an AgCl example structure, but the workflow is fully general and applies to any crystalline material once a structure file is provided.

Quickstart

Installation:

pip install -e .

Workflow 1: Vacancy-Only Training

Train models specifically for vacancy-mediated transport:

# 1) Generate vacancy data (train/test split)
python scripts/generate_data.py \
  --structure data/AgCl.cif \
  --element Ag \
  --supercell 2 2 2 \
  --defect-type vacancy \
  --max-distance 4.0 \
  --output-dir output

# 2) Train mobility classifier
python scripts/train_mobility.py \
  --data-dir output/vacancy_data \
  --output-dir output/mobility_vacancy \
  --epochs 50 \
  --batch-size 8 \
  --lr 1e-3

# 3) Train multi-hop displacement models
python scripts/train_multi_hop_models.py \
  --data-dir output/vacancy_data \
  --output-dir output/multi_hop_vacancy \
  --epochs 500 \
  --batch-size 64 \
  --mobility-threshold 0.5 \
  --save-predictions \
  --pred-max-groups 10 \
  --pred-max-per-group 20

Workflow 2: Interstitial-Only Training

Train models for interstitial-mediated transport (defaults tuned to keep ~6 neighbors within ~3 Å):

# 1) Generate interstitial data
python scripts/generate_data.py \
  --structure data/AgCl.cif \
  --element Ag \
  --supercell 2 2 2 \
  --defect-type interstitial \
  --energy-threshold 5.0 \
  --max-pair-distance 3.0 \
  --max-calculations 200 \
  --min-neighbors 6 \
  --max-distance 4.0 \
  --output-dir output

# 2) Train mobility classifier
python scripts/train_mobility.py \
  --data-dir output/interstitial_data \
  --output-dir output/mobility_interstitial \
  --epochs 50 \
  --batch-size 8 \
  --lr 1e-3

# 3) Train multi-hop displacement models
python scripts/train_multi_hop_models.py \
  --data-dir output/interstitial_data \
  --output-dir output/multi_hop_interstitial \
  --epochs 500 \
  --batch-size 64 \
  --mobility-threshold 0.5 \
  --save-predictions \
  --pred-max-groups 10 \
  --pred-max-per-group 20

Workflow 3: Combined Vacancy + Interstitial Training

Train a unified model on both defect types simultaneously. This is the recommended approach for materials where both mechanisms may be active:

# Use the combined_workflow.sh script (generates data, trains mobility, trains displacement)
bash combined_workflow.sh

# Or run manually:
# 1) Generate both defect types
python scripts/generate_data.py \
  --structure data/AgCl.cif \
  --element Ag \
  --supercell 2 2 2 \
  --defect-type both \
  --energy-threshold 5.0 \
  --max-pair-distance 3.0 \
  --max-calculations 200 \
  --min-neighbors 6 \
  --max-distance 4.0 \
  --output-dir output

# (Script automatically merges output/vacancy_data and output/interstitial_data into output/combined_data)

# 2) Train mobility classifier on combined data
python scripts/train_mobility.py \
  --data-dir output/combined_data \
  --output-dir output/mobility_combined \
  --epochs 50 \
  --batch-size 8 \
  --lr 1e-3

# 3) Train multi-hop models on combined data
python scripts/train_multi_hop_models.py \
  --data-dir output/combined_data \
  --output-dir output/multi_hop_combined \
  --epochs 500 \
  --batch-size 64 \
  --mobility-threshold 0.5 \
  --save-predictions \
  --pred-max-groups 10 \
  --pred-max-per-group 20 \
  --pred-min-disp 0.1

Generated artifacts live under the chosen --output-dir values:

  • vacancy_data/train|test/group_xxxx/: initial.cif, final_*.cif, mobility_labels.pt
  • mobility_[type]/: mobility classifier checkpoints (best_model.pt, final_model.pt) and training logs
  • multi_hop_[type]/: hop-specific checkpoints under hop_1/, hop_2/, hop_3/, plus optional predicted CIFs under predictions/

Scripts at a Glance

  • scripts/generate_data.py: Generate grouped vacancy or interstitial data (COMET-backed), label mobility, and split train/test. Use --random-sampling to fall back to the older vacancy generator; omit --num-pairs to keep all valid interstitial pairs; use --min-neighbors to warn on sparse neighbor counts per interstitial site.
  • scripts/train_mobility.py: GNN-based per-atom classifier that ingests MACE embeddings and predicts mobility masks.
  • scripts/train_multi_hop_models.py: Trains separate 1-, 2-, and 3-atom displacement heads and optionally exports predicted CIFs.

Key Arguments

Data Generation (generate_data.py)

  • --structure: Path to input CIF file (e.g., data/AgCl.cif)
  • --element: Mobile element symbol (e.g., Ag)
  • --supercell: Supercell size (e.g., 2 2 2)
  • --defect-type: One of vacancy, interstitial, or both
  • --max-distance: Max neighbor distance for grouping (Å)
  • --mobility-threshold: Min displacement to label atom as mobile (Å, default: 0.5)
  • --train-fraction: Train/test split ratio (default: 0.8)
  • --energy-threshold: Max energy above ground state for interstitials (eV, default: 5.0)
  • --max-pair-distance: Max distance between interstitial pairs (Å, default: 3.0)
  • --max-calculations: Max COMET calculations per interstitial config (default: 200)
  • --min-neighbors: Warn if interstitial site has fewer neighbors (default: 6)
  • --output-dir: Output directory root (creates subdirs for each defect type)

Mobility Training (train_mobility.py)

  • --data-dir: Path to grouped dataset (e.g., output/vacancy_data)
  • --output-dir: Where to save mobility model checkpoints
  • --cutoff: Graph edge cutoff radius (Å, default: 5.0)
  • --hidden-dim: GNN hidden dimension (default: 128)
  • --num-layers: Number of message-passing layers (default: 3)
  • --batch-size: Training batch size (default: 8)
  • --epochs: Number of training epochs (default: 50)
  • --lr: Learning rate (default: 1e-3)
  • --no-focal: Disable focal loss (use BCE instead)

Multi-Hop Training (train_multi_hop_models.py)

  • --data-dir: Path to grouped dataset
  • --output-dir: Where to save displacement model checkpoints
  • --hidden-dim: MLP hidden dimension (default: 256)
  • --num-layers: Number of MLP layers (default: 4)
  • --batch-size: Training batch size (default: 64)
  • --epochs: Number of training epochs (default: 500)
  • --lr: Learning rate (default: 1e-3)
  • --mobility-threshold: Min displacement to consider atom mobile (Å, default: 0.5)
  • --num-modes: Number of output heads (auto-estimated if not provided)
  • --max-modes: Ceiling for auto-estimation (default: 12)
  • --max-combinations-per-group: Limit combinations per group (for memory)
  • --save-predictions: Export predicted structures as CIFs
  • --pred-max-groups: Max test groups to generate predictions for (default: 10)
  • --pred-max-per-group: Max predictions per group (default: 20)
  • --pred-min-disp: Min displacement threshold for saving predictions (Å, default: 0.1)

Notes and Customization

General

  • MACE embeddings are required for both stages; ensure mace-torch is installed via pip install mace-torch
  • Supercell size trades off dataset coverage vs. runtime/memory. Start with 2 2 2 for testing, use 3 3 3 or larger for production
  • max-distance controls neighbor grouping radius. Typical values: 3-5 Å
  • The grouped dataset format is fully general: use any CIF structure and specify the mobile element

Vacancy-Specific

  • Vacancy generation is deterministic (removes one atom at a time)
  • Typical vacancy datasets have 1-3 unique displacement modes after clustering
  • Higher --mobility-threshold (e.g., 1.0 Å) filters out small relaxations, keeping only true hops

Interstitial-Specific

  • Interstitial workflows use COMET to find low-energy sites (requires comet-ml package)
  • --energy-threshold (eV): Controls how far above the ground state to search (5.0 = good balance)
  • --max-pair-distance (Å): Max distance between interstitial neighbors (3.0 recommended for AgCl)
  • --min-neighbors: Warns if a site has sparse connectivity (suggests increasing --max-pair-distance)
  • Interstitial pairs are directed (A→B and B→A treated separately)
  • Typical interstitial datasets have 4-8 unique displacement modes

Combined Workflows

  • Recommended for real applications where both mechanisms may be active
  • The combined_workflow.sh script automatically:
    1. Generates both defect types
    2. Merges datasets with proper group renumbering (avoids overwrites)
    3. Trains unified models on combined data
  • Models trained on combined data generalize to both vacancy and interstitial transport

Training Tips

  • Epochs: Mobility classifier converges quickly (50 epochs). Displacement models need more training (500 epochs recommended)
  • Batch size: Mobility uses small batches (8) due to graph size. Displacement models can use larger batches (64)
  • num_modes: Auto-estimation works well. For vacancies, 4-6 modes suffice. For interstitials, 8-12 captures diversity
  • Displacement clustering: 0.8 Å threshold balances noise reduction vs. mode preservation
  • Learning rate: 1e-3 works well with ReduceLROnPlateau scheduler (auto-reduces when validation plateaus)
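The 0.8 Å clustering mentioned above can be pictured with a minimal greedy sketch; the actual ORION clustering routine is not shown in this README, so the function below is an illustrative assumption, not the shipped implementation:

```python
import numpy as np

def cluster_displacements(disps, threshold=0.8):
    """Greedily merge displacement vectors closer than `threshold` (Å)
    to an existing cluster mean; return one averaged vector per cluster.
    Illustrative sketch only."""
    clusters = []  # each entry: list of member vectors
    for d in disps:
        for members in clusters:
            if np.linalg.norm(d - np.mean(members, axis=0)) < threshold:
                members.append(d)
                break
        else:
            clusters.append([d])
    return [np.mean(members, axis=0) for members in clusters]

# Three noisy copies of one hop plus one distinct hop
disps = np.array([[2.0, 0, 0], [2.1, 0, 0], [1.9, 0, 0], [0, 2.0, 0]])
modes = cluster_displacements(disps)   # 2 unique modes
```

With the 0.8 Å threshold, the three noisy copies of the first hop collapse into a single averaged mode while the orthogonal hop survives as a separate mode, which is the noise-reduction vs. mode-preservation balance the tip describes.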

Memory and Speed

  • Large supercells (>4×4×4) may require --max-combinations-per-group to limit memory
  • COMET calculations are the bottleneck for interstitial generation; use --max-calculations to cap runtime
  • For smoke tests: --supercell 2 2 2, --max-outcomes 20, --max-calculations 50

Problem Setting

ORION tackles a fundamental challenge in atomic transport modeling: learning to predict multiple possible outcomes from a single initial configuration. Unlike regression to a single target, defect-mediated transport is inherently multi-modal:

  • Vacancies: A single vacancy configuration has multiple neighboring atoms that could hop into the empty site, each representing a distinct valid outcome.
  • Interstitials: An interstitial atom can relocate to several nearby low-energy sites, each defining a different final state.
  • Multi-atom hops: Concerted motion (2- or 3-atom exchanges) further multiplies the number of possible outcomes.

Key challenges:

  1. One-to-many mapping: A single initial structure maps to many distinct final structures, not a single average displacement.
  2. Variable modality: Different configurations have different numbers of outcomes (e.g., vacancy with 1 clustered mode vs. interstitial with 5-6 distinct hop sites).
  3. Generalization: Models must work across different defect types, crystal structures, and coordination environments without hard-coded site information.

Training data format: Each training sample is a group: one initial structure paired with multiple final structures (outcomes), plus per-atom mobility labels derived from observed displacements. MACE embeddings encode local chemistry and coordination beyond nearest neighbors.

Example scenarios:

  • Vacancy: AgCl supercell with one Ag vacancy. Nearby Ag atoms at different crystallographic positions can hop into the vacancy, creating ~3-5 distinct outcomes. After displacement clustering, these may reduce to 1-2 unique modes.
  • Interstitial: AgCl supercell with one additional Ag atom at an interstitial site. COMET finds 5-6 low-energy neighboring interstitial sites within 3 Å, each representing a distinct hop outcome.
  • Combined: Training on both vacancy and interstitial data simultaneously allows the model to handle both transport mechanisms in a single unified framework.

How ORION Solves It

ORION uses a two-stage pipeline that explicitly handles multi-modal outcomes:

Stage 1: Mobility Classification

A graph neural network (MobilityClassifier) identifies which atoms are mobile in a given defect configuration:

  • Input: Per-atom MACE embeddings (512-dim) + radius graph (edges within cutoff, distance as edge feature)
  • Architecture: Several message-passing layers aggregate local coordination information
  • Output: Per-atom mobility probability (0 = immobile, 1 = mobile)
  • Loss: Focal loss to emphasize the rare mobile atoms
  • Inference: Thresholding (typically 0.5) produces a binary mobility mask indicating which atoms could participate in a hop
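The focal loss used in Stage 1 is the standard binary form (Lin et al.); the gamma/alpha values below are illustrative assumptions, not necessarily ORION's defaults:

```python
import numpy as np

def focal_loss(p, y, gamma=2.0, alpha=0.75):
    """Binary focal loss on per-atom mobility probabilities.
    Down-weights easy, well-classified atoms so the rare mobile
    atoms dominate the gradient. gamma/alpha are illustrative."""
    eps = 1e-8
    p = np.clip(p, eps, 1 - eps)
    pt = np.where(y == 1, p, 1 - p)          # prob. of the true class
    w = np.where(y == 1, alpha, 1 - alpha)   # class-balance weight
    return float(np.mean(-w * (1 - pt) ** gamma * np.log(pt)))

probs = np.array([0.9, 0.1, 0.8, 0.2])   # predicted mobility
labels = np.array([1, 0, 1, 0])          # ground-truth mobile mask
loss_good = focal_loss(probs, labels)
loss_bad = focal_loss(1 - probs, labels)  # much larger loss
```

The `(1 - pt) ** gamma` factor is what lets the few mobile atoms dominate: confident correct predictions contribute almost nothing, so the classifier cannot minimize the loss by simply predicting "immobile" everywhere.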

Stage 2: Multi-Modal Displacement Prediction

Separate MLP models per hop size (1, 2, 3 atoms) predict multiple possible displacement outcomes using multi-head outputs:

Architecture:

  • Input: Concatenated MACE embeddings + Fourier positional features for selected atom combinations
  • Output: num_modes displacement vectors (e.g., 12 heads × hop_size × 3 coordinates)
  • Each head represents one possible outcome mode
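The head layout can be pictured as a reshape of a flat network output; the shapes below follow the description above, while the values are random placeholders:

```python
import numpy as np

# Hypothetical flat network output for one 2-atom combination:
# num_modes * hop_size * 3 values, reshaped so each head is one
# candidate outcome (hop_size atoms x 3 Cartesian components).
num_modes, hop_size = 12, 2
rng = np.random.default_rng(0)
flat = rng.normal(size=num_modes * hop_size * 3)

heads = flat.reshape(num_modes, hop_size, 3)
# heads[m] is the m-th predicted outcome: one 3-vector per hopping atom
```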

Handling variable modality (key innovation):

The challenge: training samples have variable numbers of target outcomes (T), but the model has a fixed number of output heads (M, typically 12).

Solution — Target replication strategy:

  1. Displacement clustering: Similar displacement vectors (within 0.8 Å threshold) are clustered and averaged, reducing noisy variation while preserving distinct modes.
  2. Target replication: If T < M, targets are replicated cyclically [t0, t1, ..., t0, t1, ...] to fill all M heads. If T > M, only the first M targets are used.
  3. Hungarian matching: During training, each predicted mode is matched to a (potentially replicated) target via bipartite matching with L2 distance cost.
  4. All modes trained: Every head learns from real targets at every iteration; no "unused" heads are left to drift toward arbitrary outputs.

Examples:

  • Vacancy (T=1): One clustered displacement → replicated 12 times → all 12 heads learn the same correct displacement
  • Interstitial (T=5): Five distinct hop sites → replicated [t0,t1,t2,t3,t4,t0,t1,t2,t3,t4,t0,t1] → heads learn valid sites with duplicates
  • At inference: All modes are valid (duplicates are expected and harmless)
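The replication-and-matching scheme above can be sketched with scipy's Hungarian solver; this is a simplified assumption of the training loss, not the shipped code:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def replicate_targets(targets, num_modes):
    """Cyclically tile T targets to fill M heads (step 2 above)."""
    idx = [i % len(targets) for i in range(num_modes)]
    return targets[idx]

def matched_l2_loss(preds, targets):
    """Match each head to a (replicated) target by bipartite matching
    on L2 distance, then average the matched distances (step 3)."""
    reps = replicate_targets(targets, len(preds))
    cost = np.linalg.norm(preds[:, None, :] - reps[None, :, :], axis=-1)
    rows, cols = linear_sum_assignment(cost)
    return cost[rows, cols].mean()

# T=2 targets, M=4 heads: targets tiled as [t0, t1, t0, t1]
targets = np.array([[2.0, 0, 0], [0, 2.0, 0]])
preds = np.array([[2.0, 0, 0], [0, 2.0, 0], [2.0, 0, 0], [0, 2.0, 0]])
loss = matched_l2_loss(preds, targets)   # perfect match -> 0.0
```

Because every head is assigned a real (possibly duplicated) target, the loss supervises all heads on every sample, which is what keeps duplicate predictions harmless at inference.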

Training details:

  • Combinations: For each group, enumerate all combinations of mobile atoms (filtered by mobility threshold)
  • PBC-aware displacements: Minimum-image convention with species-aware Hungarian matching for atom correspondence
  • Loss: L2 (Euclidean) distance between matched prediction/target pairs
  • Regularization: Unobserved combinations (no hop in data) are pushed toward zero displacement with reduced weight
  • Learning rate schedule: ReduceLROnPlateau with patience=10, factor=0.5
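The minimum-image convention mentioned above can be sketched in fractional coordinates; the cell and positions below are made-up placeholders, and the real pipeline additionally does species-aware atom matching, which this sketch omits:

```python
import numpy as np

def mic_displacement(frac_initial, frac_final, cell):
    """Displacement under the minimum-image convention.
    Wrap the fractional difference into [-0.5, 0.5) per axis, then
    map back to Cartesian. `cell` rows are the lattice vectors."""
    df = frac_final - frac_initial
    df -= np.round(df)            # pick the nearest periodic image
    return df @ cell              # Cartesian displacement (Å)

cell = np.eye(3) * 10.0           # 10 Å cubic cell (illustrative)
a = np.array([0.95, 0.5, 0.5])    # initial fractional coords
b = np.array([0.05, 0.5, 0.5])    # final: atom crossed the boundary
disp = mic_displacement(a, b, cell)   # 1.0 Å hop, not a 9.0 Å jump
```

Without the wrap, a boundary-crossing hop would register as a near-cell-length displacement and corrupt both mobility labels and training targets.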

Output: Models predict all num_modes displacements per atom combination. Can be exported as CIF files for visualization or downstream kinetic Monte Carlo sampling.

Data Generation

scripts/generate_data.py creates grouped datasets:

  • Vacancy mode: Creates supercells with one vacancy, computes neighbors, stores initial + outcomes
  • Interstitial mode: Uses COMET to find low-energy interstitial sites and pair-wise hops
  • Both mode: Generates both datasets and optionally merges them with proper group renumbering
  • Output: Each group contains initial.cif, final_*.cif (multiple outcomes), and mobility_labels.pt

Why This is General

Nothing is hard-coded about vacancy positions, interstitial sites, or hop mechanisms:

  • Mobility labels come from observed displacements (any atom with >threshold motion is flagged)
  • MACE embeddings encode local chemistry and coordination, generalizing across defect types
  • Multi-modal prediction with target replication handles variable outcome counts automatically
  • Works for: Single-atom hops, multi-atom exchanges, vacancies, interstitials, or mixed datasets
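The displacement-based labeling rule is simple enough to sketch directly; positions below are made up, and displacements are assumed already minimum-image corrected:

```python
import numpy as np

def mobility_labels(initial_pos, final_pos, threshold=0.5):
    """Flag atoms whose displacement magnitude exceeds `threshold` (Å),
    mirroring the labeling rule above. Illustrative only."""
    disp = np.linalg.norm(final_pos - initial_pos, axis=1)
    return disp > threshold

initial = np.array([[0.0, 0, 0], [5.0, 0, 0], [0, 5.0, 0]])
final   = np.array([[2.5, 0, 0], [5.1, 0, 0], [0, 5.0, 0]])
mask = mobility_labels(initial, final)   # [True, False, False]
```

Only the atom that moved 2.5 Å is flagged; the 0.1 Å relaxation stays below the default 0.5 Å threshold, which is exactly why raising --mobility-threshold filters out relaxations while keeping true hops.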

Output Artifacts and Interpretation

Directory Structure

After running the full pipeline, your output directory will contain:

output/
├── vacancy_data/           # (if vacancy mode used)
│   ├── train/
│   │   ├── group_0000/
│   │   │   ├── initial.cif
│   │   │   ├── final_0.cif
│   │   │   ├── final_1.cif
│   │   │   └── mobility_labels.pt
│   │   └── group_0001/ ...
│   └── test/ ...
├── interstitial_data/      # (if interstitial mode used)
│   └── [same structure]
├── combined_data/          # (if both mode used, after merging)
│   ├── train/
│   └── test/
├── mobility_[type]/
│   ├── best_model.pt       # Best validation checkpoint
│   ├── final_model.pt      # Final epoch checkpoint
│   └── latest.pt           # Latest checkpoint with optimizer state
└── multi_hop_[type]/
    ├── hop_1/
    │   ├── best_model.pt
    │   ├── final_model.pt
    │   ├── latest.pt
    │   └── training_history.json
    ├── hop_2/ ...
    ├── hop_3/ ...
    └── predictions/        # (if --save-predictions used)
        ├── group_0000/
        │   ├── initial.cif
        │   ├── pred_hop1_combo0000_mode0_idx11.cif
        │   ├── pred_hop1_combo0001_mode1_idx23.cif
        │   └── ...
        └── group_0001/ ...

Interpreting Training Logs

Mobility classifier:

  • Loss: Focal loss (or BCE if --no-focal), should decrease to ~0.1-0.3
  • Accuracy: Overall per-atom accuracy, typically >95%
  • Precision/Recall: More meaningful than accuracy due to class imbalance (few mobile atoms)
  • Target: High recall (catch all mobile atoms) with acceptable precision

Displacement models:

  • Loss: L2 distance between predicted and target displacements (Å)
  • pos_mae: Mean absolute error for positive (mobile) combinations, target <0.01 Å
  • neg_mae: Mean absolute error for negative (immobile) combinations, should be near zero
  • Validation loss: Should track training loss; large gap suggests overfitting
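One plausible reading of the pos_mae/neg_mae split is an average of per-combination errors partitioned by the mobility mask; the numbers below are made up to illustrate the two metrics:

```python
import numpy as np

# Per-combination absolute displacement errors (Å) and whether each
# combination was observed as mobile. Illustrative values only.
abs_err = np.array([0.008, 0.012, 0.001, 0.002, 0.003])
mobile = np.array([True, True, False, False, False])

pos_mae = abs_err[mobile].mean()    # error on mobile combinations
neg_mae = abs_err[~mobile].mean()   # should stay near zero
```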

Interpreting Predictions

Predicted CIF files show possible outcomes for test structures:

Naming convention: pred_hop{N}_combo{XXXX}_mode{M}_idx{atoms}.cif

  • hop{N}: Hop size (1, 2, or 3 atoms)
  • combo{XXXX}: Combination index
  • mode{M}: Which output head (0 to num_modes-1)
  • idx{atoms}: Atom indices involved (e.g., idx11 or idx11-23-45)
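When post-processing many predicted CIFs, the naming convention above can be parsed with a small regex; the pattern is inferred from the convention as written and should be checked against actual filenames:

```python
import re

# Regex for pred_hop{N}_combo{XXXX}_mode{M}_idx{atoms}.cif; `idx` may be
# a single index (idx11) or a hyphen-separated list (idx11-23-45).
PATTERN = re.compile(
    r"pred_hop(?P<hop>\d+)_combo(?P<combo>\d+)_mode(?P<mode>\d+)"
    r"_idx(?P<atoms>\d+(?:-\d+)*)\.cif"
)

m = PATTERN.fullmatch("pred_hop1_combo0001_mode1_idx23.cif")
info = {
    "hop": int(m["hop"]),
    "combo": int(m["combo"]),
    "mode": int(m["mode"]),
    "atoms": [int(i) for i in m["atoms"].split("-")],
}
# info == {"hop": 1, "combo": 1, "mode": 1, "atoms": [23]}
```

Grouping parsed records by combo and comparing across modes is a quick way to count unique predictions per combination, as in the "Expected patterns" notes below.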

Expected patterns:

  • Vacancies: Many modes will be near-identical (duplicates from target replication). Typical: 1-3 unique predictions.
  • Interstitials: More diversity across modes, showing different hop sites. Typical: 4-8 unique predictions.
  • Combined datasets: Mix of both patterns depending on the defect type in that group

Quality checks:

  1. Visualize predicted structures in a viewer (e.g., VESTA, Ovito)
  2. Verify displaced atoms land on valid lattice sites (not in interstitial regions or on wrong species)
  3. Check that predicted displacements are physically reasonable (typically 2-4 Å for nearest-neighbor hops)
  4. Compare predictions to ground-truth final_*.cif files in test set

Troubleshooting

Mobility classifier:

  • Low recall → Increase focal loss gamma, or use more training data
  • High false positives → Decrease --mobility-threshold during data generation
  • Poor generalization → Increase --cutoff or --num-layers to capture more environment

Displacement models:

  • High training loss (>0.1 Å) → Increase --epochs, decrease --lr, or check data quality
  • Predictions on wrong sites → Likely insufficient training data or inappropriate num_modes
  • All predictions identical → Check that displacement clustering threshold (0.8 Å) isn't too aggressive
  • Validation loss plateaus → Normal with ReduceLROnPlateau; training will continue at reduced LR

Data quality:

  • Few mobile atoms detected → Adjust --mobility-threshold or --max-distance
  • Interstitials have too few neighbors → Increase --max-pair-distance
  • Training very slow → Reduce --max-combinations-per-group or use smaller supercells
