
QM-augmented_GNN

This repository contains the main code associated with the project Quantum chemistry augmented neural networks for reactivity prediction: Performance, generalizability and explainability. Note that each of the graph neural networks presented here is an adaptation of the original models developed by Yanfei Guan and co-workers. For more information, see the repository reactivity_predictions_substitution.

Requirements

  1. python 3.7
  2. tensorflow 2.0.0
  3. rdkit
  4. qmdesc (python package for predicting QM descriptors on the fly)

Conda environment

To set up a conda environment:

conda env create --name <env-name> --file environment.yml

Data

Curated data sets are included in each of the main directories in a format compatible with the respective models (cf. the datasets directories). In the case of regression_e2_sn2, data points are formatted as follows:

,reaction_id,smiles,reaction_core,activation_energy
0,0,[NH2:1][C@H:2]([C@H:3]([Cl:4])[N+:5](=[O:6])[O-:7])[N+:8](=[O:9])[O-:10].[H-:11],"[[3, 2, 1, 10]]",1.886101230187571
1,1,[CH3:1][CH2:2][C@:3]([NH2:4])([F:5])[N+:6](=[O:7])[O-:8].[H-:9],"[[4, 2, 1, 8]]",16.259824710850626

where smiles corresponds to the reactant SMILES and reaction_core lists the indices of the sites/heavy atoms whose bonding changes over the course of the reaction (indexing starts from 0). Note that the atom-map numbering of the reactant SMILES has to follow the atom order (i.e., the atom at index 0 carries map number 1, the atom at index 1 carries map number 2, etc.).
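
As an illustration of this format, the following sketch (not part of the repository) reads the regression data set with pandas and RDKit, verifies the atom-map ordering, and resolves the reaction core to atom symbols:

import ast

import pandas as pd
from rdkit import Chem

# Illustrative sketch, not part of the repository: load the curated data set and
# check that atom-map numbers follow the atom order (index i -> map number i + 1).
df = pd.read_csv("datasets/e2_sn2_regression.csv", index_col=0)

for _, row in df.iterrows():
    mol = Chem.MolFromSmiles(row["smiles"])
    assert all(atom.GetAtomMapNum() == atom.GetIdx() + 1 for atom in mol.GetAtoms())

    # reaction_core is stored as a string, e.g. "[[3, 2, 1, 10]]"
    core = ast.literal_eval(row["reaction_core"])[0]
    print(row["reaction_id"], [mol.GetAtomWithIdx(i).GetSymbol() for i in core])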

In the case of classification_e2_sn2, data points are formatted in a similar manner, but now the two reaction cores for the competing reaction pathways are included in a single data point, with the pathway corresponding to the lowest-energy transition state listed first:

,reaction_id,smiles,products_run
0,0,[N:1]#[C:2][C@@H:3]([NH2:4])[CH2:5][Br:6].[H-:7],"[[5, 4, 6], [5, 4, 2, 6]]"
1,1,[CH3:1][C@@H:2]([NH2:3])[CH2:4][Cl:5].[F-:6],"[[4, 3, 5], [4, 3, 1, 5]]"
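
A minimal sketch of recovering the two competing reaction cores from such a row (the classification file name is an assumption, not taken from the repository):

import ast

import pandas as pd

# Illustrative sketch; the file name is an assumption.
df = pd.read_csv("datasets/e2_sn2_classification.csv", index_col=0)

for _, row in df.iterrows():
    favored, competing = ast.literal_eval(row["products_run"])  # lowest-energy pathway first
    print(row["reaction_id"], favored, competing)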

In the case of classification_aromatic_substitution, data points are formatted as:

,reaction_id,rxn_smiles,products_run,PatentNumber
0,86,[CH3:1][O:2][C:3](=[O:4])[c:5]1[cH:6][cH:7][c:8]2[c:9]([cH:10]1)[O:11][CH2:12][CH2:13][O:14]2.[N+:15]([O-:16])([OH:17])=[O:18]>C(O)(=O)C>[CH3:1][O:2][C:3](=[O:4])[c:5]1[c:6]([N+:15]([O-:17])=[O:18])[cH:7][c:8]2[c:9]([cH:10]1)[O:11][CH2:12][CH2:13][O:14]2,[CH3:1][O:2][C:3](=[O:4])[c:5]1[c:6]([N+:15]([O-:17])=[O:18])[cH:7][c:8]2[c:9]([cH:10]1)[O:11][CH2:12][CH2:13][O:14]2.[CH3:1][O:2][C:3](=[O:4])[c:5]1[cH:6][cH:7][c:8]2[c:9]([c:10]1[N+:15]([O-:17])=[O:18])[O:11][CH2:12][CH2:13][O:14]2.[CH3:1][O:2][C:3](=[O:4])[c:5]1[cH:6][c:7]([N+:15]([O-:17])=[O:18])[c:8]2[c:9]([cH:10]1)[O:11][CH2:12][CH2:13][O:14]2,US03931179
1,126,[F:1][c:2]1[cH:3][cH:4][c:5]([C:6]([CH2:7][CH2:8][C:9]([OH:10])=[O:11])=[O:12])[cH:13][cH:14]1.[N+:15]([O-:16])([OH:17])=[O:18]>>[F:1][c:2]1[cH:3][cH:4][c:5]([C:6]([CH2:7][CH2:8][C:9]([OH:10])=[O:11])=[O:12])[cH:13][c:14]1[N+:15]([O-:17])=[O:18],[F:1][c:2]1[cH:3][cH:4][c:5]([C:6]([CH2:7][CH2:8][C:9]([OH:10])=[O:11])=[O:12])[cH:13][c:14]1[N+:15]([O-:17])=[O:18].[F:1][c:2]1[cH:3][cH:4][c:5]([C:6]([CH2:7][CH2:8][C:9]([OH:10])=[O:11])=[O:12])[c:13]([N+:15]([O-:17])=[O:18])[cH:14]1,US03931177

where rxn_smiles is the full atom-mapped reaction SMILES and products_run lists the potential products, separated by periods (major.minor1.minor2...).
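
The reaction SMILES follow the standard reactants>agents>products convention. A minimal sketch of pulling a row apart (the file name is an assumption, and splitting products_run on "." assumes each candidate product is a single-fragment SMILES, as in the rows above):

import pandas as pd

# Illustrative sketch; the file name is an assumption.
df = pd.read_csv("datasets/aromatic_substitution.csv", index_col=0)

for _, row in df.iterrows():
    reactants, agents, recorded_product = row["rxn_smiles"].split(">")
    candidates = row["products_run"].split(".")  # major product listed first
    print(row["reaction_id"], len(candidates), agents or "no agent recorded")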

Training

This repository contains three main directories, each providing two distinct graph neural network models (GNN and ml-QM-GNN) tailored to the corresponding data set and task, as described in the paper.

GNN

Conventional graph neural networks that rely solely on the machine-learned representation of a given reaction. To train the model, run:

python reactivity.py -m GNN --data_path <path to the .csv file> --model_dir <directory to save the trained model>

For example, to train the model on the E2/SN2 data set for barrier height prediction (cf. the regression_e2_sn2 directory):

python reactivity.py -m GNN --data_path datasets/e2_sn2_regression.csv --model_dir trained_model/GNN_e2_sn2

A checkpoint file, best_model.hdf5, will be saved in the trained_model/GNN_e2_sn2 directory.

ml-QM-GNN

These are the fusion models, which combine the machine-learned reaction representation with QM descriptors calculated on the fly. To use this architecture, the Chemprop-atom-bond package must be installed. To train the model, run:

python reactivity.py --data_path <path to the .csv file> --model_dir <directory to save the trained model>

reactivity.py uses the ml_QM_GNN mode by default. The workflow first predicts QM atomic/bond descriptors for all reactants found in the reactions. The predicted descriptors are then scaled with a min-max scaler. A dictionary containing the scikit-learn scaler objects is saved as scalers.pickle in model_dir for later prediction tasks. A checkpoint file, best_model.hdf5, is also saved in model_dir.

For example:

python reactivity.py -m ml_QM_GNN --data_path datasets/e2_sn2_regression.csv --model_dir trained_model/GNN_e2_sn2
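
To give a sense of what ends up in scalers.pickle, here is a minimal sketch of the descriptor-scaling step. The descriptor names and values below are placeholders, not output of qmdesc or an excerpt from reactivity.py; the sketch only shows the kind of object that is pickled (a dictionary mapping descriptor types to fitted scikit-learn scalers):

import pickle

import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Illustrative sketch only: descriptor names and values are placeholders.
descriptors = {
    "partial_charge": np.array([[-0.32], [0.05], [0.27]]),
    "bond_order": np.array([[0.98], [1.51], [2.02]]),
}

# One min-max scaler per descriptor type, stored in a dictionary.
scalers = {name: MinMaxScaler().fit(values) for name, values in descriptors.items()}

with open("trained_model/GNN_e2_sn2/scalers.pickle", "wb") as f:
    pickle.dump(scalers, f)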

Predicting

To use the trained model, run:

python reactivity.py -m <mode> --data_path <path to the .csv file to predict on> --model_dir <directory containing the trained model> -p

where data_path is the path to the data .csv file, whose format is discussed above, and model_dir is the directory holding the trained model. The checkpoint must be named best_model.hdf5 and stores the model parameters only. In ml_QM_GNN mode, model_dir must also contain the scalers.pickle file generated during training, as discussed in the training section.
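
At prediction time, the scalers saved during training are reloaded and applied to the freshly predicted descriptors. A minimal sketch of that step (directory name taken from the training example above; descriptor name and values are again placeholders):

import pickle

import numpy as np

# Illustrative sketch: reload the scalers saved during training and rescale
# newly predicted descriptors before they are passed to the ml_QM_GNN model.
with open("trained_model/GNN_e2_sn2/scalers.pickle", "rb") as f:
    scalers = pickle.load(f)

new_charges = np.array([[-0.11], [0.42]])  # placeholder values
print(scalers["partial_charge"].transform(new_charges))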

Cross-validating

To perform a cross-validation, run:

python cross_val.py -m <mode> --data_path <path to the data .csv file> --model_dir <directory containing the trained model> --k_fold <number of folds> --sample <number of training points>

where data_path is the path to the data .csv file, whose format is discussed above, model_dir is the directory holding the trained model, and sample (optional) is the number of training points to be sampled from the original training set for each fold.
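
Conceptually, the cross-validation loop splits the data into k folds and, when --sample is given, subsamples that many training points within each fold. A rough sketch of that logic (not the actual cross_val.py):

import pandas as pd
from sklearn.model_selection import KFold

# Rough sketch of the k-fold logic, not the actual cross_val.py.
df = pd.read_csv("datasets/e2_sn2_regression.csv", index_col=0)
k_fold, sample = 5, 1000  # mirrors --k_fold and --sample

for fold, (train_idx, test_idx) in enumerate(
        KFold(n_splits=k_fold, shuffle=True, random_state=0).split(df)):
    train = df.iloc[train_idx]
    if sample is not None and sample < len(train):
        train = train.sample(n=sample, random_state=fold)
    print(f"fold {fold}: {len(train)} training points, {len(test_idx)} test points")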
