Acegen rl #56

Status: Open. This pull request wants to merge 57 commits into base: main.

Commits (57)
384e8b9
denovo file
MorganCThomas May 30, 2024
3b7a76a
Merge branch 'main' into morgan_dev
MorganCThomas Jun 12, 2024
2e6fb69
more intuitive log directory
MorganCThomas Jun 12, 2024
f28df5f
update to new molscore curriciulum api
MorganCThomas Jun 13, 2024
82d78bf
merge main
MorganCThomas Jun 13, 2024
d8a61ca
align carl with main changes
MorganCThomas Jun 13, 2024
aa2ea89
change default config
MorganCThomas Jun 19, 2024
f5a99e4
Merge branch 'main' into carl
MorganCThomas Jun 19, 2024
8e4c3e1
initial
MorganCThomas Jun 20, 2024
181da53
simple chist implementation
MorganCThomas Jun 20, 2024
6e227cf
config
MorganCThomas Jun 20, 2024
80e8b04
modify defaults, change param reading
MorganCThomas Jun 25, 2024
3b1a1c7
update
MorganCThomas Jul 2, 2024
289a979
update gitignore
MorganCThomas Jul 2, 2024
69f9d59
add guacamol prior
MorganCThomas Jul 2, 2024
a56f159
save runtime
MorganCThomas Jul 2, 2024
ce89393
Merge branch 'main' into carl
MorganCThomas Jul 2, 2024
79209a9
update gitignore
MorganCThomas Jul 2, 2024
a77a10f
negative log-likelihood makes no diff -> simplify
MorganCThomas Jul 2, 2024
8d6aaab
Merge branch 'main' into multi-agent
albertbou92 Jul 2, 2024
c407e1c
update
MorganCThomas Jul 3, 2024
a635f6b
typo
MorganCThomas Jul 3, 2024
3a30918
fix
albertbou92 Jul 3, 2024
9bb8ff4
scalable script
albertbou92 Jul 3, 2024
1c6c953
scalable script 2
albertbou92 Jul 3, 2024
b31051f
scalable script 2
albertbou92 Jul 3, 2024
df78696
scalable script reward shaping
albertbou92 Jul 4, 2024
cc1e417
scalable script reward shaping
albertbou92 Jul 4, 2024
b624672
config default
albertbou92 Jul 4, 2024
d6c3451
config default
albertbou92 Jul 4, 2024
ceb6627
fix
albertbou92 Jul 4, 2024
7b03fc3
fix
albertbou92 Jul 4, 2024
d79d66a
fix
albertbou92 Jul 4, 2024
6c908ce
configs
MorganCThomas Jul 4, 2024
2ca0f09
merge
MorganCThomas Jul 4, 2024
1f5e0bc
add chembl34 prior
MorganCThomas Jul 10, 2024
fb72abc
Merge branch 'multi-agent' into carl
MorganCThomas Jul 11, 2024
2f254e1
bug fix
MorganCThomas Jul 24, 2024
d53c5f2
Merge branch 'main' into carl
MorganCThomas Jul 24, 2024
662441e
merge dev
MorganCThomas Jul 24, 2024
b89c09a
add MolScore option to score invalids
MorganCThomas Aug 26, 2024
1c46535
Merge branch 'main' into carl
MorganCThomas Aug 26, 2024
7a1f7ea
write correct seed ; explicit warning in collate smiles
MorganCThomas Oct 22, 2024
57148bd
always allow kl_coef ; write molscore scores intermediate
MorganCThomas Dec 13, 2024
44afb57
no need to do molscore config saving
MorganCThomas Dec 13, 2024
8b54f5a
move baseline code; add RND
MorganCThomas Dec 18, 2024
bce081d
fix entropy; add rnd coefficient
MorganCThomas Dec 31, 2024
f2c7c6e
add screening baseline
MorganCThomas Jan 7, 2025
469e195
add screening baseline
MorganCThomas Jan 7, 2025
21cabce
move baseline code to rl_env
MorganCThomas Jan 20, 2025
9aebaa5
refactor; test; lint
MorganCThomas Jan 28, 2025
ec8e3aa
switch default config
MorganCThomas Jan 28, 2025
9da59e1
refactor
MorganCThomas Jan 28, 2025
83e45ab
refactor; address future warnings
MorganCThomas Jan 28, 2025
e828c63
remove unused prior
MorganCThomas Jan 28, 2025
c0ab7ec
remove deprecated function
MorganCThomas Jan 28, 2025
2d451f1
remove patience
MorganCThomas Jan 28, 2025
7 changes: 7 additions & 0 deletions .gitignore
@@ -163,5 +163,12 @@ cython_debug/
results/
logs_*

# Configs unless specified
*.yaml

# Nohup files
*.out
*.nohup

*.DS_Store*
*gpt2_enamine_real.ckpt*
20 changes: 20 additions & 0 deletions acegen/data/chem_utils.py
@@ -2,6 +2,8 @@

import numpy as np

import torch

from rdkit.Chem import AllChem as Chem, Draw


@@ -60,3 +62,21 @@ def draw(mol_list, molsPerRow=5, subImgSize=(300, 300)):
image = Draw.MolsToGridImage(mols, molsPerRow=molsPerRow, subImgSize=subImgSize)

return image


def get_fp(mol):
"""Create a Circular/Path based fingerprint from a SMILES string or RDKitMol."""
mol = get_mol(mol)
if mol:
ecfp = Chem.GetMorganFingerprintAsBitVect(mol, radius=3, nBits=256)
rdk = Chem.RDKFingerprint(mol, maxPath=6, fpSize=256, nBitsPerHash=2)
fp = torch.cat([torch.tensor(ecfp), torch.tensor(rdk)])
return fp
else:
return torch.zeros((512), dtype=torch.int64)


def get_fp_hist(mols):
"""Compute the histogram of fingerprints from a list of SMILES strings or Mols."""
fp_hist = torch.vstack([get_fp(mol) for mol in mols]).sum(0)
return fp_hist
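
A brief usage sketch of the new fingerprint helpers (illustrative only, not part of the diff; it assumes an environment with RDKit and this branch installed, and that get_fp / get_fp_hist are importable from acegen.data.chem_utils as added above):

import torch
from acegen.data.chem_utils import get_fp, get_fp_hist

smiles = ["CCO", "c1ccccc1", "not_a_smiles"]  # invalid entries fall back to a zero vector

fp = get_fp(smiles[0])      # 512-dim tensor: 256 Morgan (radius 3) bits + 256 RDKit path bits
hist = get_fp_hist(smiles)  # per-bit counts over the whole list, shape (512,)
print(fp.shape, hist.shape)
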
16 changes: 13 additions & 3 deletions acegen/data/utils.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import warnings

import torch
from tensordict import TensorDict

@@ -75,10 +77,18 @@ def collate_smiles_to_tensordict(
"""Function to take a list of encoded sequences and turn them into a tensordict."""
collated_arr = torch.ones(len(arr), max_length) * -1
for i, seq in enumerate(arr):
collated_arr[i, : seq.size(0)] = seq
if seq.size(0) > max_length:
warnings.warn(
f"Sequence {i} is longer than max_length. Truncating to {max_length}."
)
collated_arr[i, :max_length] = seq[:max_length]
else:
collated_arr[i, : seq.size(0)] = seq
data = smiles_to_tensordict(
collated_arr, reward=reward, replace_mask_value=0, device=device
)
data.set("sequence", data.get("observation"))
data.set("sequence_mask", data.get("mask"))
data.set("sequence", data.get("observation").clone())
data.set("sequence_mask", data.get("mask").clone())
data.set(("next", "sequence"), data.get("next", "observation").clone())
data.set(("next", "sequence_mask"), data.get("next", "mask").clone())
return data
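
A rough usage sketch of the new truncation path (illustrative; the full signature of collate_smiles_to_tensordict is not shown in this hunk, so the keyword arguments below are assumptions):

import torch
from acegen.data.utils import collate_smiles_to_tensordict

# two encoded sequences; the second exceeds max_length and triggers the new warning
seqs = [torch.tensor([1, 4, 7, 2]), torch.tensor([1, 4, 7, 9, 9, 9, 2])]
data = collate_smiles_to_tensordict(seqs, max_length=5, reward=torch.zeros(2))
# sequence 1 is truncated to its first 5 tokens, and "sequence"/"sequence_mask"
# are now clones of "observation"/"mask" rather than views into them
print(data)
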
20 changes: 20 additions & 0 deletions acegen/models/__init__.py
@@ -1,5 +1,6 @@
import logging
import tarfile
from functools import partial
from importlib import import_module, resources
from pathlib import Path

@@ -67,6 +68,14 @@ def extract(path):
resources.files("acegen.priors") / "gru_chembl_filtered.ckpt",
SMILESTokenizerChEMBL(),
),
"gru_chembl34": (
create_gru_actor,
create_gru_critic,
create_gru_actor_critic,
resources.files("acegen.priors") / "gru_chembl34_vocabulary.ckpt",
resources.files("acegen.priors") / "gru_chembl34.ckpt",
SMILESTokenizerChEMBL(),
),
"lstm": (
create_lstm_actor,
create_lstm_critic,
@@ -75,6 +84,17 @@
resources.files("acegen.priors") / "lstm_chembl.ckpt",
SMILESTokenizerChEMBL(),
),
"lstm_guacamol": (
partial(create_lstm_actor, embedding_size=1024, hidden_size=1024, dropout=0.2),
partial(create_lstm_critic, embedding_size=1024, hidden_size=1024, dropout=0.2),
partial(
create_lstm_actor_critic, embedding_size=1024, hidden_size=1024, dropout=0.2
),
resources.files("acegen.priors") / "lstm_guacamol_vocabulary.txt",
resources.files("acegen.priors")
/ "lstm_guacamol_model_final_0.473_acegen.ckpt",
SMILESTokenizerGuacaMol(),
),
"gpt2": (
create_gpt2_actor,
create_gpt2_critic,
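
For context, a hypothetical way a training script might pick up one of the new entries; the name of the dict holding these tuples is outside the hunk, so the `models` import below is an assumption:

from acegen.models import models  # assumed name of the registry dict shown above

(create_actor, create_critic, create_actor_critic,
 vocabulary_path, checkpoint_path, tokenizer) = models["gru_chembl34"]
# vocabulary_path / checkpoint_path point at the new files under acegen/priors/,
# and tokenizer is a SMILESTokenizerChEMBL instance
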
16 changes: 16 additions & 0 deletions acegen/models/utils.py
@@ -1,4 +1,8 @@
import warnings
from typing import Union

import torch
from tensordict.nn import TensorDictModule


def adapt_state_dict(source_state_dict: dict, target_state_dict: dict):
@@ -32,3 +36,15 @@ def adapt_state_dict(source_state_dict: dict, target_state_dict: dict):
target_state_dict[key_target] = value_source

return target_state_dict


def reinitialize_model(
model: Union[torch.nn.Module, TensorDictModule], seed: int = 101
):
"""Random initialization of a models parameters."""
torch.manual_seed(seed)
for p in model.parameters():
if len(p.shape) == 1:
torch.nn.init.constant_(p, 0)
else:
torch.nn.init.uniform_(p)
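
A minimal sketch of the new helper on a toy module (illustrative; it relies only on the behaviour visible above: 1-D parameters are zeroed, everything else is re-drawn from a uniform distribution):

import torch
from acegen.models.utils import reinitialize_model

net = torch.nn.Linear(8, 4)
reinitialize_model(net, seed=101)
assert torch.all(net.bias == 0)   # 1-D parameters (biases) set to zero
assert net.weight.min() >= 0.0    # weights re-drawn from U(0, 1), the uniform_ default
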
Binary file added acegen/priors/gru_chembl34.ckpt
Binary file not shown.
Binary file added acegen/priors/gru_chembl34_vocabulary.ckpt
Binary file not shown.
38 changes: 38 additions & 0 deletions acegen/rl_env/baselines.py
@@ -0,0 +1,38 @@
import torch


class MovingAverageBaseline:
"""Class to keep track on the running mean and variance of tensors batches."""

def __init__(self, epsilon=1e-3, shape=(), device=torch.device("cpu")):
self.mean = torch.zeros(shape, dtype=torch.float64).to(device)
self.std = torch.zeros(shape, dtype=torch.float64).to(device)
self.count = epsilon

def update(self, x):
batch_mean = torch.mean(x, dim=0)
batch_std = torch.std(x, dim=0)
batch_count = x.shape[0]
self.update_from_moments(batch_mean, batch_std, batch_count)

def update_from_moments(self, batch_mean, batch_std, batch_count):
delta = batch_mean - self.mean
std_delta = batch_std - self.std
tot_count = self.count + batch_count
new_mean = self.mean + delta * batch_count / tot_count
new_std = self.std + std_delta * batch_count / tot_count
new_count = tot_count
self.mean, self.std, self.count = new_mean, new_std, new_count


class LeaveOneOutBaseline:
"""Class to compute the leave-one-out baseline for a given tensor."""

def __init__(self):
self.mean = None

def update(self, x):
with torch.no_grad():
loo = x.unsqueeze(1).expand(-1, x.size(0))
loo_mask = 1 - torch.eye(loo.size(0), device=loo.device)
self.mean = (loo * loo_mask).sum(0) / loo_mask.sum(0)
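
A quick sketch of the two baselines on dummy rewards (illustrative; presumably these serve as reward baselines for variance reduction in the RL scripts):

import torch
from acegen.rl_env.baselines import MovingAverageBaseline, LeaveOneOutBaseline

rewards = torch.tensor([0.1, 0.5, 0.9, 0.3])

mab = MovingAverageBaseline(shape=())
mab.update(rewards)   # running mean/std accumulated across batches
advantage = rewards - mab.mean

loo = LeaveOneOutBaseline()
loo.update(rewards)   # loo.mean[i] is the mean of rewards excluding element i
advantage_loo = rewards - loo.mean
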
26 changes: 13 additions & 13 deletions acegen/rl_env/token_env.py
@@ -3,10 +3,10 @@
import torch
from tensordict.tensordict import TensorDict, TensorDictBase
from torchrl.data import (
CompositeSpec,
DiscreteTensorSpec,
Composite,
Categorical,
OneHotDiscreteTensorSpec,
UnboundedContinuousTensorSpec,
Unbounded,
)
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs import EnvBase
@@ -183,9 +183,9 @@ def _set_specs(self) -> None:
obs_spec = (
OneHotDiscreteTensorSpec
if self.one_hot_obs_encoding
else DiscreteTensorSpec
else Categorical
)
self.observation_spec = CompositeSpec(
self.observation_spec = Composite(
{
"observation": obs_spec(
n=self.length_vocabulary,
@@ -222,9 +222,9 @@ def _set_specs(self) -> None:
action_spec = (
OneHotDiscreteTensorSpec
if self.one_hot_action_encoding
else DiscreteTensorSpec
else Categorical
)
self.action_spec = CompositeSpec(
self.action_spec = Composite(
{
"action": action_spec(
n=self.length_vocabulary,
@@ -233,9 +233,9 @@ def _set_specs(self) -> None:
)
}
).expand(self.num_envs)
self.reward_spec = CompositeSpec(
self.reward_spec = Composite(
{
"reward": UnboundedContinuousTensorSpec(
"reward": Unbounded(
shape=(1,),
dtype=torch.float32,
device=self.device,
@@ -244,15 +244,15 @@
).expand(self.num_envs)

self.done_spec = (
CompositeSpec(
Composite(
{
"done": DiscreteTensorSpec(
"done": Categorical(
n=2, dtype=torch.bool, device=self.device
),
"truncated": DiscreteTensorSpec(
"truncated": Categorical(
n=2, dtype=torch.bool, device=self.device
),
"terminated": DiscreteTensorSpec(
"terminated": Categorical(
n=2, dtype=torch.bool, device=self.device
),
}
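
The changes in this file swap the older torchrl spec classes for their newer, shorter aliases. A minimal equivalence sketch, assuming a torchrl version where the new names are available (as the updated imports above require):

import torch
from torchrl.data import Categorical, Composite, Unbounded

action_spec = Composite(
    {"action": Categorical(n=30)}  # previously DiscreteTensorSpec inside CompositeSpec
)
reward_spec = Composite(
    {"reward": Unbounded(shape=(1,), dtype=torch.float32)}  # previously UnboundedContinuousTensorSpec
)
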
29 changes: 15 additions & 14 deletions acegen/vocabulary/tokenizers.py
@@ -97,24 +97,22 @@ def __init__(self, start_token="GO", end_token="EOS"):
self.REGEXP_ORDER = ["brackets", "brcl"]
self.start_token = start_token
self.end_token = end_token
self.encode_dict = {
"Br": "Y",
"Cl": "X",
"Si": "A",
"Se": "Z",
"@@": "R",
"se": "E",
}
self.decode_dict = {v: k for k, v in self.encode_dict.items()}

def tokenize(self, data, with_begin_and_end=False):
"""Tokenizes a SMILES string."""
for symbol, token in self.encode_dict.items():
data = data.replace(symbol, token)

def split_by(data, regexps):
if not regexps:
return list(data)
regexp = self.REGEXPS[regexps[0]]
splitted = regexp.split(data)
tokens = []
for i, split in enumerate(splitted):
if i % 2 == 0:
tokens += split_by(split, regexps[1:])
else:
tokens.append(split)
return tokens

tokens = split_by(data, self.REGEXP_ORDER)
tokens = list(data)
if with_begin_and_end:
tokens = [self.start_token] + tokens + [self.end_token]
return tokens
@@ -127,6 +125,9 @@ def untokenize(self, tokens, **kwargs):
break
if token != self.start_token:
smi += token

for symbol, token in self.decode_dict.items():
smi = smi.replace(symbol, token)
return smi


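
A standalone sketch of the encode/decode round-trip this change introduces (the class containing this __init__ is outside the hunk, so the dictionaries are reproduced locally rather than imported). Multi-character tokens such as Br and Cl are first mapped to single placeholder characters, which is why the regexp-based split is no longer needed and a plain list(data) suffices:

encode_dict = {"Br": "Y", "Cl": "X", "Si": "A", "Se": "Z", "@@": "R", "se": "E"}
decode_dict = {v: k for k, v in encode_dict.items()}

smi = "CC(Cl)Br"
encoded = smi
for symbol, placeholder in encode_dict.items():
    encoded = encoded.replace(symbol, placeholder)
tokens = list(encoded)   # ['C', 'C', '(', 'X', ')', 'Y']
decoded = "".join(tokens)
for placeholder, symbol in decode_dict.items():
    decoded = decoded.replace(placeholder, symbol)
assert decoded == smi    # round-trips back to "CC(Cl)Br"
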