Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aizynthfinder/aizynthfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def tree_search(self, show_progress: bool = False) -> float:

try:
is_solved = self.tree.one_iteration()
self.config.search.algorithm_config["current_iteration"] += 1
except StopIteration:
break

Expand Down
4 changes: 3 additions & 1 deletion aizynthfinder/context/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ class _SearchConfiguration:
"immediate_instantiation": (),
"mcts_grouping": None,
"search_rewards_weights": [],
"enhancement": "Default",
"current_iteration": 0,
}
)
max_transforms: int = 6
max_transforms: int = 10
iteration_limit: int = 100
time_limit: int = 120
return_first: bool = False
Expand Down
51 changes: 48 additions & 3 deletions aizynthfinder/search/mcts/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
self.is_expanded: bool = False
self.is_expandable: bool = not self.state.is_terminal
self._parent = parent
self.depth = 0

if owner is None:
self.created_at_iteration: Optional[int] = None
Expand Down Expand Up @@ -249,13 +250,34 @@ def expand(self) -> None:
for child in self.parent.children:
if child is not self:
cache_molecules.extend(child.state.expandable_mols)

if self.parent:
self.depth = self.parent.depth + 1
if not self.parent:
self.depth = 0
# Calculate the possible actions, fill the child_info lists
# Actions by default only assumes 1 set of reactants
actions, priors = self._expansion_policy(
self.state.expandable_mols, cache_molecules
)
self._fill_children_lists(actions, priors)
enhancement = self._config.search.algorithm_config["enhancement"]
n = 0
if enhancement == "Default":
self._fill_children_lists(actions, priors)
elif enhancement == "eUCT":
self._fill_children_lists(actions, priors)
print('through here')
elif enhancement == "dUCT-v1" or "dUCT-v2":
if enhancement == "dUCT-v1":
n = 20
elif enhancement == "dUCT-v2":
n = 50
action_prior_pairs = list(zip(actions, priors))
sorted_action_prior_pairs = sorted(action_prior_pairs, key=lambda x: x[1], reverse=True)
top_n_action_prior_pairs = sorted_action_prior_pairs[:n]
actions[:], priors[:] = zip(*top_n_action_prior_pairs)
actions = list(actions)
priors = list(priors)
self._fill_children_lists(actions, priors)

# Reverse the expansion if it did not produce any children
if len(actions) == 0:
Expand Down Expand Up @@ -377,6 +399,23 @@ def _children_u(self) -> np.ndarray:
child_visits = np.array(self._children_visitations)
return self._algo_config["C"] * np.sqrt(2 * total_visits / child_visits)

def _children_du(self, depth) -> np.ndarray:
total_visits = np.log(np.sum(self._children_visitations))
child_visits = np.array(self._children_visitations)
dc = self._algo_config["C"] + 0.5 * depth
return dc * np.sqrt(2 * total_visits / child_visits)

def _children_eu(self) -> np.ndarray:
total_visits = np.log(np.sum(self._children_visitations))
child_visits = np.array(self._children_visitations)
current_iteration = self._algo_config["current_iteration"]
if current_iteration == 0:
phi = (1 / self._config.search.iteration_limit)/2
else:
phi = (current_iteration/self._config.search.iteration_limit)/2
lamda = (self._algo_config["C"]/(phi+1))
return lamda * np.sqrt(2 * total_visits / child_visits)

def _create_children_nodes(
self, states: List[MctsState], child_idx: int
) -> List["MctsNode"]:
Expand Down Expand Up @@ -536,7 +575,13 @@ def _regenerated_blacklisted(self, reaction: RetroReaction) -> bool:
def _score_and_select(self) -> Optional["MctsNode"]:
if not max(self._children_values) > 0:
raise ValueError("Has no selectable children")
scores = self._children_q() + self._children_u()

if self._algo_config["enhancement"] == "dUCT-v1" or "dUCT-v2":
scores = self._children_q() + self._children_du(self.depth)
elif self._algo_config["enhancement"] == "eUCT":
scores = self._children_q() + self._children_eu()
elif self._algo_config["enhancement"] == "Default":
scores = self._children_q() + self._children_u()
indices = np.where(scores == scores.max())[0]
index = np.random.choice(indices)
return self._select_child(index)
Expand Down
Binary file not shown.
41 changes: 23 additions & 18 deletions sources/retrosynthesis/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,34 @@ def service_check():
def retrosynthesis():
access_key = str(request.args.get('key'))
if access_key != ACCESS_KEY:
print("invalid key")
return json.dumps({'Message': 'invalid key', 'Timestamp': time.time()})
print("Invalid key")
return json.dumps({'Message': 'Invalid key', 'Timestamp': time.time()})

smiles = str(request.args.get('smiles'))
solved_route_dict_ls, raw_routes = retrosynthesis_process(smiles)
enhancement = str(request.args.get('enhancement', 'Default'))
finder = sources.retrosynthesis.startup.make_config()
finder.config.search.algorithm_config["enhancement"] = enhancement
solved_route_dict_ls, raw_routes = retrosynthesis_process(smiles, finder)
page_data = {'Message': solved_route_dict_ls, 'Raw_Routes': raw_routes, 'Timestamp': time.time()}
json_dump = json.dumps(page_data)
return json_dump


def retrosynthesis_process(smiles):
"""
Takes a smiles string and returns a list of retrosynthetic routes stored as dictionaries
def retrosynthesis_process(smiles, finder):
"""
# load config containing policy file locations
print(smiles)
Takes a SMILES string and a pre-configured finder object and returns a list of retrosynthetic routes as dictionaries.
"""
print(f"Running retrosynthesis for SMILES: {smiles}")

from rdkit import Chem
from aizynthfinder.interfaces import aizynthcli
from sources.retrosynthesis.classes import RetroRoute
# from sources.retrosynthesis.startup import finder
finder = sources.retrosynthesis.startup.make_config()

mol = Chem.MolFromSmiles(smiles)
print(mol)
if not mol:
raise ValueError("Invalid SMILES string")
print(f"Molecule generated: {mol}")
aizynthcli._process_single_smiles(smiles, finder, None, False, None, [], None)
# Find solved routes and process routes objects into list of dictionaries
routes = finder.routes
solved_routes = []
for idx, node in enumerate(routes.nodes):
Expand All @@ -53,12 +57,13 @@ def retrosynthesis_process(smiles):
for idx, route in enumerate(solved_routes, 1):
retro_route = RetroRoute(route['dict'])
retro_route.find_child_nodes2(retro_route.route_dict)
route_dic = {'score': route['all_score']['state score'], 'steps': retro_route.reactions,
'depth': route['node'].state.max_transforms}
solved_route_dict.update({f'Route {idx}': route_dic})
route_dic = {
'score': route['all_score']['state score'],
'steps': retro_route.reactions,
'depth': route['node'].state.max_transforms,
}
solved_route_dict[f"Route {idx}"] = route_dic
route_dicts = routes.dicts[0:10]
raw_routes = []
for idx, route_dict in enumerate(route_dicts, 1):
raw_routes.append(route_dict)
raw_routes = [route_dict for route_dict in route_dicts]

return solved_route_dict, raw_routes