Review notes (text before the first "diff --git" line is ignored by git apply):
- Line breaks and indentation were rebuilt from a whitespace-collapsed copy;
  blank context lines were re-inserted where the hunk line counts require
  them, so their exact positions should be confirmed against the target tree.
- New-side hunk ranges were recomputed for the revised additions; the
  post-image blob ids on the "index" lines predate those revisions.
- The .bloom entry carries no binary payload and cannot be applied as-is.

diff --git a/aizynthfinder/aizynthfinder.py b/aizynthfinder/aizynthfinder.py
index c8255e9..f9f6abb 100644
--- a/aizynthfinder/aizynthfinder.py
+++ b/aizynthfinder/aizynthfinder.py
@@ -228,6 +228,10 @@ def tree_search(self, show_progress: bool = False) -> float:

             try:
                 is_solved = self.tree.one_iteration()
+                # Iteration counter read by the eUCT exploration term
+                # (MctsNode._children_eu). NOTE(review): it is not reset in
+                # this hunk; confirm it starts at 0 for every new search.
+                self.config.search.algorithm_config["current_iteration"] += 1
             except StopIteration:
                 break

diff --git a/aizynthfinder/context/config.py b/aizynthfinder/context/config.py
index c55d194..c8ecf1d 100644
--- a/aizynthfinder/context/config.py
+++ b/aizynthfinder/context/config.py
@@ -41,9 +41,13 @@ class _SearchConfiguration:
             "immediate_instantiation": (),
             "mcts_grouping": None,
             "search_rewards_weights": [],
+            # Selection variant: "Default", "eUCT", "dUCT-v1" or "dUCT-v2"
+            "enhancement": "Default",
+            # Search iterations completed so far; read by the eUCT variant
+            "current_iteration": 0,
         }
     )
-    max_transforms: int = 6
+    max_transforms: int = 10
     iteration_limit: int = 100
     time_limit: int = 120
     return_first: bool = False
