Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Baselines and evaluation scripts #248

Merged
merged 25 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
8396762
add geom stats script
AlexandraVolokhova Sep 21, 2023
2879ae1
add comment
AlexandraVolokhova Sep 21, 2023
44077ae
fix bug
AlexandraVolokhova Sep 21, 2023
13a739c
regrouped functions for geom
AlexandraVolokhova Sep 22, 2023
ae82682
add metrics script
AlexandraVolokhova Sep 22, 2023
50c64f2
add rdkit baseline generation
AlexandraVolokhova Sep 22, 2023
22beb23
progress with clustering
AlexandraVolokhova Sep 24, 2023
d1efce0
finilise metrics and rdk clustering
AlexandraVolokhova Sep 25, 2023
0742661
recent changes in metrics, top k, hack, etc
AlexandraVolokhova Sep 26, 2023
130d903
kde plots script initial
AlexandraVolokhova Sep 27, 2023
e32a342
Merge branch 'conformer_iclr' of github.com:alexhernandezgarcia/gflow…
AlexandraVolokhova Sep 27, 2023
2ab7185
edit kde plots
AlexandraVolokhova Sep 28, 2023
750b9db
Merge branch 'conformer_iclr' of github.com:alexhernandezgarcia/gflow…
AlexandraVolokhova Sep 28, 2023
2157983
after-deadline commit
AlexandraVolokhova Nov 13, 2023
436890f
black, isort
AlexandraVolokhova Nov 13, 2023
0d785d4
readme edit
AlexandraVolokhova Nov 13, 2023
8206e02
black, isort again
AlexandraVolokhova Nov 13, 2023
b9183a1
script for generating uniform samples with reward weights
AlexandraVolokhova Jan 11, 2024
8413883
finish mcmc baseline script
AlexandraVolokhova Jan 22, 2024
6adda4d
add rminus1_stop argument
AlexandraVolokhova Jan 23, 2024
6dcb1cd
black, hacked isort
AlexandraVolokhova Jan 23, 2024
c70c392
hacked isort
AlexandraVolokhova Jan 23, 2024
a3486fe
Merge branch 'conformer_ai4mat23' of github.com:alexhernandezgarcia/g…
AlexandraVolokhova Jan 23, 2024
43d9e82
change number of min samples to 1100
AlexandraVolokhova Jan 23, 2024
8fb1f82
updated smiles parsing
AlexandraVolokhova Jan 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion gflownet/envs/conformers/conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from gflownet.utils.molecule.rdkit_conformer import RDKitConformer
from gflownet.utils.molecule.rotatable_bonds import find_rotor_from_smiles


