diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py
index 4b831df9a..fe3778879 100644
--- a/gflownet/envs/conformers/conformer.py
+++ b/gflownet/envs/conformers/conformer.py
@@ -15,7 +15,6 @@
 from gflownet.utils.molecule.rdkit_conformer import RDKitConformer
 from gflownet.utils.molecule.rotatable_bonds import find_rotor_from_smiles
 
-
 PREDEFINED_SMILES = [
     "O=C(c1ccccc1)c1ccc2c(c1)OCCOCCOCCOCCO2",
     "O=S(=O)(NN=C1CCCCCC1)c1ccc(Cl)cc1",
diff --git a/gflownet/utils/molecule/geom.py b/gflownet/utils/molecule/geom.py
new file mode 100644
index 000000000..e7ac71308
--- /dev/null
+++ b/gflownet/utils/molecule/geom.py
@@ -0,0 +1,65 @@
+import json
+import os
+import pickle
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+from rdkit import Chem
+from tqdm import tqdm
+
+from gflownet.utils.molecule.rotatable_bonds import (
+    get_rotatable_ta_list,
+    is_hydrogen_ta,
+)
+
+
+def get_conf_geom(base_path, smiles, conf_idx=0, summary_file=None):
+    if summary_file is None:
+        drugs_file = base_path / "rdkit_folder/summary_drugs.json"
+        with open(drugs_file, "r") as f:
+            summary_file = json.load(f)
+
+    pickle_path = base_path / "rdkit_folder" / summary_file[smiles]["pickle_path"]
+    if os.path.isfile(pickle_path):
+        with open(pickle_path, "rb") as f:
+            dic = pickle.load(f)
+        mol = dic["conformers"][conf_idx]["rd_mol"]
+        return mol
+
+
+def get_all_confs_geom(base_path, smiles, summary_file=None):
+    if summary_file is None:
+        drugs_file = base_path / "rdkit_folder/summary_drugs.json"
+        with open(drugs_file, "r") as f:
+            summary_file = json.load(f)
+    try:
+        pickle_path = base_path / "rdkit_folder" / summary_file[smiles]["pickle_path"]
+        if os.path.isfile(pickle_path):
+            with open(pickle_path, "rb") as f:
+                dic = pickle.load(f)
+            conformers = [x["rd_mol"] for x in dic["conformers"]]
+            return conformers
+    except KeyError:
+        print("No pickle_path file for {}".format(smiles))
+        return None
+
+
+def get_rd_mol(smiles):
+    mol = Chem.MolFromSmiles(smiles)
+    mol = Chem.AddHs(mol)
+    return mol
+
+
+def has_same_can_smiles(mol1, mol2):
+    sm1 = Chem.CanonSmiles(Chem.MolToSmiles(mol1))
+    sm2 = Chem.CanonSmiles(Chem.MolToSmiles(mol2))
+    return sm1 == sm2
+
+
+def all_same_graphs(mols):
+    ref = mols[0]
+    same = []
+    for mol in mols:
+        same.append(has_same_can_smiles(ref, mol))
+    return np.all(same)
diff --git a/gflownet/utils/molecule/metrics.py b/gflownet/utils/molecule/metrics.py
new file mode 100644
index 000000000..f52dc9f38
--- /dev/null
+++ b/gflownet/utils/molecule/metrics.py
@@ -0,0 +1,37 @@
+# some functions inspired by: https://gist.github.com/ZhouGengmo/5b565f51adafcd911c0bc115b2ef027c
+
+import copy
+
+import numpy as np
+import pandas as pd
+from rdkit import Chem
+from rdkit.Chem import rdMolAlign as MA
+from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMolecule
+from rdkit.Geometry.rdGeometry import Point3D
+
+
+def get_best_rmsd(gen_mol, ref_mol):
+    gen_mol = Chem.RemoveHs(gen_mol)
+    ref_mol = Chem.RemoveHs(ref_mol)
+    rmsd = MA.GetBestRMS(gen_mol, ref_mol)
+    return rmsd
+
+
+def get_cov_mat(ref_mols, gen_mols, threshold=1.25):
+    rmsd_mat = np.zeros([len(ref_mols), len(gen_mols)], dtype=np.float32)
+    for i, gen_mol in enumerate(gen_mols):
+        gen_mol_c = copy.deepcopy(gen_mol)
+        for j, ref_mol in enumerate(ref_mols):
+            ref_mol_c = copy.deepcopy(ref_mol)
+            rmsd_mat[j, i] = get_best_rmsd(gen_mol_c, ref_mol_c)
+    rmsd_mat_min = rmsd_mat.min(-1)
+    return (rmsd_mat_min <= threshold).mean(), rmsd_mat_min.mean()
+
+
+def normalise_positions(mol):
+    conf = mol.GetConformer()
+    pos = conf.GetPositions()
+    pos = pos - pos.mean(axis=0)
+    for idx, p in enumerate(pos):
+        conf.SetAtomPosition(idx, Point3D(*p))
+    return mol
diff --git a/gflownet/utils/molecule/rotatable_bonds.py b/gflownet/utils/molecule/rotatable_bonds.py
index 4dc73a590..4a7c7cfe3 100644
--- a/gflownet/utils/molecule/rotatable_bonds.py
+++ b/gflownet/utils/molecule/rotatable_bonds.py
@@ -111,3 +111,11 @@ def is_connected_to_three_hydrogens(mol, atom_id, except_id):
     first = is_connected_to_three_hydrogens(mol, ta[1], ta[2])
     second = is_connected_to_three_hydrogens(mol, ta[2], ta[1])
     return first or second
+
+
+def has_hydrogen_tas(mol):
+    tas = get_rotatable_ta_list(mol)
+    hydrogen_flags = []
+    for t in tas:
+        hydrogen_flags.append(is_hydrogen_ta(mol, t))
+    return np.any(hydrogen_flags)
diff --git a/scripts/conformer/README.md b/scripts/conformer/README.md
new file mode 100644
index 000000000..bd87be9f0
--- /dev/null
+++ b/scripts/conformer/README.md
@@ -0,0 +1,22 @@
+This folder contains scripts for working with the GEOM dataset and for running RDKit-based baselines.
+
+## Calculating statistics for the GEOM dataset
+The script `geom_stats.py` extracts statistics from the molecular conformation data in the GEOM dataset using the RDKit library. The GEOM dataset is expected to be in the "rdkit_folder" format (a tutorial and download links are available at https://github.com/learningmatter-mit/geom/tree/master). The script parses the dataset, calculates the relevant statistics, and writes the results to a CSV file.
+
+Statistics collected include:
+* SMILES representation of the molecule.
+* Whether the molecule is self-consistent, i.e. all of its conformations in the dataset correspond to the same SMILES.
+* Whether the molecule is consistent with RDKit, i.e. all conformations in the dataset correspond to the same SMILES and this SMILES matches the one stored in the dataset.
+* The number of rotatable torsion angles in the molecular conformation (both from GEOM conformers and from the RDKit-generated graph).
+* Whether the molecule contains hydrogen torsion angles.
+* The total number of unique conformations for the molecule.
+* The number of heavy atoms in the molecule.
+* The total number of atoms in the molecule.
+
+### Usage
+
+You can use the script with the following command-line arguments:
+
+* `--geom_dir` (optional): Path to the directory containing the GEOM dataset in the rdkit_folder format. The default path is '/home/mila/a/alexandra.volokhova/scratch/datasets/geom'.
+* `--output_file` (optional): Path to the output CSV file where the statistics will be saved. The default path is './geom_stats.csv'.
+
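A quick way to exercise the new `geom.py` helpers together — a minimal sketch, not part of the diff; the GEOM path and the example SMILES below are placeholders:

```python
from pathlib import Path

from gflownet.utils.molecule.geom import all_same_graphs, get_all_confs_geom, get_rd_mol
from gflownet.utils.molecule.rotatable_bonds import has_hydrogen_tas

base_path = Path("/data/geom")  # hypothetical local copy of GEOM in rdkit_folder format
smiles = "CC(C)Cc1ccc(C(C)C(=O)O)cc1"  # illustrative; any key of summary_drugs.json works

confs = get_all_confs_geom(base_path, smiles)  # None if the SMILES has no pickle_path
if confs is not None:
    print(f"{len(confs)} reference conformers")
    print("self-consistent:", all_same_graphs(confs))
    print("consistent with RDKit:", all_same_graphs(confs + [get_rd_mol(smiles)]))
    print("has hydrogen torsion angles:", has_hydrogen_tas(confs[0]))
```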
diff --git a/scripts/conformer/compute_metrics.py b/scripts/conformer/compute_metrics.py
new file mode 100644
index 000000000..df6eb8993
--- /dev/null
+++ b/scripts/conformer/compute_metrics.py
@@ -0,0 +1,189 @@
+import argparse
+import json
+import os
+import pickle
+import random
+from copy import deepcopy
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+
+from gflownet.utils.molecule.geom import get_all_confs_geom
+from gflownet.utils.molecule.metrics import get_best_rmsd, get_cov_mat
+from gflownet.envs.conformers.conformer import PREDEFINED_SMILES
+
+
+def distant_enough(conf, others, delta):
+    conf = deepcopy(conf)
+    dist = []
+    for c in others:
+        dist.append(get_best_rmsd(deepcopy(c), conf))
+    dist = np.array(dist)
+    return np.all(dist > delta)
+
+
+def get_diverse_top_k(confs_list, k=None, delta=1.25):
+    """confs_list should be sorted by energy, lowest energy first!"""
+    result = [confs_list[0]]
+    for conf in confs_list[1:]:
+        if distant_enough(conf, result, delta):
+            result.append(conf)
+        if len(result) >= k:
+            break
+    if len(result) < k:
+        print(
+            f"Cannot find {k} different conformations in {len(confs_list)} for delta {delta}"
+        )
+        print(f"Found only {len(result)}. Adding random from generated")
+        result += random.sample(list(confs_list), k - len(result))
+    return result
+
+
+def get_smiles_from_filename(filename):
+    smiles = filename.split("_")[1]
+    if smiles == "mcmc":
+        smiles_idx = int(filename.split("_")[2][5:])
+        smiles = PREDEFINED_SMILES[smiles_idx]
+    if smiles.endswith(".pkl"):
+        smiles = smiles[:-4]
+    return smiles
+
+
+def main(args):
+    base_path = Path(args.geom_dir)
+    drugs_file = base_path / "rdkit_folder/summary_drugs.json"
+    with open(drugs_file, "r") as f:
+        drugs_summ = json.load(f)
+
+    filenames = [Path(args.gen_dir) / x for x in os.listdir(args.gen_dir)]
+    gen_files = []
+    gen_confs = []
+    smiles = []
+    energies = []
+    for fp in filenames:
+        with open(fp, "rb") as f:
+            gen_files.append(pickle.load(f))
+        # 1k cut-off to use the same max number of gen samples for all methods
+        gen_confs.append(gen_files[-1]["conformer"][:1000])
+        if "smiles" in gen_files[-1].keys():
+            smiles.append(gen_files[-1]["smiles"])
+        if "energy" in gen_files[-1].keys():
+            energies.append(gen_files[-1]["energy"][:1000])
+    if len(smiles) == 0:
+        smiles = [get_smiles_from_filename(x.name) for x in filenames]
+    print("All SMILES:")
+    print(*smiles, sep="\n")
+    ref_confs = [get_all_confs_geom(base_path, sm, drugs_summ) for sm in smiles]
+    # filter out molecules without reference conformers
+    gen_confs = [gen_confs[idx] for idx, val in enumerate(ref_confs) if val is not None]
+    smiles = [smiles[idx] for idx, val in enumerate(ref_confs) if val is not None]
+    if len(energies) > 0:
+        energies = [
+            energies[idx] for idx, val in enumerate(ref_confs) if val is not None
+        ]
+    ref_confs = [val for val in ref_confs if val is not None]
+    assert len(gen_confs) == len(ref_confs) == len(smiles)
+
+    if args.use_top_k:
+        if len(energies) == 0:
+            raise ValueError("Cannot use top-k without energies")
+        energies = [np.array(e) for e in energies]
+        indices = [np.argsort(e) for e in energies]
+        gen_confs = [np.array(x)[idx] for x, idx in zip(gen_confs, indices)]
+        if args.diverse:
+            gen_confs = [
+                get_diverse_top_k(x, k=len(ref) * 2, delta=args.delta)
+                for x, ref in zip(gen_confs, ref_confs)
+            ]
+        if not args.hack:
+            gen_confs = [x[: 2 * len(ref)] for x, ref in zip(gen_confs, ref_confs)]
+
+    geom_stats = pd.read_csv(args.geom_stats, index_col=0)
+    cov_list = []
+    mat_list = []
+    n_conf = []
+    n_tas = []
+    consistent = []
+    has_h_tas = []
+    hack = False
+    for ref, gen, sm in tqdm(zip(ref_confs, gen_confs, smiles), total=len(ref_confs)):
+        if len(gen) < 2 * len(ref):
+            print("Received fewer samples than needed for computing metrics!")
+            print(
+                f"Computing metrics with {len(gen)} generated confs for {len(ref)} reference confs"
+            )
+        try:
+            if len(gen) > 2 * len(ref):
+                hack = True
+                print(
+                    f"Warning! Computing metrics with {len(gen)} generated confs for {len(ref)} reference confs"
+                )
+            cov, mat = get_cov_mat(ref, gen, threshold=1.25)
+        except RuntimeError as e:
+            print(e)
+            cov, mat = None, None
+        cov_list.append(cov)
+        mat_list.append(mat)
+        n_conf.append(len(ref))
+        n_tas.append(
+            geom_stats[geom_stats.smiles == sm].n_rotatable_torsion_angles_rdkit.values[
+                0
+            ]
+        )
+        consistent.append(
+            geom_stats[geom_stats.smiles == sm].rdkit_consistent.values[0]
+        )
+        has_h_tas.append(geom_stats[geom_stats.smiles == sm].has_hydrogen_tas.values[0])
+
+    data = {
+        "smiles": smiles,
+        "cov": cov_list,
+        "mat": mat_list,
+        "n_ref_confs": n_conf,
+        "n_tas": n_tas,
+        "has_hydrogen_tas": has_h_tas,
+        "consistent": consistent,
+    }
+    df = pd.DataFrame(data)
+    name = Path(args.gen_dir).name
+    if name in ["xtb", "gfn-ff", "torchani"]:
+        name = Path(args.gen_dir).parent.name + "_" + name
+    if args.use_top_k:
+        name += "_top_k"
+    if args.diverse:
+        name += f"_diverse_{args.delta}"
+    if hack:
+        name += "_hacked"
+
+    output_file = Path(args.output_dir) / "{}_metrics.csv".format(name)
+    df.to_csv(output_file, index=False)
+    print("Saved metrics at {}".format(output_file))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--geom_dir",
+        type=str,
+        default="/home/mila/a/alexandra.volokhova/scratch/datasets/geom",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/metrics",
+    )
+    parser.add_argument(
+        "--geom_stats",
+        type=str,
+        default="/home/mila/a/alexandra.volokhova/projects/gflownet/scripts/conformer/generated_files/geom_stats.csv",
+    )
+    parser.add_argument("--gen_dir", type=str, default="./")
+    # store_true flags: argparse's type=bool treats any non-empty string as True
+    parser.add_argument("--use_top_k", action="store_true")
+    parser.add_argument("--diverse", action="store_true")
+    parser.add_argument("--delta", type=float, default=1.0)
+    parser.add_argument("--hack", action="store_true")
+    args = parser.parse_args()
+    main(args)
diff --git a/scripts/conformer/gen_uniform_samples_with_rewards.py b/scripts/conformer/gen_uniform_samples_with_rewards.py
new file mode 100644
index 000000000..1235433a2
--- /dev/null
+++ b/scripts/conformer/gen_uniform_samples_with_rewards.py
@@ -0,0 +1,103 @@
+# tblite import should stay here first! otherwise everything fails with tblite errors
+import argparse
+import os
+import pickle
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from gflownet.envs.conformers.conformer import PREDEFINED_SMILES, Conformer
+from gflownet.proxy.conformers.tblite import TBLiteMoleculeEnergy
+from gflownet.proxy.conformers.torchani import TorchANIMoleculeEnergy
+from gflownet.utils.molecule.rotatable_bonds import (
+    find_rotor_from_smiles,
+    get_rotatable_ta_list,
+)
+from rdkit.Chem import AllChem, rdMolTransforms
+from tqdm import tqdm
+
+
+def get_uniform_samples_and_energy_weights(smiles, n_samples, energy_model="torchani"):
+    n_torsion_angles = len(find_rotor_from_smiles(smiles))
+    uniform_tas = np.random.uniform(0.0, 2 * np.pi, size=(n_samples, n_torsion_angles))
+    env = Conformer(
+        smiles=smiles,
+        n_torsion_angles=n_torsion_angles,
+        reward_func="boltzmann",
+        reward_beta=32,
+    )
+    if energy_model == "torchani":
+        proxy = TorchANIMoleculeEnergy(device="cpu", float_precision=32)
+    elif energy_model == "tblite":
+        proxy = TBLiteMoleculeEnergy(device="cpu", float_precision=32, batch_size=417)
+    else:
+        raise NotImplementedError(
+            f"No proxy available for {energy_model}, use one of ['torchani', 'tblite']"
+        )
+    proxy.setup(env)
+
+    def get_weights(batch):
+        batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1)
+        energies = proxy(env.statebatch2proxy(batch))
+        rewards = env.proxy2reward(-energies)
+        return rewards.numpy()
+
+    weights = get_weights(uniform_tas)
+    ddict = {f"ta_{idx}": uniform_tas[:, idx] for idx in range(n_torsion_angles)}
+    ddict.update({"weights": weights})
+    return pd.DataFrame(ddict)
+
+
+def main(args):
+    output_root = Path(args.output_dir)
+    if not output_root.exists():
+        os.mkdir(output_root)
+
+    samples_root = Path(
+        "/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/samples"
+    )
+    selected_mols = pd.read_csv(
+        samples_root / "gfn_samples_2-12" / "torchani_selected.csv"
+    )
+    result = dict()
+    for smiles in tqdm(selected_mols["SMILES"].values):
+        result.update(
+            {
+                smiles: get_uniform_samples_and_energy_weights(
+                    smiles, args.n_samples, args.energy_model
+                )
+            }
+        )
+        if args.save_each_df:
+            sm_idx = PREDEFINED_SMILES.index(smiles)
+            filename = (
+                output_root
+                / f"{args.energy_model}_{sm_idx}_weighted_samples_selected_smiles.csv"
+            )
+            result[smiles].to_csv(filename)
+    filename = output_root / f"{args.energy_model}_weighted_samples_selected_smiles.pkl"
+    with open(filename, "wb") as file:
+        pickle.dump(result, file)
+    print(f"saved results at {filename}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--n_samples",
+        type=int,
+        default=1000,
+    )
+    parser.add_argument(
+        "--energy_model", type=str, choices=["torchani", "tblite"], default="torchani"
+    )
+    parser.add_argument("--save_each_df", action="store_true")
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/samples/uniform_samples_with_reward_weights",
+    )
+    args = parser.parse_args()
+    main(args)
diff --git a/scripts/conformer/geom_stats.py b/scripts/conformer/geom_stats.py
new file mode 100644
index 000000000..5194b1068
--- /dev/null
+++ b/scripts/conformer/geom_stats.py
@@ -0,0 +1,77 @@
+import argparse
+import json
+import os
+import pickle
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+from rdkit import Chem
+from tqdm import tqdm
+
+from gflownet.utils.molecule.geom import (
+    all_same_graphs,
+    get_all_confs_geom,
+    get_conf_geom,
+    get_rd_mol,
+)
+from gflownet.utils.molecule.rotatable_bonds import (
+    get_rotatable_ta_list,
+    has_hydrogen_tas,
+)
+
+"""
+Here we use the rdkit_folder format of the GEOM dataset.
+A tutorial and download links are available at https://github.com/learningmatter-mit/geom/tree/master
+"""
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--geom_dir",
+        type=str,
+        default="/home/mila/a/alexandra.volokhova/scratch/datasets/geom",
+    )
+    parser.add_argument("--output_file", type=str, default="./geom_stats.csv")
+    args = parser.parse_args()
+
+    base_path = Path(args.geom_dir)
+    drugs_file = base_path / "rdkit_folder/summary_drugs.json"
+    with open(drugs_file, "r") as f:
+        drugs_summ = json.load(f)
+
+    smiles = []
+    self_consistent = []
+    rdkit_consistent = []
+    n_tas_geom = []
+    n_tas_rdkit = []
+    unique_confs = []
+    n_atoms = []
+    n_heavy_atoms = []
+    hydrogen_tas = []
+    for sm, sub_dic in tqdm(list(drugs_summ.items())):
+        confs = get_all_confs_geom(base_path, sm, summary_file=drugs_summ)
+        if confs is not None:
+            rd_mol = get_rd_mol(sm)
+            smiles.append(sm)
+            unique_confs.append(len(confs))
+            n_atoms.append(confs[0].GetNumAtoms())
+            n_heavy_atoms.append(confs[0].GetNumHeavyAtoms())
+            self_consistent.append(all_same_graphs(confs))
+            rdkit_consistent.append(all_same_graphs(confs + [rd_mol]))
+            n_tas_geom.append(len(get_rotatable_ta_list(confs[0])))
+            n_tas_rdkit.append(len(get_rotatable_ta_list(rd_mol)))
+            hydrogen_tas.append(has_hydrogen_tas(confs[0]))
+
+    data = {
+        "smiles": smiles,
+        "self_consistent": self_consistent,
+        "rdkit_consistent": rdkit_consistent,
+        "n_rotatable_torsion_angles_geom": n_tas_geom,
+        "n_rotatable_torsion_angles_rdkit": n_tas_rdkit,
+        "has_hydrogen_tas": hydrogen_tas,
+        "n_confs": unique_confs,
+        "n_heavy_atoms": n_heavy_atoms,
+        "n_atoms": n_atoms,
+    }
+    df = pd.DataFrame(data)
+    df.to_csv(args.output_file)
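For reference, `get_cov_mat` implements recall-style coverage (COV) and matching (MAT): for each reference conformer it takes the minimum best-alignment RMSD over the generated set, then reports the fraction of references under the threshold and the mean of those minima. A self-contained sketch, using plain RDKit embeddings as stand-ins for generated samples (the SMILES and seeds are illustrative):

```python
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem

from gflownet.utils.molecule.metrics import get_cov_mat


def embed(smiles, seed):
    # Embed a single conformer with a fixed seed so the example is reproducible.
    mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
    AllChem.EmbedMolecule(mol, randomSeed=seed)
    return mol


smiles = "CCCCO"  # small flexible molecule, chosen only for illustration
ref_mols = [embed(smiles, s) for s in range(3)]
gen_mols = [embed(smiles, s) for s in range(10, 16)]

cov, mat = get_cov_mat(ref_mols, gen_mols, threshold=1.25)
print(f"COV-R: {cov:.2f}, MAT-R: {mat:.3f} A")
```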
diff --git a/scripts/conformer/kde_plots.py b/scripts/conformer/kde_plots.py
new file mode 100644
index 000000000..9ceed35bf
--- /dev/null
+++ b/scripts/conformer/kde_plots.py
@@ -0,0 +1,138 @@
+# The tblite import below must come before the gflownet imports,
+# otherwise everything fails with tblite errors
+import argparse
+import os
+import pickle
+import time
+from datetime import datetime
+from pathlib import Path
+
+import numpy as np
+from scipy.special import logsumexp
+from tblite import interface
+from tqdm import tqdm
+
+from gflownet.envs.conformers.conformer import Conformer
+from gflownet.proxy.conformers.tblite import TBLiteMoleculeEnergy
+from gflownet.proxy.conformers.torchani import TorchANIMoleculeEnergy
+from gflownet.proxy.conformers.xtb import XTBMoleculeEnergy
+from gflownet.utils.common import torch2np
+
+PROXY_DICT = {
+    "tblite": TBLiteMoleculeEnergy,
+    "xtb": XTBMoleculeEnergy,
+    "torchani": TorchANIMoleculeEnergy,
+}
+PROXY_NAME_DICT = {"tblite": "xtb", "xtb": "gfn-ff", "torchani": "torchani"}
+
+
+def get_smiles_and_proxy_class(filename):
+    sm = filename.split("_")[2]
+    proxy_name = filename.split("_")[-1][:-4]
+    proxy_class = PROXY_DICT[proxy_name]
+    proxy = proxy_class(device="cpu", float_precision=32)
+    env = Conformer(
+        smiles=sm,
+        n_torsion_angles=2,
+        reward_func="boltzmann",
+        reward_beta=32,
+        proxy=proxy,
+        reward_sampling_method="nested",
+    )
+    # proxy.setup(env)
+    return sm, PROXY_NAME_DICT[proxy_name], proxy, env
+
+
+def load_samples(filename, base_path):
+    path = base_path / filename
+    with open(path, "rb") as f:
+        data = pickle.load(f)
+    return data["x"]
+
+
+def get_true_kde(env, n_samples, bandwidth=0.1):
+    x_from_reward = env.sample_from_reward(n_samples=n_samples)
+    x_from_reward = torch2np(env.statetorch2kde(x_from_reward))
+    kde_true = env.fit_kde(x_from_reward, kernel="gaussian", bandwidth=bandwidth)
+    return kde_true
+
+
+def get_metrics(kde_true, kde_pred, test_samples):
+    scores_true = kde_true.score_samples(test_samples)
+    log_density_true = scores_true - logsumexp(scores_true, axis=0)
+    scores_pred = kde_pred.score_samples(test_samples)
+    log_density_pred = scores_pred - logsumexp(scores_pred, axis=0)
+    density_true = np.exp(log_density_true)
+    density_pred = np.exp(log_density_pred)
+    # L1 error
+    l1 = np.abs(density_pred - density_true).mean()
+    # KL divergence
+    kl = (density_true * (log_density_true - log_density_pred)).mean()
+    # Jensen-Shannon divergence
+    log_mean_dens = np.logaddexp(log_density_true, log_density_pred) + np.log(0.5)
+    jsd = 0.5 * np.sum(density_true * (log_density_true - log_mean_dens))
+    jsd += 0.5 * np.sum(density_pred * (log_density_pred - log_mean_dens))
+    return l1, kl, jsd
+
+
+def main(args):
+    base_path = Path(args.samples_dir)
+    output_base = Path(args.output_dir)
+    if not output_base.exists():
+        os.mkdir(output_base)
+    for filename in tqdm(os.listdir(base_path)):
+        ct = time.time()
+        print(f'{datetime.now().strftime("%H-%M-%S")}: Initialising env')
+        smiles, pr_name, proxy, env = get_smiles_and_proxy_class(filename)
+
+        # create output dir
+        current_datetime = datetime.now()
+        timestamp = current_datetime.strftime("%Y-%m-%d_%H-%M-%S")
+        out_name = f"{pr_name}_{smiles}_{args.n_test}_{args.bandwidth}_{timestamp}"
+        output_dir = output_base / out_name
+        os.mkdir(output_dir)
+        print(f"Will save results at {output_dir}")
+
+        samples = load_samples(filename, base_path)
+        n_samples = samples.shape[0]
+        print(f'{datetime.now().strftime("%H-%M-%S")}: Computing true kde')
+        kde_true = get_true_kde(env, n_samples, bandwidth=args.bandwidth)
+        print(f'{datetime.now().strftime("%H-%M-%S")}: Computing pred kde')
+        kde_pred = env.fit_kde(samples, kernel="gaussian", bandwidth=args.bandwidth)
+        print(f'{datetime.now().strftime("%H-%M-%S")}: Making figures')
+        fig_pred = env.plot_kde(kde_pred)
+        fig_pred.savefig(
+            output_dir / f"kde_pred_{out_name}.png", bbox_inches="tight", format="png"
+        )
+        fig_true = env.plot_kde(kde_true)
+        fig_true.savefig(
+            output_dir / f"kde_true_{out_name}.png", bbox_inches="tight", format="png"
+        )
+        print(f'{datetime.now().strftime("%H-%M-%S")}: Computing metrics')
+        test_samples = np.array(env.get_grid_terminating_states(args.n_test))[:, :2]
+        l1, kl, jsd = get_metrics(kde_true, kde_pred, test_samples)
+        met = {"l1": l1, "kl": kl, "jsd": jsd}
+        # write metrics and fitted KDEs to disk
+        with open(output_dir / "metrics.pkl", "wb") as file:
+            pickle.dump(met, file)
+        with open(output_dir / "kde_true.pkl", "wb") as file:
+            pickle.dump(kde_true, file)
+        with open(output_dir / "kde_pred.pkl", "wb") as file:
+            pickle.dump(kde_pred, file)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--samples_dir",
+        type=str,
+        default="/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/samples/mcmc_samples",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/kde_stats/mcmc",
+    )
+    parser.add_argument("--bandwidth", type=float, default=0.15)
+    parser.add_argument("--n_test", type=int, default=10000)
+    args = parser.parse_args()
+    main(args)
diff --git a/scripts/conformer/mcmc_baseline.py b/scripts/conformer/mcmc_baseline.py
new file mode 100644
index 000000000..db3542a03
--- /dev/null
+++ b/scripts/conformer/mcmc_baseline.py
@@ -0,0 +1,282 @@
+try:
+    from tblite import interface
+except:
+    pass
+
+import argparse
+import os
+import pickle
+from pathlib import Path
+
+import getdist.plots as gdplt
+import numpy as np
+import torch
+from cobaya.run import run
+from getdist.mcsamples import MCSamplesFromCobaya
+from gflownet.envs.conformers.conformer import Conformer
+from gflownet.proxy.conformers.tblite import TBLiteMoleculeEnergy
+from gflownet.proxy.conformers.torchani import TorchANIMoleculeEnergy
+from gflownet.proxy.conformers.xtb import XTBMoleculeEnergy
+from scipy import stats
+
+
+def convert_to_numpy_if_needed(array):
+    if torch.is_tensor(array):
+        return array.cpu().detach().numpy()
+    else:
+        return array
+
+
+def main(args):
+    if args.proxy_name == "torchani":
+        proxy_class = TorchANIMoleculeEnergy
+    elif args.proxy_name == "tblite":
+        proxy_class = TBLiteMoleculeEnergy
+    elif args.proxy_name == "xtb":
+        proxy_class = XTBMoleculeEnergy
+
+    # Leave as is
+    DEVICE = "cpu"
+    FLOAT_PRECISION = 32
+    REWARD_FUNC = "boltzmann"
+    REWARD_BETA = 32
+
+    # output dir
+    output_dir = Path(args.output_dir)
+    if not output_dir.exists():
+        os.mkdir(output_dir)
+
+    for smile in args.ids:
+        # Change n_torsion_angles
+        env = Conformer(
+            smiles=int(smile),
+            n_torsion_angles=-1,
+            reward_func=REWARD_FUNC,
+            reward_beta=REWARD_BETA,
+        )
+
+        ndims = len(env.torsion_angles)
+
+        proxy = proxy_class(device=DEVICE, float_precision=FLOAT_PRECISION)
+        proxy.setup(env)
+
+        print(f"Sampling for {ndims} dimensions with {args.proxy_name} proxy")
+
+        if ndims == 2:
+
+            def reward(p0, p1):
+                batch = np.concatenate([[p0], [p1]]).reshape(1, -1)
+                batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1)
+                energies = proxy(env.statebatch2proxy(batch))
+                rewards = env.proxy2reward(-energies)
+                rewards = convert_to_numpy_if_needed(rewards)
+                return np.log(rewards)
+
+        elif ndims == 3:
+
+            def reward(p0, p1, p2):
+                batch = np.concatenate([[p0], [p1], [p2]]).reshape(1, -1)
+                batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1)
+                energies = proxy(env.statebatch2proxy(batch))
+                rewards = env.proxy2reward(-energies)
+                rewards = convert_to_numpy_if_needed(rewards)
+                return np.log(rewards)
+
+        elif ndims == 4:
+
+            def reward(p0, p1, p2, p3):
+                batch = np.concatenate([[p0], [p1], [p2], [p3]]).reshape(1, -1)
+                batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1)
+                energies = proxy(env.statebatch2proxy(batch))
+                rewards = env.proxy2reward(-energies)
+                rewards = convert_to_numpy_if_needed(rewards)
+                return np.log(rewards)
+
+        elif ndims == 5:
+
+            def reward(p0, p1, p2, p3, p4):
+                batch = np.concatenate([[p0], [p1], [p2], [p3], [p4]]).reshape(1, -1)
+                batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1)
+                energies = proxy(env.statebatch2proxy(batch))
+                rewards = env.proxy2reward(-energies)
+                rewards = convert_to_numpy_if_needed(rewards)
+                return np.log(rewards)
+
+        elif ndims == 6:
+
+            def reward(p0, p1, p2, p3, p4, p5):
+                batch = np.concatenate([[p0], [p1], [p2], [p3], [p4], [p5]]).reshape(
+                    1, -1
+                )
+                batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1)
+                energies = proxy(env.statebatch2proxy(batch))
+                rewards = env.proxy2reward(-energies)
+                rewards = convert_to_numpy_if_needed(rewards)
+                return np.log(rewards)
+
+        elif ndims == 7:
+
+            def reward(p0, p1, p2, p3, p4, p5, p6):
+                batch = np.concatenate(
+                    [[p0], [p1], [p2], [p3], [p4], [p5], [p6]]
+                ).reshape(1, -1)
+                batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1)
+                energies = proxy(env.statebatch2proxy(batch))
+                rewards = env.proxy2reward(-energies)
+                rewards = convert_to_numpy_if_needed(rewards)
+                return np.log(rewards)
+
+        elif ndims == 8:
+
+            def reward(p0, p1, p2, p3, p4, p5, p6, p7):
+                batch = np.concatenate(
+                    [[p0], [p1], [p2], [p3], [p4], [p5], [p6], [p7]]
+                ).reshape(1, -1)
+                batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1)
+                energies = proxy(env.statebatch2proxy(batch))
+                rewards = env.proxy2reward(-energies)
+                rewards = convert_to_numpy_if_needed(rewards)
+                return np.log(rewards)
+
+        elif ndims == 9:
+
+            def reward(p0, p1, p2, p3, p4, p5, p6, p7, p8):
+                batch = np.concatenate(
+                    [[p0], [p1], [p2], [p3], [p4], [p5], [p6], [p7], [p8]]
+                ).reshape(1, -1)
+                batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1)
+                energies = proxy(env.statebatch2proxy(batch))
+                rewards = env.proxy2reward(-energies)
+                rewards = convert_to_numpy_if_needed(rewards)
+                return np.log(rewards)
+
+        elif ndims == 10:
+
+            def reward(p0, p1, p2, p3, p4, p5, p6, p7, p8, p9):
+                batch = np.concatenate(
+                    [[p0], [p1], [p2], [p3], [p4], [p5], [p6], [p7], [p8], [p9]]
+                ).reshape(1, -1)
+                batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1)
+                energies = proxy(env.statebatch2proxy(batch))
+                rewards = env.proxy2reward(-energies)
+                rewards = convert_to_numpy_if_needed(rewards)
+                return np.log(rewards)
+
+        elif ndims == 11:
+
+            def reward(p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10):
+                batch = np.concatenate(
+                    [[p0], [p1], [p2], [p3], [p4], [p5], [p6], [p7], [p8], [p9], [p10]]
+                ).reshape(1, -1)
+                batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1)
+                energies = proxy(env.statebatch2proxy(batch))
+                rewards = env.proxy2reward(-energies)
+                rewards = convert_to_numpy_if_needed(rewards)
+                return np.log(rewards)
+
+        elif ndims == 12:
+
+            def reward(p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11):
+                batch = np.concatenate(
+                    [
+                        [p0],
+                        [p1],
+                        [p2],
+                        [p3],
+                        [p4],
+                        [p5],
+                        [p6],
+                        [p7],
+                        [p8],
+                        [p9],
+                        [p10],
+                        [p11],
+                    ]
+                ).reshape(1, -1)
+                batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1)
+                energies = proxy(env.statebatch2proxy(batch))
+                rewards = env.proxy2reward(-energies)
+                rewards = convert_to_numpy_if_needed(rewards)
+                return np.log(rewards)
+
+        info = {"likelihood": {"reward": reward}}
+
+        info["params"] = {
+            f"p{i}": {
+                "prior": {"min": 0, "max": 2 * np.pi},
+                "ref": np.pi,
+                "proposal": 0.01,
+            }
+            for i in range(ndims)
+        }
+
+        Rminus1_stop = args.rminus1_stop
+        info["sampler"] = {"mcmc": {"Rminus1_stop": Rminus1_stop, "max_tries": 1000}}
+
+        updated_info, sampler = run(info)
+
+        gdsamples = MCSamplesFromCobaya(updated_info, sampler.products()["sample"])
+
+        def get_energy(batch):
+            batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1)
+            energies = proxy(env.statebatch2proxy(batch))
+            return energies
+
+        if gdsamples.samples.shape[0] >= 1100:
+            npars = len(info["params"])
+            dct = {
+                "x": gdsamples.samples[-1000:, :npars]
+            }  # , "energy": np.exp(gdsamples.loglikes[-10000:])}
+            dct["energy"] = get_energy(gdsamples.samples[-1000:, :npars])
+            dct["conformer"] = [
+                env.set_conformer(state).rdk_mol
+                for state in gdsamples.samples[-1000:, :npars]
+            ]
+            pickle.dump(
+                dct,
+                open(
+                    output_dir / f"conformers_mcmc_smile{smile}_{args.proxy_name}.pkl",
+                    "wb",
+                ),
+            )
+            print(f"Finished smile {smile} (dimensions {ndims})")
+        else:
+            print(f"Not enough samples for smile {smile}")
+            return 0
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--ids", nargs="+", required=True, type=int)
+    parser.add_argument("--output_dir", type=str, default="./mcmc_outputs/")
+    parser.add_argument("--proxy_name", type=str, default="torchani")
+    parser.add_argument("--rminus1_stop", type=float, default=0.05)
+    args = parser.parse_args()
+    main(args)
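The per-`ndims` ladder in `mcmc_baseline.py` exists because cobaya infers the sampled parameter names from the likelihood's signature. A dimension-generic sketch that could replace it, assuming cobaya's `input_params` override for external likelihoods (an assumption worth checking against the cobaya docs); `env`, `proxy`, `ndims`, and `convert_to_numpy_if_needed` are the names already in scope inside `main`:

```python
import numpy as np


def make_reward(env, proxy, ndims):
    """Dimension-generic version of the per-ndims reward ladder."""

    def reward(**kwargs):
        # Assemble the torsion angles in the fixed order p0..p{ndims-1},
        # then append the zero column exactly as the hand-written rewards do.
        batch = np.array([[kwargs[f"p{i}"] for i in range(ndims)]])
        batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1)
        energies = proxy(env.statebatch2proxy(batch))
        rewards = convert_to_numpy_if_needed(env.proxy2reward(-energies))
        return np.log(rewards)

    return reward


# cobaya cannot infer parameter names from **kwargs, so they are declared
# explicitly via input_params.
info = {
    "likelihood": {
        "reward": {
            "external": make_reward(env, proxy, ndims),
            "input_params": [f"p{i}" for i in range(ndims)],
        }
    }
}
```

Building the batch in a loop also removes the copy-paste hazard inherent in writing each arity out by hand.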
diff --git a/scripts/conformer/merge_metrics.py b/scripts/conformer/merge_metrics.py
new file mode 100644
index 000000000..a73922c2e
--- /dev/null
+++ b/scripts/conformer/merge_metrics.py
@@ -0,0 +1,58 @@
+from pathlib import Path
+
+import pandas as pd
+
+# method = "gfn-ff"
+# method = "torchani"
+method = "xtb"
+
+base_path = Path(
+    "/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/metrics/"
+)
+df = pd.read_csv(
+    base_path / f"gfn_conformers_v2_{method}_metrics.csv", index_col="smiles"
+)
+dftop = pd.read_csv(
+    base_path / f"gfn_conformers_v2_{method}_top_k_metrics.csv", index_col="smiles"
+)
+rd = pd.read_csv(
+    base_path
+    / "rdkit_samples_target_smiles_first_batch_2023-09-24_21-15-42_metrics.csv",
+    index_col="smiles",
+)
+rdc = pd.read_csv(
+    base_path
+    / "rdkit_cluster_samples_target_smiles_first_batch_2023-09-24_21-16-23_metrics.csv",
+    index_col="smiles",
+)
+rd = rd.loc[df.index]
+rdc = rdc.loc[df.index]
+dftop = dftop.loc[df.index]
+df[f"{method}_cov"] = df["cov"]
+df[f"{method}_mat"] = df["mat"]
+df[f"{method}_tk_cov"] = dftop["cov"]
+df[f"{method}_tk_mat"] = dftop["mat"]
+df["rdkit_cov"] = rd["cov"]
+df["rdkit_mat"] = rd["mat"]
+df["rdkit_cl_cov"] = rdc["cov"]
+df["rdkit_cl_mat"] = rdc["mat"]
+df = df.sort_values("n_tas")
+df.to_csv("./merged_metrics_{}.csv".format(method))
+df = df[
+    [
+        f"{method}_cov",
+        f"{method}_mat",
+        f"{method}_tk_cov",
+        f"{method}_tk_mat",
+        "rdkit_cov",
+        "rdkit_mat",
+        "rdkit_cl_cov",
+        "rdkit_cl_mat",
+    ]
+]
+print("Mean")
+print(df.mean())
+print("Median")
+print(df.median())
+print("Var")
+print(df.var())
diff --git a/scripts/conformer/rdkit_baselines.py b/scripts/conformer/rdkit_baselines.py
new file mode 100644
index 000000000..3606599fe
--- /dev/null
+++ b/scripts/conformer/rdkit_baselines.py
@@ -0,0 +1,223 @@
+import argparse
+import copy
+import os
+import pickle
+from datetime import datetime
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+from rdkit import Chem
+from rdkit.Chem import AllChem
+from rdkit.Chem import rdMolAlign as MA
+from rdkit.Chem import rdMolTransforms
+from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMolecule
+from rdkit.Geometry.rdGeometry import Point3D
+from scipy.spatial.transform import Rotation
+from sklearn.cluster import KMeans
+from tqdm import tqdm
+
+
+def gen_multiple_conf_rdkit(smiles, n_confs, optimise=True, randomise_tas=False):
+    mols = []
+    for _ in range(n_confs):
+        mols.append(
+            get_single_conf_rdkit(
+                smiles, optimise=optimise, randomise_tas=randomise_tas
+            )
+        )
+    return mols
+
+
+def get_single_conf_rdkit(smiles, optimise=True, randomise_tas=False):
+    mol = Chem.MolFromSmiles(smiles)
+    mol = Chem.AddHs(mol)
+    AllChem.EmbedMolecule(mol)
+    if optimise:
+        try:
+            AllChem.MMFFOptimizeMolecule(mol, confId=0)
+        except Exception as e:
+            print(e)
+    if randomise_tas:
+        rotatable_bonds = get_torsions(mol)
+        values = 2 * np.pi * np.random.rand(len(rotatable_bonds))
+        conf = mol.GetConformers()[0]
+        for rb, val in zip(rotatable_bonds, values):
+            rdMolTransforms.SetDihedralRad(conf, rb[0], rb[1], rb[2], rb[3], val)
+        Chem.rdMolTransforms.CanonicalizeConformer(conf)
+    return mol
+
+
+def write_conformers(conf_list, smiles, output_dir, prefix="", idx=None):
+    conf_dict = {"conformer": conf_list, "smiles": smiles}
+    filename = output_dir / f"{prefix}conformer_{idx}.pkl"
+    with open(filename, "wb") as file:
+        pickle.dump(conf_dict, file)
+
+
+def get_torsions(m):
+    # taken from https://gist.github.com/ZhouGengmo/5b565f51adafcd911c0bc115b2ef027c
+    m = Chem.RemoveHs(copy.deepcopy(m))
+    torsionList = []
+    torsionSmarts = "[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]"
+    torsionQuery = Chem.MolFromSmarts(torsionSmarts)
+    matches = m.GetSubstructMatches(torsionQuery)
+    for match in matches:
+        idx2 = match[0]
+        idx3 = match[1]
+        bond = m.GetBondBetweenAtoms(idx2, idx3)
+        jAtom = m.GetAtomWithIdx(idx2)
+        kAtom = m.GetAtomWithIdx(idx3)
+        for b1 in jAtom.GetBonds():
+            if b1.GetIdx() == bond.GetIdx():
+                continue
+            idx1 = b1.GetOtherAtomIdx(idx2)
+            for b2 in kAtom.GetBonds():
+                if (b2.GetIdx() == bond.GetIdx()) or (b2.GetIdx() == b1.GetIdx()):
+                    continue
+                idx4 = b2.GetOtherAtomIdx(idx3)
+                # skip 3-membered rings
+                if idx4 == idx1:
+                    continue
+                # skip torsions that include hydrogens
+                if (m.GetAtomWithIdx(idx1).GetAtomicNum() == 1) or (
+                    m.GetAtomWithIdx(idx4).GetAtomicNum() == 1
+                ):
+                    continue
+                if m.GetAtomWithIdx(idx4).IsInRing():
+                    torsionList.append((idx4, idx3, idx2, idx1))
+                    break
+                else:
+                    torsionList.append((idx1, idx2, idx3, idx4))
+                    break
+            break
+    return torsionList
+
+
+def clustering(smiles, M=1000, N=100):
+    # adapted from https://gist.github.com/ZhouGengmo/5b565f51adafcd911c0bc115b2ef027c
+    total_sz = 0
+    rdkit_coords_list = []
+
+    # add MMFF-optimised conformers (M)
+    rdkit_mols = gen_multiple_conf_rdkit(
+        smiles, n_confs=M, optimise=True, randomise_tas=False
+    )
+    rdkit_mols = [Chem.RemoveHs(x) for x in rdkit_mols]
+    sz = len(rdkit_mols)
+    # normalize
+    tgt_coords = rdkit_mols[0].GetConformers()[0].GetPositions().astype(np.float32)
+    tgt_coords = tgt_coords - np.mean(tgt_coords, axis=0)
+
+    for item in rdkit_mols:
+        _coords = item.GetConformers()[0].GetPositions().astype(np.float32)
+        _coords = _coords - _coords.mean(axis=0)  # need to normalize first
+        _R, _score = Rotation.align_vectors(_coords, tgt_coords)
+        rdkit_coords_list.append(np.dot(_coords, _R.as_matrix()))
+    total_sz += sz
+
+    # add conformers without MMFF optimisation (M // 4)
+    rdkit_mols = gen_multiple_conf_rdkit(
+        smiles, n_confs=int(M // 4), optimise=False, randomise_tas=False
+    )
+    rdkit_mols = [Chem.RemoveHs(x) for x in rdkit_mols]
+    sz = len(rdkit_mols)
+
+    for item in rdkit_mols:
+        _coords = item.GetConformers()[0].GetPositions().astype(np.float32)
+        _coords = _coords - _coords.mean(axis=0)  # need to normalize first
+        _R, _score = Rotation.align_vectors(_coords, tgt_coords)
+        rdkit_coords_list.append(np.dot(_coords, _R.as_matrix()))
+    total_sz += sz
+
+    # add conformers with uniformly randomised torsion angles (M // 4)
+    rdkit_mols = gen_multiple_conf_rdkit(
+        smiles, n_confs=int(M // 4), optimise=False, randomise_tas=True
+    )
+    rdkit_mols = [Chem.RemoveHs(x) for x in rdkit_mols]
+    sz = len(rdkit_mols)
+
+    for item in rdkit_mols:
+        _coords = item.GetConformers()[0].GetPositions().astype(np.float32)
+        _coords = _coords - _coords.mean(axis=0)  # need to normalize first
+        _R, _score = Rotation.align_vectors(_coords, tgt_coords)
+        rdkit_coords_list.append(np.dot(_coords, _R.as_matrix()))
+    total_sz += sz
+
+    # clustering
+    rdkit_coords_flatten = np.array(rdkit_coords_list).reshape(total_sz, -1)
+    cluster_size = N
+    kmeans = KMeans(n_clusters=cluster_size, random_state=42).fit(rdkit_coords_flatten)
+    ids = kmeans.predict(rdkit_coords_flatten)
+    # get cluster centers
+    center_coords = kmeans.cluster_centers_
+    coords_list = [center_coords[i].reshape(-1, 3) for i in range(cluster_size)]
+    mols = []
+    for coord in coords_list:
+        mol = get_single_conf_rdkit(smiles, optimise=False, randomise_tas=False)
+        mol = set_atom_positions(mol, coord)
+        mols.append(copy.deepcopy(mol))
+    return mols
+
+
+def set_atom_positions(mol, atom_positions):
+    """
+    mol: rdkit mol with a single embedded conformer
+    atom_positions: 2D np.array of shape [n_atoms, 3]
+    """
+    conf = mol.GetConformers()[0]
+    for idx, pos in enumerate(atom_positions):
+        conf.SetAtomPosition(idx, Point3D(*pos))
+    return mol
+
+
+def gen_multiple_conf_rdkit_cluster(smiles, n_confs):
+    M = min(10 * n_confs, 2000)
+    mols = clustering(smiles, N=n_confs, M=M)
+    return mols
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--output_base_dir",
+        type=str,
+        default="/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/samples",
+    )
+    parser.add_argument("--method", type=str, default="rdkit")
+    parser.add_argument(
+        "--target_smiles",
+        type=str,
+        default="/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/target_smiles/target_smiles_4_initial.csv",
+    )
+    parser.add_argument("--n_confs", type=int, default=None)
+    args = parser.parse_args()
+
+    output_base_dir = Path(args.output_base_dir)
+    if not output_base_dir.exists():
+        os.mkdir(output_base_dir)
+
+    current_datetime = datetime.now()
+    timestamp = current_datetime.strftime("%Y-%m-%d_%H-%M-%S")
+    ts = Path(args.target_smiles).name[:-4]
+    output_dir = output_base_dir / f"{args.method}_samples_{ts}_{timestamp}"
+    if output_dir.exists():
+        print("Output dir already exists! Exiting generation")
+    else:
+        os.mkdir(output_dir)
+        target_smiles = pd.read_csv(args.target_smiles, index_col=0)
+        for idx, (_, item) in tqdm(
+            enumerate(target_smiles.iterrows()), total=len(target_smiles)
+        ):
+            n_samples = args.n_confs
+            if n_samples is None:
+                n_samples = 2 * item.n_confs
+            print(f"start generating {n_samples} confs")
+            if args.method == "rdkit":
+                confs = gen_multiple_conf_rdkit(item.smiles, n_samples, optimise=True)
+            if args.method == "rdkit_cluster":
+                confs = gen_multiple_conf_rdkit_cluster(item.smiles, n_samples)
+            write_conformers(
+                confs, item.smiles, output_dir, prefix=f"{args.method}_", idx=idx
+            )
+        print("Finished generation, results are in {}".format(output_dir))
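A minimal usage sketch for the clustering baseline, assuming it is run from `scripts/conformer` with the helpers above importable; the SMILES, conformer count, and output directory are illustrative:

```python
from pathlib import Path

from rdkit_baselines import gen_multiple_conf_rdkit_cluster, write_conformers

smiles = "CC(C)Cc1ccc(C(C)C(=O)O)cc1"  # placeholder molecule
out_dir = Path("./samples")
out_dir.mkdir(exist_ok=True)

# ~50 raw embeddings (M = 10 * n_confs), reduced to 5 k-means centers
mols = gen_multiple_conf_rdkit_cluster(smiles, n_confs=5)
write_conformers(mols, smiles, out_dir, prefix="rdkit_cluster_", idx=0)
```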