diff --git a/aizynthfinder/search/mcts/node.py b/aizynthfinder/search/mcts/node.py
index 063cd9a..94f4956 100644
--- a/aizynthfinder/search/mcts/node.py
+++ b/aizynthfinder/search/mcts/node.py
@@ -70,6 +70,8 @@ def __init__(
         self.is_expanded: bool = False
         self.is_expandable: bool = not self.state.is_terminal
         self._parent = parent
+        # Depth in the search tree (root = 0); assigned when the node is expanded
+        self.depth: int = 0

         if owner is None:
             self.created_at_iteration: Optional[int] = None
@@ -249,13 +251,27 @@ def expand(self) -> None:
             for child in self.parent.children:
                 if child is not self:
                     cache_molecules.extend(child.state.expandable_mols)
-
+        # One level below the parent; a node without a parent is the root
+        self.depth = self.parent.depth + 1 if self.parent else 0
         # Calculate the possible actions, fill the child_info lists
         # Actions by default only assumes 1 set of reactants
         actions, priors = self._expansion_policy(
             self.state.expandable_mols, cache_molecules
         )
-        self._fill_children_lists(actions, priors)
+        enhancement = self._config.search.algorithm_config["enhancement"]
+        # The dUCT variants keep only the highest-prior actions
+        top_n = {"dUCT-v1": 20, "dUCT-v2": 50}.get(enhancement)
+        if top_n is not None:
+            ranked = sorted(
+                zip(actions, priors), key=lambda pair: pair[1], reverse=True
+            )[:top_n]
+            # Comprehensions (not zip(*ranked)) so an empty expansion yields
+            # empty lists instead of an unpacking error
+            actions = [action for action, _ in ranked]
+            priors = [prior for _, prior in ranked]
+        elif enhancement not in ("Default", "eUCT"):
+            raise ValueError(f"Unknown search enhancement: {enhancement!r}")
+        self._fill_children_lists(actions, priors)

         # Reverse the expansion if it did not produce any children
         if len(actions) == 0:
@@ -377,6 +393,39 @@ def _children_u(self) -> np.ndarray:
         child_visits = np.array(self._children_visitations)
         return self._algo_config["C"] * np.sqrt(2 * total_visits / child_visits)

+    def _children_du(self, depth: int) -> np.ndarray:
+        """
+        Exploration term for the dUCT variants.
+
+        Same form as ``_children_u`` but the exploration constant grows
+        linearly with the node depth: ``C + 0.5 * depth``.
+
+        :param depth: the depth of this node in the search tree
+        :return: the exploration bonus of each child
+        """
+        total_visits = np.log(np.sum(self._children_visitations))
+        child_visits = np.array(self._children_visitations)
+        depth_scaled_c = self._algo_config["C"] + 0.5 * depth
+        return depth_scaled_c * np.sqrt(2 * total_visits / child_visits)
+
+    def _children_eu(self) -> np.ndarray:
+        """
+        Exploration term for the eUCT variant.
+
+        Same form as ``_children_u`` but the exploration constant is divided
+        by ``1 + phi``, where ``phi`` is half the fraction of the iteration
+        limit used so far, so exploration decays as the search progresses.
+
+        :return: the exploration bonus of each child
+        """
+        total_visits = np.log(np.sum(self._children_visitations))
+        child_visits = np.array(self._children_visitations)
+        # An unstarted counter (0) is treated as iteration 1
+        iteration = max(self._algo_config["current_iteration"], 1)
+        phi = iteration / self._config.search.iteration_limit / 2
+        decayed_c = self._algo_config["C"] / (phi + 1)
+        return decayed_c * np.sqrt(2 * total_visits / child_visits)
+
     def _create_children_nodes(
         self, states: List[MctsState], child_idx: int
     ) -> List["MctsNode"]:
@@ -536,7 +585,19 @@ def _regenerated_blacklisted(self, reaction: RetroReaction) -> bool:
     def _score_and_select(self) -> Optional["MctsNode"]:
         if not max(self._children_values) > 0:
             raise ValueError("Has no selectable children")
-        scores = self._children_q() + self._children_u()
+
+        # Membership test on purpose: ``x == "a" or "b"`` is always truthy
+        # and would route every configuration to the dUCT term
+        enhancement = self._algo_config["enhancement"]
+        if enhancement in ("dUCT-v1", "dUCT-v2"):
+            exploration = self._children_du(self.depth)
+        elif enhancement == "eUCT":
+            exploration = self._children_eu()
+        elif enhancement == "Default":
+            exploration = self._children_u()
+        else:
+            raise ValueError(f"Unknown search enhancement: {enhancement!r}")
+        scores = self._children_q() + exploration
         indices = np.where(scores == scores.max())[0]
         index = np.random.choice(indices)
         return self._select_child(index)
diff --git a/sources/retrosynthesis/config_files/zinc_and_emol_inchi_key.bloom b/sources/retrosynthesis/config_files/zinc_and_emol_inchi_key.bloom
index e69de29..98f50f7 100644
Binary files a/sources/retrosynthesis/config_files/zinc_and_emol_inchi_key.bloom and b/sources/retrosynthesis/config_files/zinc_and_emol_inchi_key.bloom differ
diff --git a/sources/retrosynthesis/routes.py b/sources/retrosynthesis/routes.py
index d129e2f..3518c77 100644
--- a/sources/retrosynthesis/routes.py
+++ b/sources/retrosynthesis/routes.py
@@ -19,30 +19,52 @@ def service_check():
 def retrosynthesis():
     access_key = str(request.args.get('key'))
     if access_key != ACCESS_KEY:
-        print("invalid key")
-        return json.dumps({'Message': 'invalid key', 'Timestamp': time.time()})
+        print("Invalid key")
+        return json.dumps({'Message': 'Invalid key', 'Timestamp': time.time()})
+
     smiles = str(request.args.get('smiles'))
-    solved_route_dict_ls, raw_routes = retrosynthesis_process(smiles)
+    enhancement = str(request.args.get('enhancement', 'Default'))
+    # Reject unsupported values here rather than failing inside the search
+    if enhancement not in ('Default', 'eUCT', 'dUCT-v1', 'dUCT-v2'):
+        print("Invalid enhancement")
+        return json.dumps({'Message': 'Invalid enhancement', 'Timestamp': time.time()})
+    finder = sources.retrosynthesis.startup.make_config()
+    finder.config.search.algorithm_config["enhancement"] = enhancement
+    try:
+        solved_route_dict_ls, raw_routes = retrosynthesis_process(smiles, finder)
+    except ValueError as err:
+        # e.g. a SMILES string that RDKit could not parse
+        print(f"Retrosynthesis failed: {err}")
+        return json.dumps({'Message': str(err), 'Timestamp': time.time()})
     page_data = {'Message': solved_route_dict_ls, 'Raw_Routes': raw_routes, 'Timestamp': time.time()}
     json_dump = json.dumps(page_data)
     return json_dump


-def retrosynthesis_process(smiles):
-    """
-    Takes a smiles string and returns a list of retrosynthetic routes stored as dictionaries
+def retrosynthesis_process(smiles, finder=None):
     """
-    # load config containing policy file locations
-    print(smiles)
+    Takes a SMILES string and returns the solved retrosynthetic routes as dictionaries.
+
+    :param smiles: SMILES string of the target molecule
+    :param finder: pre-configured finder object; a default one is built when omitted
+    :return: the solved routes keyed by 'Route <n>' and up to ten raw route dictionaries
+    :raises ValueError: if RDKit cannot parse the SMILES string
+    """
+    print(f"Running retrosynthesis for SMILES: {smiles}")
+
     from rdkit import Chem
     from aizynthfinder.interfaces import aizynthcli
     from sources.retrosynthesis.classes import RetroRoute
-    # from sources.retrosynthesis.startup import finder
-    finder = sources.retrosynthesis.startup.make_config()
+
+    # Keep the one-argument call working for existing callers
+    if finder is None:
+        finder = sources.retrosynthesis.startup.make_config()
     mol = Chem.MolFromSmiles(smiles)
-    print(mol)
+    if mol is None:
+        raise ValueError("Invalid SMILES string")
+    print(f"Molecule generated: {mol}")
     aizynthcli._process_single_smiles(smiles, finder, None, False, None, [], None)
     # Find solved routes and process routes objects into list of dictionaries
     routes = finder.routes
     solved_routes = []
     for idx, node in enumerate(routes.nodes):
@@ -53,12 +75,14 @@ def retrosynthesis_process(smiles):
     for idx, route in enumerate(solved_routes, 1):
         retro_route = RetroRoute(route['dict'])
         retro_route.find_child_nodes2(retro_route.route_dict)
-        route_dic = {'score': route['all_score']['state score'], 'steps': retro_route.reactions,
-                     'depth': route['node'].state.max_transforms}
-        solved_route_dict.update({f'Route {idx}': route_dic})
+        route_dic = {
+            'score': route['all_score']['state score'],
+            'steps': retro_route.reactions,
+            'depth': route['node'].state.max_transforms,
+        }
+        solved_route_dict[f"Route {idx}"] = route_dic

     route_dicts = routes.dicts[0:10]
-    raw_routes = []
-    for idx, route_dict in enumerate(route_dicts, 1):
-        raw_routes.append(route_dict)
+    # Copy of (at most) the first ten raw route dictionaries
+    raw_routes = list(route_dicts)
     return solved_route_dict, raw_routes