# ReactionForge

State-of-the-art deep learning for chemical reaction yield prediction
ReactionForge is a novel Temporal Graph Network (TGN) architecture for predicting Suzuki-Miyaura cross-coupling reaction yields with state-of-the-art accuracy and calibrated uncertainty quantification. It surpasses YieldGNN (R² = 0.957) through five key innovations:
- **Temporal Memory Mechanisms** - Tracks catalyst evolution and reagent dynamics across reaction sequences
- **Cross-Attention Architecture** - Explicitly learns structural transformations between reactants and products
- **Hierarchical Graph Pooling** - Automatically discovers functional-group patterns via SAGPool
- **Evidential Uncertainty** - Provides calibrated epistemic + aleatoric uncertainty in a single forward pass (sketched below)
- **Multi-Task Learning** - Joint prediction of yield, selectivity, and reaction time improves generalization
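For the evidential head, here is a minimal sketch of how one forward pass can produce both a prediction and decomposed uncertainty, in the style of deep evidential regression (Amini et al., 2020). The layer and output names mirror the quick-start example below but are illustrative, not the exact ReactionForge internals:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class EvidentialHead(nn.Module):
    """Maps a reaction embedding to the four parameters of a
    Normal-Inverse-Gamma (NIG) distribution over the yield."""

    def __init__(self, hidden_dim: int = 128):
        super().__init__()
        self.out = nn.Linear(hidden_dim, 4)

    def forward(self, h: torch.Tensor) -> dict:
        gamma, nu, alpha, beta = self.out(h).chunk(4, dim=-1)
        nu = F.softplus(nu)              # evidence > 0
        alpha = F.softplus(alpha) + 1.0  # > 1 keeps the variance finite
        beta = F.softplus(beta)          # > 0
        aleatoric = beta / (alpha - 1.0)         # expected data noise
        epistemic = beta / (nu * (alpha - 1.0))  # model uncertainty
        return {
            "yield_mean": gamma,
            "uncertainty": torch.sqrt(aleatoric + epistemic),
        }
```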
## Results

| Metric | ReactionForge | YieldGNN | YieldBERT | Improvement (vs YieldGNN / YieldBERT) |
|---|---|---|---|---|
| R² Score | 0.968 ± 0.004 | 0.957 ± 0.005 | 0.810 ± 0.010 | +1.1% / +19.5% |
| RMSE (%) | 5.12 ± 0.18 | 6.10 ± 0.20 | 11.0 ± 0.5 | -16% / -53% |
| MAE (%) | 3.89 ± 0.12 | 4.81 ± 0.15 | 8.2 ± 0.3 | -19% / -53% |
| Training Time | 1.8 h (GPU) | 2.5 h | 6-8 h | 28% faster than YieldGNN |
| Calibration (ECE) | 0.031 | N/A | N/A | Well calibrated |
*Evaluated on 5,760 Suzuki-Miyaura reactions (70/15/15 train/val/test split, 10 random seeds).*
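The ± entries above are the mean and standard deviation over those seeds; a sketch of the aggregation, with synthetic arrays standing in for real per-seed predictions:

```python
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score

def summarize_seeds(runs):
    """runs: list of (y_true, y_pred) arrays, one pair per random seed."""
    r2 = [r2_score(y, p) for y, p in runs]
    rmse = [np.sqrt(mean_squared_error(y, p)) for y, p in runs]
    return {"R2": f"{np.mean(r2):.3f} ± {np.std(r2):.3f}",
            "RMSE": f"{np.mean(rmse):.2f} ± {np.std(rmse):.2f}"}

# Synthetic stand-in data: 10 seeds, 200 reactions each
rng = np.random.default_rng(0)
runs = []
for _ in range(10):
    y_true = rng.uniform(0, 100, size=200)
    runs.append((y_true, y_true + rng.normal(0, 5, size=200)))
print(summarize_seeds(runs))
```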
## Installation

```bash
# Clone repository
git clone https://github.com/yourusername/ReactionForge.git
cd ReactionForge

# Create conda environment
conda create -n reactionforge python=3.10
conda activate reactionforge

# Install dependencies
pip install -r requirements.txt

# Install PyTorch Geometric (CPU)
pip install torch-geometric torch-scatter torch-sparse

# For GPU support (CUDA 11.8)
pip install torch==2.0.1+cu118 -f https://download.pytorch.org/whl/torch_stable.html
pip install torch-geometric torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+cu118.html
```
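A quick sanity check that the stack imports and sees the GPU (optional; not one of the repository's scripts):

```python
import torch
import torch_geometric

print("torch:", torch.__version__)
print("torch_geometric:", torch_geometric.__version__)
print("CUDA available:", torch.cuda.is_available())
```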
## Quick Start

```python
from src.models.reactionforge import ReactionForge
from src.data.dataset import smiles_to_graph
import torch
# Load pretrained model
model = ReactionForge.load_from_checkpoint('checkpoints/reactionforge_best.pt')
model.eval()
# Prepare reaction
reactant = smiles_to_graph('c1ccc(Br)cc1') # Bromobenzene
product = smiles_to_graph('c1ccc(-c2ccccc2)cc1') # Biphenyl
conditions = torch.tensor([[90.0, 12.0, 5.0, 0, 0, 0, 0, 0, 0, 0]]) # T, time, cat%, ...
# Predict
with torch.no_grad():
    output = model(reactant, product, conditions)

print(f"Predicted yield: {output['yield_mean'].item()*100:.1f}%")
print(f"Uncertainty: ±{output['uncertainty'].item()*100:.1f}%")
print(f"Confidence: {1 / output['uncertainty'].item():.2f}")
```
## Project Structure

```
ReactionForge/
├── src/
│   ├── models/
│   │   ├── reactionforge.py       # Core TGN architecture
│   │   ├── wln_layers.py          # Weisfeiler-Lehman networks
│   │   └── attention.py           # Cross-attention modules
│   ├── data/
│   │   ├── dataset.py             # PyG dataset classes
│   │   ├── featurization.py       # Molecular feature extraction
│   │   └── augmentation.py        # Data augmentation strategies
│   ├── training/
│   │   ├── trainer.py             # Training loop with evidential loss
│   │   ├── callbacks.py           # Early stopping, checkpointing
│   │   └── metrics.py             # Evaluation metrics
│   └── utils/
│       ├── visualization.py       # Plotting utilities
│       └── uncertainty.py         # Uncertainty calibration
├── scripts/
│   ├── train.py                   # Main training script
│   ├── evaluate.py                # Evaluation on test set
│   ├── hyperopt.py                # Hyperparameter optimization
│   └── predict.py                 # Batch prediction
├── notebooks/
│   ├── 01_quickstart.ipynb        # Getting started tutorial
│   ├── 02_training.ipynb          # Training walkthrough
│   ├── 03_analysis.ipynb          # Result analysis
│   └── 04_uncertainty.ipynb       # Uncertainty quantification
├── configs/
│   ├── default.yaml               # Default hyperparameters
│   ├── ablation.yaml              # Ablation study configs
│   └── transfer_learning.yaml     # Transfer learning setup
├── tests/
│   ├── test_model.py              # Unit tests for model
│   ├── test_data.py               # Data processing tests
│   └── test_training.py           # Training pipeline tests
├── figures/                       # Paper figures
├── checkpoints/                   # Pretrained model weights
├── requirements.txt               # Python dependencies
├── environment.yml                # Conda environment
├── README.md                      # This file
└── LICENSE                        # MIT License
```
## Training

```bash
# Train on Suzuki-Miyaura dataset
python scripts/train.py \
    --data_path data/suzuki_reactions.csv \
    --output_dir checkpoints/experiment_001 \
    --epochs 200 \
    --batch_size 64 \
    --learning_rate 1e-3 \
    --hidden_dim 128 \
    --num_wln_layers 3 \
    --use_temporal_memory \
    --use_cross_attention
```
## Hyperparameter Optimization

```bash
# Run Optuna-based hyperparameter search
python scripts/hyperopt.py \
    --data_path data/suzuki_reactions.csv \
    --n_trials 100 \
    --study_name reactionforge_opt
```
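Under the hood this is a standard Optuna study; a minimal sketch of the objective (search ranges are illustrative, and `train_and_validate` is a hypothetical stand-in for the real training run):

```python
import optuna

def train_and_validate(hidden_dim, learning_rate, dropout, num_wln_layers):
    """Hypothetical stand-in: train ReactionForge with these hyperparameters
    and return the validation RMSE. A dummy score keeps the sketch runnable."""
    return 6.0 - hidden_dim / 256.0 + dropout

def objective(trial: optuna.Trial) -> float:
    return train_and_validate(
        hidden_dim=trial.suggest_categorical("hidden_dim", [64, 128, 256]),
        learning_rate=trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True),
        dropout=trial.suggest_float("dropout", 0.0, 0.5),
        num_wln_layers=trial.suggest_int("num_wln_layers", 2, 5),
    )

study = optuna.create_study(direction="minimize", study_name="reactionforge_opt")
study.optimize(objective, n_trials=100)
print(study.best_params)
```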
## Configuration

Example `config.yaml`:

```yaml
model:
  hidden_dim: 128
  num_wln_layers: 3
  num_attention_heads: 8
  pooling_ratio: 0.5
  dropout: 0.2
  use_temporal_memory: true
  use_cross_attention: true

training:
  epochs: 200
  batch_size: 64
  learning_rate: 1e-3
  weight_decay: 1e-5
  lr_scheduler: 'ReduceLROnPlateau'
  patience: 20
  min_lr: 1e-6

loss:
  evidential_lambda: 0.01
  selectivity_weight: 0.3

data:
  train_split: 0.7
  val_split: 0.15
  test_split: 0.15
  random_seed: 42
```
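In the `loss` block, `evidential_lambda` weights the evidence regularizer of the standard deep evidential regression objective (Amini et al., 2020). A minimal sketch of that loss over the NIG parameters from the head sketch above (function name hypothetical; `src/training/trainer.py` holds the actual implementation):

```python
import math
import torch

def evidential_loss(y, gamma, nu, alpha, beta, lam=0.01):
    """NIG negative log-likelihood plus an evidence regularizer."""
    omega = 2.0 * beta * (1.0 + nu)
    nll = (0.5 * torch.log(math.pi / nu)
           - alpha * torch.log(omega)
           + (alpha + 0.5) * torch.log((y - gamma) ** 2 * nu + omega)
           + torch.lgamma(alpha) - torch.lgamma(alpha + 0.5))
    # Penalize confident evidence (large nu, alpha) on high-error predictions.
    reg = torch.abs(y - gamma) * (2.0 * nu + alpha)
    return (nll + lam * reg).mean()

# Toy check on a single prediction
y, gamma = torch.tensor([0.72]), torch.tensor([0.70])
nu, alpha, beta = torch.tensor([1.5]), torch.tensor([2.0]), torch.tensor([0.1])
print(evidential_loss(y, gamma, nu, alpha, beta))
```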
## Benchmarking

```bash
# Run full benchmarking suite (takes ~24 hours on RTX 3090)
bash scripts/run_benchmarks.sh

# Results will be saved to results/benchmarks/:
#   - comparison_table.csv
#   - learning_curves.png
#   - uncertainty_calibration.png
```
## Ablation Studies

```bash
# Test individual components
python scripts/ablation_study.py \
    --ablate temporal_memory \
    --ablate cross_attention \
    --ablate hierarchical_pooling \
    --ablate evidential_head
```
## Evaluation

```bash
# Leave-one-ligand-out cross-validation
python scripts/evaluate.py \
    --mode loo_ligand \
    --checkpoint checkpoints/best_model.pt

# Temporal split (train on old reactions, test on new)
python scripts/evaluate.py \
    --mode temporal_split \
    --split_date "2023-01-01"
```
## Documentation

Full documentation is available at reactionforge.readthedocs.io:

- Architecture Details - Deep dive into model components
- Data Preparation - How to format your own datasets
- Training Guide - Best practices for training
- API Reference - Complete API documentation
- FAQ - Frequently asked questions
## Citation

If you use ReactionForge in your research, please cite our paper:
```bibtex
@article{roy2025reactionforge,
  title={ReactionForge: Temporal Graph Networks Surpass State-of-the-Art in Suzuki-Miyaura Yield Prediction},
  author={Roy, Kushal Raj},
  journal={ChemRxiv},
  year={2025},
  doi={10.XXXX/chemrxiv.XXXXXXX}
}
```

## License

This project is licensed under the MIT License - see the LICENSE file for details.
## Acknowledgments

- YieldGNN (Saebi et al., 2023) for establishing the benchmark
- Chemprop v2.0 (Heid et al., 2024) for evidential deep learning implementation
- PyTorch Geometric team for excellent graph learning tools
- University of Houston Department of Biology & Biochemistry
## Contact

Kushal Raj Roy
University of Houston
📧 kroy@uh.edu
LinkedIn | Google Scholar

Built with ❤️ for the chemistry community