PREDEFINED_SMILES = [
"O=C(c1ccccc1)c1ccc2c(c1)OCCOCCOCCOCCO2",
"O=S(=O)(NN=C1CCCCCC1)c1ccc(Cl)cc1",
Expand Down
65 changes: 65 additions & 0 deletions gflownet/utils/molecule/geom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import json
import os
import pickle
from pathlib import Path

import numpy as np
import pandas as pd
from rdkit import Chem
from tqdm import tqdm

from gflownet.utils.molecule.rotatable_bonds import (
get_rotatable_ta_list,
is_hydrogen_ta,
)


def get_conf_geom(base_path, smiles, conf_idx=0, summary_file=None):
if summary_file is None:
drugs_file = base_path / "rdkit_folder/summary_drugs.json"
with open(drugs_file, "r") as f:
summary_file = json.load(f)

pickle_path = base_path / "rdkit_folder" / summary_file[smiles]["pickle_path"]
if os.path.isfile(pickle_path):
with open(pickle_path, "rb") as f:
dic = pickle.load(f)
mol = dic["conformers"][conf_idx]["rd_mol"]
return mol


def get_all_confs_geom(base_path, smiles, summary_file=None):
if summary_file is None:
drugs_file = base_path / "rdkit_folder/summary_drugs.json"
with open(drugs_file, "r") as f:
summary_file = json.load(f)
try:
pickle_path = base_path / "rdkit_folder" / summary_file[smiles]["pickle_path"]
if os.path.isfile(pickle_path):
with open(pickle_path, "rb") as f:
dic = pickle.load(f)
conformers = [x["rd_mol"] for x in dic["conformers"]]
return conformers
except KeyError:
print("No pickle_path file for {}".format(smiles))
return None


def get_rd_mol(smiles):
mol = Chem.MolFromSmiles(smiles)
mol = Chem.AddHs(mol)
return mol


def has_same_can_smiles(mol1, mol2):
sm1 = Chem.CanonSmiles(Chem.MolToSmiles(mol1))
sm2 = Chem.CanonSmiles(Chem.MolToSmiles(mol2))
return sm1 == sm2


def all_same_graphs(mols):
ref = mols[0]
same = []
for mol in mols:
same.append(has_same_can_smiles(ref, mol))
return np.all(same)
37 changes: 37 additions & 0 deletions gflownet/utils/molecule/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# some functions inspired by: https://gist.github.com/ZhouGengmo/5b565f51adafcd911c0bc115b2ef027c

import copy

import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import rdMolAlign as MA
from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMolecule
from rdkit.Geometry.rdGeometry import Point3D


def get_best_rmsd(gen_mol, ref_mol):
gen_mol = Chem.RemoveHs(gen_mol)
ref_mol = Chem.RemoveHs(ref_mol)
rmsd = MA.GetBestRMS(gen_mol, ref_mol)
return rmsd


def get_cov_mat(ref_mols, gen_mols, threshold=1.25):
rmsd_mat = np.zeros([len(ref_mols), len(gen_mols)], dtype=np.float32)
for i, gen_mol in enumerate(gen_mols):
gen_mol_c = copy.deepcopy(gen_mol)
for j, ref_mol in enumerate(ref_mols):
ref_mol_c = copy.deepcopy(ref_mol)
rmsd_mat[j, i] = get_best_rmsd(gen_mol_c, ref_mol_c)
rmsd_mat_min = rmsd_mat.min(-1)
return (rmsd_mat_min <= threshold).mean(), rmsd_mat_min.mean()


def normalise_positions(mol):
conf = mol.GetConformer()
pos = conf.GetPositions()
pos = pos - pos.mean(axis=0)
for idx, p in enumerate(pos):
conf.SetAtomPosition(idx, Point3D(*p))
return mol
8 changes: 8 additions & 0 deletions gflownet/utils/molecule/rotatable_bonds.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,11 @@ def is_connected_to_three_hydrogens(mol, atom_id, except_id):
first = is_connected_to_three_hydrogens(mol, ta[1], ta[2])
second = is_connected_to_three_hydrogens(mol, ta[2], ta[1])
return first or second


def has_hydrogen_tas(mol):
tas = get_rotatable_ta_list(mol)
hydrogen_flags = []
for t in tas:
hydrogen_flags.append(is_hydrogen_ta(mol, t))
return np.any(hydrogen_flags)
22 changes: 22 additions & 0 deletions scripts/conformer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
This folder contains scripts for dealing with GEOM dataset and running RDKit-base baselines.

## Calculating statistics for GEOM dataset
The script `geom_stats.py` extracts statistical information from molecular conformation data in the GEOM dataset using the RDKit library. The GEOM dataset is expected to be in the "rdkit_folder" format (tutorial and downloading links are here: https://github.com/learningmatter-mit/geom/tree/master). This script parses the dataset, calculates relevant statistics, and outputs the results to a CSV file.

Statistics collected include:
* SMILES representation of the molecule.
* Whether the molecule is self-consistent, i.e. its conformations in the dataset correspond to the same SMILES.
* Whether the the molecule is consistent with the RDKit, i.e. all conformations in the dataset correspond to the same SMILES and this SMILES is the same as stored in the dataset.
* The number of rotatable torsion angles in the molecular conformation (both from GEOM conformers and RDKit-generated graph).
* Whether the molecule contains hydrogen torsion angles.
* The total number of unique conformations for the molecule.
* The number of heavy atoms in the molecule.
* The total number of atoms in the molecule.

### Usage

You can use the script with the following command-line arguments:

* `--geom_dir` (optional): Path to the directory containing the GEOM dataset in an rdkit_folder. The default path is '/home/mila/a/alexandra.volokhova/scratch/datasets/geom'.
* `--output_file` (optional): Path to the output CSV file where the statistics will be saved. The default path is './geom_stats.csv'.

189 changes: 189 additions & 0 deletions scripts/conformer/compute_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
import argparse
import json
import os
import pickle
import random
from copy import deepcopy
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

from gflownet.utils.molecule.geom import get_all_confs_geom
from gflownet.utils.molecule.metrics import get_best_rmsd, get_cov_mat
from gflownet.envs.conformers.conformer import PREDEFINED_SMILES


def distant_enough(conf, others, delta):
conf = deepcopy(conf)
dist = []
for c in others:
dist.append(get_best_rmsd(deepcopy(c), conf))
dist = np.array(dist)
return np.all(dist > delta)


def get_diverse_top_k(confs_list, k=None, delta=1.25):
"""confs_list should be sorted according to the energy! lowest energy first"""
result = [confs_list[0]]
for conf in confs_list[1:]:
if distant_enough(conf, result, delta):
result.append(conf)
if len(result) >= k:
break
if len(result) < k:
print(
f"Cannot find {k} different conformations in {len(confs_list)} for delta {delta}"
)
print(f"Found only {len(result)}. Adding random from generated")
result += random.sample(confs_list.tolist(), k - len(result))
return result


def get_smiles_from_filename(filename):
smiles = filename.split("_")[1]
if smiles == 'mcmc':
smiles_idx = int(filename.split("_")[2][5:])
smiles = PREDEFINED_SMILES[smiles_idx]
if smiles.endswith(".pkl"):
smiles = smiles[:-4]
return smiles


def main(args):
base_path = Path(args.geom_dir)
drugs_file = base_path / "rdkit_folder/summary_drugs.json"
with open(drugs_file, "r") as f:
drugs_summ = json.load(f)

filenames = [Path(args.gen_dir) / x for x in os.listdir(args.gen_dir)]
gen_files = []
gen_confs = []
smiles = []
energies = []
for fp in filenames:
with open(fp, "rb") as f:
gen_files.append(pickle.load(f))
# 1k cut-off to use the same max number of gen samples for all methods
gen_confs.append(gen_files[-1]["conformer"][:1000])
if "smiles" in gen_files[-1].keys():
smiles.append(gen_files[-1]["smiles"])
if "energy" in gen_files[-1].keys():
energies.append(gen_files[-1]["energy"][:1000])
if len(smiles) == 0:
smiles = [get_smiles_from_filename(x.name) for x in filenames]
print("All smiles")
print(*smiles, sep="\n")
ref_confs = [get_all_confs_geom(base_path, sm, drugs_summ) for sm in smiles]
# filter out nans
gen_confs = [gen_confs[idx] for idx, val in enumerate(ref_confs) if val is not None]
smiles = [smiles[idx] for idx, val in enumerate(ref_confs) if val is not None]
if len(energies) > 0:
energies = [
energies[idx] for idx, val in enumerate(ref_confs) if val is not None
]
ref_confs = [val for val in ref_confs if val is not None]
assert len(gen_confs) == len(ref_confs) == len(smiles)

if args.use_top_k:
if len(energies) == 0:
raise Exception("Cannot use top-k without energies")
energies = [np.array(e) for e in energies]
indecies = [np.argsort(e) for e in energies]
gen_confs = [np.array(x)[idx] for x, idx in zip(gen_confs, indecies)]
if args.diverse:
gen_confs = [
get_diverse_top_k(x, k=len(ref) * 2, delta=args.delta)
for x, ref in zip(gen_confs, ref_confs)
]
if not args.hack:
gen_confs = [x[: 2 * len(ref)] for x, ref in zip(gen_confs, ref_confs)]

geom_stats = pd.read_csv(args.geom_stats, index_col=0)
cov_list = []
mat_list = []
n_conf = []
n_tas = []
consistent = []
has_h_tas = []
hack = False
for ref, gen, sm in tqdm(zip(ref_confs, gen_confs, smiles), total=len(ref_confs)):
if len(gen) < 2 * len(ref):
print("Recieved less samples that needed for computing metrics!")
print(
f"Computing metrics with {len(gen)} generated confs for {len(ref)} reference confs"
)
try:
if len(gen) > 2 * len(ref):
hack = True
print(
f"Warning! Computing metrics with {len(gen)} generated confs for {len(ref)} reference confs"
)
cov, mat = get_cov_mat(ref, gen, threshold=1.25)
except RuntimeError as e:
print(e)
cov, mat = None, None
cov_list.append(cov)
mat_list.append(mat)
n_conf.append(len(ref))
n_tas.append(
geom_stats[geom_stats.smiles == sm].n_rotatable_torsion_angles_rdkit.values[
0
]
)
consistent.append(
geom_stats[geom_stats.smiles == sm].rdkit_consistent.values[0]
)
has_h_tas.append(geom_stats[geom_stats.smiles == sm].has_hydrogen_tas.values[0])

data = {
"smiles": smiles,
"cov": cov_list,
"mat": mat_list,
"n_ref_confs": n_conf,
"n_tas": n_tas,
"has_hydrogen_tas": has_h_tas,
"consistent": consistent,
}
df = pd.DataFrame(data)
name = Path(args.gen_dir).name
if name in ["xtb", "gfn-ff", "torchani"]:
name = Path(args.gen_dir).name
name = Path(args.gen_dir).parent.name + "_" + name
if args.use_top_k:
name += "_top_k"
if args.diverse:
name += f"_diverse_{args.delta}"
if hack:
name += "_hacked"

output_file = Path(args.output_dir) / "{}_metrics.csv".format(name)
df.to_csv(output_file, index=False)
print("Saved metrics at {}".format(output_file))


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--geom_dir",
type=str,
default="/home/mila/a/alexandra.volokhova/scratch/datasets/geom",
)
parser.add_argument(
"--output_dir",
type=str,
default="/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/metrics",
)
parser.add_argument(
"--geom_stats",
type=str,
default="/home/mila/a/alexandra.volokhova/projects/gflownet/scripts/conformer/generated_files/geom_stats.csv",
)
parser.add_argument("--gen_dir", type=str, default="./")
parser.add_argument("--use_top_k", type=bool, default=False)
parser.add_argument("--diverse", type=bool, default=False)
parser.add_argument("--delta", type=float, default=1.0)
parser.add_argument("--hack", type=bool, default=False)
args = parser.parse_args()
main(args)
Loading
Loading