From 8396762009f1cd531e6f4bac0835b27193bd2433 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Thu, 21 Sep 2023 12:30:55 -0400 Subject: [PATCH 01/22] add geom stats script --- scripts/conformer/geom_stats.py | 116 ++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 scripts/conformer/geom_stats.py diff --git a/scripts/conformer/geom_stats.py b/scripts/conformer/geom_stats.py new file mode 100644 index 000000000..d5b7fd9f1 --- /dev/null +++ b/scripts/conformer/geom_stats.py @@ -0,0 +1,116 @@ +import argparse +import os +import json +import pickle +import numpy as np +import pandas as pd + +from rdkit import Chem +from pathlib import Path +from tqdm import tqdm + +from gflownet.utils.molecule.rotatable_bonds import get_rotatable_ta_list, is_hydrogen_ta + +def get_conf_geom(base_path, smiles, conf_idx=0, summary_file=None): + if summary_file is None: + drugs_file = base_path / 'rdkit_folder/summary_drugs.json' + with open(drugs_file, "r") as f: + summary_file = json.load(f) + + pickle_path = base_path / "rdkit_folder" / summary_file[smiles]['pickle_path'] + if os.path.isfile(pickle_path): + with open(pickle_path, "rb") as f: + dic = pickle.load(f) + mol = dic['conformers'][conf_idx]['rd_mol'] + return mol + +def get_all_confs_geom(base_path, smiles, summary_file=None): + if summary_file is None: + drugs_file = base_path / 'rdkit_folder/summary_drugs.json' + with open(drugs_file, "r") as f: + summary_file = json.load(f) + try: + pickle_path = base_path / "rdkit_folder" / summary_file[smiles]['pickle_path'] + if os.path.isfile(pickle_path): + with open(pickle_path, "rb") as f: + dic = pickle.load(f) + conformers = [x['rd_mol'] for x in dic['conformers']] + return conformers + except KeyError: + print('No pickle_path file for {}'.format(smiles)) + +def get_rd_mol(smiles): + mol = Chem.MolFromSmiles(smiles) + mol = Chem.AddHs(mol) + return mol + +def has_same_can_smiles(mol1, mol2): + sm1 = Chem.CanonSmiles(Chem.MolToSmiles(mol1)) + 
sm2 = Chem.CanonSmiles(Chem.MolToSmiles(mol2)) + return sm1 == sm2 + +def all_same_graphs(mols): + ref = mols[0] + same = [] + for mol in mols: + same.append(has_same_can_smiles(ref, mol)) + return np.all(same) + +def has_hydrogen_tas(mol): + tas = get_rotatable_ta_list(mol) + hydrogen_flags = [] + for t in tas: + hydrogen_flags.append(is_hydrogen_ta(mol, t)) + return np.any(hydrogen_flags) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--geom_dir', type=str, default='/home/mila/a/alexandra.volokhova/scratch/datasets/geom') + parser.add_argument('--output_dir', type=str, default='./') + args = parser.parse_args() + + base_path = Path(args.geom_dir) + drugs_file = base_path / 'rdkit_folder/summary_drugs.json' + with open(drugs_file, "r") as f: + drugs_summ = json.load(f) + + + smiles = [] + self_consistent = [] + rdkit_consistent = [] + n_tas_geom = [] + n_tas_rdkit = [] + unique_confs = [] + n_atoms = [] + n_heavy_atoms = [] + hydrogen_tas = [] + for sm, sub_dic in tqdm(list(drugs_summ.items())): + confs = get_all_confs_geom(base_path, sm, summary_file=drugs_summ) + if not confs is None: + rd_mol = get_rd_mol(sm) + smiles.append(sm) + unique_confs.append(len(confs)) + n_atoms.append(confs[0].GetNumAtoms()) + n_heavy_atoms.append(confs[0].GetNumHeavyAtoms()) + self_consistent.append(all_same_graphs(confs)) + rdkit_consistent.append(all_same_graphs(confs + [rd_mol])) + n_tas_geom.append(len(get_rotatable_ta_list(confs[0]))) + n_tas_rdkit.append(len(get_rotatable_ta_list(rd_mol))) + hydrogen_tas.append(has_hydrogen_tas(confs[0])) + + data = { + 'smiles': smiles, + 'self_consistent': self_consistent, + 'rdkit_consistent': rdkit_consistent, + 'n_rotatable_torsion_angles_geom' : n_tas_geom, + 'n_rotatable_torsion_angles_rdkit' : n_tas_rdkit, + 'has_hydrogen_tas': hydrogen_tas, + 'n_confs': unique_confs, + 'n_heavy_atoms': n_heavy_atoms, + 'n_atoms': n_atoms, +} +df = pd.DataFrame(data) +df.to_csv(os.path.join(args.output_dir, 
'geom_stats.csv')) + + From 2879ae137ff7f11693b29d35219e4461e3c32968 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Thu, 21 Sep 2023 12:36:17 -0400 Subject: [PATCH 02/22] add comment --- scripts/conformer/geom_stats.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/scripts/conformer/geom_stats.py b/scripts/conformer/geom_stats.py index d5b7fd9f1..33058dac5 100644 --- a/scripts/conformer/geom_stats.py +++ b/scripts/conformer/geom_stats.py @@ -11,6 +11,11 @@ from gflownet.utils.molecule.rotatable_bonds import get_rotatable_ta_list, is_hydrogen_ta +""" +Here we use rdkit_folder format of the GEOM dataset +Tutorial and downloading links are here: https://github.com/learningmatter-mit/geom/tree/master +""" + def get_conf_geom(base_path, smiles, conf_idx=0, summary_file=None): if summary_file is None: drugs_file = base_path / 'rdkit_folder/summary_drugs.json' From 44077aee58c10adc7a47b6a722f32d0389e7f204 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Thu, 21 Sep 2023 17:21:07 -0400 Subject: [PATCH 03/22] fix bug --- scripts/conformer/geom_stats.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/conformer/geom_stats.py b/scripts/conformer/geom_stats.py index 33058dac5..5e7640261 100644 --- a/scripts/conformer/geom_stats.py +++ b/scripts/conformer/geom_stats.py @@ -114,8 +114,8 @@ def has_hydrogen_tas(mol): 'n_confs': unique_confs, 'n_heavy_atoms': n_heavy_atoms, 'n_atoms': n_atoms, -} -df = pd.DataFrame(data) -df.to_csv(os.path.join(args.output_dir, 'geom_stats.csv')) + } + df = pd.DataFrame(data) + df.to_csv(os.path.join(args.output_dir, 'geom_stats.csv')) From 13a739cb4685fe0466fcf1de7b9e4f184243cab2 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Thu, 21 Sep 2023 20:52:20 -0400 Subject: [PATCH 04/22] regrouped functions for geom --- gflownet/utils/molecule/geom.py | 56 ++++++++++++++++++++ gflownet/utils/molecule/rotatable_bonds.py | 7 +++ scripts/conformer/geom_stats.py | 61 ++-------------------- 3 
files changed, 67 insertions(+), 57 deletions(-) create mode 100644 gflownet/utils/molecule/geom.py diff --git a/gflownet/utils/molecule/geom.py b/gflownet/utils/molecule/geom.py new file mode 100644 index 000000000..ec80193c0 --- /dev/null +++ b/gflownet/utils/molecule/geom.py @@ -0,0 +1,56 @@ +import os +import json +import pickle +import numpy as np +import pandas as pd + +from rdkit import Chem +from pathlib import Path +from tqdm import tqdm + +from gflownet.utils.molecule.rotatable_bonds import get_rotatable_ta_list, is_hydrogen_ta + +def get_conf_geom(base_path, smiles, conf_idx=0, summary_file=None): + if summary_file is None: + drugs_file = base_path / 'rdkit_folder/summary_drugs.json' + with open(drugs_file, "r") as f: + summary_file = json.load(f) + + pickle_path = base_path / "rdkit_folder" / summary_file[smiles]['pickle_path'] + if os.path.isfile(pickle_path): + with open(pickle_path, "rb") as f: + dic = pickle.load(f) + mol = dic['conformers'][conf_idx]['rd_mol'] + return mol + +def get_all_confs_geom(base_path, smiles, summary_file=None): + if summary_file is None: + drugs_file = base_path / 'rdkit_folder/summary_drugs.json' + with open(drugs_file, "r") as f: + summary_file = json.load(f) + try: + pickle_path = base_path / "rdkit_folder" / summary_file[smiles]['pickle_path'] + if os.path.isfile(pickle_path): + with open(pickle_path, "rb") as f: + dic = pickle.load(f) + conformers = [x['rd_mol'] for x in dic['conformers']] + return conformers + except KeyError: + print('No pickle_path file for {}'.format(smiles)) + +def get_rd_mol(smiles): + mol = Chem.MolFromSmiles(smiles) + mol = Chem.AddHs(mol) + return mol + +def has_same_can_smiles(mol1, mol2): + sm1 = Chem.CanonSmiles(Chem.MolToSmiles(mol1)) + sm2 = Chem.CanonSmiles(Chem.MolToSmiles(mol2)) + return sm1 == sm2 + +def all_same_graphs(mols): + ref = mols[0] + same = [] + for mol in mols: + same.append(has_same_can_smiles(ref, mol)) + return np.all(same) \ No newline at end of file diff --git 
a/gflownet/utils/molecule/rotatable_bonds.py b/gflownet/utils/molecule/rotatable_bonds.py index 4dc73a590..fc7579306 100644 --- a/gflownet/utils/molecule/rotatable_bonds.py +++ b/gflownet/utils/molecule/rotatable_bonds.py @@ -111,3 +111,10 @@ def is_connected_to_three_hydrogens(mol, atom_id, except_id): first = is_connected_to_three_hydrogens(mol, ta[1], ta[2]) second = is_connected_to_three_hydrogens(mol, ta[2], ta[1]) return first or second + +def has_hydrogen_tas(mol): + tas = get_rotatable_ta_list(mol) + hydrogen_flags = [] + for t in tas: + hydrogen_flags.append(is_hydrogen_ta(mol, t)) + return np.any(hydrogen_flags) \ No newline at end of file diff --git a/scripts/conformer/geom_stats.py b/scripts/conformer/geom_stats.py index 5e7640261..d33156e89 100644 --- a/scripts/conformer/geom_stats.py +++ b/scripts/conformer/geom_stats.py @@ -9,70 +9,17 @@ from pathlib import Path from tqdm import tqdm -from gflownet.utils.molecule.rotatable_bonds import get_rotatable_ta_list, is_hydrogen_ta +from gflownet.utils.molecule.rotatable_bonds import get_rotatable_ta_list, has_hydrogen_tas +from gflownet.utils.molecule.geom import get_conf_geom, get_all_confs_geom, get_rd_mol, all_same_graphs """ Here we use rdkit_folder format of the GEOM dataset Tutorial and downloading links are here: https://github.com/learningmatter-mit/geom/tree/master """ - -def get_conf_geom(base_path, smiles, conf_idx=0, summary_file=None): - if summary_file is None: - drugs_file = base_path / 'rdkit_folder/summary_drugs.json' - with open(drugs_file, "r") as f: - summary_file = json.load(f) - - pickle_path = base_path / "rdkit_folder" / summary_file[smiles]['pickle_path'] - if os.path.isfile(pickle_path): - with open(pickle_path, "rb") as f: - dic = pickle.load(f) - mol = dic['conformers'][conf_idx]['rd_mol'] - return mol - -def get_all_confs_geom(base_path, smiles, summary_file=None): - if summary_file is None: - drugs_file = base_path / 'rdkit_folder/summary_drugs.json' - with open(drugs_file, "r") 
as f: - summary_file = json.load(f) - try: - pickle_path = base_path / "rdkit_folder" / summary_file[smiles]['pickle_path'] - if os.path.isfile(pickle_path): - with open(pickle_path, "rb") as f: - dic = pickle.load(f) - conformers = [x['rd_mol'] for x in dic['conformers']] - return conformers - except KeyError: - print('No pickle_path file for {}'.format(smiles)) - -def get_rd_mol(smiles): - mol = Chem.MolFromSmiles(smiles) - mol = Chem.AddHs(mol) - return mol - -def has_same_can_smiles(mol1, mol2): - sm1 = Chem.CanonSmiles(Chem.MolToSmiles(mol1)) - sm2 = Chem.CanonSmiles(Chem.MolToSmiles(mol2)) - return sm1 == sm2 - -def all_same_graphs(mols): - ref = mols[0] - same = [] - for mol in mols: - same.append(has_same_can_smiles(ref, mol)) - return np.all(same) - -def has_hydrogen_tas(mol): - tas = get_rotatable_ta_list(mol) - hydrogen_flags = [] - for t in tas: - hydrogen_flags.append(is_hydrogen_ta(mol, t)) - return np.any(hydrogen_flags) - - if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--geom_dir', type=str, default='/home/mila/a/alexandra.volokhova/scratch/datasets/geom') - parser.add_argument('--output_dir', type=str, default='./') + parser.add_argument('--output_file', type=str, default='./geom_stats.csv') args = parser.parse_args() base_path = Path(args.geom_dir) @@ -116,6 +63,6 @@ def has_hydrogen_tas(mol): 'n_atoms': n_atoms, } df = pd.DataFrame(data) - df.to_csv(os.path.join(args.output_dir, 'geom_stats.csv')) + df.to_csv(args.output_file) From ae82682c61bd0ca90c1f2b4fdaeac0c389228246 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Thu, 21 Sep 2023 22:22:46 -0400 Subject: [PATCH 05/22] add metrics script --- gflownet/utils/molecule/metrics.py | 36 +++++++++++++++++++ scripts/conformer/compute_metrics.py | 53 ++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+) create mode 100644 gflownet/utils/molecule/metrics.py create mode 100644 scripts/conformer/compute_metrics.py diff --git 
a/gflownet/utils/molecule/metrics.py b/gflownet/utils/molecule/metrics.py new file mode 100644 index 000000000..bf713b70a --- /dev/null +++ b/gflownet/utils/molecule/metrics.py @@ -0,0 +1,36 @@ +# some functions inspired by: https://gist.github.com/ZhouGengmo/5b565f51adafcd911c0bc115b2ef027c + +import numpy as np +import pandas as pd +import copy + +from rdkit import Chem + +from rdkit.Chem import rdMolAlign as MA +from rdkit import Chem +from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMolecule +from rdkit.Geometry.rdGeometry import Point3D + +def get_best_rmsd(gen_mol, ref_mol): + gen_mol = Chem.RemoveHs(gen_mol) + ref_mol = Chem.RemoveHs(ref_mol) + rmsd = MA.GetBestRMS(gen_mol, ref_mol) + return rmsd + +def get_cov_mat(ref_mols, gen_mols, threshold=1.25): + rmsd_mat = np.zeros([len(ref_mols), len(gen_mols)], dtype=np.float32) + for i, gen_mol in enumerate(gen_mols): + gen_mol_c = copy.deepcopy(gen_mol) + for j, ref_mol in enumerate(ref_mols): + ref_mol_c = copy.deepcopy(ref_mol) + rmsd_mat[j, i] = get_best_rmsd(gen_mol_c, ref_mol_c) + rmsd_mat_min = rmsd_mat.min(-1) + return (rmsd_mat_min <= threshold).mean(), rmsd_mat_min.mean() + +def normalise_positions(mol): + conf = mol.GetConformer() + pos = conf.GetPositions() + pos = pos - pos.mean(axis=0) + for idx, p in enumerate(pos): + conf.SetAtomPosition(idx, Point3D(*p)) + return mol \ No newline at end of file diff --git a/scripts/conformer/compute_metrics.py b/scripts/conformer/compute_metrics.py new file mode 100644 index 000000000..35ea59c85 --- /dev/null +++ b/scripts/conformer/compute_metrics.py @@ -0,0 +1,53 @@ +import argparse +import os +import json +import pickle +import numpy as np +import pandas as pd +from pathlib import Path +from tqdm import tqdm + +from gflownet.utils.molecule.metrics import get_cov_mat +from gflownet.utils.molecule.geom import get_all_confs_geom + + +def main(args): + base_path = Path(args.geom_dir) + drugs_file = base_path / 'rdkit_folder/summary_drugs.json' + with 
open(drugs_file, "r") as f: + drugs_summ = json.load(f) + + gen_files = [Path(args.gen_dir) / x for x in os.listdir(args.gen_dir)] + smiles = [x.name.split('_')[-1][:-4] for x in gen_files] + ref_confs = [get_all_confs_geom(base_path, sm, drugs_summ) for sm in smiles] + gen_confs = [] + for fp in gen_files: + with open(fp, 'rb') as f: + gen_confs.append(pickle.load(f)['conformer']) + + cov_list = [] + mat_list = [] + n_conf = [] + for ref, gen in tqdm(zip(ref_confs, gen_confs)): + cov, mat = get_cov_mat(ref, gen[:2*len(ref)], threshold=1.25) + cov_list.append(cov) + mat_list.append(mat) + n_conf.append(len(ref)) + + data = { + 'smiles': smiles, + 'cov': cov_list, + 'mat': mat_list, + 'n_ref_confs': n_conf + } + df = pd.DataFrame(data) + df.to_csv(args.output_file, index=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--geom_dir', type=str, default='/home/mila/a/alexandra.volokhova/scratch/datasets/geom') + parser.add_argument('--output_file', type=str, default='./gfn_metrics.csv') + parser.add_argument('--gen_dir', type=str, default='./') + args = parser.parse_args() + main(args) \ No newline at end of file From 50c64f2e595b5422bca412886eb7250b55279593 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Fri, 22 Sep 2023 17:56:24 -0400 Subject: [PATCH 06/22] add rdkit baseline generation --- scripts/conformer/rdkit_baselines.py | 56 ++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 scripts/conformer/rdkit_baselines.py diff --git a/scripts/conformer/rdkit_baselines.py b/scripts/conformer/rdkit_baselines.py new file mode 100644 index 000000000..3c83f8a3d --- /dev/null +++ b/scripts/conformer/rdkit_baselines.py @@ -0,0 +1,56 @@ +import argparse +import numpy as np +import pickle +import os + +from pathlib import Path +from rdkit import Chem +from rdkit.Chem import AllChem +from rdkit.Chem import rdMolAlign as MA +from tqdm import tqdm + +from rdkit.Chem.rdForceFieldHelpers import 
MMFFOptimizeMolecule + +def gen_multiple_conf_rdkit(smiles, n_confs=1000): + mols = [] + for _ in range(n_confs): + mols.append(get_single_conf_rdkit(smiles)) + return mols + +def get_single_conf_rdkit(smiles): + mol = Chem.MolFromSmiles(smiles) + mol = Chem.AddHs(mol) + AllChem.EmbedMolecule(mol) + try: + AllChem.MMFFOptimizeMolecule(mol, confId=0) + except Exception as e: + print(e) + return mol + +def write_conformers(conf_list, smiles, output_dir, prefix=''): + conf_dict = {'conformer': conf_list} + filename = output_dir / f'{prefix}conformer_{smiles}.pkl' + with open(filename, 'wb') as file: + pickle.dump(conf_dict, file) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, default='./samples') + parser.add_argument('--method', type=str, default='rdkit') + parser.add_argument('--target_smiles', type=str, + default='/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/target_smiles.pkl') + parser.add_argument('--n_confs', type=int, default=300) + args = parser.parse_args() + + output_dir = Path(args.output_dir) + if output_dir.exists(): + print("Output dir already exists! 
Exit generation") + else: + os.mkdir(output_dir) + with open(args.target_smiles, 'rb') as file: + smiles_list = pickle.load(file) + if args.method == 'rdkit': + for smiles in tqdm(smiles_list): + confs = gen_multiple_conf_rdkit(smiles, args.n_confs) + write_conformers(confs, smiles, output_dir, prefix='rdkit_') From 22beb23e348f6d0671d2cd78c5d0bbfc64f0ddbb Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Sun, 24 Sep 2023 00:22:58 -0400 Subject: [PATCH 07/22] progress with clustering --- scripts/conformer/rdkit_baselines.py | 154 ++++++++++++++++++++++++--- 1 file changed, 142 insertions(+), 12 deletions(-) diff --git a/scripts/conformer/rdkit_baselines.py b/scripts/conformer/rdkit_baselines.py index 3c83f8a3d..e705610de 100644 --- a/scripts/conformer/rdkit_baselines.py +++ b/scripts/conformer/rdkit_baselines.py @@ -2,29 +2,43 @@ import numpy as np import pickle import os +import copy +from scipy.spatial.transform import Rotation +from sklearn.cluster import KMeans from pathlib import Path from rdkit import Chem from rdkit.Chem import AllChem from rdkit.Chem import rdMolAlign as MA +from rdkit.Chem import rdMolTransforms +from rdkit.Geometry.rdGeometry import Point3D from tqdm import tqdm from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMolecule -def gen_multiple_conf_rdkit(smiles, n_confs=1000): +def gen_multiple_conf_rdkit(smiles, n_confs, optimise=True, randomise_tas=False): mols = [] for _ in range(n_confs): - mols.append(get_single_conf_rdkit(smiles)) + mols.append(get_single_conf_rdkit(smiles, optimise=optimise, + randomise_tas=randomise_tas)) return mols -def get_single_conf_rdkit(smiles): - mol = Chem.MolFromSmiles(smiles) +def get_single_conf_rdkit(smiles, optimise=True, randomise_tas=False): + mol = Chem.MolFromSmiles(smiles) mol = Chem.AddHs(mol) AllChem.EmbedMolecule(mol) - try: - AllChem.MMFFOptimizeMolecule(mol, confId=0) - except Exception as e: - print(e) + if optimise: + try: + AllChem.MMFFOptimizeMolecule(mol, confId=0) + except 
Exception as e: + print(e) + if randomise_tas: + rotable_bonds = get_torsions(mol) + values = 3.1415926 * 2 * np.random.rand(len(rotable_bonds)) + conf = mol.GetConformers()[0] + for rb, val in zip(rotable_bonds, values): + rdMolTransforms.SetDihedralRad(conf, rb[0], rb[1], rb[2], rb[3], val) + Chem.rdMolTransforms.CanonicalizeConformer(conf) return mol def write_conformers(conf_list, smiles, output_dir, prefix=''): @@ -33,6 +47,120 @@ def write_conformers(conf_list, smiles, output_dir, prefix=''): with open(filename, 'wb') as file: pickle.dump(conf_dict, file) +def get_torsions(m): + # taken from https://gist.github.com/ZhouGengmo/5b565f51adafcd911c0bc115b2ef027c + m = Chem.RemoveHs(copy.deepcopy(m)) + torsionList = [] + torsionSmarts = "[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]" + torsionQuery = Chem.MolFromSmarts(torsionSmarts) + matches = m.GetSubstructMatches(torsionQuery) + for match in matches: + idx2 = match[0] + idx3 = match[1] + bond = m.GetBondBetweenAtoms(idx2, idx3) + jAtom = m.GetAtomWithIdx(idx2) + kAtom = m.GetAtomWithIdx(idx3) + for b1 in jAtom.GetBonds(): + if b1.GetIdx() == bond.GetIdx(): + continue + idx1 = b1.GetOtherAtomIdx(idx2) + for b2 in kAtom.GetBonds(): + if (b2.GetIdx() == bond.GetIdx()) or (b2.GetIdx() == b1.GetIdx()): + continue + idx4 = b2.GetOtherAtomIdx(idx3) + # skip 3-membered rings + if idx4 == idx1: + continue + # skip torsions that include hydrogens + if (m.GetAtomWithIdx(idx1).GetAtomicNum() == 1) or ( + m.GetAtomWithIdx(idx4).GetAtomicNum() == 1 + ): + continue + if m.GetAtomWithIdx(idx4).IsInRing(): + torsionList.append((idx4, idx3, idx2, idx1)) + break + else: + torsionList.append((idx1, idx2, idx3, idx4)) + break + break + return torsionList + + +def clustering(smiles, M=1000, N=100): + # adapted from https://gist.github.com/ZhouGengmo/5b565f51adafcd911c0bc115b2ef027c + total_sz = 0 + rdkit_coords_list = [] + + # add MMFF optimize conformers, 20x + rdkit_mols = gen_multiple_conf_rdkit(smiles, n_confs=M, + optimise=True, 
randomise_tas=False) + rdkit_mols = [Chem.RemoveHs(x) for x in rdkit_mols] + sz = len(rdkit_mols) + # normalize + tgt_coords = rdkit_mols[0].GetConformers()[0].GetPositions().astype(np.float32) + tgt_coords = tgt_coords - np.mean(tgt_coords, axis=0) + + for item in rdkit_mols: + _coords = item.GetConformers()[0].GetPositions().astype(np.float32) + _coords = _coords - _coords.mean(axis=0) # need to normalize first + _R, _score = Rotation.align_vectors(_coords, tgt_coords) + rdkit_coords_list.append(np.dot(_coords, _R.as_matrix())) + total_sz += sz + + # add no MMFF optimize conformers, 5x + rdkit_mols = gen_multiple_conf_rdkit(smiles, n_confs=int(M // 4), + optimise=False, randomise_tas=False) + rdkit_mols = [Chem.RemoveHs(x) for x in rdkit_mols] + sz = len(rdkit_mols) + + for item in rdkit_mols: + _coords = item.GetConformers()[0].GetPositions().astype(np.float32) + _coords = _coords - _coords.mean(axis=0) # need to normalize first + _R, _score = Rotation.align_vectors(_coords, tgt_coords) + rdkit_coords_list.append(np.dot(_coords, _R.as_matrix())) + total_sz += sz + + ### add uniform rotation bonds conformers, 5x + rdkit_mols = gen_multiple_conf_rdkit(smiles, n_confs=int(M // 4), + optimise=False, randomise_tas=True) + rdkit_mols = [Chem.RemoveHs(x) for x in rdkit_mols] + sz = len(rdkit_mols) + + for item in rdkit_mols: + _coords = item.GetConformers()[0].GetPositions().astype(np.float32) + _coords = _coords - _coords.mean(axis=0) # need to normalize first + _R, _score = Rotation.align_vectors(_coords, tgt_coords) + rdkit_coords_list.append(np.dot(_coords, _R.as_matrix())) + total_sz += sz + + # clustering + rdkit_coords_flatten = np.array(rdkit_coords_list).reshape(total_sz, -1) + cluster_size = N + kmeans = KMeans(n_clusters=cluster_size, random_state=42).fit(rdkit_coords_flatten) + ids = kmeans.predict(rdkit_coords_flatten) + # get cluster center + center_coords = kmeans.cluster_centers_ + coords_list = [center_coords[i].reshape(-1,3) for i in 
range(cluster_size)] + mols = [] + for coord in coords_list: + mol = get_single_conf_rdkit(smiles, optimise=False, randomise_tas=False) + mol = set_atom_positions(mol, coord) + mols.append(copy.deepcopy(mol)) + return mols + +def set_atom_positions(mol, atom_positions): + """ + mol: rdkit mol with a single embeded conformer + atom_positions: 2D np.array of shape [n_atoms, 3] + """ + conf = mol.GetConformers()[0] + for idx, pos in enumerate(atom_positions): + conf.SetAtomPosition(idx, Point3D(*pos)) + return mol + +def gen_multiple_conf_rdkit_cluster(smiles, n_confs): + pass + if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -50,7 +178,9 @@ def write_conformers(conf_list, smiles, output_dir, prefix=''): os.mkdir(output_dir) with open(args.target_smiles, 'rb') as file: smiles_list = pickle.load(file) - if args.method == 'rdkit': - for smiles in tqdm(smiles_list): - confs = gen_multiple_conf_rdkit(smiles, args.n_confs) - write_conformers(confs, smiles, output_dir, prefix='rdkit_') + for smiles in tqdm(smiles_list): + if args.method == 'rdkit': + confs = gen_multiple_conf_rdkit(smiles, args.n_confs, optimise=True) + if args.method == 'rdkit_cluster': + confs = gen_multiple_conf_rdkit_cluster(smiles, args.n_confs) + write_conformers(confs, smiles, output_dir, prefix=f'{args.method}_') From d1efce0736de7bbce2437f4c38aae7490bbbe74f Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Sun, 24 Sep 2023 21:55:30 -0400 Subject: [PATCH 08/22] finilise metrics and rdk clustering --- scripts/conformer/compute_metrics.py | 53 ++++++++++++++++++++++------ scripts/conformer/rdkit_baselines.py | 39 ++++++++++++-------- 2 files changed, 68 insertions(+), 24 deletions(-) diff --git a/scripts/conformer/compute_metrics.py b/scripts/conformer/compute_metrics.py index 35ea59c85..5730d4415 100644 --- a/scripts/conformer/compute_metrics.py +++ b/scripts/conformer/compute_metrics.py @@ -17,37 +17,70 @@ def main(args): with open(drugs_file, "r") as f: drugs_summ = 
json.load(f) - gen_files = [Path(args.gen_dir) / x for x in os.listdir(args.gen_dir)] - smiles = [x.name.split('_')[-1][:-4] for x in gen_files] - ref_confs = [get_all_confs_geom(base_path, sm, drugs_summ) for sm in smiles] + filenames = [Path(args.gen_dir) / x for x in os.listdir(args.gen_dir)] + gen_files = [] gen_confs = [] - for fp in gen_files: + smiles = [] + for fp in filenames: with open(fp, 'rb') as f: - gen_confs.append(pickle.load(f)['conformer']) + gen_files.append(pickle.load(f)) + gen_confs.append(gen_files[-1]['conformer']) + if 'smiles' in gen_files[-1].keys(): + smiles.append(gen_files[-1]['smiles']) + if len(smiles) == 0: + smiles = [x.name.split('_')[-1][:-4] for x in filenames] + print('All smiles') + print(*smiles, sep='\n') + ref_confs = [get_all_confs_geom(base_path, sm, drugs_summ) for sm in smiles] + geom_stats = pd.read_csv(args.geom_stats, index_col=0) cov_list = [] mat_list = [] n_conf = [] - for ref, gen in tqdm(zip(ref_confs, gen_confs)): - cov, mat = get_cov_mat(ref, gen[:2*len(ref)], threshold=1.25) + n_tas = [] + consistent = [] + has_h_tas = [] + for ref, gen, sm in tqdm(zip(ref_confs, gen_confs, smiles), total=len(ref_confs)): + if len(gen) < 2*len(ref): + print("Recieved less samples that needed for computing metrics! 
Return nans") + cov, mat = None, None + else: + try: + cov, mat = get_cov_mat(ref, gen[:2*len(ref)], threshold=1.25) + except RuntimeError as e: + print(e) + cov, mat = None, None cov_list.append(cov) mat_list.append(mat) n_conf.append(len(ref)) + n_tas.append(geom_stats[geom_stats.smiles == sm].n_rotatable_torsion_angles_rdkit.values[0]) + consistent.append(geom_stats[geom_stats.smiles == sm].rdkit_consistent.values[0]) + has_h_tas.append(geom_stats[geom_stats.smiles == sm].has_hydrogen_tas.values[0]) data = { 'smiles': smiles, 'cov': cov_list, 'mat': mat_list, - 'n_ref_confs': n_conf + 'n_ref_confs': n_conf, + 'n_tas': n_tas, + 'has_hydrogen_tas': has_h_tas, + 'consistent': consistent } df = pd.DataFrame(data) - df.to_csv(args.output_file, index=False) + name = Path(args.gen_dir).name + if name in ['xtb', 'gfn-ff', 'torchani']: + name = Path(args.gen_dir).name + name = Path(args.gen_dir).parent.name + '_' + name + output_file = Path(args.output_dir) / '{}_metrics.csv'.format(name) + df.to_csv(output_file, index=False) + print('Saved metrics at {}'.format(output_file)) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--geom_dir', type=str, default='/home/mila/a/alexandra.volokhova/scratch/datasets/geom') - parser.add_argument('--output_file', type=str, default='./gfn_metrics.csv') + parser.add_argument('--output_dir', type=str, default='/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/metrics') + parser.add_argument('--geom_stats', type=str, default='/home/mila/a/alexandra.volokhova/projects/gflownet/scripts/conformer/geom_stats.csv') parser.add_argument('--gen_dir', type=str, default='./') args = parser.parse_args() main(args) \ No newline at end of file diff --git a/scripts/conformer/rdkit_baselines.py b/scripts/conformer/rdkit_baselines.py index e705610de..2490bc41e 100644 --- a/scripts/conformer/rdkit_baselines.py +++ b/scripts/conformer/rdkit_baselines.py @@ -3,7 +3,9 @@ import pickle import os import copy 
+import pandas as pd +from datetime import datetime from scipy.spatial.transform import Rotation from sklearn.cluster import KMeans from pathlib import Path @@ -41,9 +43,9 @@ def get_single_conf_rdkit(smiles, optimise=True, randomise_tas=False): Chem.rdMolTransforms.CanonicalizeConformer(conf) return mol -def write_conformers(conf_list, smiles, output_dir, prefix=''): - conf_dict = {'conformer': conf_list} - filename = output_dir / f'{prefix}conformer_{smiles}.pkl' +def write_conformers(conf_list, smiles, output_dir, prefix='', idx=None): + conf_dict = {'conformer': conf_list, 'smiles': smiles} + filename = output_dir / f'{prefix}conformer_{idx}.pkl' with open(filename, 'wb') as file: pickle.dump(conf_dict, file) @@ -159,28 +161,37 @@ def set_atom_positions(mol, atom_positions): return mol def gen_multiple_conf_rdkit_cluster(smiles, n_confs): - pass + M = min(10 * n_confs, 2000) + mols = clustering(smiles, N=n_confs, M=M) + return mols if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--output_dir', type=str, default='./samples') + parser.add_argument('--output_base_dir', type=str, default='/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/samples') parser.add_argument('--method', type=str, default='rdkit') parser.add_argument('--target_smiles', type=str, - default='/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/target_smiles.pkl') + default='/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/target_smiles/target_smiles_4_initial.csv') parser.add_argument('--n_confs', type=int, default=300) args = parser.parse_args() - output_dir = Path(args.output_dir) + output_base_dir = Path(args.output_base_dir) + if not output_base_dir.exists(): + os.mkdir(output_base_dir) + + current_datetime = datetime.now() + timestamp = current_datetime.strftime("%Y-%m-%d_%H-%M-%S") + ts = Path(args.target_smiles).name[:-4] + output_dir = output_base_dir / f'{args.method}_samples_{ts}_{timestamp}' 
if output_dir.exists(): - print("Output dir already exists! Exit generation") + print("Output dir already exisits! Exit generation") else: os.mkdir(output_dir) - with open(args.target_smiles, 'rb') as file: - smiles_list = pickle.load(file) - for smiles in tqdm(smiles_list): + target_smiles = pd.read_csv(args.target_smiles, index_col=0) + for idx, (_, item) in tqdm(enumerate(target_smiles.iterrows()), total=len(target_smiles)): if args.method == 'rdkit': - confs = gen_multiple_conf_rdkit(smiles, args.n_confs, optimise=True) + confs = gen_multiple_conf_rdkit(item.smiles, 2 * item.n_confs, optimise=True) if args.method == 'rdkit_cluster': - confs = gen_multiple_conf_rdkit_cluster(smiles, args.n_confs) - write_conformers(confs, smiles, output_dir, prefix=f'{args.method}_') + confs = gen_multiple_conf_rdkit_cluster(item.smiles, 2 * item.n_confs) + write_conformers(confs, item.smiles, output_dir, prefix=f'{args.method}_', idx=idx) + print("Finished generation, results are in {}".format(output_dir)) From 074266108248719ed0b6fb6f5c7db6ecf4da12f9 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Tue, 26 Sep 2023 16:59:08 -0400 Subject: [PATCH 09/22] recent changes in metrics, top k, hack, etc --- scripts/conformer/compute_metrics.py | 83 ++++++++++++++++++++++++---- scripts/conformer/merge_metrics.py | 49 ++++++++++++++++ scripts/conformer/rdkit_baselines.py | 10 +++- 3 files changed, 127 insertions(+), 15 deletions(-) create mode 100644 scripts/conformer/merge_metrics.py diff --git a/scripts/conformer/compute_metrics.py b/scripts/conformer/compute_metrics.py index 5730d4415..31d58e56d 100644 --- a/scripts/conformer/compute_metrics.py +++ b/scripts/conformer/compute_metrics.py @@ -6,10 +6,40 @@ import pandas as pd from pathlib import Path from tqdm import tqdm +from copy import deepcopy +import random -from gflownet.utils.molecule.metrics import get_cov_mat +from gflownet.utils.molecule.metrics import get_cov_mat, get_best_rmsd from gflownet.utils.molecule.geom 
import get_all_confs_geom +def distant_enough(conf, others, delta): + conf = deepcopy(conf) + dist = [] + for c in others: + dist.append(get_best_rmsd(deepcopy(c), conf)) + dist = np.array(dist) + return np.all(dist > delta) + + +def get_diverse_top_k(confs_list, k=None, delta=1.25): + """confs_list should be sorted according to the energy! lowest energy first""" + result = [confs_list[0]] + for conf in confs_list[1:]: + if distant_enough(conf, result, delta): + result.append(conf) + if len(result) >= k: + break + if len(result) < k: + print(f"Cannot find {k} different conformations in {len(confs_list)} for delta {delta}") + print(f"Found only {len(result)}. Adding random from generated") + result += random.sample(confs_list.tolist(), k - len(result)) + return result + +def get_smiles_from_filename(filename): + smiles = filename.split('_')[1] + if smiles.endswith('.pkl'): + smiles = smiles[:-4] + return smiles def main(args): base_path = Path(args.geom_dir) @@ -21,18 +51,33 @@ def main(args): gen_files = [] gen_confs = [] smiles = [] + energies = [] for fp in filenames: with open(fp, 'rb') as f: gen_files.append(pickle.load(f)) - gen_confs.append(gen_files[-1]['conformer']) + # 1k cut-off to use the same max number of gen samples for all methods + gen_confs.append(gen_files[-1]['conformer'][:1000]) if 'smiles' in gen_files[-1].keys(): smiles.append(gen_files[-1]['smiles']) + if 'energy' in gen_files[-1].keys(): + energies.append(gen_files[-1]['energy'][:1000]) if len(smiles) == 0: - smiles = [x.name.split('_')[-1][:-4] for x in filenames] + smiles = [get_smiles_from_filename(x.name) for x in filenames] print('All smiles') print(*smiles, sep='\n') ref_confs = [get_all_confs_geom(base_path, sm, drugs_summ) for sm in smiles] + if args.use_top_k: + if len(energies) == 0: + raise Exception("Cannot use top-k without energies") + energies = [np.array(e) for e in energies] + indecies = [np.argsort(e) for e in energies] + gen_confs = [np.array(x)[idx] for x, idx in 
zip(gen_confs, indecies)] + if args.diverse: + gen_confs = [get_diverse_top_k(x, k=len(ref)*2, delta=args.delta) for x, ref in zip(gen_confs, ref_confs)] + if not args.hack: + gen_confs = [x[:2*len(ref)] for x, ref in zip(gen_confs, ref_confs)] + geom_stats = pd.read_csv(args.geom_stats, index_col=0) cov_list = [] mat_list = [] @@ -40,16 +85,19 @@ def main(args): n_tas = [] consistent = [] has_h_tas = [] + hack = False for ref, gen, sm in tqdm(zip(ref_confs, gen_confs, smiles), total=len(ref_confs)): if len(gen) < 2*len(ref): - print("Recieved less samples that needed for computing metrics! Return nans") + print("Recieved less samples that needed for computing metrics!") + print(f"Computing metrics with {len(gen)} generated confs for {len(ref)} reference confs") + try: + if len(gen) > 2*len(ref): + hack = True + print(f"Warning! Computing metrics with {len(gen)} generated confs for {len(ref)} reference confs") + cov, mat = get_cov_mat(ref, gen, threshold=1.25) + except RuntimeError as e: + print(e) cov, mat = None, None - else: - try: - cov, mat = get_cov_mat(ref, gen[:2*len(ref)], threshold=1.25) - except RuntimeError as e: - print(e) - cov, mat = None, None cov_list.append(cov) mat_list.append(mat) n_conf.append(len(ref)) @@ -70,7 +118,14 @@ def main(args): name = Path(args.gen_dir).name if name in ['xtb', 'gfn-ff', 'torchani']: name = Path(args.gen_dir).name - name = Path(args.gen_dir).parent.name + '_' + name + name = Path(args.gen_dir).parent.name + '_' + name + if args.use_top_k: + name += '_top_k' + if args.diverse: + name += f'_diverse_{args.delta}' + if hack: + name += '_hacked' + output_file = Path(args.output_dir) / '{}_metrics.csv'.format(name) df.to_csv(output_file, index=False) print('Saved metrics at {}'.format(output_file)) @@ -81,6 +136,10 @@ def main(args): parser.add_argument('--geom_dir', type=str, default='/home/mila/a/alexandra.volokhova/scratch/datasets/geom') parser.add_argument('--output_dir', type=str, 
default='/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/metrics') parser.add_argument('--geom_stats', type=str, default='/home/mila/a/alexandra.volokhova/projects/gflownet/scripts/conformer/geom_stats.csv') - parser.add_argument('--gen_dir', type=str, default='./') + parser.add_argument('--gen_dir', type=str, default='./') + parser.add_argument('--use_top_k', type=bool, default=False) + parser.add_argument('--diverse', type=bool, default=False) + parser.add_argument('--delta', type=float, default=1.) + parser.add_argument('--hack', type=bool, default=False) args = parser.parse_args() main(args) \ No newline at end of file diff --git a/scripts/conformer/merge_metrics.py b/scripts/conformer/merge_metrics.py new file mode 100644 index 000000000..db3ebc311 --- /dev/null +++ b/scripts/conformer/merge_metrics.py @@ -0,0 +1,49 @@ +import pandas as pd +from pathlib import Path + +# method = "gfn-ff" +# method = "torchani" +method = "xtb" + +base_path = Path('/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/metrics/') +df = pd.read_csv(base_path / f"gfn_conformers_v2_{method}_metrics.csv", index_col="smiles") +dftop = pd.read_csv(base_path / f"gfn_conformers_v2_{method}_top_k_metrics.csv", index_col="smiles") +rd = pd.read_csv( + base_path / "rdkit_samples_target_smiles_first_batch_2023-09-24_21-15-42_metrics.csv", + index_col="smiles", +) +rdc = pd.read_csv( + base_path / "rdkit_cluster_samples_target_smiles_first_batch_2023-09-24_21-16-23_metrics.csv", + index_col="smiles", +) +rd = rd.loc[df.index] +rdc = rdc.loc[df.index] +dftop = dftop.loc[df.index] +df[f"{method}_cov"] = df["cov"] +df[f"{method}_mat"] = df["mat"] +df[f"{method}_tk_cov"] = dftop["cov"] +df[f"{method}_tk_mat"] = dftop["mat"] +df["rdkit_cov"] = rd["cov"] +df["rdkit_mat"] = rd["mat"] +df["rdkit_cl_cov"] = rdc["cov"] +df["rdkit_cl_mat"] = rdc["mat"] +df = df.sort_values("n_tas") +df.to_csv('./merged_metrics_{}.csv'.format(method)) +df = df[ + [ + f"{method}_cov", + 
f"{method}_mat", + f"{method}_tk_cov", + f"{method}_tk_mat", + "rdkit_cov", + "rdkit_mat", + "rdkit_cl_cov", + "rdkit_cl_mat", + ] +] +print("Mean") +print(df.mean()) +print("Median") +print(df.median()) +print("Var") +print(df.var()) \ No newline at end of file diff --git a/scripts/conformer/rdkit_baselines.py b/scripts/conformer/rdkit_baselines.py index 2490bc41e..56253aa3e 100644 --- a/scripts/conformer/rdkit_baselines.py +++ b/scripts/conformer/rdkit_baselines.py @@ -172,7 +172,7 @@ def gen_multiple_conf_rdkit_cluster(smiles, n_confs): parser.add_argument('--method', type=str, default='rdkit') parser.add_argument('--target_smiles', type=str, default='/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/target_smiles/target_smiles_4_initial.csv') - parser.add_argument('--n_confs', type=int, default=300) + parser.add_argument('--n_confs', type=int, default=None) args = parser.parse_args() output_base_dir = Path(args.output_base_dir) @@ -189,9 +189,13 @@ def gen_multiple_conf_rdkit_cluster(smiles, n_confs): os.mkdir(output_dir) target_smiles = pd.read_csv(args.target_smiles, index_col=0) for idx, (_, item) in tqdm(enumerate(target_smiles.iterrows()), total=len(target_smiles)): + n_samples = args.n_confs + if n_samples is None: + n_samples = 2 * item.n_confs + print(f"start generating {n_samples} confs") if args.method == 'rdkit': - confs = gen_multiple_conf_rdkit(item.smiles, 2 * item.n_confs, optimise=True) + confs = gen_multiple_conf_rdkit(item.smiles, n_samples, optimise=True) if args.method == 'rdkit_cluster': - confs = gen_multiple_conf_rdkit_cluster(item.smiles, 2 * item.n_confs) + confs = gen_multiple_conf_rdkit_cluster(item.smiles, n_samples) write_conformers(confs, item.smiles, output_dir, prefix=f'{args.method}_', idx=idx) print("Finished generation, results are in {}".format(output_dir)) From 130d903408d59e116dce6fe94ecee9f5f0092975 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Wed, 27 Sep 2023 18:26:43 -0400 Subject: [PATCH 
10/22] kde plots script initial --- scripts/conformer/kde_plots.py | 131 +++++++++++++++++++++++++++++++++ 1 file changed, 131 insertions(+) create mode 100644 scripts/conformer/kde_plots.py diff --git a/scripts/conformer/kde_plots.py b/scripts/conformer/kde_plots.py new file mode 100644 index 000000000..12b913468 --- /dev/null +++ b/scripts/conformer/kde_plots.py @@ -0,0 +1,131 @@ +# IMPORT THIS FIRST!!!!! +from tblite import interface + +import time +import argparse +import pickle +import numpy as np +import os + +from pathlib import Path +from scipy.special import logsumexp +from datetime import datetime +from tqdm import tqdm + + +from gflownet.proxy.conformers.xtb import XTBMoleculeEnergy +from gflownet.proxy.conformers.torchani import TorchANIMoleculeEnergy +from gflownet.proxy.conformers.tblite import TBLiteMoleculeEnergy +from gflownet.envs.conformers.conformer import Conformer +from gflownet.utils.common import torch2np + +PROXY_DICT = { + 'tblite': TBLiteMoleculeEnergy, + 'xtb': XTBMoleculeEnergy, + 'torchani': TorchANIMoleculeEnergy +} +PROXY_NAME_DICT = { + 'tblite': 'xtb', + 'xtb': 'gfn-ff', + 'torchani': 'torchani' +} + +def get_smiles_and_proxy_class(filename): + sm = filename.split('_')[2] + proxy_name = filename.split('_')[-1][:-4] + proxy_class = PROXY_DICT[proxy_name] + proxy = proxy_class(device="cpu", float_precision=32) + env = Conformer(smiles=sm, n_torsion_angles=2, reward_func="boltzmann", reward_beta=32, proxy=proxy) + # proxy.setup(env) + return sm, PROXY_NAME_DICT[proxy_name], proxy, env + +def load_samples(filename, base_path): + path = base_path / filename + with open(path, 'rb') as f: + data = pickle.load(f) + return data['x'] + +def get_true_kde(env, n_samples, bandwidth=0.1): + x_from_reward = env.sample_from_reward(n_samples=n_samples) + x_from_reward = torch2np(env.statetorch2kde(x_from_reward)) + kde_true = env.fit_kde( + x_from_reward, + kernel='gaussian', + bandwidth=bandwidth) + return kde_true + + +def get_metrics(kde_true, 
kde_pred, test_samples): + scores_true = kde_true.score_samples(test_samples) + log_density_true = scores_true - logsumexp(scores_true, axis=0) + scores_pred = kde_pred.score_samples(test_samples) + log_density_pred = scores_pred - logsumexp(scores_pred, axis=0) + density_true = np.exp(log_density_true) + density_pred = np.exp(log_density_pred) + # L1 error + l1 = np.abs(density_pred - density_true).mean() + # KL divergence + kl = (density_true * (log_density_true - log_density_pred)).mean() + # Jensen-Shannon divergence + log_mean_dens = np.logaddexp(log_density_true, log_density_pred) + np.log(0.5) + jsd = 0.5 * np.sum(density_true * (log_density_true - log_mean_dens)) + jsd += 0.5 * np.sum(density_pred * (log_density_pred - log_mean_dens)) + return l1, kl, jsd + + +def main(args): + base_path = Path(args.samples_dir) + output_base = Path(args.output_dir) + if not output_base.exists(): + os.mkdir(output_base) + for filename in tqdm(os.listdir(base_path)[:2]): + ct = time.time() + print(f'{datetime.now().strftime("%H-%M-%S")}: Initialising env') + smiles, pr_name, proxy, env = get_smiles_and_proxy_class(filename) + + # create output dir + current_datetime = datetime.now() + timestamp = current_datetime.strftime("%Y-%m-%d_%H-%M-%S") + out_name = f'{pr_name}_{smiles}_{args.n_test}_{args.bandwidth}_{timestamp}' + output_dir = output_base / out_name + os.mkdir(output_dir) + print(f'Will save results at {output_dir}') + + samples = load_samples(filename, base_path)[:100] + n_samples = samples.shape[0] + print(f'{datetime.now().strftime("%H-%M-%S")}: Computing true kde') + kde_true = get_true_kde(env, n_samples, bandwidth=args.bandwidth) + print(f'{datetime.now().strftime("%H-%M-%S")}: Computing pred kde') + kde_pred = env.fit_kde(samples, kernel='gaussian', bandwidth=args.bandwidth) + print(f'{datetime.now().strftime("%H-%M-%S")}: Making figures') + fig_pred = env.plot_kde(kde_pred) + fig_pred.savefig(output_dir / f'kde_pred_{out_name}.png', bbox_inches="tight", 
format='png') + fig_true = env.plot_kde(kde_true) + fig_true.savefig(output_dir / f'kde_true_{out_name}.png', bbox_inches="tight", format='png') + print(f'{datetime.now().strftime("%H-%M-%S")}: Computing metrics') + test_samples = np.array(env.get_grid_terminating_states(args.n_test))[:, :2] + l1, kl, jsd = get_metrics(kde_true, kde_pred, test_samples) + met = { + 'l1': l1, + 'kl': kl, + 'jsd': jsd + } + # write stuff + with open(output_dir / 'metrics.pkl', 'wb') as file: + pickle.dump(met, file) + with open(output_dir / 'kde_true.pkl', 'wb') as file: + pickle.dump(kde_true, file) + with open(output_dir / 'kde_pred.pkl', 'wb') as file: + pickle.dump(kde_pred, file) + + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--samples_dir', type=str, default='/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/samples/mcmc_samples') + parser.add_argument('--output_dir', type=str, default='/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/kde_stats/mcmc') + parser.add_argument('--bandwidth', type=float, default=0.1) + parser.add_argument('--n_test', type=int, default=10000) + args = parser.parse_args() + main(args) \ No newline at end of file From 2ab718589861c1777e4430de1f32905ef8c56b06 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Wed, 27 Sep 2023 20:25:45 -0400 Subject: [PATCH 11/22] edit kde plots --- scripts/conformer/kde_plots.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/scripts/conformer/kde_plots.py b/scripts/conformer/kde_plots.py index 12b913468..4ddab72c8 100644 --- a/scripts/conformer/kde_plots.py +++ b/scripts/conformer/kde_plots.py @@ -35,7 +35,8 @@ def get_smiles_and_proxy_class(filename): proxy_name = filename.split('_')[-1][:-4] proxy_class = PROXY_DICT[proxy_name] proxy = proxy_class(device="cpu", float_precision=32) - env = Conformer(smiles=sm, n_torsion_angles=2, reward_func="boltzmann", reward_beta=32, proxy=proxy) + env = 
Conformer(smiles=sm, n_torsion_angles=2, reward_func="boltzmann", reward_beta=32, proxy=proxy, + reward_sampling_method='nested') # proxy.setup(env) return sm, PROXY_NAME_DICT[proxy_name], proxy, env @@ -78,7 +79,7 @@ def main(args): output_base = Path(args.output_dir) if not output_base.exists(): os.mkdir(output_base) - for filename in tqdm(os.listdir(base_path)[:2]): + for filename in tqdm(os.listdir(base_path)): ct = time.time() print(f'{datetime.now().strftime("%H-%M-%S")}: Initialising env') smiles, pr_name, proxy, env = get_smiles_and_proxy_class(filename) @@ -91,7 +92,7 @@ def main(args): os.mkdir(output_dir) print(f'Will save results at {output_dir}') - samples = load_samples(filename, base_path)[:100] + samples = load_samples(filename, base_path) n_samples = samples.shape[0] print(f'{datetime.now().strftime("%H-%M-%S")}: Computing true kde') kde_true = get_true_kde(env, n_samples, bandwidth=args.bandwidth) @@ -125,7 +126,7 @@ def main(args): parser = argparse.ArgumentParser() parser.add_argument('--samples_dir', type=str, default='/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/samples/mcmc_samples') parser.add_argument('--output_dir', type=str, default='/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/kde_stats/mcmc') - parser.add_argument('--bandwidth', type=float, default=0.1) + parser.add_argument('--bandwidth', type=float, default=0.15) parser.add_argument('--n_test', type=int, default=10000) args = parser.parse_args() main(args) \ No newline at end of file From 2157983308c81ce8ddffa0f12f6beef858ded62a Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Mon, 13 Nov 2023 16:39:57 -0500 Subject: [PATCH 12/22] after-deadline commit --- gflownet/utils/molecule/geom.py | 1 + scripts/conformer/README.md | 22 ++++++++++++++++++++++ scripts/conformer/compute_metrics.py | 8 ++++++++ 3 files changed, 31 insertions(+) create mode 100644 scripts/conformer/README.md diff --git a/gflownet/utils/molecule/geom.py 
b/gflownet/utils/molecule/geom.py index ec80193c0..e5f33b8b9 100644 --- a/gflownet/utils/molecule/geom.py +++ b/gflownet/utils/molecule/geom.py @@ -37,6 +37,7 @@ def get_all_confs_geom(base_path, smiles, summary_file=None): return conformers except KeyError: print('No pickle_path file for {}'.format(smiles)) + return None def get_rd_mol(smiles): mol = Chem.MolFromSmiles(smiles) diff --git a/scripts/conformer/README.md b/scripts/conformer/README.md new file mode 100644 index 000000000..e5be3940c --- /dev/null +++ b/scripts/conformer/README.md @@ -0,0 +1,22 @@ +This folder contains scripts for dealing with the GEOM dataset and running RDKit-based baselines. + +## Calculating statistics for GEOM dataset +The script `geom_stats.py` extracts statistical information from molecular conformation data in the GEOM dataset using the RDKit library. The GEOM dataset is expected to be in the "rdkit_folder" format (tutorial and downloading links are here: https://github.com/learningmatter-mit/geom/tree/master). This script parses the dataset, calculates various statistics, and outputs the results to a CSV file. + +Statistics collected include: +* SMILES representation of the molecule. +* Whether the molecule is self-consistent, i.e. its conformations in the dataset correspond to the same SMILES. +* Whether the molecule is consistent with RDKit, i.e. all conformations in the dataset correspond to the same SMILES and this SMILES is the same as stored in the dataset. +* The number of rotatable torsion angles in the molecular conformation (both from GEOM and RDKit). +* Whether the molecule contains hydrogen torsion angles. +* The total number of unique conformations for the molecule. +* The number of heavy atoms in the molecule. +* The total number of atoms in the molecule. + +### Usage + +You can use the script with the following command-line arguments: + +* `--geom_dir` (optional): Path to the directory containing the GEOM dataset in an rdkit_folder.
The default path is '/home/mila/a/alexandra.volokhova/scratch/datasets/geom'. +* `--output_file` (optional): Path to the output CSV file where the statistics will be saved. The default path is './geom_stats.csv'. + diff --git a/scripts/conformer/compute_metrics.py b/scripts/conformer/compute_metrics.py index 31d58e56d..2b6a39147 100644 --- a/scripts/conformer/compute_metrics.py +++ b/scripts/conformer/compute_metrics.py @@ -66,6 +66,14 @@ def main(args): print('All smiles') print(*smiles, sep='\n') ref_confs = [get_all_confs_geom(base_path, sm, drugs_summ) for sm in smiles] + # filter out nans + gen_confs = [gen_confs[idx] for idx, val in enumerate(ref_confs) if val is not None] + smiles = [smiles[idx] for idx, val in enumerate(ref_confs) if val is not None] + if len(energies) > 0: + energies = [energies[idx] for idx, val in enumerate(ref_confs) if val is not None] + ref_confs = [val for val in ref_confs if val is not None] + assert len(gen_confs) == len(ref_confs) == len(smiles) + if args.use_top_k: if len(energies) == 0: raise Exception("Cannot use top-k without energies") From 436890f8b601ffae30b1b3c38761592efa35e894 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Mon, 13 Nov 2023 16:42:22 -0500 Subject: [PATCH 13/22] black, isort --- gflownet/envs/base.py | 3 +- gflownet/envs/conformers/conformer.py | 1 - gflownet/envs/crystals/composition.py | 7 +- gflownet/envs/crystals/lattice_parameters.py | 14 +- gflownet/envs/ctorus.py | 3 +- gflownet/gflownet.py | 12 +- gflownet/utils/batch.py | 12 +- .../utils/crystals/build_lattice_dicts.py | 19 +-- gflownet/utils/molecule/geom.py | 30 ++-- gflownet/utils/molecule/metrics.py | 13 +- gflownet/utils/molecule/rotatable_bonds.py | 3 +- gflownet/utils/oracle.py | 3 +- playground/botorch/mes_exact_deepKernel.py | 1 - playground/botorch/mes_gp.py | 1 - playground/botorch/mes_gp_debug.py | 7 +- playground/botorch/mes_nn_bao_fix.py | 5 +- playground/botorch/mes_nn_hardcode_gpVal.py | 4 +- 
playground/botorch/mes_nn_like_gp.py | 5 +- .../mes_nn_like_gp_nondiagonalcovar.py | 5 +- playground/botorch/mes_var_deepKernel.py | 5 +- scripts/conformer/compute_metrics.py | 138 +++++++++++------- scripts/conformer/geom_stats.py | 45 +++--- scripts/conformer/kde_plots.py | 106 +++++++------- scripts/conformer/merge_metrics.py | 27 ++-- scripts/conformer/rdkit_baselines.py | 92 +++++++----- scripts/dav_mp20_stats.py | 3 +- scripts/pyxtal/pyxtal_vs_pymatgen.py | 8 +- .../gflownet/envs/test_lattice_parameters.py | 18 +-- tests/gflownet/envs/test_tree.py | 11 +- .../policy/test_multihead_tree_policy.py | 11 +- .../utils/molecule/test_rotatable_bonds.py | 6 +- .../gflownet/utils/molecule/test_torsions.py | 6 +- tests/gflownet/utils/test_batch.py | 13 +- 33 files changed, 329 insertions(+), 308 deletions(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index e0381ed37..b7f59128f 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -15,7 +15,8 @@ from torch.distributions import Categorical from torchtyping import TensorType -from gflownet.utils.common import copy, set_device, set_float_precision, tbool, tfloat +from gflownet.utils.common import (copy, set_device, set_float_precision, + tbool, tfloat) CMAP = mpl.colormaps["cividis"] diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py index 4b831df9a..fe3778879 100644 --- a/gflownet/envs/conformers/conformer.py +++ b/gflownet/envs/conformers/conformer.py @@ -15,7 +15,6 @@ from gflownet.utils.molecule.rdkit_conformer import RDKitConformer from gflownet.utils.molecule.rotatable_bonds import find_rotor_from_smiles - PREDEFINED_SMILES = [ "O=C(c1ccccc1)c1ccc2c(c1)OCCOCCOCCOCCO2", "O=S(=O)(NN=C1CCCCCC1)c1ccc(Cl)cc1", diff --git a/gflownet/envs/crystals/composition.py b/gflownet/envs/crystals/composition.py index 2e1e75240..721f702b7 100644 --- a/gflownet/envs/crystals/composition.py +++ b/gflownet/envs/crystals/composition.py @@ -14,11 +14,8 @@ from 
gflownet.utils.common import tlong from gflownet.utils.crystals.constants import ELEMENT_NAMES, OXIDATION_STATES from gflownet.utils.crystals.pyxtal_cache import ( - get_space_group, - space_group_check_compatible, - space_group_lowest_free_wp_multiplicity, - space_group_wyckoff_gcd, -) + get_space_group, space_group_check_compatible, + space_group_lowest_free_wp_multiplicity, space_group_wyckoff_gcd) class Composition(GFlowNetEnv): diff --git a/gflownet/envs/crystals/lattice_parameters.py b/gflownet/envs/crystals/lattice_parameters.py index 957a7e229..e15bfed54 100644 --- a/gflownet/envs/crystals/lattice_parameters.py +++ b/gflownet/envs/crystals/lattice_parameters.py @@ -9,16 +9,10 @@ from torchtyping import TensorType from gflownet.envs.grid import Grid -from gflownet.utils.crystals.constants import ( - CUBIC, - HEXAGONAL, - LATTICE_SYSTEMS, - MONOCLINIC, - ORTHORHOMBIC, - RHOMBOHEDRAL, - TETRAGONAL, - TRICLINIC, -) +from gflownet.utils.crystals.constants import (CUBIC, HEXAGONAL, + LATTICE_SYSTEMS, MONOCLINIC, + ORTHORHOMBIC, RHOMBOHEDRAL, + TETRAGONAL, TRICLINIC) class LatticeParameters(Grid): diff --git a/gflownet/envs/ctorus.py b/gflownet/envs/ctorus.py index c006428f6..8b6b79d0e 100644 --- a/gflownet/envs/ctorus.py +++ b/gflownet/envs/ctorus.py @@ -9,7 +9,8 @@ import numpy.typing as npt import pandas as pd import torch -from torch.distributions import Categorical, MixtureSameFamily, Uniform, VonMises +from torch.distributions import (Categorical, MixtureSameFamily, Uniform, + VonMises) from torchtyping import TensorType from gflownet.envs.htorus import HybridTorus diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 4eebc4d6b..64b228acb 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -21,15 +21,9 @@ from gflownet.envs.base import GFlowNetEnv from gflownet.utils.batch import Batch from gflownet.utils.buffer import Buffer -from gflownet.utils.common import ( - batch_with_rest, - set_device, - set_float_precision, - tbool, - tfloat, - 
tlong, - torch2np, -) +from gflownet.utils.common import (batch_with_rest, set_device, + set_float_precision, tbool, tfloat, tlong, + torch2np) class GFlowNetAgent: diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index a35f01ddf..c76ecfa9e 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -7,16 +7,8 @@ from torchtyping import TensorType from gflownet.envs.base import GFlowNetEnv -from gflownet.utils.common import ( - concat_items, - copy, - extend, - set_device, - set_float_precision, - tbool, - tfloat, - tlong, -) +from gflownet.utils.common import (concat_items, copy, extend, set_device, + set_float_precision, tbool, tfloat, tlong) class Batch: diff --git a/gflownet/utils/crystals/build_lattice_dicts.py b/gflownet/utils/crystals/build_lattice_dicts.py index 65d4a7958..f62c4bd19 100644 --- a/gflownet/utils/crystals/build_lattice_dicts.py +++ b/gflownet/utils/crystals/build_lattice_dicts.py @@ -8,19 +8,12 @@ import numpy as np import yaml -from lattice_constants import ( - CRYSTAL_CLASSES_WIKIPEDIA, - CRYSTAL_LATTICE_SYSTEMS, - CRYSTAL_SYSTEMS, - POINT_SYMMETRIES, - RHOMBOHEDRAL_SPACE_GROUPS_WIKIPEDIA, -) -from pymatgen.symmetry.groups import ( - PointGroup, - SpaceGroup, - SymmetryGroup, - sg_symbol_from_int_number, -) +from lattice_constants import (CRYSTAL_CLASSES_WIKIPEDIA, + CRYSTAL_LATTICE_SYSTEMS, CRYSTAL_SYSTEMS, + POINT_SYMMETRIES, + RHOMBOHEDRAL_SPACE_GROUPS_WIKIPEDIA) +from pymatgen.symmetry.groups import (PointGroup, SpaceGroup, SymmetryGroup, + sg_symbol_from_int_number) N_SPACE_GROUPS = 230 diff --git a/gflownet/utils/molecule/geom.py b/gflownet/utils/molecule/geom.py index e5f33b8b9..273f2cdae 100644 --- a/gflownet/utils/molecule/geom.py +++ b/gflownet/utils/molecule/geom.py @@ -1,57 +1,63 @@ -import os import json +import os import pickle +from pathlib import Path + import numpy as np import pandas as pd - from rdkit import Chem -from pathlib import Path from tqdm import tqdm -from 
gflownet.utils.molecule.rotatable_bonds import get_rotatable_ta_list, is_hydrogen_ta +from gflownet.utils.molecule.rotatable_bonds import (get_rotatable_ta_list, + is_hydrogen_ta) + def get_conf_geom(base_path, smiles, conf_idx=0, summary_file=None): if summary_file is None: - drugs_file = base_path / 'rdkit_folder/summary_drugs.json' + drugs_file = base_path / "rdkit_folder/summary_drugs.json" with open(drugs_file, "r") as f: summary_file = json.load(f) - pickle_path = base_path / "rdkit_folder" / summary_file[smiles]['pickle_path'] + pickle_path = base_path / "rdkit_folder" / summary_file[smiles]["pickle_path"] if os.path.isfile(pickle_path): with open(pickle_path, "rb") as f: dic = pickle.load(f) - mol = dic['conformers'][conf_idx]['rd_mol'] + mol = dic["conformers"][conf_idx]["rd_mol"] return mol + def get_all_confs_geom(base_path, smiles, summary_file=None): if summary_file is None: - drugs_file = base_path / 'rdkit_folder/summary_drugs.json' + drugs_file = base_path / "rdkit_folder/summary_drugs.json" with open(drugs_file, "r") as f: summary_file = json.load(f) try: - pickle_path = base_path / "rdkit_folder" / summary_file[smiles]['pickle_path'] + pickle_path = base_path / "rdkit_folder" / summary_file[smiles]["pickle_path"] if os.path.isfile(pickle_path): with open(pickle_path, "rb") as f: dic = pickle.load(f) - conformers = [x['rd_mol'] for x in dic['conformers']] + conformers = [x["rd_mol"] for x in dic["conformers"]] return conformers except KeyError: - print('No pickle_path file for {}'.format(smiles)) + print("No pickle_path file for {}".format(smiles)) return None + def get_rd_mol(smiles): mol = Chem.MolFromSmiles(smiles) mol = Chem.AddHs(mol) return mol + def has_same_can_smiles(mol1, mol2): sm1 = Chem.CanonSmiles(Chem.MolToSmiles(mol1)) sm2 = Chem.CanonSmiles(Chem.MolToSmiles(mol2)) return sm1 == sm2 + def all_same_graphs(mols): ref = mols[0] same = [] for mol in mols: same.append(has_same_can_smiles(ref, mol)) - return np.all(same) \ No newline at 
end of file + return np.all(same) diff --git a/gflownet/utils/molecule/metrics.py b/gflownet/utils/molecule/metrics.py index bf713b70a..f52dc9f38 100644 --- a/gflownet/utils/molecule/metrics.py +++ b/gflownet/utils/molecule/metrics.py @@ -1,22 +1,22 @@ # some functions inspired by: https://gist.github.com/ZhouGengmo/5b565f51adafcd911c0bc115b2ef027c -import numpy as np -import pandas as pd import copy +import numpy as np +import pandas as pd from rdkit import Chem - from rdkit.Chem import rdMolAlign as MA -from rdkit import Chem from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMolecule from rdkit.Geometry.rdGeometry import Point3D + def get_best_rmsd(gen_mol, ref_mol): gen_mol = Chem.RemoveHs(gen_mol) ref_mol = Chem.RemoveHs(ref_mol) rmsd = MA.GetBestRMS(gen_mol, ref_mol) return rmsd + def get_cov_mat(ref_mols, gen_mols, threshold=1.25): rmsd_mat = np.zeros([len(ref_mols), len(gen_mols)], dtype=np.float32) for i, gen_mol in enumerate(gen_mols): @@ -27,10 +27,11 @@ def get_cov_mat(ref_mols, gen_mols, threshold=1.25): rmsd_mat_min = rmsd_mat.min(-1) return (rmsd_mat_min <= threshold).mean(), rmsd_mat_min.mean() + def normalise_positions(mol): conf = mol.GetConformer() pos = conf.GetPositions() - pos = pos - pos.mean(axis=0) + pos = pos - pos.mean(axis=0) for idx, p in enumerate(pos): conf.SetAtomPosition(idx, Point3D(*p)) - return mol \ No newline at end of file + return mol diff --git a/gflownet/utils/molecule/rotatable_bonds.py b/gflownet/utils/molecule/rotatable_bonds.py index fc7579306..4a7c7cfe3 100644 --- a/gflownet/utils/molecule/rotatable_bonds.py +++ b/gflownet/utils/molecule/rotatable_bonds.py @@ -112,9 +112,10 @@ def is_connected_to_three_hydrogens(mol, atom_id, except_id): second = is_connected_to_three_hydrogens(mol, ta[2], ta[1]) return first or second + def has_hydrogen_tas(mol): tas = get_rotatable_ta_list(mol) hydrogen_flags = [] for t in tas: hydrogen_flags.append(is_hydrogen_ta(mol, t)) - return np.any(hydrogen_flags) \ No newline at end of 
file + return np.any(hydrogen_flags) diff --git a/gflownet/utils/oracle.py b/gflownet/utils/oracle.py index c4713745a..1fc9c92c6 100644 --- a/gflownet/utils/oracle.py +++ b/gflownet/utils/oracle.py @@ -14,7 +14,8 @@ ) pass try: - from bbdob import DeceptiveTrap, FourPeaks, NKLandscape, OneMax, TwoMin, WModel + from bbdob import (DeceptiveTrap, FourPeaks, NKLandscape, OneMax, TwoMin, + WModel) from bbdob.utils import idx2one_hot except: print( diff --git a/playground/botorch/mes_exact_deepKernel.py b/playground/botorch/mes_exact_deepKernel.py index b77bb2e89..743b68256 100644 --- a/playground/botorch/mes_exact_deepKernel.py +++ b/playground/botorch/mes_exact_deepKernel.py @@ -8,7 +8,6 @@ from math import floor import gpytorch - # import tqdm import torch from botorch.test_functions import Hartmann diff --git a/playground/botorch/mes_gp.py b/playground/botorch/mes_gp.py index b51df0ce6..8afde5dc8 100644 --- a/playground/botorch/mes_gp.py +++ b/playground/botorch/mes_gp.py @@ -6,7 +6,6 @@ import numpy as np import torch - # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Branin, Hartmann diff --git a/playground/botorch/mes_gp_debug.py b/playground/botorch/mes_gp_debug.py index 06c5a3ed6..76af6ff00 100644 --- a/playground/botorch/mes_gp_debug.py +++ b/playground/botorch/mes_gp_debug.py @@ -3,7 +3,6 @@ import gpytorch import numpy as np import torch - # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Hartmann @@ -50,8 +49,10 @@ def forward(self, x): from botorch.models.utils import add_output_dim from botorch.posteriors.gpytorch import GPyTorchPosterior -from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal -from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood +from gpytorch.distributions import (MultitaskMultivariateNormal, + MultivariateNormal) +from 
gpytorch.likelihoods.gaussian_likelihood import \ + FixedNoiseGaussianLikelihood class myGPModel(SingleTaskGP): diff --git a/playground/botorch/mes_nn_bao_fix.py b/playground/botorch/mes_nn_bao_fix.py index c4f7de6d0..861268aeb 100644 --- a/playground/botorch/mes_nn_bao_fix.py +++ b/playground/botorch/mes_nn_bao_fix.py @@ -2,7 +2,6 @@ import numpy as np import torch - # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Hartmann @@ -56,8 +55,8 @@ from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior -from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal - +from gpytorch.distributions import (MultitaskMultivariateNormal, + MultivariateNormal) # from botorch.posteriors. from torch.distributions import Normal diff --git a/playground/botorch/mes_nn_hardcode_gpVal.py b/playground/botorch/mes_nn_hardcode_gpVal.py index 6320d4f05..42dcbb9b4 100644 --- a/playground/botorch/mes_nn_hardcode_gpVal.py +++ b/playground/botorch/mes_nn_hardcode_gpVal.py @@ -2,7 +2,6 @@ import numpy as np import torch - # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Hartmann @@ -57,7 +56,8 @@ from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior -from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal +from gpytorch.distributions import (MultitaskMultivariateNormal, + MultivariateNormal) class NN_Model(Model): diff --git a/playground/botorch/mes_nn_like_gp.py b/playground/botorch/mes_nn_like_gp.py index d0664a342..0b15c98be 100644 --- a/playground/botorch/mes_nn_like_gp.py +++ b/playground/botorch/mes_nn_like_gp.py @@ -3,15 +3,14 @@ import numpy as np import torch from 
botorch.acquisition.max_value_entropy_search import qMaxValueEntropy - # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.test_functions import Hartmann -from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal +from gpytorch.distributions import (MultitaskMultivariateNormal, + MultivariateNormal) from gpytorch.mlls import ExactMarginalLogLikelihood - # from botorch.posteriors. from torch import distributions, tensor from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential diff --git a/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py b/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py index 2c75fd6a4..1d6626b33 100644 --- a/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py +++ b/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py @@ -3,15 +3,14 @@ import numpy as np import torch from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy - # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.test_functions import Hartmann -from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal +from gpytorch.distributions import (MultitaskMultivariateNormal, + MultivariateNormal) from gpytorch.mlls import ExactMarginalLogLikelihood - # from botorch.posteriors. 
from torch import distributions, tensor from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential diff --git a/playground/botorch/mes_var_deepKernel.py b/playground/botorch/mes_var_deepKernel.py index f712eaaf0..989af46c4 100644 --- a/playground/botorch/mes_var_deepKernel.py +++ b/playground/botorch/mes_var_deepKernel.py @@ -10,7 +10,6 @@ from math import floor import gpytorch - # import tqdm import torch from botorch.test_functions import Hartmann @@ -215,9 +214,7 @@ def posterior( from botorch.acquisition.max_value_entropy_search import ( - qLowerBoundMaxValueEntropy, - qMaxValueEntropy, -) + qLowerBoundMaxValueEntropy, qMaxValueEntropy) proxy = myGPModel(model, train_x, train_y.unsqueeze(-1)) qMES = qLowerBoundMaxValueEntropy(proxy, candidate_set=train_x, use_gumbel=True) diff --git a/scripts/conformer/compute_metrics.py b/scripts/conformer/compute_metrics.py index 2b6a39147..4b243c811 100644 --- a/scripts/conformer/compute_metrics.py +++ b/scripts/conformer/compute_metrics.py @@ -1,16 +1,18 @@ import argparse -import os import json +import os import pickle +import random +from copy import deepcopy +from pathlib import Path + import numpy as np import pandas as pd -from pathlib import Path from tqdm import tqdm -from copy import deepcopy -import random -from gflownet.utils.molecule.metrics import get_cov_mat, get_best_rmsd from gflownet.utils.molecule.geom import get_all_confs_geom +from gflownet.utils.molecule.metrics import get_best_rmsd, get_cov_mat + def distant_enough(conf, others, delta): conf = deepcopy(conf) @@ -30,20 +32,24 @@ def get_diverse_top_k(confs_list, k=None, delta=1.25): if len(result) >= k: break if len(result) < k: - print(f"Cannot find {k} different conformations in {len(confs_list)} for delta {delta}") + print( + f"Cannot find {k} different conformations in {len(confs_list)} for delta {delta}" + ) print(f"Found only {len(result)}. 
Adding random from generated") result += random.sample(confs_list.tolist(), k - len(result)) return result + def get_smiles_from_filename(filename): - smiles = filename.split('_')[1] - if smiles.endswith('.pkl'): + smiles = filename.split("_")[1] + if smiles.endswith(".pkl"): smiles = smiles[:-4] return smiles + def main(args): base_path = Path(args.geom_dir) - drugs_file = base_path / 'rdkit_folder/summary_drugs.json' + drugs_file = base_path / "rdkit_folder/summary_drugs.json" with open(drugs_file, "r") as f: drugs_summ = json.load(f) @@ -53,24 +59,26 @@ def main(args): smiles = [] energies = [] for fp in filenames: - with open(fp, 'rb') as f: + with open(fp, "rb") as f: gen_files.append(pickle.load(f)) # 1k cut-off to use the same max number of gen samples for all methods - gen_confs.append(gen_files[-1]['conformer'][:1000]) - if 'smiles' in gen_files[-1].keys(): - smiles.append(gen_files[-1]['smiles']) - if 'energy' in gen_files[-1].keys(): - energies.append(gen_files[-1]['energy'][:1000]) - if len(smiles) == 0: + gen_confs.append(gen_files[-1]["conformer"][:1000]) + if "smiles" in gen_files[-1].keys(): + smiles.append(gen_files[-1]["smiles"]) + if "energy" in gen_files[-1].keys(): + energies.append(gen_files[-1]["energy"][:1000]) + if len(smiles) == 0: smiles = [get_smiles_from_filename(x.name) for x in filenames] - print('All smiles') - print(*smiles, sep='\n') + print("All smiles") + print(*smiles, sep="\n") ref_confs = [get_all_confs_geom(base_path, sm, drugs_summ) for sm in smiles] # filter out nans gen_confs = [gen_confs[idx] for idx, val in enumerate(ref_confs) if val is not None] smiles = [smiles[idx] for idx, val in enumerate(ref_confs) if val is not None] if len(energies) > 0: - energies = [energies[idx] for idx, val in enumerate(ref_confs) if val is not None] + energies = [ + energies[idx] for idx, val in enumerate(ref_confs) if val is not None + ] ref_confs = [val for val in ref_confs if val is not None] assert len(gen_confs) == len(ref_confs) == 
len(smiles) @@ -81,11 +89,13 @@ def main(args): indecies = [np.argsort(e) for e in energies] gen_confs = [np.array(x)[idx] for x, idx in zip(gen_confs, indecies)] if args.diverse: - gen_confs = [get_diverse_top_k(x, k=len(ref)*2, delta=args.delta) for x, ref in zip(gen_confs, ref_confs)] + gen_confs = [ + get_diverse_top_k(x, k=len(ref) * 2, delta=args.delta) + for x, ref in zip(gen_confs, ref_confs) + ] if not args.hack: - gen_confs = [x[:2*len(ref)] for x, ref in zip(gen_confs, ref_confs)] + gen_confs = [x[: 2 * len(ref)] for x, ref in zip(gen_confs, ref_confs)] - geom_stats = pd.read_csv(args.geom_stats, index_col=0) cov_list = [] mat_list = [] @@ -95,13 +105,17 @@ def main(args): has_h_tas = [] hack = False for ref, gen, sm in tqdm(zip(ref_confs, gen_confs, smiles), total=len(ref_confs)): - if len(gen) < 2*len(ref): + if len(gen) < 2 * len(ref): print("Recieved less samples that needed for computing metrics!") - print(f"Computing metrics with {len(gen)} generated confs for {len(ref)} reference confs") + print( + f"Computing metrics with {len(gen)} generated confs for {len(ref)} reference confs" + ) try: - if len(gen) > 2*len(ref): + if len(gen) > 2 * len(ref): hack = True - print(f"Warning! Computing metrics with {len(gen)} generated confs for {len(ref)} reference confs") + print( + f"Warning! 
Computing metrics with {len(gen)} generated confs for {len(ref)} reference confs" + ) cov, mat = get_cov_mat(ref, gen, threshold=1.25) except RuntimeError as e: print(e) @@ -109,45 +123,63 @@ def main(args): cov_list.append(cov) mat_list.append(mat) n_conf.append(len(ref)) - n_tas.append(geom_stats[geom_stats.smiles == sm].n_rotatable_torsion_angles_rdkit.values[0]) - consistent.append(geom_stats[geom_stats.smiles == sm].rdkit_consistent.values[0]) + n_tas.append( + geom_stats[geom_stats.smiles == sm].n_rotatable_torsion_angles_rdkit.values[ + 0 + ] + ) + consistent.append( + geom_stats[geom_stats.smiles == sm].rdkit_consistent.values[0] + ) has_h_tas.append(geom_stats[geom_stats.smiles == sm].has_hydrogen_tas.values[0]) - + data = { - 'smiles': smiles, - 'cov': cov_list, - 'mat': mat_list, - 'n_ref_confs': n_conf, - 'n_tas': n_tas, - 'has_hydrogen_tas': has_h_tas, - 'consistent': consistent + "smiles": smiles, + "cov": cov_list, + "mat": mat_list, + "n_ref_confs": n_conf, + "n_tas": n_tas, + "has_hydrogen_tas": has_h_tas, + "consistent": consistent, } df = pd.DataFrame(data) name = Path(args.gen_dir).name - if name in ['xtb', 'gfn-ff', 'torchani']: + if name in ["xtb", "gfn-ff", "torchani"]: name = Path(args.gen_dir).name - name = Path(args.gen_dir).parent.name + '_' + name + name = Path(args.gen_dir).parent.name + "_" + name if args.use_top_k: - name += '_top_k' + name += "_top_k" if args.diverse: - name += f'_diverse_{args.delta}' + name += f"_diverse_{args.delta}" if hack: - name += '_hacked' + name += "_hacked" - output_file = Path(args.output_dir) / '{}_metrics.csv'.format(name) + output_file = Path(args.output_dir) / "{}_metrics.csv".format(name) df.to_csv(output_file, index=False) - print('Saved metrics at {}'.format(output_file)) + print("Saved metrics at {}".format(output_file)) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--geom_dir', type=str, 
default='/home/mila/a/alexandra.volokhova/scratch/datasets/geom') - parser.add_argument('--output_dir', type=str, default='/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/metrics') - parser.add_argument('--geom_stats', type=str, default='/home/mila/a/alexandra.volokhova/projects/gflownet/scripts/conformer/geom_stats.csv') - parser.add_argument('--gen_dir', type=str, default='./') - parser.add_argument('--use_top_k', type=bool, default=False) - parser.add_argument('--diverse', type=bool, default=False) - parser.add_argument('--delta', type=float, default=1.) - parser.add_argument('--hack', type=bool, default=False) + parser.add_argument( + "--geom_dir", + type=str, + default="/home/mila/a/alexandra.volokhova/scratch/datasets/geom", + ) + parser.add_argument( + "--output_dir", + type=str, + default="/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/metrics", + ) + parser.add_argument( + "--geom_stats", + type=str, + default="/home/mila/a/alexandra.volokhova/projects/gflownet/scripts/conformer/geom_stats.csv", + ) + parser.add_argument("--gen_dir", type=str, default="./") + parser.add_argument("--use_top_k", type=bool, default=False) + parser.add_argument("--diverse", type=bool, default=False) + parser.add_argument("--delta", type=float, default=1.0) + parser.add_argument("--hack", type=bool, default=False) args = parser.parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/scripts/conformer/geom_stats.py b/scripts/conformer/geom_stats.py index d33156e89..5c6fa1699 100644 --- a/scripts/conformer/geom_stats.py +++ b/scripts/conformer/geom_stats.py @@ -1,32 +1,37 @@ import argparse -import os import json +import os import pickle +from pathlib import Path + import numpy as np import pandas as pd - from rdkit import Chem -from pathlib import Path from tqdm import tqdm -from gflownet.utils.molecule.rotatable_bonds import get_rotatable_ta_list, has_hydrogen_tas -from gflownet.utils.molecule.geom import 
get_conf_geom, get_all_confs_geom, get_rd_mol, all_same_graphs +from gflownet.utils.molecule.geom import (all_same_graphs, get_all_confs_geom, + get_conf_geom, get_rd_mol) +from gflownet.utils.molecule.rotatable_bonds import (get_rotatable_ta_list, + has_hydrogen_tas) """ Here we use rdkit_folder format of the GEOM dataset Tutorial and downloading links are here: https://github.com/learningmatter-mit/geom/tree/master """ -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--geom_dir', type=str, default='/home/mila/a/alexandra.volokhova/scratch/datasets/geom') - parser.add_argument('--output_file', type=str, default='./geom_stats.csv') + parser.add_argument( + "--geom_dir", + type=str, + default="/home/mila/a/alexandra.volokhova/scratch/datasets/geom", + ) + parser.add_argument("--output_file", type=str, default="./geom_stats.csv") args = parser.parse_args() base_path = Path(args.geom_dir) - drugs_file = base_path / 'rdkit_folder/summary_drugs.json' + drugs_file = base_path / "rdkit_folder/summary_drugs.json" with open(drugs_file, "r") as f: drugs_summ = json.load(f) - smiles = [] self_consistent = [] @@ -52,17 +57,15 @@ hydrogen_tas.append(has_hydrogen_tas(confs[0])) data = { - 'smiles': smiles, - 'self_consistent': self_consistent, - 'rdkit_consistent': rdkit_consistent, - 'n_rotatable_torsion_angles_geom' : n_tas_geom, - 'n_rotatable_torsion_angles_rdkit' : n_tas_rdkit, - 'has_hydrogen_tas': hydrogen_tas, - 'n_confs': unique_confs, - 'n_heavy_atoms': n_heavy_atoms, - 'n_atoms': n_atoms, + "smiles": smiles, + "self_consistent": self_consistent, + "rdkit_consistent": rdkit_consistent, + "n_rotatable_torsion_angles_geom": n_tas_geom, + "n_rotatable_torsion_angles_rdkit": n_tas_rdkit, + "has_hydrogen_tas": hydrogen_tas, + "n_confs": unique_confs, + "n_heavy_atoms": n_heavy_atoms, + "n_atoms": n_atoms, } df = pd.DataFrame(data) df.to_csv(args.output_file) - - diff --git a/scripts/conformer/kde_plots.py 
b/scripts/conformer/kde_plots.py index 4ddab72c8..9ceed35bf 100644 --- a/scripts/conformer/kde_plots.py +++ b/scripts/conformer/kde_plots.py @@ -1,58 +1,58 @@ # IMPORT THIS FIRST!!!!! -from tblite import interface - -import time import argparse -import pickle -import numpy as np import os - +import pickle +import time +from datetime import datetime from pathlib import Path + +import numpy as np from scipy.special import logsumexp -from datetime import datetime +from tblite import interface from tqdm import tqdm - -from gflownet.proxy.conformers.xtb import XTBMoleculeEnergy -from gflownet.proxy.conformers.torchani import TorchANIMoleculeEnergy -from gflownet.proxy.conformers.tblite import TBLiteMoleculeEnergy from gflownet.envs.conformers.conformer import Conformer +from gflownet.proxy.conformers.tblite import TBLiteMoleculeEnergy +from gflownet.proxy.conformers.torchani import TorchANIMoleculeEnergy +from gflownet.proxy.conformers.xtb import XTBMoleculeEnergy from gflownet.utils.common import torch2np PROXY_DICT = { - 'tblite': TBLiteMoleculeEnergy, - 'xtb': XTBMoleculeEnergy, - 'torchani': TorchANIMoleculeEnergy -} -PROXY_NAME_DICT = { - 'tblite': 'xtb', - 'xtb': 'gfn-ff', - 'torchani': 'torchani' + "tblite": TBLiteMoleculeEnergy, + "xtb": XTBMoleculeEnergy, + "torchani": TorchANIMoleculeEnergy, } +PROXY_NAME_DICT = {"tblite": "xtb", "xtb": "gfn-ff", "torchani": "torchani"} + def get_smiles_and_proxy_class(filename): - sm = filename.split('_')[2] - proxy_name = filename.split('_')[-1][:-4] + sm = filename.split("_")[2] + proxy_name = filename.split("_")[-1][:-4] proxy_class = PROXY_DICT[proxy_name] proxy = proxy_class(device="cpu", float_precision=32) - env = Conformer(smiles=sm, n_torsion_angles=2, reward_func="boltzmann", reward_beta=32, proxy=proxy, - reward_sampling_method='nested') + env = Conformer( + smiles=sm, + n_torsion_angles=2, + reward_func="boltzmann", + reward_beta=32, + proxy=proxy, + reward_sampling_method="nested", + ) # proxy.setup(env) return 
sm, PROXY_NAME_DICT[proxy_name], proxy, env + def load_samples(filename, base_path): path = base_path / filename - with open(path, 'rb') as f: + with open(path, "rb") as f: data = pickle.load(f) - return data['x'] + return data["x"] + def get_true_kde(env, n_samples, bandwidth=0.1): x_from_reward = env.sample_from_reward(n_samples=n_samples) x_from_reward = torch2np(env.statetorch2kde(x_from_reward)) - kde_true = env.fit_kde( - x_from_reward, - kernel='gaussian', - bandwidth=bandwidth) + kde_true = env.fit_kde(x_from_reward, kernel="gaussian", bandwidth=bandwidth) return kde_true @@ -87,46 +87,52 @@ def main(args): # create output dir current_datetime = datetime.now() timestamp = current_datetime.strftime("%Y-%m-%d_%H-%M-%S") - out_name = f'{pr_name}_{smiles}_{args.n_test}_{args.bandwidth}_{timestamp}' + out_name = f"{pr_name}_{smiles}_{args.n_test}_{args.bandwidth}_{timestamp}" output_dir = output_base / out_name os.mkdir(output_dir) - print(f'Will save results at {output_dir}') + print(f"Will save results at {output_dir}") samples = load_samples(filename, base_path) n_samples = samples.shape[0] print(f'{datetime.now().strftime("%H-%M-%S")}: Computing true kde') kde_true = get_true_kde(env, n_samples, bandwidth=args.bandwidth) print(f'{datetime.now().strftime("%H-%M-%S")}: Computing pred kde') - kde_pred = env.fit_kde(samples, kernel='gaussian', bandwidth=args.bandwidth) + kde_pred = env.fit_kde(samples, kernel="gaussian", bandwidth=args.bandwidth) print(f'{datetime.now().strftime("%H-%M-%S")}: Making figures') fig_pred = env.plot_kde(kde_pred) - fig_pred.savefig(output_dir / f'kde_pred_{out_name}.png', bbox_inches="tight", format='png') + fig_pred.savefig( + output_dir / f"kde_pred_{out_name}.png", bbox_inches="tight", format="png" + ) fig_true = env.plot_kde(kde_true) - fig_true.savefig(output_dir / f'kde_true_{out_name}.png', bbox_inches="tight", format='png') + fig_true.savefig( + output_dir / f"kde_true_{out_name}.png", bbox_inches="tight", format="png" + ) 
print(f'{datetime.now().strftime("%H-%M-%S")}: Computing metrics') test_samples = np.array(env.get_grid_terminating_states(args.n_test))[:, :2] l1, kl, jsd = get_metrics(kde_true, kde_pred, test_samples) - met = { - 'l1': l1, - 'kl': kl, - 'jsd': jsd - } + met = {"l1": l1, "kl": kl, "jsd": jsd} # write stuff - with open(output_dir / 'metrics.pkl', 'wb') as file: + with open(output_dir / "metrics.pkl", "wb") as file: pickle.dump(met, file) - with open(output_dir / 'kde_true.pkl', 'wb') as file: + with open(output_dir / "kde_true.pkl", "wb") as file: pickle.dump(kde_true, file) - with open(output_dir / 'kde_pred.pkl', 'wb') as file: + with open(output_dir / "kde_pred.pkl", "wb") as file: pickle.dump(kde_pred, file) - - -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--samples_dir', type=str, default='/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/samples/mcmc_samples') - parser.add_argument('--output_dir', type=str, default='/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/kde_stats/mcmc') - parser.add_argument('--bandwidth', type=float, default=0.15) - parser.add_argument('--n_test', type=int, default=10000) + parser.add_argument( + "--samples_dir", + type=str, + default="/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/samples/mcmc_samples", + ) + parser.add_argument( + "--output_dir", + type=str, + default="/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/kde_stats/mcmc", + ) + parser.add_argument("--bandwidth", type=float, default=0.15) + parser.add_argument("--n_test", type=int, default=10000) args = parser.parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/scripts/conformer/merge_metrics.py b/scripts/conformer/merge_metrics.py index db3ebc311..a73922c2e 100644 --- a/scripts/conformer/merge_metrics.py +++ b/scripts/conformer/merge_metrics.py @@ -1,19 +1,28 @@ -import pandas as pd from 
pathlib import Path +import pandas as pd + # method = "gfn-ff" # method = "torchani" method = "xtb" -base_path = Path('/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/metrics/') -df = pd.read_csv(base_path / f"gfn_conformers_v2_{method}_metrics.csv", index_col="smiles") -dftop = pd.read_csv(base_path / f"gfn_conformers_v2_{method}_top_k_metrics.csv", index_col="smiles") +base_path = Path( + "/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/metrics/" +) +df = pd.read_csv( + base_path / f"gfn_conformers_v2_{method}_metrics.csv", index_col="smiles" +) +dftop = pd.read_csv( + base_path / f"gfn_conformers_v2_{method}_top_k_metrics.csv", index_col="smiles" +) rd = pd.read_csv( - base_path / "rdkit_samples_target_smiles_first_batch_2023-09-24_21-15-42_metrics.csv", + base_path + / "rdkit_samples_target_smiles_first_batch_2023-09-24_21-15-42_metrics.csv", index_col="smiles", ) rdc = pd.read_csv( - base_path / "rdkit_cluster_samples_target_smiles_first_batch_2023-09-24_21-16-23_metrics.csv", + base_path + / "rdkit_cluster_samples_target_smiles_first_batch_2023-09-24_21-16-23_metrics.csv", index_col="smiles", ) rd = rd.loc[df.index] @@ -22,13 +31,13 @@ df[f"{method}_cov"] = df["cov"] df[f"{method}_mat"] = df["mat"] df[f"{method}_tk_cov"] = dftop["cov"] -df[f"{method}_tk_mat"] = dftop["mat"] +df[f"{method}_tk_mat"] = dftop["mat"] df["rdkit_cov"] = rd["cov"] df["rdkit_mat"] = rd["mat"] df["rdkit_cl_cov"] = rdc["cov"] df["rdkit_cl_mat"] = rdc["mat"] df = df.sort_values("n_tas") -df.to_csv('./merged_metrics_{}.csv'.format(method)) +df.to_csv("./merged_metrics_{}.csv".format(method)) df = df[ [ f"{method}_cov", @@ -46,4 +55,4 @@ print("Median") print(df.median()) print("Var") -print(df.var()) \ No newline at end of file +print(df.var()) diff --git a/scripts/conformer/rdkit_baselines.py b/scripts/conformer/rdkit_baselines.py index 56253aa3e..3606599fe 100644 --- a/scripts/conformer/rdkit_baselines.py +++ 
b/scripts/conformer/rdkit_baselines.py @@ -1,30 +1,34 @@ import argparse -import numpy as np -import pickle -import os import copy -import pandas as pd - +import os +import pickle from datetime import datetime -from scipy.spatial.transform import Rotation -from sklearn.cluster import KMeans from pathlib import Path + +import numpy as np +import pandas as pd from rdkit import Chem from rdkit.Chem import AllChem from rdkit.Chem import rdMolAlign as MA from rdkit.Chem import rdMolTransforms +from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMolecule from rdkit.Geometry.rdGeometry import Point3D +from scipy.spatial.transform import Rotation +from sklearn.cluster import KMeans from tqdm import tqdm -from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMolecule def gen_multiple_conf_rdkit(smiles, n_confs, optimise=True, randomise_tas=False): mols = [] for _ in range(n_confs): - mols.append(get_single_conf_rdkit(smiles, optimise=optimise, - randomise_tas=randomise_tas)) + mols.append( + get_single_conf_rdkit( + smiles, optimise=optimise, randomise_tas=randomise_tas + ) + ) return mols + def get_single_conf_rdkit(smiles, optimise=True, randomise_tas=False): mol = Chem.MolFromSmiles(smiles) mol = Chem.AddHs(mol) @@ -43,12 +47,14 @@ def get_single_conf_rdkit(smiles, optimise=True, randomise_tas=False): Chem.rdMolTransforms.CanonicalizeConformer(conf) return mol -def write_conformers(conf_list, smiles, output_dir, prefix='', idx=None): - conf_dict = {'conformer': conf_list, 'smiles': smiles} - filename = output_dir / f'{prefix}conformer_{idx}.pkl' - with open(filename, 'wb') as file: + +def write_conformers(conf_list, smiles, output_dir, prefix="", idx=None): + conf_dict = {"conformer": conf_list, "smiles": smiles} + filename = output_dir / f"{prefix}conformer_{idx}.pkl" + with open(filename, "wb") as file: pickle.dump(conf_dict, file) + def get_torsions(m): # taken from https://gist.github.com/ZhouGengmo/5b565f51adafcd911c0bc115b2ef027c m = 
Chem.RemoveHs(copy.deepcopy(m)) @@ -94,14 +100,15 @@ def clustering(smiles, M=1000, N=100): rdkit_coords_list = [] # add MMFF optimize conformers, 20x - rdkit_mols = gen_multiple_conf_rdkit(smiles, n_confs=M, - optimise=True, randomise_tas=False) + rdkit_mols = gen_multiple_conf_rdkit( + smiles, n_confs=M, optimise=True, randomise_tas=False + ) rdkit_mols = [Chem.RemoveHs(x) for x in rdkit_mols] sz = len(rdkit_mols) # normalize tgt_coords = rdkit_mols[0].GetConformers()[0].GetPositions().astype(np.float32) tgt_coords = tgt_coords - np.mean(tgt_coords, axis=0) - + for item in rdkit_mols: _coords = item.GetConformers()[0].GetPositions().astype(np.float32) _coords = _coords - _coords.mean(axis=0) # need to normalize first @@ -110,8 +117,9 @@ def clustering(smiles, M=1000, N=100): total_sz += sz # add no MMFF optimize conformers, 5x - rdkit_mols = gen_multiple_conf_rdkit(smiles, n_confs=int(M // 4), - optimise=False, randomise_tas=False) + rdkit_mols = gen_multiple_conf_rdkit( + smiles, n_confs=int(M // 4), optimise=False, randomise_tas=False + ) rdkit_mols = [Chem.RemoveHs(x) for x in rdkit_mols] sz = len(rdkit_mols) @@ -123,8 +131,9 @@ def clustering(smiles, M=1000, N=100): total_sz += sz ### add uniform rotation bonds conformers, 5x - rdkit_mols = gen_multiple_conf_rdkit(smiles, n_confs=int(M // 4), - optimise=False, randomise_tas=True) + rdkit_mols = gen_multiple_conf_rdkit( + smiles, n_confs=int(M // 4), optimise=False, randomise_tas=True + ) rdkit_mols = [Chem.RemoveHs(x) for x in rdkit_mols] sz = len(rdkit_mols) @@ -142,7 +151,7 @@ def clustering(smiles, M=1000, N=100): ids = kmeans.predict(rdkit_coords_flatten) # get cluster center center_coords = kmeans.cluster_centers_ - coords_list = [center_coords[i].reshape(-1,3) for i in range(cluster_size)] + coords_list = [center_coords[i].reshape(-1, 3) for i in range(cluster_size)] mols = [] for coord in coords_list: mol = get_single_conf_rdkit(smiles, optimise=False, randomise_tas=False) @@ -150,6 +159,7 @@ def 
clustering(smiles, M=1000, N=100): mols.append(copy.deepcopy(mol)) return mols + def set_atom_positions(mol, atom_positions): """ mol: rdkit mol with a single embeded conformer @@ -160,42 +170,54 @@ def set_atom_positions(mol, atom_positions): conf.SetAtomPosition(idx, Point3D(*pos)) return mol + def gen_multiple_conf_rdkit_cluster(smiles, n_confs): M = min(10 * n_confs, 2000) mols = clustering(smiles, N=n_confs, M=M) return mols -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--output_base_dir', type=str, default='/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/samples') - parser.add_argument('--method', type=str, default='rdkit') - parser.add_argument('--target_smiles', type=str, - default='/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/target_smiles/target_smiles_4_initial.csv') - parser.add_argument('--n_confs', type=int, default=None) + parser.add_argument( + "--output_base_dir", + type=str, + default="/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/samples", + ) + parser.add_argument("--method", type=str, default="rdkit") + parser.add_argument( + "--target_smiles", + type=str, + default="/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/target_smiles/target_smiles_4_initial.csv", + ) + parser.add_argument("--n_confs", type=int, default=None) args = parser.parse_args() output_base_dir = Path(args.output_base_dir) if not output_base_dir.exists(): - os.mkdir(output_base_dir) - + os.mkdir(output_base_dir) + current_datetime = datetime.now() timestamp = current_datetime.strftime("%Y-%m-%d_%H-%M-%S") ts = Path(args.target_smiles).name[:-4] - output_dir = output_base_dir / f'{args.method}_samples_{ts}_{timestamp}' + output_dir = output_base_dir / f"{args.method}_samples_{ts}_{timestamp}" if output_dir.exists(): print("Output dir already exisits! 
Exit generation") else: os.mkdir(output_dir) target_smiles = pd.read_csv(args.target_smiles, index_col=0) - for idx, (_, item) in tqdm(enumerate(target_smiles.iterrows()), total=len(target_smiles)): + for idx, (_, item) in tqdm( + enumerate(target_smiles.iterrows()), total=len(target_smiles) + ): n_samples = args.n_confs if n_samples is None: n_samples = 2 * item.n_confs print(f"start generating {n_samples} confs") - if args.method == 'rdkit': + if args.method == "rdkit": confs = gen_multiple_conf_rdkit(item.smiles, n_samples, optimise=True) - if args.method == 'rdkit_cluster': + if args.method == "rdkit_cluster": confs = gen_multiple_conf_rdkit_cluster(item.smiles, n_samples) - write_conformers(confs, item.smiles, output_dir, prefix=f'{args.method}_', idx=idx) + write_conformers( + confs, item.smiles, output_dir, prefix=f"{args.method}_", idx=idx + ) print("Finished generation, results are in {}".format(output_dir)) diff --git a/scripts/dav_mp20_stats.py b/scripts/dav_mp20_stats.py index 3df1c78c9..868c3745e 100644 --- a/scripts/dav_mp20_stats.py +++ b/scripts/dav_mp20_stats.py @@ -19,7 +19,8 @@ from collections import Counter -from external.repos.ActiveLearningMaterials.dave.utils.loaders import make_loaders +from external.repos.ActiveLearningMaterials.dave.utils.loaders import \ + make_loaders from gflownet.proxy.crystals.dave import DAVE from gflownet.utils.common import load_gflow_net_from_run_path, resolve_path diff --git a/scripts/pyxtal/pyxtal_vs_pymatgen.py b/scripts/pyxtal/pyxtal_vs_pymatgen.py index 62a226ae7..6ffcbaa21 100644 --- a/scripts/pyxtal/pyxtal_vs_pymatgen.py +++ b/scripts/pyxtal/pyxtal_vs_pymatgen.py @@ -4,12 +4,8 @@ """ from argparse import ArgumentParser -from pymatgen.symmetry.groups import ( - PointGroup, - SpaceGroup, - SymmetryGroup, - sg_symbol_from_int_number, -) +from pymatgen.symmetry.groups import (PointGroup, SpaceGroup, SymmetryGroup, + sg_symbol_from_int_number) from pyxtal.symmetry import Group N_SYMMETRY_GROUPS = 230 diff --git 
a/tests/gflownet/envs/test_lattice_parameters.py b/tests/gflownet/envs/test_lattice_parameters.py index 16aea2814..d74ea29bd 100644 --- a/tests/gflownet/envs/test_lattice_parameters.py +++ b/tests/gflownet/envs/test_lattice_parameters.py @@ -2,17 +2,13 @@ import pytest import torch -from gflownet.envs.crystals.lattice_parameters import ( - CUBIC, - HEXAGONAL, - LATTICE_SYSTEMS, - MONOCLINIC, - ORTHORHOMBIC, - RHOMBOHEDRAL, - TETRAGONAL, - TRICLINIC, - LatticeParameters, -) +from gflownet.envs.crystals.lattice_parameters import (CUBIC, HEXAGONAL, + LATTICE_SYSTEMS, + MONOCLINIC, + ORTHORHOMBIC, + RHOMBOHEDRAL, + TETRAGONAL, TRICLINIC, + LatticeParameters) @pytest.fixture() diff --git a/tests/gflownet/envs/test_tree.py b/tests/gflownet/envs/test_tree.py index d21009288..3d7af9c84 100644 --- a/tests/gflownet/envs/test_tree.py +++ b/tests/gflownet/envs/test_tree.py @@ -5,15 +5,8 @@ import pytest import torch -from gflownet.envs.tree import ( - ActionType, - Attribute, - NodeType, - Operator, - Stage, - Status, - Tree, -) +from gflownet.envs.tree import (ActionType, Attribute, NodeType, Operator, + Stage, Status, Tree) from gflownet.utils.common import tfloat NAN = float("NaN") diff --git a/tests/gflownet/policy/test_multihead_tree_policy.py b/tests/gflownet/policy/test_multihead_tree_policy.py index a28570a53..5d5448099 100644 --- a/tests/gflownet/policy/test_multihead_tree_policy.py +++ b/tests/gflownet/policy/test_multihead_tree_policy.py @@ -4,13 +4,10 @@ from torch_geometric.data import Batch from gflownet.envs.tree import Attribute, Operator, Tree -from gflownet.policy.multihead_tree import ( - Backbone, - FeatureSelectionHead, - LeafSelectionHead, - OperatorSelectionHead, - ThresholdSelectionHead, -) +from gflownet.policy.multihead_tree import (Backbone, FeatureSelectionHead, + LeafSelectionHead, + OperatorSelectionHead, + ThresholdSelectionHead) N_OBSERVATIONS = 17 N_FEATURES = 5 diff --git a/tests/gflownet/utils/molecule/test_rotatable_bonds.py 
b/tests/gflownet/utils/molecule/test_rotatable_bonds.py index b316dc1c0..931bb518b 100644 --- a/tests/gflownet/utils/molecule/test_rotatable_bonds.py +++ b/tests/gflownet/utils/molecule/test_rotatable_bonds.py @@ -2,10 +2,8 @@ from rdkit import Chem from gflownet.utils.molecule import constants -from gflownet.utils.molecule.rotatable_bonds import ( - find_rotor_from_smiles, - is_hydrogen_ta, -) +from gflownet.utils.molecule.rotatable_bonds import (find_rotor_from_smiles, + is_hydrogen_ta) def test_simple_ad(): diff --git a/tests/gflownet/utils/molecule/test_torsions.py b/tests/gflownet/utils/molecule/test_torsions.py index acdeef1db..ef11ca543 100644 --- a/tests/gflownet/utils/molecule/test_torsions.py +++ b/tests/gflownet/utils/molecule/test_torsions.py @@ -8,7 +8,8 @@ from gflownet.utils.molecule import constants from gflownet.utils.molecule.featurizer import MolDGLFeaturizer from gflownet.utils.molecule.rdkit_conformer import get_torsion_angles_values -from gflownet.utils.molecule.torsions import apply_rotations, get_rotation_masks +from gflownet.utils.molecule.torsions import (apply_rotations, + get_rotation_masks) def test_four_nodes_chain(): @@ -147,7 +148,8 @@ def stress_test_apply_rotation_alanine_dipeptide(): from rdkit.Geometry.rdGeometry import Point3D from gflownet.utils.molecule.featurizer import MolDGLFeaturizer - from gflownet.utils.molecule.rdkit_conformer import get_torsion_angles_values + from gflownet.utils.molecule.rdkit_conformer import \ + get_torsion_angles_values mol = Chem.MolFromSmiles(constants.ad_smiles) mol = Chem.AddHs(mol) diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 338dfd061..e26dd2455 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -8,16 +8,9 @@ from gflownet.proxy.corners import Corners from gflownet.proxy.tetris import Tetris as TetrisScore from gflownet.utils.batch import Batch -from gflownet.utils.common import ( - concat_items, - copy, - 
set_device, - set_float_precision, - tbool, - tfloat, - tint, - tlong, -) +from gflownet.utils.common import (concat_items, copy, set_device, + set_float_precision, tbool, tfloat, tint, + tlong) # Sets the number of repetitions for the tests. Please increase to ~10 after # introducing changes to the Batch class and decrease again to 1 when passed. From 0d785d454d14e14ff1c26e52fe1192ddb6a8413a Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Mon, 13 Nov 2023 17:32:14 -0500 Subject: [PATCH 14/22] readme edit --- scripts/conformer/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/conformer/README.md b/scripts/conformer/README.md index e5be3940c..bd87be9f0 100644 --- a/scripts/conformer/README.md +++ b/scripts/conformer/README.md @@ -1,13 +1,13 @@ This folder contains scripts for dealing with GEOM dataset and running RDKit-base baselines. ## Calculating statistics for GEOM dataset -The script `geom_stats.py` extracts statistical information from molecular conformation data in the GEOM dataset using the RDKit library. The GEOM dataset is expected to be in the "rdkit_folder" format (tutorial and downloading links are here: https://github.com/learningmatter-mit/geom/tree/master). This script parses the dataset, calculates various statistics, and outputs the results to a CSV file. +The script `geom_stats.py` extracts statistical information from molecular conformation data in the GEOM dataset using the RDKit library. The GEOM dataset is expected to be in the "rdkit_folder" format (tutorial and downloading links are here: https://github.com/learningmatter-mit/geom/tree/master). This script parses the dataset, calculates relevant statistics, and outputs the results to a CSV file. Statistics collected include: * SMILES representation of the molecule. * Whether the molecule is self-consistent, i.e. its conformations in the dataset correspond to the same SMILES. -* Whether the the milecule is consistent with the RDKit, i.e. 
all conformations in the dataset correspond to the same SMILES and this SMILES is the same as stored in the dataset. -* The number of rotatable torsion angles in the molecular conformation (both from GEOM and RDKit). +* Whether the the molecule is consistent with the RDKit, i.e. all conformations in the dataset correspond to the same SMILES and this SMILES is the same as stored in the dataset. +* The number of rotatable torsion angles in the molecular conformation (both from GEOM conformers and RDKit-generated graph). * Whether the molecule contains hydrogen torsion angles. * The total number of unique conformations for the molecule. * The number of heavy atoms in the molecule. From 8206e0245cc45e6ff1dc803b9e3c4ec00b9145c4 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Mon, 13 Nov 2023 18:14:05 -0500 Subject: [PATCH 15/22] black, isort again --- gflownet/envs/base.py | 3 +-- gflownet/envs/crystals/composition.py | 7 +++++-- gflownet/envs/crystals/lattice_parameters.py | 14 ++++++++++---- gflownet/envs/ctorus.py | 3 +-- gflownet/gflownet.py | 12 +++++++++--- gflownet/utils/batch.py | 12 ++++++++++-- .../utils/crystals/build_lattice_dicts.py | 19 +++++++++++++------ gflownet/utils/molecule/geom.py | 6 ++++-- gflownet/utils/oracle.py | 3 +-- playground/botorch/mes_exact_deepKernel.py | 1 + playground/botorch/mes_gp.py | 1 + playground/botorch/mes_gp_debug.py | 7 +++---- playground/botorch/mes_nn_bao_fix.py | 5 +++-- playground/botorch/mes_nn_hardcode_gpVal.py | 4 ++-- playground/botorch/mes_nn_like_gp.py | 5 +++-- .../mes_nn_like_gp_nondiagonalcovar.py | 5 +++-- playground/botorch/mes_var_deepKernel.py | 5 ++++- scripts/conformer/geom_stats.py | 14 ++++++++++---- scripts/dav_mp20_stats.py | 3 +-- scripts/pyxtal/pyxtal_vs_pymatgen.py | 8 ++++++-- .../gflownet/envs/test_lattice_parameters.py | 18 +++++++++++------- tests/gflownet/envs/test_tree.py | 11 +++++++++-- .../policy/test_multihead_tree_policy.py | 11 +++++++---- 
.../utils/molecule/test_rotatable_bonds.py | 6 ++++-- .../gflownet/utils/molecule/test_torsions.py | 6 ++---- tests/gflownet/utils/test_batch.py | 13 ++++++++++--- 26 files changed, 134 insertions(+), 68 deletions(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index b7f59128f..e0381ed37 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -15,8 +15,7 @@ from torch.distributions import Categorical from torchtyping import TensorType -from gflownet.utils.common import (copy, set_device, set_float_precision, - tbool, tfloat) +from gflownet.utils.common import copy, set_device, set_float_precision, tbool, tfloat CMAP = mpl.colormaps["cividis"] diff --git a/gflownet/envs/crystals/composition.py b/gflownet/envs/crystals/composition.py index 721f702b7..2e1e75240 100644 --- a/gflownet/envs/crystals/composition.py +++ b/gflownet/envs/crystals/composition.py @@ -14,8 +14,11 @@ from gflownet.utils.common import tlong from gflownet.utils.crystals.constants import ELEMENT_NAMES, OXIDATION_STATES from gflownet.utils.crystals.pyxtal_cache import ( - get_space_group, space_group_check_compatible, - space_group_lowest_free_wp_multiplicity, space_group_wyckoff_gcd) + get_space_group, + space_group_check_compatible, + space_group_lowest_free_wp_multiplicity, + space_group_wyckoff_gcd, +) class Composition(GFlowNetEnv): diff --git a/gflownet/envs/crystals/lattice_parameters.py b/gflownet/envs/crystals/lattice_parameters.py index e15bfed54..957a7e229 100644 --- a/gflownet/envs/crystals/lattice_parameters.py +++ b/gflownet/envs/crystals/lattice_parameters.py @@ -9,10 +9,16 @@ from torchtyping import TensorType from gflownet.envs.grid import Grid -from gflownet.utils.crystals.constants import (CUBIC, HEXAGONAL, - LATTICE_SYSTEMS, MONOCLINIC, - ORTHORHOMBIC, RHOMBOHEDRAL, - TETRAGONAL, TRICLINIC) +from gflownet.utils.crystals.constants import ( + CUBIC, + HEXAGONAL, + LATTICE_SYSTEMS, + MONOCLINIC, + ORTHORHOMBIC, + RHOMBOHEDRAL, + TETRAGONAL, + TRICLINIC, +) 
class LatticeParameters(Grid): diff --git a/gflownet/envs/ctorus.py b/gflownet/envs/ctorus.py index 8b6b79d0e..c006428f6 100644 --- a/gflownet/envs/ctorus.py +++ b/gflownet/envs/ctorus.py @@ -9,8 +9,7 @@ import numpy.typing as npt import pandas as pd import torch -from torch.distributions import (Categorical, MixtureSameFamily, Uniform, - VonMises) +from torch.distributions import Categorical, MixtureSameFamily, Uniform, VonMises from torchtyping import TensorType from gflownet.envs.htorus import HybridTorus diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 64b228acb..4eebc4d6b 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -21,9 +21,15 @@ from gflownet.envs.base import GFlowNetEnv from gflownet.utils.batch import Batch from gflownet.utils.buffer import Buffer -from gflownet.utils.common import (batch_with_rest, set_device, - set_float_precision, tbool, tfloat, tlong, - torch2np) +from gflownet.utils.common import ( + batch_with_rest, + set_device, + set_float_precision, + tbool, + tfloat, + tlong, + torch2np, +) class GFlowNetAgent: diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index c76ecfa9e..a35f01ddf 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -7,8 +7,16 @@ from torchtyping import TensorType from gflownet.envs.base import GFlowNetEnv -from gflownet.utils.common import (concat_items, copy, extend, set_device, - set_float_precision, tbool, tfloat, tlong) +from gflownet.utils.common import ( + concat_items, + copy, + extend, + set_device, + set_float_precision, + tbool, + tfloat, + tlong, +) class Batch: diff --git a/gflownet/utils/crystals/build_lattice_dicts.py b/gflownet/utils/crystals/build_lattice_dicts.py index f62c4bd19..65d4a7958 100644 --- a/gflownet/utils/crystals/build_lattice_dicts.py +++ b/gflownet/utils/crystals/build_lattice_dicts.py @@ -8,12 +8,19 @@ import numpy as np import yaml -from lattice_constants import (CRYSTAL_CLASSES_WIKIPEDIA, - CRYSTAL_LATTICE_SYSTEMS, 
CRYSTAL_SYSTEMS, - POINT_SYMMETRIES, - RHOMBOHEDRAL_SPACE_GROUPS_WIKIPEDIA) -from pymatgen.symmetry.groups import (PointGroup, SpaceGroup, SymmetryGroup, - sg_symbol_from_int_number) +from lattice_constants import ( + CRYSTAL_CLASSES_WIKIPEDIA, + CRYSTAL_LATTICE_SYSTEMS, + CRYSTAL_SYSTEMS, + POINT_SYMMETRIES, + RHOMBOHEDRAL_SPACE_GROUPS_WIKIPEDIA, +) +from pymatgen.symmetry.groups import ( + PointGroup, + SpaceGroup, + SymmetryGroup, + sg_symbol_from_int_number, +) N_SPACE_GROUPS = 230 diff --git a/gflownet/utils/molecule/geom.py b/gflownet/utils/molecule/geom.py index 273f2cdae..e7ac71308 100644 --- a/gflownet/utils/molecule/geom.py +++ b/gflownet/utils/molecule/geom.py @@ -8,8 +8,10 @@ from rdkit import Chem from tqdm import tqdm -from gflownet.utils.molecule.rotatable_bonds import (get_rotatable_ta_list, - is_hydrogen_ta) +from gflownet.utils.molecule.rotatable_bonds import ( + get_rotatable_ta_list, + is_hydrogen_ta, +) def get_conf_geom(base_path, smiles, conf_idx=0, summary_file=None): diff --git a/gflownet/utils/oracle.py b/gflownet/utils/oracle.py index 1fc9c92c6..c4713745a 100644 --- a/gflownet/utils/oracle.py +++ b/gflownet/utils/oracle.py @@ -14,8 +14,7 @@ ) pass try: - from bbdob import (DeceptiveTrap, FourPeaks, NKLandscape, OneMax, TwoMin, - WModel) + from bbdob import DeceptiveTrap, FourPeaks, NKLandscape, OneMax, TwoMin, WModel from bbdob.utils import idx2one_hot except: print( diff --git a/playground/botorch/mes_exact_deepKernel.py b/playground/botorch/mes_exact_deepKernel.py index 743b68256..b77bb2e89 100644 --- a/playground/botorch/mes_exact_deepKernel.py +++ b/playground/botorch/mes_exact_deepKernel.py @@ -8,6 +8,7 @@ from math import floor import gpytorch + # import tqdm import torch from botorch.test_functions import Hartmann diff --git a/playground/botorch/mes_gp.py b/playground/botorch/mes_gp.py index 8afde5dc8..b51df0ce6 100644 --- a/playground/botorch/mes_gp.py +++ b/playground/botorch/mes_gp.py @@ -6,6 +6,7 @@ import numpy as np import 
torch + # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Branin, Hartmann diff --git a/playground/botorch/mes_gp_debug.py b/playground/botorch/mes_gp_debug.py index 76af6ff00..06c5a3ed6 100644 --- a/playground/botorch/mes_gp_debug.py +++ b/playground/botorch/mes_gp_debug.py @@ -3,6 +3,7 @@ import gpytorch import numpy as np import torch + # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Hartmann @@ -49,10 +50,8 @@ def forward(self, x): from botorch.models.utils import add_output_dim from botorch.posteriors.gpytorch import GPyTorchPosterior -from gpytorch.distributions import (MultitaskMultivariateNormal, - MultivariateNormal) -from gpytorch.likelihoods.gaussian_likelihood import \ - FixedNoiseGaussianLikelihood +from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal +from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood class myGPModel(SingleTaskGP): diff --git a/playground/botorch/mes_nn_bao_fix.py b/playground/botorch/mes_nn_bao_fix.py index 861268aeb..c4f7de6d0 100644 --- a/playground/botorch/mes_nn_bao_fix.py +++ b/playground/botorch/mes_nn_bao_fix.py @@ -2,6 +2,7 @@ import numpy as np import torch + # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Hartmann @@ -55,8 +56,8 @@ from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior -from gpytorch.distributions import (MultitaskMultivariateNormal, - MultivariateNormal) +from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal + # from botorch.posteriors. 
from torch.distributions import Normal diff --git a/playground/botorch/mes_nn_hardcode_gpVal.py b/playground/botorch/mes_nn_hardcode_gpVal.py index 42dcbb9b4..6320d4f05 100644 --- a/playground/botorch/mes_nn_hardcode_gpVal.py +++ b/playground/botorch/mes_nn_hardcode_gpVal.py @@ -2,6 +2,7 @@ import numpy as np import torch + # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Hartmann @@ -56,8 +57,7 @@ from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior -from gpytorch.distributions import (MultitaskMultivariateNormal, - MultivariateNormal) +from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal class NN_Model(Model): diff --git a/playground/botorch/mes_nn_like_gp.py b/playground/botorch/mes_nn_like_gp.py index 0b15c98be..d0664a342 100644 --- a/playground/botorch/mes_nn_like_gp.py +++ b/playground/botorch/mes_nn_like_gp.py @@ -3,14 +3,15 @@ import numpy as np import torch from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy + # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.test_functions import Hartmann -from gpytorch.distributions import (MultitaskMultivariateNormal, - MultivariateNormal) +from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal from gpytorch.mlls import ExactMarginalLogLikelihood + # from botorch.posteriors. 
from torch import distributions, tensor from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential diff --git a/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py b/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py index 1d6626b33..2c75fd6a4 100644 --- a/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py +++ b/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py @@ -3,14 +3,15 @@ import numpy as np import torch from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy + # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.test_functions import Hartmann -from gpytorch.distributions import (MultitaskMultivariateNormal, - MultivariateNormal) +from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal from gpytorch.mlls import ExactMarginalLogLikelihood + # from botorch.posteriors. from torch import distributions, tensor from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential diff --git a/playground/botorch/mes_var_deepKernel.py b/playground/botorch/mes_var_deepKernel.py index 989af46c4..f712eaaf0 100644 --- a/playground/botorch/mes_var_deepKernel.py +++ b/playground/botorch/mes_var_deepKernel.py @@ -10,6 +10,7 @@ from math import floor import gpytorch + # import tqdm import torch from botorch.test_functions import Hartmann @@ -214,7 +215,9 @@ def posterior( from botorch.acquisition.max_value_entropy_search import ( - qLowerBoundMaxValueEntropy, qMaxValueEntropy) + qLowerBoundMaxValueEntropy, + qMaxValueEntropy, +) proxy = myGPModel(model, train_x, train_y.unsqueeze(-1)) qMES = qLowerBoundMaxValueEntropy(proxy, candidate_set=train_x, use_gumbel=True) diff --git a/scripts/conformer/geom_stats.py b/scripts/conformer/geom_stats.py index 5c6fa1699..5194b1068 100644 --- a/scripts/conformer/geom_stats.py +++ b/scripts/conformer/geom_stats.py @@ -9,10 +9,16 @@ from rdkit 
import Chem from tqdm import tqdm -from gflownet.utils.molecule.geom import (all_same_graphs, get_all_confs_geom, - get_conf_geom, get_rd_mol) -from gflownet.utils.molecule.rotatable_bonds import (get_rotatable_ta_list, - has_hydrogen_tas) +from gflownet.utils.molecule.geom import ( + all_same_graphs, + get_all_confs_geom, + get_conf_geom, + get_rd_mol, +) +from gflownet.utils.molecule.rotatable_bonds import ( + get_rotatable_ta_list, + has_hydrogen_tas, +) """ Here we use rdkit_folder format of the GEOM dataset diff --git a/scripts/dav_mp20_stats.py b/scripts/dav_mp20_stats.py index 868c3745e..3df1c78c9 100644 --- a/scripts/dav_mp20_stats.py +++ b/scripts/dav_mp20_stats.py @@ -19,8 +19,7 @@ from collections import Counter -from external.repos.ActiveLearningMaterials.dave.utils.loaders import \ - make_loaders +from external.repos.ActiveLearningMaterials.dave.utils.loaders import make_loaders from gflownet.proxy.crystals.dave import DAVE from gflownet.utils.common import load_gflow_net_from_run_path, resolve_path diff --git a/scripts/pyxtal/pyxtal_vs_pymatgen.py b/scripts/pyxtal/pyxtal_vs_pymatgen.py index 6ffcbaa21..62a226ae7 100644 --- a/scripts/pyxtal/pyxtal_vs_pymatgen.py +++ b/scripts/pyxtal/pyxtal_vs_pymatgen.py @@ -4,8 +4,12 @@ """ from argparse import ArgumentParser -from pymatgen.symmetry.groups import (PointGroup, SpaceGroup, SymmetryGroup, - sg_symbol_from_int_number) +from pymatgen.symmetry.groups import ( + PointGroup, + SpaceGroup, + SymmetryGroup, + sg_symbol_from_int_number, +) from pyxtal.symmetry import Group N_SYMMETRY_GROUPS = 230 diff --git a/tests/gflownet/envs/test_lattice_parameters.py b/tests/gflownet/envs/test_lattice_parameters.py index d74ea29bd..16aea2814 100644 --- a/tests/gflownet/envs/test_lattice_parameters.py +++ b/tests/gflownet/envs/test_lattice_parameters.py @@ -2,13 +2,17 @@ import pytest import torch -from gflownet.envs.crystals.lattice_parameters import (CUBIC, HEXAGONAL, - LATTICE_SYSTEMS, - MONOCLINIC, - ORTHORHOMBIC, - 
RHOMBOHEDRAL, - TETRAGONAL, TRICLINIC, - LatticeParameters) +from gflownet.envs.crystals.lattice_parameters import ( + CUBIC, + HEXAGONAL, + LATTICE_SYSTEMS, + MONOCLINIC, + ORTHORHOMBIC, + RHOMBOHEDRAL, + TETRAGONAL, + TRICLINIC, + LatticeParameters, +) @pytest.fixture() diff --git a/tests/gflownet/envs/test_tree.py b/tests/gflownet/envs/test_tree.py index 3d7af9c84..d21009288 100644 --- a/tests/gflownet/envs/test_tree.py +++ b/tests/gflownet/envs/test_tree.py @@ -5,8 +5,15 @@ import pytest import torch -from gflownet.envs.tree import (ActionType, Attribute, NodeType, Operator, - Stage, Status, Tree) +from gflownet.envs.tree import ( + ActionType, + Attribute, + NodeType, + Operator, + Stage, + Status, + Tree, +) from gflownet.utils.common import tfloat NAN = float("NaN") diff --git a/tests/gflownet/policy/test_multihead_tree_policy.py b/tests/gflownet/policy/test_multihead_tree_policy.py index 5d5448099..a28570a53 100644 --- a/tests/gflownet/policy/test_multihead_tree_policy.py +++ b/tests/gflownet/policy/test_multihead_tree_policy.py @@ -4,10 +4,13 @@ from torch_geometric.data import Batch from gflownet.envs.tree import Attribute, Operator, Tree -from gflownet.policy.multihead_tree import (Backbone, FeatureSelectionHead, - LeafSelectionHead, - OperatorSelectionHead, - ThresholdSelectionHead) +from gflownet.policy.multihead_tree import ( + Backbone, + FeatureSelectionHead, + LeafSelectionHead, + OperatorSelectionHead, + ThresholdSelectionHead, +) N_OBSERVATIONS = 17 N_FEATURES = 5 diff --git a/tests/gflownet/utils/molecule/test_rotatable_bonds.py b/tests/gflownet/utils/molecule/test_rotatable_bonds.py index 931bb518b..b316dc1c0 100644 --- a/tests/gflownet/utils/molecule/test_rotatable_bonds.py +++ b/tests/gflownet/utils/molecule/test_rotatable_bonds.py @@ -2,8 +2,10 @@ from rdkit import Chem from gflownet.utils.molecule import constants -from gflownet.utils.molecule.rotatable_bonds import (find_rotor_from_smiles, - is_hydrogen_ta) +from 
gflownet.utils.molecule.rotatable_bonds import ( + find_rotor_from_smiles, + is_hydrogen_ta, +) def test_simple_ad(): diff --git a/tests/gflownet/utils/molecule/test_torsions.py b/tests/gflownet/utils/molecule/test_torsions.py index ef11ca543..acdeef1db 100644 --- a/tests/gflownet/utils/molecule/test_torsions.py +++ b/tests/gflownet/utils/molecule/test_torsions.py @@ -8,8 +8,7 @@ from gflownet.utils.molecule import constants from gflownet.utils.molecule.featurizer import MolDGLFeaturizer from gflownet.utils.molecule.rdkit_conformer import get_torsion_angles_values -from gflownet.utils.molecule.torsions import (apply_rotations, - get_rotation_masks) +from gflownet.utils.molecule.torsions import apply_rotations, get_rotation_masks def test_four_nodes_chain(): @@ -148,8 +147,7 @@ def stress_test_apply_rotation_alanine_dipeptide(): from rdkit.Geometry.rdGeometry import Point3D from gflownet.utils.molecule.featurizer import MolDGLFeaturizer - from gflownet.utils.molecule.rdkit_conformer import \ - get_torsion_angles_values + from gflownet.utils.molecule.rdkit_conformer import get_torsion_angles_values mol = Chem.MolFromSmiles(constants.ad_smiles) mol = Chem.AddHs(mol) diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index e26dd2455..338dfd061 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -8,9 +8,16 @@ from gflownet.proxy.corners import Corners from gflownet.proxy.tetris import Tetris as TetrisScore from gflownet.utils.batch import Batch -from gflownet.utils.common import (concat_items, copy, set_device, - set_float_precision, tbool, tfloat, tint, - tlong) +from gflownet.utils.common import ( + concat_items, + copy, + set_device, + set_float_precision, + tbool, + tfloat, + tint, + tlong, +) # Sets the number of repetitions for the tests. Please increase to ~10 after # introducing changes to the Batch class and decrease again to 1 when passed. 
From b9183a105099726e671ee88c56cae2eee39aeb73 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Thu, 11 Jan 2024 17:51:30 -0500 Subject: [PATCH 16/22] script for generating uniform samples with reward weights --- .../gen_uniform_samples_with_rewards.py | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 scripts/conformer/gen_uniform_samples_with_rewards.py diff --git a/scripts/conformer/gen_uniform_samples_with_rewards.py b/scripts/conformer/gen_uniform_samples_with_rewards.py new file mode 100644 index 000000000..b4801cf54 --- /dev/null +++ b/scripts/conformer/gen_uniform_samples_with_rewards.py @@ -0,0 +1,79 @@ + +# tblite import should stay here first! othervise everything fails with tblite errors +from gflownet.proxy.conformers.tblite import TBLiteMoleculeEnergy + +import os +import pandas as pd +import numpy as np +import pickle +import argparse +from pathlib import Path +import seaborn as sns +import matplotlib.pyplot as plt +from tqdm import tqdm + +from rdkit.Chem import AllChem, rdMolTransforms +from gflownet.envs.conformers.conformer import PREDEFINED_SMILES +from gflownet.utils.molecule.rotatable_bonds import get_rotatable_ta_list, find_rotor_from_smiles +from gflownet.proxy.conformers.torchani import TorchANIMoleculeEnergy +from gflownet.envs.conformers.conformer import Conformer + + + +def get_uniform_samples_and_energy_weights(smiles, n_samples, energy_model='torchani'): + n_torsion_angles = len(find_rotor_from_smiles(smiles)) + uniform_tas = np.random.uniform(0., 2*np.pi, size=(n_samples, n_torsion_angles)) + env = Conformer(smiles=smiles, n_torsion_angles=n_torsion_angles, reward_func="boltzmann", reward_beta=32) + if energy_model == 'torchani': + proxy = TorchANIMoleculeEnergy(device='cpu', float_precision=32) + elif energy_model == 'tblite': + proxy = TBLiteMoleculeEnergy(device='cpu', float_precision=32, batch_size=417) + else: + raise NotImplementedError(f"No proxy availabe for {energy_model}, use one of 
['torchani', 'tblite']") + proxy.setup(env) + def get_weights(batch): + batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) + energies = proxy(env.statebatch2proxy(batch)) + rewards = env.proxy2reward(-energies) + return rewards.numpy() + weights = get_weights(uniform_tas) + ddict = {f'ta_{idx}': uniform_tas[:, idx] for idx in range(n_torsion_angles)} + ddict.update({'weights': weights}) + return pd.DataFrame(ddict) + + +def main(args): + output_root = Path(args.output_dir) + if not output_root.exists(): + os.mkdir(output_root) + + samples_root = Path('/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/samples') + selected_mols = pd.read_csv(samples_root / 'gfn_samples_2-12' / 'torchani_selected.csv') + result = dict() + for smiles in tqdm(selected_mols['SMILES'].values): + result.update({ + smiles: get_uniform_samples_and_energy_weights(smiles, args.n_samples, args.energy_model) + }) + if args.save_each_df: + sm_idx = PREDEFINED_SMILES.index(smiles) + filename = filename = output_root / f"{args.energy_model}_{sm_idx}_weighted_samples_selected_smiles.csv" + result[smiles].to_csv(filename) + filename = output_root / f'{args.energy_model}_weighted_samples_selected_smiles.pkl' + with open(filename, 'wb') as file: + pickle.dump(result, file) + print(f"saved results at {filename}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--n_samples", + type=int, + default=1000, + ) + parser.add_argument("--energy_model", type=str, choices=['torchani', 'tblite'], default='torchani') + parser.add_argument("--save_each_df", type=bool, default=False) + parser.add_argument("--output_dir", type=str, + default="/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/samples/uniform_samples_with_reward_weights") + args = parser.parse_args() + main(args) From 84138832c704bf0535393add6009b5f5e66b84ce Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Mon, 22 Jan 2024 16:27:52 -0500 
Subject: [PATCH 17/22] finish mcmc baseline script --- scripts/conformer/mcmc_baseline.py | 215 +++++++++++++++++++++++++++++ 1 file changed, 215 insertions(+) create mode 100644 scripts/conformer/mcmc_baseline.py diff --git a/scripts/conformer/mcmc_baseline.py b/scripts/conformer/mcmc_baseline.py new file mode 100644 index 000000000..048e966df --- /dev/null +++ b/scripts/conformer/mcmc_baseline.py @@ -0,0 +1,215 @@ +try: + from tblite import interface +except: + pass + +import argparse +from getdist.mcsamples import MCSamplesFromCobaya +import getdist.plots as gdplt + +from scipy import stats +from pathlib import Path +from cobaya.run import run +import pickle +import numpy as np +import os +import torch + +from gflownet.proxy.conformers.xtb import XTBMoleculeEnergy +from gflownet.proxy.conformers.torchani import TorchANIMoleculeEnergy +from gflownet.proxy.conformers.tblite import TBLiteMoleculeEnergy +from gflownet.envs.conformers.conformer import Conformer + +def convert_to_numpy_if_needed(array): + if torch.is_tensor(array): + return array.cpu().detach().numpy() + else: + return array + + +def main(args): + if args.proxy_name == 'torchani': + proxy_class = TorchANIMoleculeEnergy + elif args.proxy_name == 'tblite': + proxy_class = TBLiteMoleculeEnergy + elif args.proxy_name == 'xtb': + proxy_class = XTBMoleculeEnergy + + # Leave as is + DEVICE = "cpu" + FLOAT_PRECISION = 32 + REWARD_FUNC = "boltzmann" + REWARD_BETA = 32 + + # output dir + output_dir = Path(args.output_dir) + if not output_dir.exists(): + os.mkdir(output_dir) + + for smile in args.ids: + # Change n_torsion_angles + env = Conformer( + smiles=int(smile), n_torsion_angles=-1, reward_func=REWARD_FUNC, reward_beta=REWARD_BETA + ) + + ndims = len(env.torsion_angles) + + proxy = proxy_class(device=DEVICE, float_precision=FLOAT_PRECISION) + proxy.setup(env) + + print(f"Sampling for {ndims} dimensions with {args.proxy_name} proxy") + + if ndims == 2: + def reward(p0, p1): + batch = np.concatenate([[p0], 
[p1]]).reshape(1, -1) + batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) + + proxy_batch = env.statebatch2proxy(batch) + energies = proxy(env.statebatch2proxy(batch)) + rewards = env.proxy2reward(-energies) + rewards = convert_to_numpy_if_needed(rewards) + return np.log(rewards) + elif ndims == 3: + def reward(p0, p1, p2): + batch = np.concatenate([[p0], [p1], [p2]]).reshape(1, -1) + batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) + + proxy_batch = env.statebatch2proxy(batch) + energies = proxy(env.statebatch2proxy(batch)) + rewards = env.proxy2reward(-energies) + rewards = convert_to_numpy_if_needed(rewards) + return np.log(rewards) + elif ndims == 4: + def reward(p0, p1, p2, p3): + batch = np.concatenate([[p0], [p1], [p2], [p3]]).reshape(1, -1) + batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) + + proxy_batch = env.statebatch2proxy(batch) + energies = proxy(env.statebatch2proxy(batch)) + rewards = env.proxy2reward(-energies) + rewards = convert_to_numpy_if_needed(rewards) + return np.log(rewards) + if ndims == 5: + def reward(p0, p1, p2, p3, p4): + batch = np.concatenate([[p0], [p1], [p2], [p3], [p4]]).reshape(1, -1) + batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) + + proxy_batch = env.statebatch2proxy(batch) + energies = proxy(env.statebatch2proxy(batch)) + rewards = env.proxy2reward(-energies) + rewards = convert_to_numpy_if_needed(rewards) + return np.log(rewards) + if ndims == 6: + def reward(p0, p1, p2, p3, p4, p5): + batch = np.concatenate([[p0], [p1], [p2], [p3], [p4], [p5]]).reshape(1, -1) + batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) + + proxy_batch = env.statebatch2proxy(batch) + energies = proxy(env.statebatch2proxy(batch)) + rewards = env.proxy2reward(-energies) + rewards = convert_to_numpy_if_needed(rewards) + return np.log(rewards) + elif ndims == 7: + def reward(p0, p1, p2, p3, p4, p5 ,p6): + batch = np.concatenate([[p0], 
[p1], [p2], [p2], [p3], [p4], [p5], [p6]]).reshape(1, -1) + batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) + + proxy_batch = env.statebatch2proxy(batch) + energies = proxy(env.statebatch2proxy(batch)) + rewards = env.proxy2reward(-energies) + rewards = convert_to_numpy_if_needed(rewards) + return np.log(rewards) + elif ndims == 8: + def reward(p0, p1, p2, p3, p4, p5 ,p6, p7): + batch = np.concatenate([[p0], [p1], [p2], [p3], [p4], [p5], [p6], [p7]]).reshape(1, -1) + batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) + + proxy_batch = env.statebatch2proxy(batch) + energies = proxy(env.statebatch2proxy(batch)) + rewards = env.proxy2reward(-energies) + rewards = convert_to_numpy_if_needed(rewards) + return np.log(rewards) + elif ndims == 9: + def reward(p0, p1, p2, p3, p4, p5 ,p6, p7, p8): + batch = np.concatenate([[p0], [p1], [p2], [p3], [p4], [p5], [p6], [p7], [p8]]).reshape(1, -1) + batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) + + proxy_batch = env.statebatch2proxy(batch) + energies = proxy(env.statebatch2proxy(batch)) + rewards = env.proxy2reward(-energies) + rewards = convert_to_numpy_if_needed(rewards) + return np.log(rewards) + elif ndims == 10: + def reward(p0, p1, p2, p3, p4, p5 ,p6, p7, p8, p9): + batch = np.concatenate([[p0], [p1], [p2], [p3], [p4], [p5], [p6], [p7], [p8], [p9]]).reshape(1, -1) + batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) + + proxy_batch = env.statebatch2proxy(batch) + energies = proxy(env.statebatch2proxy(batch)) + rewards = env.proxy2reward(-energies) + rewards = convert_to_numpy_if_needed(rewards) + return np.log(rewards) + elif ndims == 11: + def reward(p0, p1, p2, p3, p4, p5 ,p6, p7, p8, p9, p10): + batch = np.concatenate([[p0], [p1], [p2], [p3], [p4], [p5], [p6], [p7], [p8], [p9], [p10]]).reshape(1, -1) + batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) + + proxy_batch = env.statebatch2proxy(batch) + energies = 
proxy(env.statebatch2proxy(batch)) + rewards = env.proxy2reward(-energies) + rewards = convert_to_numpy_if_needed(rewards) + return np.log(rewards) + elif ndims == 12: + def reward(p0, p1, p2, p3, p4, p5 ,p6, p7, p8, p9, p10, p11): + batch = np.concatenate([[p0], [p1], [p2], [p3], [p4], [p5], [p6], [p7], [p8], [p9], [p10], [p11]]).reshape(1, -1) + batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) + + proxy_batch = env.statebatch2proxy(batch) + energies = proxy(env.statebatch2proxy(batch)) + rewards = env.proxy2reward(-energies) + rewards = convert_to_numpy_if_needed(rewards) + return np.log(rewards) + + + info = {"likelihood": {"reward": reward}} + + info["params"] = {f"p{i}": {"prior": {"min": 0, "max": 2*np.pi}, "ref": np.pi, "proposal": 0.01} for i in range(ndims)} + + Rminus1_stop = 0.05 + info["sampler"] = {"mcmc": {"Rminus1_stop": Rminus1_stop, "max_tries": 1000}} + + updated_info, sampler = run(info); + + gdsamples = MCSamplesFromCobaya(updated_info, sampler.products()["sample"]) + + def get_energy(batch): + batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) + + proxy_batch = env.statebatch2proxy(batch) + energies = proxy(env.statebatch2proxy(batch)) + return energies + + if gdsamples.samples.shape[0] >= 1000: + npars = len(info["params"]) + dct = {"x": gdsamples.samples[-1000:, :npars]} #, "energy": np.exp(gdsamples.loglikes[-10000:])} + + dct["energy"] = get_energy(gdsamples.samples[-1000:, :npars]) + + dct["conformer"] = [env.set_conformer(state).rdk_mol for state in gdsamples.samples[-1000:, :npars]] + pickle.dump( + dct, open(output_dir / f"conformers_mcmc_smile{smile}_{args.proxy_name}.pkl", "wb") + ) + print(f'Finished smile {smile} (dimensions {ndims})') + + else: + print(f'Not enough samples for smile {smile}') + return 0 + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--ids', nargs='+', required=True, type=int) + parser.add_argument('--output_dir', type=str, 
default='./mcmc_outputs/') + parser.add_argument('--proxy_name', type=str, default='torchani') + args = parser.parse_args() + main(args) \ No newline at end of file From 6adda4d2740ba1524195aa48bcb8aa6e4a8b8e95 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Tue, 23 Jan 2024 12:56:38 -0500 Subject: [PATCH 18/22] add rminus1_stop argument --- scripts/conformer/mcmc_baseline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/conformer/mcmc_baseline.py b/scripts/conformer/mcmc_baseline.py index 048e966df..af18a687e 100644 --- a/scripts/conformer/mcmc_baseline.py +++ b/scripts/conformer/mcmc_baseline.py @@ -175,7 +175,7 @@ def reward(p0, p1, p2, p3, p4, p5 ,p6, p7, p8, p9, p10, p11): info["params"] = {f"p{i}": {"prior": {"min": 0, "max": 2*np.pi}, "ref": np.pi, "proposal": 0.01} for i in range(ndims)} - Rminus1_stop = 0.05 + Rminus1_stop = args.rminus1_stop info["sampler"] = {"mcmc": {"Rminus1_stop": Rminus1_stop, "max_tries": 1000}} updated_info, sampler = run(info); @@ -211,5 +211,6 @@ def get_energy(batch): parser.add_argument('--ids', nargs='+', required=True, type=int) parser.add_argument('--output_dir', type=str, default='./mcmc_outputs/') parser.add_argument('--proxy_name', type=str, default='torchani') + parser.add_argument('--rminus1_stop', type=float, default=0.05) args = parser.parse_args() main(args) \ No newline at end of file From 6dcb1cda186a76c4cf1b2699c562f453975950b8 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Tue, 23 Jan 2024 12:59:06 -0500 Subject: [PATCH 19/22] black, hacked isort --- .../gen_uniform_samples_with_rewards.py | 84 +++++--- scripts/conformer/mcmc_baseline.py | 186 ++++++++++++------ 2 files changed, 181 insertions(+), 89 deletions(-) diff --git a/scripts/conformer/gen_uniform_samples_with_rewards.py b/scripts/conformer/gen_uniform_samples_with_rewards.py index b4801cf54..3b8ad6845 100644 --- a/scripts/conformer/gen_uniform_samples_with_rewards.py +++ 
b/scripts/conformer/gen_uniform_samples_with_rewards.py @@ -1,4 +1,3 @@ - # tblite import should stay here first! othervise everything fails with tblite errors from gflownet.proxy.conformers.tblite import TBLiteMoleculeEnergy @@ -14,31 +13,42 @@ from rdkit.Chem import AllChem, rdMolTransforms from gflownet.envs.conformers.conformer import PREDEFINED_SMILES -from gflownet.utils.molecule.rotatable_bonds import get_rotatable_ta_list, find_rotor_from_smiles +from gflownet.utils.molecule.rotatable_bonds import ( + get_rotatable_ta_list, + find_rotor_from_smiles, +) from gflownet.proxy.conformers.torchani import TorchANIMoleculeEnergy from gflownet.envs.conformers.conformer import Conformer - -def get_uniform_samples_and_energy_weights(smiles, n_samples, energy_model='torchani'): +def get_uniform_samples_and_energy_weights(smiles, n_samples, energy_model="torchani"): n_torsion_angles = len(find_rotor_from_smiles(smiles)) - uniform_tas = np.random.uniform(0., 2*np.pi, size=(n_samples, n_torsion_angles)) - env = Conformer(smiles=smiles, n_torsion_angles=n_torsion_angles, reward_func="boltzmann", reward_beta=32) - if energy_model == 'torchani': - proxy = TorchANIMoleculeEnergy(device='cpu', float_precision=32) - elif energy_model == 'tblite': - proxy = TBLiteMoleculeEnergy(device='cpu', float_precision=32, batch_size=417) + uniform_tas = np.random.uniform(0.0, 2 * np.pi, size=(n_samples, n_torsion_angles)) + env = Conformer( + smiles=smiles, + n_torsion_angles=n_torsion_angles, + reward_func="boltzmann", + reward_beta=32, + ) + if energy_model == "torchani": + proxy = TorchANIMoleculeEnergy(device="cpu", float_precision=32) + elif energy_model == "tblite": + proxy = TBLiteMoleculeEnergy(device="cpu", float_precision=32, batch_size=417) else: - raise NotImplementedError(f"No proxy availabe for {energy_model}, use one of ['torchani', 'tblite']") + raise NotImplementedError( + f"No proxy availabe for {energy_model}, use one of ['torchani', 'tblite']" + ) proxy.setup(env) + def 
get_weights(batch): batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) energies = proxy(env.statebatch2proxy(batch)) rewards = env.proxy2reward(-energies) return rewards.numpy() + weights = get_weights(uniform_tas) - ddict = {f'ta_{idx}': uniform_tas[:, idx] for idx in range(n_torsion_angles)} - ddict.update({'weights': weights}) + ddict = {f"ta_{idx}": uniform_tas[:, idx] for idx in range(n_torsion_angles)} + ddict.update({"weights": weights}) return pd.DataFrame(ddict) @@ -46,20 +56,31 @@ def main(args): output_root = Path(args.output_dir) if not output_root.exists(): os.mkdir(output_root) - - samples_root = Path('/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/samples') - selected_mols = pd.read_csv(samples_root / 'gfn_samples_2-12' / 'torchani_selected.csv') + + samples_root = Path( + "/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/samples" + ) + selected_mols = pd.read_csv( + samples_root / "gfn_samples_2-12" / "torchani_selected.csv" + ) result = dict() - for smiles in tqdm(selected_mols['SMILES'].values): - result.update({ - smiles: get_uniform_samples_and_energy_weights(smiles, args.n_samples, args.energy_model) - }) - if args.save_each_df: - sm_idx = PREDEFINED_SMILES.index(smiles) - filename = filename = output_root / f"{args.energy_model}_{sm_idx}_weighted_samples_selected_smiles.csv" - result[smiles].to_csv(filename) - filename = output_root / f'{args.energy_model}_weighted_samples_selected_smiles.pkl' - with open(filename, 'wb') as file: + for smiles in tqdm(selected_mols["SMILES"].values): + result.update( + { + smiles: get_uniform_samples_and_energy_weights( + smiles, args.n_samples, args.energy_model + ) + } + ) + if args.save_each_df: + sm_idx = PREDEFINED_SMILES.index(smiles) + filename = filename = ( + output_root + / f"{args.energy_model}_{sm_idx}_weighted_samples_selected_smiles.csv" + ) + result[smiles].to_csv(filename) + filename = output_root / 
f"{args.energy_model}_weighted_samples_selected_smiles.pkl" + with open(filename, "wb") as file: pickle.dump(result, file) print(f"saved results at {filename}") @@ -71,9 +92,14 @@ def main(args): type=int, default=1000, ) - parser.add_argument("--energy_model", type=str, choices=['torchani', 'tblite'], default='torchani') + parser.add_argument( + "--energy_model", type=str, choices=["torchani", "tblite"], default="torchani" + ) parser.add_argument("--save_each_df", type=bool, default=False) - parser.add_argument("--output_dir", type=str, - default="/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/samples/uniform_samples_with_reward_weights") + parser.add_argument( + "--output_dir", + type=str, + default="/home/mila/a/alexandra.volokhova/projects/gflownet/results/conformer/samples/uniform_samples_with_reward_weights", + ) args = parser.parse_args() main(args) diff --git a/scripts/conformer/mcmc_baseline.py b/scripts/conformer/mcmc_baseline.py index af18a687e..224641ded 100644 --- a/scripts/conformer/mcmc_baseline.py +++ b/scripts/conformer/mcmc_baseline.py @@ -4,21 +4,21 @@ pass import argparse -from getdist.mcsamples import MCSamplesFromCobaya -import getdist.plots as gdplt - -from scipy import stats -from pathlib import Path -from cobaya.run import run +import os import pickle +from pathlib import Path + +import getdist.plots as gdplt import numpy as np -import os import torch - -from gflownet.proxy.conformers.xtb import XTBMoleculeEnergy -from gflownet.proxy.conformers.torchani import TorchANIMoleculeEnergy -from gflownet.proxy.conformers.tblite import TBLiteMoleculeEnergy +from cobaya.run import run +from getdist.mcsamples import MCSamplesFromCobaya from gflownet.envs.conformers.conformer import Conformer +from gflownet.proxy.conformers.tblite import TBLiteMoleculeEnergy +from gflownet.proxy.conformers.torchani import TorchANIMoleculeEnergy +from gflownet.proxy.conformers.xtb import XTBMoleculeEnergy +from scipy import stats + def 
convert_to_numpy_if_needed(array): if torch.is_tensor(array): @@ -28,12 +28,12 @@ def convert_to_numpy_if_needed(array): def main(args): - if args.proxy_name == 'torchani': + if args.proxy_name == "torchani": proxy_class = TorchANIMoleculeEnergy - elif args.proxy_name == 'tblite': + elif args.proxy_name == "tblite": proxy_class = TBLiteMoleculeEnergy - elif args.proxy_name == 'xtb': - proxy_class = XTBMoleculeEnergy + elif args.proxy_name == "xtb": + proxy_class = XTBMoleculeEnergy # Leave as is DEVICE = "cpu" @@ -49,17 +49,21 @@ def main(args): for smile in args.ids: # Change n_torsion_angles env = Conformer( - smiles=int(smile), n_torsion_angles=-1, reward_func=REWARD_FUNC, reward_beta=REWARD_BETA + smiles=int(smile), + n_torsion_angles=-1, + reward_func=REWARD_FUNC, + reward_beta=REWARD_BETA, ) - + ndims = len(env.torsion_angles) - + proxy = proxy_class(device=DEVICE, float_precision=FLOAT_PRECISION) proxy.setup(env) print(f"Sampling for {ndims} dimensions with {args.proxy_name} proxy") - + if ndims == 2: + def reward(p0, p1): batch = np.concatenate([[p0], [p1]]).reshape(1, -1) batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) @@ -69,7 +73,9 @@ def reward(p0, p1): rewards = env.proxy2reward(-energies) rewards = convert_to_numpy_if_needed(rewards) return np.log(rewards) - elif ndims == 3: + + elif ndims == 3: + def reward(p0, p1, p2): batch = np.concatenate([[p0], [p1], [p2]]).reshape(1, -1) batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) @@ -79,7 +85,9 @@ def reward(p0, p1, p2): rewards = env.proxy2reward(-energies) rewards = convert_to_numpy_if_needed(rewards) return np.log(rewards) - elif ndims == 4: + + elif ndims == 4: + def reward(p0, p1, p2, p3): batch = np.concatenate([[p0], [p1], [p2], [p3]]).reshape(1, -1) batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) @@ -89,7 +97,9 @@ def reward(p0, p1, p2, p3): rewards = env.proxy2reward(-energies) rewards = convert_to_numpy_if_needed(rewards) 
return np.log(rewards) + if ndims == 5: + def reward(p0, p1, p2, p3, p4): batch = np.concatenate([[p0], [p1], [p2], [p3], [p4]]).reshape(1, -1) batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) @@ -99,9 +109,13 @@ def reward(p0, p1, p2, p3, p4): rewards = env.proxy2reward(-energies) rewards = convert_to_numpy_if_needed(rewards) return np.log(rewards) + if ndims == 6: + def reward(p0, p1, p2, p3, p4, p5): - batch = np.concatenate([[p0], [p1], [p2], [p3], [p4], [p5]]).reshape(1, -1) + batch = np.concatenate([[p0], [p1], [p2], [p3], [p4], [p5]]).reshape( + 1, -1 + ) batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) proxy_batch = env.statebatch2proxy(batch) @@ -109,9 +123,13 @@ def reward(p0, p1, p2, p3, p4, p5): rewards = env.proxy2reward(-energies) rewards = convert_to_numpy_if_needed(rewards) return np.log(rewards) - elif ndims == 7: - def reward(p0, p1, p2, p3, p4, p5 ,p6): - batch = np.concatenate([[p0], [p1], [p2], [p2], [p3], [p4], [p5], [p6]]).reshape(1, -1) + + elif ndims == 7: + + def reward(p0, p1, p2, p3, p4, p5, p6): + batch = np.concatenate( + [[p0], [p1], [p2], [p2], [p3], [p4], [p5], [p6]] + ).reshape(1, -1) batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) proxy_batch = env.statebatch2proxy(batch) @@ -119,9 +137,13 @@ def reward(p0, p1, p2, p3, p4, p5 ,p6): rewards = env.proxy2reward(-energies) rewards = convert_to_numpy_if_needed(rewards) return np.log(rewards) - elif ndims == 8: - def reward(p0, p1, p2, p3, p4, p5 ,p6, p7): - batch = np.concatenate([[p0], [p1], [p2], [p3], [p4], [p5], [p6], [p7]]).reshape(1, -1) + + elif ndims == 8: + + def reward(p0, p1, p2, p3, p4, p5, p6, p7): + batch = np.concatenate( + [[p0], [p1], [p2], [p3], [p4], [p5], [p6], [p7]] + ).reshape(1, -1) batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) proxy_batch = env.statebatch2proxy(batch) @@ -129,9 +151,13 @@ def reward(p0, p1, p2, p3, p4, p5 ,p6, p7): rewards = 
env.proxy2reward(-energies) rewards = convert_to_numpy_if_needed(rewards) return np.log(rewards) - elif ndims == 9: - def reward(p0, p1, p2, p3, p4, p5 ,p6, p7, p8): - batch = np.concatenate([[p0], [p1], [p2], [p3], [p4], [p5], [p6], [p7], [p8]]).reshape(1, -1) + + elif ndims == 9: + + def reward(p0, p1, p2, p3, p4, p5, p6, p7, p8): + batch = np.concatenate( + [[p0], [p1], [p2], [p3], [p4], [p5], [p6], [p7], [p8]] + ).reshape(1, -1) batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) proxy_batch = env.statebatch2proxy(batch) @@ -139,9 +165,13 @@ def reward(p0, p1, p2, p3, p4, p5 ,p6, p7, p8): rewards = env.proxy2reward(-energies) rewards = convert_to_numpy_if_needed(rewards) return np.log(rewards) - elif ndims == 10: - def reward(p0, p1, p2, p3, p4, p5 ,p6, p7, p8, p9): - batch = np.concatenate([[p0], [p1], [p2], [p3], [p4], [p5], [p6], [p7], [p8], [p9]]).reshape(1, -1) + + elif ndims == 10: + + def reward(p0, p1, p2, p3, p4, p5, p6, p7, p8, p9): + batch = np.concatenate( + [[p0], [p1], [p2], [p3], [p4], [p5], [p6], [p7], [p8], [p9]] + ).reshape(1, -1) batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) proxy_batch = env.statebatch2proxy(batch) @@ -149,9 +179,13 @@ def reward(p0, p1, p2, p3, p4, p5 ,p6, p7, p8, p9): rewards = env.proxy2reward(-energies) rewards = convert_to_numpy_if_needed(rewards) return np.log(rewards) - elif ndims == 11: - def reward(p0, p1, p2, p3, p4, p5 ,p6, p7, p8, p9, p10): - batch = np.concatenate([[p0], [p1], [p2], [p3], [p4], [p5], [p6], [p7], [p8], [p9], [p10]]).reshape(1, -1) + + elif ndims == 11: + + def reward(p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10): + batch = np.concatenate( + [[p0], [p1], [p2], [p3], [p4], [p5], [p6], [p7], [p8], [p9], [p10]] + ).reshape(1, -1) batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) proxy_batch = env.statebatch2proxy(batch) @@ -159,9 +193,26 @@ def reward(p0, p1, p2, p3, p4, p5 ,p6, p7, p8, p9, p10): rewards = env.proxy2reward(-energies) 
rewards = convert_to_numpy_if_needed(rewards) return np.log(rewards) - elif ndims == 12: - def reward(p0, p1, p2, p3, p4, p5 ,p6, p7, p8, p9, p10, p11): - batch = np.concatenate([[p0], [p1], [p2], [p3], [p4], [p5], [p6], [p7], [p8], [p9], [p10], [p11]]).reshape(1, -1) + + elif ndims == 12: + + def reward(p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11): + batch = np.concatenate( + [ + [p0], + [p1], + [p2], + [p3], + [p4], + [p5], + [p6], + [p7], + [p8], + [p9], + [p10], + [p11], + ] + ).reshape(1, -1) batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) proxy_batch = env.statebatch2proxy(batch) @@ -169,48 +220,63 @@ def reward(p0, p1, p2, p3, p4, p5 ,p6, p7, p8, p9, p10, p11): rewards = env.proxy2reward(-energies) rewards = convert_to_numpy_if_needed(rewards) return np.log(rewards) - - + info = {"likelihood": {"reward": reward}} - info["params"] = {f"p{i}": {"prior": {"min": 0, "max": 2*np.pi}, "ref": np.pi, "proposal": 0.01} for i in range(ndims)} - + info["params"] = { + f"p{i}": { + "prior": {"min": 0, "max": 2 * np.pi}, + "ref": np.pi, + "proposal": 0.01, + } + for i in range(ndims) + } + Rminus1_stop = args.rminus1_stop info["sampler"] = {"mcmc": {"Rminus1_stop": Rminus1_stop, "max_tries": 1000}} - updated_info, sampler = run(info); + updated_info, sampler = run(info) gdsamples = MCSamplesFromCobaya(updated_info, sampler.products()["sample"]) - + def get_energy(batch): batch = np.concatenate([batch, np.zeros((batch.shape[0], 1))], axis=-1) proxy_batch = env.statebatch2proxy(batch) energies = proxy(env.statebatch2proxy(batch)) return energies - + if gdsamples.samples.shape[0] >= 1000: npars = len(info["params"]) - dct = {"x": gdsamples.samples[-1000:, :npars]} #, "energy": np.exp(gdsamples.loglikes[-10000:])} + dct = { + "x": gdsamples.samples[-1000:, :npars] + } # , "energy": np.exp(gdsamples.loglikes[-10000:])} dct["energy"] = get_energy(gdsamples.samples[-1000:, :npars]) - dct["conformer"] = [env.set_conformer(state).rdk_mol for state in 
gdsamples.samples[-1000:, :npars]] + dct["conformer"] = [ + env.set_conformer(state).rdk_mol + for state in gdsamples.samples[-1000:, :npars] + ] pickle.dump( - dct, open(output_dir / f"conformers_mcmc_smile{smile}_{args.proxy_name}.pkl", "wb") + dct, + open( + output_dir / f"conformers_mcmc_smile{smile}_{args.proxy_name}.pkl", + "wb", + ), ) - print(f'Finished smile {smile} (dimensions {ndims})') - - else: - print(f'Not enough samples for smile {smile}') + print(f"Finished smile {smile} (dimensions {ndims})") + + else: + print(f"Not enough samples for smile {smile}") return 0 -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--ids', nargs='+', required=True, type=int) - parser.add_argument('--output_dir', type=str, default='./mcmc_outputs/') - parser.add_argument('--proxy_name', type=str, default='torchani') - parser.add_argument('--rminus1_stop', type=float, default=0.05) + parser.add_argument("--ids", nargs="+", required=True, type=int) + parser.add_argument("--output_dir", type=str, default="./mcmc_outputs/") + parser.add_argument("--proxy_name", type=str, default="torchani") + parser.add_argument("--rminus1_stop", type=float, default=0.05) args = parser.parse_args() - main(args) \ No newline at end of file + main(args) From c70c392079554cbea5d2a045c46567c88ad325f3 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Tue, 23 Jan 2024 13:00:08 -0500 Subject: [PATCH 20/22] hacked isort --- .../gen_uniform_samples_with_rewards.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/scripts/conformer/gen_uniform_samples_with_rewards.py b/scripts/conformer/gen_uniform_samples_with_rewards.py index 3b8ad6845..1235433a2 100644 --- a/scripts/conformer/gen_uniform_samples_with_rewards.py +++ b/scripts/conformer/gen_uniform_samples_with_rewards.py @@ -1,24 +1,22 @@ # tblite import should stay here first! 
othervise everything fails with tblite errors -from gflownet.proxy.conformers.tblite import TBLiteMoleculeEnergy - +import argparse import os -import pandas as pd -import numpy as np import pickle -import argparse from pathlib import Path -import seaborn as sns -import matplotlib.pyplot as plt -from tqdm import tqdm -from rdkit.Chem import AllChem, rdMolTransforms -from gflownet.envs.conformers.conformer import PREDEFINED_SMILES +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from gflownet.envs.conformers.conformer import PREDEFINED_SMILES, Conformer +from gflownet.proxy.conformers.tblite import TBLiteMoleculeEnergy +from gflownet.proxy.conformers.torchani import TorchANIMoleculeEnergy from gflownet.utils.molecule.rotatable_bonds import ( - get_rotatable_ta_list, find_rotor_from_smiles, + get_rotatable_ta_list, ) -from gflownet.proxy.conformers.torchani import TorchANIMoleculeEnergy -from gflownet.envs.conformers.conformer import Conformer +from rdkit.Chem import AllChem, rdMolTransforms +from tqdm import tqdm def get_uniform_samples_and_energy_weights(smiles, n_samples, energy_model="torchani"): From 43d9e82d5e8ed7be8a53405ca483cfd24260ba51 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Tue, 23 Jan 2024 16:09:14 -0500 Subject: [PATCH 21/22] change number of min samples to 1100 --- scripts/conformer/mcmc_baseline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/conformer/mcmc_baseline.py b/scripts/conformer/mcmc_baseline.py index 224641ded..db3542a03 100644 --- a/scripts/conformer/mcmc_baseline.py +++ b/scripts/conformer/mcmc_baseline.py @@ -246,7 +246,7 @@ def get_energy(batch): energies = proxy(env.statebatch2proxy(batch)) return energies - if gdsamples.samples.shape[0] >= 1000: + if gdsamples.samples.shape[0] >= 1100: npars = len(info["params"]) dct = { "x": gdsamples.samples[-1000:, :npars] From 8fb1f824f2032cbcb784740ed56f673abc5d7749 Mon Sep 17 00:00:00 2001 From: 
AlexandraVolokhova Date: Tue, 30 Jan 2024 16:00:54 -0500 Subject: [PATCH 22/22] updated smiles parsing --- scripts/conformer/compute_metrics.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/conformer/compute_metrics.py b/scripts/conformer/compute_metrics.py index 4b243c811..df6eb8993 100644 --- a/scripts/conformer/compute_metrics.py +++ b/scripts/conformer/compute_metrics.py @@ -12,6 +12,7 @@ from gflownet.utils.molecule.geom import get_all_confs_geom from gflownet.utils.molecule.metrics import get_best_rmsd, get_cov_mat +from gflownet.envs.conformers.conformer import PREDEFINED_SMILES def distant_enough(conf, others, delta): @@ -42,6 +43,9 @@ def get_diverse_top_k(confs_list, k=None, delta=1.25): def get_smiles_from_filename(filename): smiles = filename.split("_")[1] + if smiles == 'mcmc': + smiles_idx = int(filename.split("_")[2][5:]) + smiles = PREDEFINED_SMILES[smiles_idx] if smiles.endswith(".pkl"): smiles = smiles[:-4] return smiles @@ -174,7 +178,7 @@ def main(args): parser.add_argument( "--geom_stats", type=str, - default="/home/mila/a/alexandra.volokhova/projects/gflownet/scripts/conformer/geom_stats.csv", + default="/home/mila/a/alexandra.volokhova/projects/gflownet/scripts/conformer/generated_files/geom_stats.csv", ) parser.add_argument("--gen_dir", type=str, default="./") parser.add_argument("--use_top_k", type=bool, default=False)