kroy3/ReactionForge

ReactionForge: Temporal Graph Network for Reaction Yield Prediction

State-of-the-art deep learning for chemical reaction yield prediction


🔬 Overview

ReactionForge is a Temporal Graph Network (TGN) architecture for predicting Suzuki-Miyaura cross-coupling reaction yields with calibrated uncertainty quantification. On the benchmark reported below it reaches R² = 0.968, compared with R² = 0.957 for YieldGNN, through five key innovations:

  1. 🕐 Temporal Memory Mechanisms - Tracks catalyst evolution and reagent dynamics across reaction sequences
  2. 🔀 Cross-Attention Architecture - Explicitly learns structural transformations between reactants and products
  3. 🌲 Hierarchical Graph Pooling - Automatically discovers functional group patterns via SAGPool
  4. 📊 Evidential Uncertainty - Provides calibrated epistemic + aleatoric uncertainty in a single forward pass
  5. 🎯 Multi-Task Learning - Joint prediction of yield, selectivity, and reaction time improves generalization
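Item 4 does not say how a single forward pass can yield both kinds of uncertainty. The sketch below assumes the evidential head follows the Normal-Inverse-Gamma formulation of deep evidential regression (Amini et al., 2020); this README does not confirm that choice. The function name `nig_uncertainty` and the example values are hypothetical.

```python
import math

def nig_uncertainty(gamma, nu, alpha, beta):
    """Split predictive uncertainty from Normal-Inverse-Gamma parameters.

    gamma is the predicted mean; nu > 0, alpha > 1 and beta > 0 are the
    evidence parameters a network head would output (assumed formulation).
    Returns (mean, aleatoric_variance, epistemic_variance).
    """
    if nu <= 0 or alpha <= 1 or beta <= 0:
        raise ValueError("require nu > 0, alpha > 1, beta > 0")
    aleatoric = beta / (alpha - 1.0)         # E[sigma^2]: noise inherent in the data
    epistemic = beta / (nu * (alpha - 1.0))  # Var[mu]: shrinks as evidence nu grows
    return gamma, aleatoric, epistemic

# Hypothetical head outputs for one reaction (yield on a 0-1 scale).
mean, aleatoric, epistemic = nig_uncertainty(gamma=0.82, nu=4.0, alpha=3.0, beta=0.01)
total_std = math.sqrt(aleatoric + epistemic)
print(f"yield={mean:.2f} aleatoric={aleatoric:.4f} epistemic={epistemic:.5f} std={total_std:.3f}")
```

Because all four parameters come from the same head, no sampling or ensembling is needed, which is what makes a single-pass estimate possible under this formulation.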

Performance Highlights

| Metric | ReactionForge | YieldGNN | YieldBERT | Improvement (vs. YieldGNN / YieldBERT) |
|---|---|---|---|---|
| R² Score | 0.968 ± 0.004 | 0.957 ± 0.005 | 0.810 ± 0.010 | +1.1% / +19.5% |
| RMSE (%) | 5.12 ± 0.18 | 6.10 ± 0.20 | 11.0 ± 0.5 | -16% / -53% |
| MAE (%) | 3.89 ± 0.12 | 4.81 ± 0.15 | 8.2 ± 0.3 | -19% / -53% |
| Training Time | 1.8h (GPU) | 2.5h | 6-8h | 28% faster (vs. YieldGNN) |
| Calibration (ECE) | 0.031 | N/A | N/A | Well-calibrated |

Evaluated on 5,760 Suzuki-Miyaura reactions (70/30 split, 10 seeds)
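The improvement figures in the table are relative changes against each baseline. The sketch below reproduces that arithmetic from the reported means and gives the standard definitions of R², RMSE and MAE. The helper names and the toy yields are illustrative and are not taken from `src/training/metrics.py`.

```python
import math

def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_y = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_y) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot

def rmse(y_true, y_pred):
    """Root mean squared error, in the units of the target."""
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))

def mae(y_true, y_pred):
    """Mean absolute error, in the units of the target."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

def relative_change(ours, baseline):
    """Signed percentage change of `ours` relative to `baseline`."""
    return 100.0 * (ours - baseline) / baseline

# Relative changes computed from the means reported in the table.
r2_gain = relative_change(0.968, 0.957)   # R² vs. YieldGNN
rmse_gain = relative_change(5.12, 6.10)   # RMSE vs. YieldGNN
mae_gain = relative_change(3.89, 4.81)    # MAE vs. YieldGNN
print(f"{r2_gain:+.1f}% {rmse_gain:+.1f}% {mae_gain:+.1f}%")

# Toy yields in percent, only to exercise the metric functions.
y_true = [10.0, 50.0, 90.0]
y_pred = [12.0, 47.0, 91.0]
print(r2_score(y_true, y_pred), rmse(y_true, y_pred), mae(y_true, y_pred))
```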


🚀 Quick Start

Installation

# Clone repository
git clone https://github.com/kroy3/ReactionForge.git
cd ReactionForge

# Create conda environment
conda create -n reactionforge python=3.10
conda activate reactionforge

# Install dependencies
pip install -r requirements.txt

# Install PyTorch Geometric (CPU)
pip install torch-geometric torch-scatter torch-sparse

# For GPU support (CUDA 11.8)
pip install torch==2.0.1+cu118 -f https://download.pytorch.org/whl/torch_stable.html
pip install torch-geometric torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+cu118.html

Quick Prediction

from src.models.reactionforge import ReactionForge
from src.data.dataset import smiles_to_graph
import torch

# Load pretrained model
model = ReactionForge.load_from_checkpoint('checkpoints/reactionforge_best.pt')
model.eval()

# Prepare reaction
reactant = smiles_to_graph('c1ccc(Br)cc1')  # Bromobenzene
product = smiles_to_graph('c1ccc(-c2ccccc2)cc1')  # Biphenyl
conditions = torch.tensor([[90.0, 12.0, 5.0, 0, 0, 0, 0, 0, 0, 0]])  # T, time, cat%, ...

# Predict
with torch.no_grad():
    output = model(reactant, product, conditions)
    
print(f"Predicted yield: {output['yield_mean'].item()*100:.1f}%")
print(f"Uncertainty: ±{output['uncertainty'].item()*100:.1f}%")
print(f"Confidence: {1 / output['uncertainty'].item():.2f}")

📂 Repository Structure

ReactionForge/
├── src/
│   ├── models/
│   │   ├── reactionforge.py       # Core TGN architecture
│   │   ├── wln_layers.py          # Weisfeiler-Lehman networks
│   │   └── attention.py           # Cross-attention modules
│   ├── data/
│   │   ├── dataset.py             # PyG dataset classes
│   │   ├── featurization.py       # Molecular feature extraction
│   │   └── augmentation.py        # Data augmentation strategies
│   ├── training/
│   │   ├── trainer.py             # Training loop with evidential loss
│   │   ├── callbacks.py           # Early stopping, checkpointing
│   │   └── metrics.py             # Evaluation metrics
│   └── utils/
│       ├── visualization.py       # Plotting utilities
│       └── uncertainty.py         # Uncertainty calibration
├── scripts/
│   ├── train.py                   # Main training script
│   ├── evaluate.py                # Evaluation on test set
│   ├── hyperopt.py                # Hyperparameter optimization
│   └── predict.py                 # Batch prediction
├── notebooks/
│   ├── 01_quickstart.ipynb        # Getting started tutorial
│   ├── 02_training.ipynb          # Training walkthrough
│   ├── 03_analysis.ipynb          # Result analysis
│   └── 04_uncertainty.ipynb       # Uncertainty quantification
├── configs/
│   ├── default.yaml               # Default hyperparameters
│   ├── ablation.yaml              # Ablation study configs
│   └── transfer_learning.yaml     # Transfer learning setup
├── tests/
│   ├── test_model.py              # Unit tests for model
│   ├── test_data.py               # Data processing tests
│   └── test_training.py           # Training pipeline tests
├── figures/                       # Paper figures
├── checkpoints/                   # Pretrained model weights
├── requirements.txt               # Python dependencies
├── environment.yml                # Conda environment
├── README.md                      # This file
└── LICENSE                        # MIT License

🎓 Training Your Own Model

Basic Training

# Train on Suzuki-Miyaura dataset
python scripts/train.py \
    --data_path data/suzuki_reactions.csv \
    --output_dir checkpoints/experiment_001 \
    --epochs 200 \
    --batch_size 64 \
    --learning_rate 1e-3 \
    --hidden_dim 128 \
    --num_wln_layers 3 \
    --use_temporal_memory \
    --use_cross_attention

Advanced: Hyperparameter Optimization

# Run Optuna-based hyperparameter search
python scripts/hyperopt.py \
    --data_path data/suzuki_reactions.csv \
    --n_trials 100 \
    --study_name reactionforge_opt

Configuration Files

Example config.yaml:

model:
  hidden_dim: 128
  num_wln_layers: 3
  num_attention_heads: 8
  pooling_ratio: 0.5
  dropout: 0.2
  use_temporal_memory: true
  use_cross_attention: true

training:
  epochs: 200
  batch_size: 64
  learning_rate: 1e-3
  weight_decay: 1e-5
  lr_scheduler: 'ReduceLROnPlateau'
  patience: 20
  min_lr: 1e-6

loss:
  evidential_lambda: 0.01
  selectivity_weight: 0.3
  
data:
  train_split: 0.7
  val_split: 0.15
  test_split: 0.15
  random_seed: 42
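Two details in a config like this are easy to get wrong. The three split fractions must sum to 1, and, if the file is read with PyYAML (an assumption; this README does not say which loader the scripts use), values written without a decimal point such as `1e-3` are loaded as strings rather than floats. The sketch below works on an already-parsed dictionary; `validate_config` is a hypothetical helper, not part of the repository.

```python
import math

def validate_config(cfg):
    """Return a copy of cfg with numeric strings coerced and splits checked."""
    # Copy each section so the caller's dictionary is left untouched.
    out = {section: dict(values) for section, values in cfg.items()}
    for values in out.values():
        for key, value in list(values.items()):
            if isinstance(value, str):
                try:
                    values[key] = float(value)  # e.g. '1e-3' becomes 0.001
                except ValueError:
                    pass                        # genuine strings stay as they are
    data = out["data"]
    total = data["train_split"] + data["val_split"] + data["test_split"]
    if not math.isclose(total, 1.0, abs_tol=1e-9):
        raise ValueError(f"splits sum to {total}, expected 1.0")
    return out

# A fragment of the example config as a loader might return it.
cfg = {
    "training": {"learning_rate": "1e-3", "lr_scheduler": "ReduceLROnPlateau"},
    "data": {"train_split": 0.7, "val_split": 0.15, "test_split": 0.15},
}
clean = validate_config(cfg)
print(clean["training"]["learning_rate"])
```

Writing the values as `1.0e-3` in the YAML file avoids the coercion step altogether.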

📊 Reproducing Paper Results

Main Benchmarking Experiment

# Run full benchmarking suite (takes ~24 hours on RTX 3090)
bash scripts/run_benchmarks.sh

# Results will be saved to results/benchmarks/
# - comparison_table.csv
# - learning_curves.png
# - uncertainty_calibration.png

Ablation Studies

# Test individual components
python scripts/ablation_study.py \
    --ablate temporal_memory \
    --ablate cross_attention \
    --ablate hierarchical_pooling \
    --ablate evidential_head
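The command above repeats `--ablate` once per component. How `scripts/ablation_study.py` parses this is not shown here; one common way to accept a repeatable flag is `argparse` with `action="append"`. In the sketch below, the parser, the `COMPONENTS` list and `parse_ablations` are assumptions made for illustration.

```python
import argparse

# Component names taken from the command above.
COMPONENTS = ["temporal_memory", "cross_attention", "hierarchical_pooling", "evidential_head"]

def parse_ablations(argv):
    """Map each component to True (kept) or False (ablated)."""
    parser = argparse.ArgumentParser(description="Ablation flag sketch (hypothetical)")
    parser.add_argument(
        "--ablate",
        action="append",       # collect every occurrence into a list
        choices=COMPONENTS,    # reject unknown component names
        default=[],
        help="component to disable; may be given several times",
    )
    args = parser.parse_args(argv)
    return {name: name not in args.ablate for name in COMPONENTS}

flags = parse_ablations(["--ablate", "temporal_memory", "--ablate", "cross_attention"])
print(flags)
```

Whether the real script disables all listed components in one run or runs one ablation per flag is not stated in this README, so check the script before relying on either behavior.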

Out-of-Distribution Evaluation

# Leave-one-ligand-out cross-validation
python scripts/evaluate.py \
    --mode loo_ligand \
    --checkpoint checkpoints/best_model.pt

# Temporal split (train on old reactions, test on new)
python scripts/evaluate.py \
    --mode temporal_split \
    --split_date "2023-01-01"
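The two out-of-distribution modes differ only in how the records are partitioned. The sketch below shows both partitioning schemes on plain dictionaries. The field names `ligand` and `date` and the sample records are hypothetical; the column names in the real dataset are not given here.

```python
from datetime import date

def leave_one_group_out(records, key):
    """Yield (held_out_value, train, test), holding out one group per fold."""
    for value in sorted({r[key] for r in records}):
        train = [r for r in records if r[key] != value]
        test = [r for r in records if r[key] == value]
        yield value, train, test

def temporal_split(records, split_date, key="date"):
    """Train on records strictly before split_date, test on the rest."""
    train = [r for r in records if r[key] < split_date]
    test = [r for r in records if r[key] >= split_date]
    return train, test

# Hypothetical records, only to exercise the two splitters.
reactions = [
    {"ligand": "SPhos", "date": date(2022, 5, 1), "yield": 0.81},
    {"ligand": "XPhos", "date": date(2022, 11, 3), "yield": 0.64},
    {"ligand": "SPhos", "date": date(2023, 2, 14), "yield": 0.77},
    {"ligand": "PPh3", "date": date(2023, 6, 9), "yield": 0.35},
]

folds = [(lig, len(tr), len(te)) for lig, tr, te in leave_one_group_out(reactions, "ligand")]
print(folds)

train, test = temporal_split(reactions, date(2023, 1, 1))
print(len(train), len(test))
```

In the leave-one-ligand-out case every test ligand is unseen during training, so the score measures generalization to new ligands rather than interpolation within known ones.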

📖 Documentation

Full documentation is available at reactionforge.readthedocs.io


🤝 Citation

If you use ReactionForge in your research, please cite our paper:

@article{roy2025reactionforge,
  title={ReactionForge: Temporal Graph Networks Surpass State-of-the-Art in Suzuki-Miyaura Yield Prediction},
  author={Roy, Kushal Raj},
  journal={ChemRxiv},
  year={2025},
  doi={10.XXXX/chemrxiv.XXXXXXX}
}

📜 License

This project is licensed under the MIT License - see the LICENSE file for details.


🙏 Acknowledgments

  • YieldGNN (Saebi et al., 2023) for establishing the benchmark
  • Chemprop v2.0 (Heid et al., 2024) for evidential deep learning implementation
  • PyTorch Geometric team for excellent graph learning tools
  • University of Houston Department of Biology & Biochemistry

💬 Contact

Kushal Raj Roy
University of Houston
📧 kroy@uh.edu
🔗 LinkedIn | Google Scholar


Built with ❤️ for the chemistry community
