From 9dca93630862e4cf1f6b9681024fb69787f0b9ac Mon Sep 17 00:00:00 2001 From: tonbl Date: Tue, 30 Sep 2025 14:16:01 +0100 Subject: [PATCH 1/3] added batch SMILES input, .txt file acceptor, progress bar, multi-SMILES functionality, and includes first instance of post softmax grouping of templates --- aizynthfinder/context/config.py | 2 +- .../context/policy/expansion_strategies.py | 292 +++++++++++++++++- .../context/policy/post_sm_grouping.py | 113 +++++++ sources/retrosynthesis/routes.py | 216 +++++++++---- sources/retrosynthesis/startup.py | 31 +- 5 files changed, 572 insertions(+), 82 deletions(-) create mode 100644 aizynthfinder/context/policy/post_sm_grouping.py diff --git a/aizynthfinder/context/config.py b/aizynthfinder/context/config.py index 01ad212..10c22c4 100644 --- a/aizynthfinder/context/config.py +++ b/aizynthfinder/context/config.py @@ -35,7 +35,7 @@ class _SearchConfiguration: default_factory=lambda: { "C": 1.4, "default_prior": 0.5, - "use_prior": False, + "use_prior": True, "prune_cycles_in_search": True, "search_rewards": ["state score"], "immediate_instantiation": (), diff --git a/aizynthfinder/context/policy/expansion_strategies.py b/aizynthfinder/context/policy/expansion_strategies.py index 83e10f6..0020e76 100644 --- a/aizynthfinder/context/policy/expansion_strategies.py +++ b/aizynthfinder/context/policy/expansion_strategies.py @@ -14,6 +14,8 @@ from aizynthfinder.utils.exceptions import PolicyException from aizynthfinder.utils.logging import logger from aizynthfinder.utils.models import load_model +from aizynthfinder.context.policy.post_sm_grouping import compute_center_group_key + if TYPE_CHECKING: from aizynthfinder.chem import TreeMolecule @@ -59,6 +61,7 @@ def __init__(self, key: str, config: Configuration, **kwargs: str) -> None: self._logger = logger() self.key = key + def __call__( self, molecules: Sequence[TreeMolecule], @@ -304,12 +307,20 @@ def __init__(self, key: str, config: Configuration, **kwargs: str) -> None: self._load_mask_file(maskfile) if maskfile else None ) - if hasattr(self.model, "output_size") and len(self.templates) != self.model.output_size: # type: ignore + if hasattr(self.model, "output_size") and len(self.templates) != self.model.output_size: raise PolicyException( - f"The number of templates ({len(self.templates)}) does not agree with the " # type: ignore + f"The number of templates ({len(self.templates)}) does not agree with the " f"output dimensions of the model ({self.model.output_size})" ) self._cache: Dict[str, Tuple[np.ndarray, np.ndarray]] = {} + self._cutoff_impl_name: str = str(kwargs.get("cutoff_impl", "templates")).lower() + self._cutoff_impl = getattr(self, f"_cutoff_predictions_{self._cutoff_impl_name}", None) + if self._cutoff_impl is None: + raise PolicyException(f"Unknown cutoff_impl='{self._cutoff_impl_name}'. Use 'templates' or 'groups'.") + + self._group_index = None + if self._cutoff_impl_name == "groups": + self._build_group_index() def get_actions( self, @@ -358,31 +369,284 @@ def reset_cache(self) -> None: """Reset the prediction cache""" self._cache = {} + def _build_group_index(self) -> None: + """ + Build a mapping: group_key -> list of row positions of N templates. + Uses the template SMARTS to compute a product-side reacting-center + first shell signature. + """ + df = self.templates + + if self.template_column in df.columns: + smarts_col = self.template_column + else: + smarts_col = None + for c in ("retro_template", "template", "smarts", "reaction_smarts"): + if c in df.columns: + smarts_col = c + break + if smarts_col is None: + self._logger.info("Grouping: no SMARTS column found; skipping") + self._group_index = None + return + + print(f"[DEBUG groups] computing group keys for {len(df)} templates " + f"using column '{smarts_col}'") + if "group_key" not in df.columns: + print("[DEBUG groups] no existing group_key column; computing…") + df["group_key"] = df[smarts_col].fillna("").map(compute_center_group_key) + + from collections import defaultdict + gi = defaultdict(list) + for pos, g in enumerate(df["group_key"].values): + gi[g].append(pos) + self._group_index = gi + + ng = len(gi) + avg = len(df) / max(ng, 1) + print(f"[DEBUG groups] built {ng} groups over {len(df)} templates " + f"(avg {avg:.2f}/group)") + + sizes = sorted(((k, len(v)) for k, v in gi.items()), key=lambda x: x[1], reverse=True)[:10] + if sizes: + show = ", ".join([f"{k[:10]}…:{n}" if isinstance(k, str) else f"{str(k)[:10]}…:{n}" for k, n in sizes]) + print(f"[DEBUG groups] largest groups (key:count): {show}") + + + def _dbg_row_info(self, pos: int) -> str: + """Build a short string with template identity for position `pos`.""" + df = self.templates + try: + label = df.index.values[pos] + except Exception: + label = pos + get = lambda col: (str(df.iloc[pos][col]) if col in df.columns else "") + thash = get("template_hash") + gkey = get("group_key") + # abbreviate long hashes + if isinstance(thash, str) and len(thash) > 10: + thash = thash[:10] + "…" + if isinstance(gkey, str) and len(gkey) > 10: + gkey = gkey[:10] + "…" + return f"pos={pos} label={label} hash={thash} group={gkey}" + + def _dbg_dump_selection(self, preds: np.ndarray, selected: np.ndarray, tag: str) -> None: + """Print the final selected templates (bounded).""" + if selected is None or len(selected) == 0: + print(f"[DEBUG cutoff:{tag}] no templates selected") + return + topn = min(25, len(selected)) + total = float(np.nansum(preds)) + cum = float(np.nansum(preds[selected])) + print(f"[DEBUG cutoff:{tag}] selected {len(selected)} templates " + f"(mass {cum:.6g} of {total:.6g}); showing top {topn}") + # order by prob descending + order = np.argsort(preds[selected])[::-1] + for rank, rel in enumerate(order[:topn], start=1): + pos = int(selected[rel]) + p = float(preds[pos]) + print(f" #{rank:>2} p={p:.6g} {self._dbg_row_info(pos)}") + def _cutoff_predictions(self, predictions: np.ndarray) -> np.ndarray: + """ + Select up to TOTAL_TARGET templates: + If GROUPING: + 1) Rank groups by total probability; keep top NUM_GROUPS groups. + 2) From each kept group, take PER_GROUP templates + 3) Top up to TOTAL_TARGET with global best + Else: + Take global top TOTAL_TARGET templates. + """ + + # ---------------------------------- + GROUPING = True + NUM_GROUPS = 20 + PER_GROUP = 5 + TOTAL_TARGET = 100 + FILL_FROM_GLOBAL = True + # ---------------------------------- + + preds = predictions.copy().astype(float) + + if getattr(self, "mask", None) is not None: + preds[~self.mask] = 0.0 + + all_idx = np.arange(preds.size, dtype=np.int32) + global_order = all_idx[np.argsort(preds)[::-1]] + if not GROUPING: + return global_order[: min(TOTAL_TARGET, global_order.size)] + if getattr(self, "_group_index", None) is None: + self._build_group_index() + gi = getattr(self, "_group_index", None) + + if not gi or NUM_GROUPS <= 0 or PER_GROUP <= 0: + return global_order[: min(TOTAL_TARGET, global_order.size)] + + group_items = list(gi.items()) + group_sums = np.array( + [float(preds[np.asarray(members, dtype=np.int32)].sum()) for _, members in group_items], + dtype=float, + ) + + # Order groups by total mass + g_order = np.argsort(group_sums)[::-1] + + selected: list[int] = [] + + # Take top NUM_GROUPS groups + for g_pos in g_order[: min(NUM_GROUPS, len(g_order))]: + _, members = group_items[g_pos] + members = np.asarray(members, dtype=np.int32) + if members.size == 0: + continue + + # Top PER_GROUP within this group + local_sorted = members[np.argsort(preds[members])[::-1]] + take_k = int(min(PER_GROUP, local_sorted.size)) + + count = 0 + for idx in local_sorted: + if count >= take_k: + break + if preds[idx] > 0.0: + selected.append(int(idx)) + count += 1 + + if len(selected) >= TOTAL_TARGET: + break + + # Top up from global templates if needed + if FILL_FROM_GLOBAL and len(selected) < TOTAL_TARGET: + chosen = set(selected) + for idx in global_order: + if len(selected) >= TOTAL_TARGET: + break + if idx not in chosen and preds[idx] > 0.0: + selected.append(int(idx)) + + if not selected: + return global_order[: min(TOTAL_TARGET, global_order.size)] + if len(selected) > TOTAL_TARGET: + selected = selected[:TOTAL_TARGET] + + return np.asarray(selected, dtype=np.int32) + + def _cutoff_predictions_templates(self, predictions: np.ndarray) -> np.ndarray: """ Get the top transformations, by selecting those that have: * cumulative probability less than a threshold (cutoff_cumulative) * or at most N (cutoff_number) """ + # # DEBUG + # TARGET_HASH = "0c33fbaaea663fe5aa40102c16fffe484773b331031394dbfad9e1eace47a3c2" + # + # df = None + # if hasattr(self, "templates"): + # df = self.templates + # elif hasattr(self, "_template_library") and hasattr(self._template_library, "data"): + # df = self._template_library.data + # + # if df is not None: + # row = df.loc[df["template_hash"] == TARGET_HASH] + # if len(row) == 1: + # try: + # iloc = int(df.index.get_indexer([row.index[0]])[0]) + # except Exception: + # code = int(row["template_code"].iloc[0]) + # iloc = int(df.index.get_indexer([code])[0]) + # p = float(predictions[iloc]) + # rank = int((predictions > p).sum() + 1) + # pct = 100.0 * rank / predictions.size + # masked = (getattr(self, "mask", None) is not None and not self.mask[iloc]) + # print(f"[POLICY] prob={p:.6g} rank={rank}/{predictions.size} (top {pct:.4f}%)" + # + (" [MASKED]" if masked else "")) + # # END + + preds = predictions.copy() if self.mask is not None: - predictions[~self.mask] = 0 - sortidx = np.argsort(predictions)[::-1] - cumsum: np.ndarray = np.cumsum(predictions[sortidx]) + preds[~self.mask] = 0 + + sortidx = np.argsort(preds)[::-1] + cumsum: np.ndarray = np.cumsum(preds[sortidx]) + if any(cumsum >= self.cutoff_cumulative): maxidx = int(np.argmin(cumsum < self.cutoff_cumulative)) else: maxidx = len(cumsum) + maxidx = min(maxidx, self.cutoff_number) or 1 - return sortidx[:maxidx] - def _load_mask_file(self, maskfile: str) -> np.ndarray: - self._logger.info(f"Loading masking of templates from {maskfile} to {self.key}") - mask = np.load(maskfile)["arr_0"] - if len(mask) != len(self.templates): - raise PolicyException( - f"The number of masks {len(mask)} does not match the number of templates {len(self.templates)}" - ) - return mask + selected = sortidx[:maxidx] + try: + self._dbg_dump_selection(preds, selected, tag="templates") + except Exception as _e: + print(f"[DEBUG cutoff:templates] printing failed: {_e}") + return selected + + def _cutoff_predictions_groups(self, predictions: np.ndarray) -> np.ndarray: + """ + Group-level cutoff: + 1) Sum probabilities per group. + 2) Select top groups by cumulative probability and cutoff_number (applied to groups). + 3) Return template indices from selected groups, sorted by their individual prob, + then truncate to self.cutoff_number templates. + """ + + if getattr(self, "_group_index", None) is None: + self._build_group_index() + + preds = predictions.copy() + if self.mask is not None: + preds[~self.mask] = 0 + + gi = getattr(self, "_group_index", None) + if not gi: + print("[DEBUG cutoff:groups] no group index; falling back to template cutoff") + return self._cutoff_predictions_templates(preds) + + group_keys = list(gi.keys()) + group_members = [gi[k] for k in group_keys] + group_sums = np.array([float(preds[m].sum()) for m in group_members], dtype=float) + + g_order = np.argsort(group_sums)[::-1] + g_cumsum = np.cumsum(group_sums[g_order]) + + if any(g_cumsum >= self.cutoff_cumulative): + gmax = int(np.argmin(g_cumsum < self.cutoff_cumulative)) + else: + gmax = len(g_order) + + gmax = min(gmax, self.cutoff_number) or 1 + + print(f"[DEBUG cutoff:groups] considering {len(group_keys)} groups; " + f"selecting top {gmax} by group mass (cum ≤ {self.cutoff_cumulative})") + show_g = min(10, len(g_order)) + for rank, gi_pos in enumerate(g_order[:show_g], start=1): + key = group_keys[gi_pos] + members = group_members[gi_pos] + mass = float(group_sums[gi_pos]) + print(f" [group #{rank:>2}] mass={mass:.6g} key={str(key)[:10]}… |members|={len(members)}") + + kept_indices = [] + for i in g_order[:gmax]: + kept_indices.extend(group_members[i]) + + if not kept_indices: + print("[DEBUG cutoff:groups] kept_indices empty; falling back to template cutoff") + return self._cutoff_predictions_templates(preds) + + kept_indices = np.array(kept_indices, dtype=np.int32) + order = np.argsort(preds[kept_indices])[::-1] + out = kept_indices[order] + + if len(out) > self.cutoff_number: + out = out[: self.cutoff_number] + + try: + self._dbg_dump_selection(preds, out, tag="groups") + except Exception as _e: + print(f"[DEBUG cutoff:groups] printing failed: {_e}") + + return out def _update_cache(self, molecules: Sequence[TreeMolecule]) -> None: pred_inchis = [] diff --git a/aizynthfinder/context/policy/post_sm_grouping.py b/aizynthfinder/context/policy/post_sm_grouping.py new file mode 100644 index 0000000..1958983 --- /dev/null +++ b/aizynthfinder/context/policy/post_sm_grouping.py @@ -0,0 +1,113 @@ +from rdkit import Chem +from rdkit.Chem import rdchem +import hashlib + +try: + from rdkit.Chem import rdChemReactions as _rxnmod +except Exception: + from rdkit.Chem import AllChem as _rxnmod + +def _rxn_from_smarts(s): + try: + return _rxnmod.ReactionFromSmarts(str(s), useSmiles=False) + except Exception: + return None + +def _bond_set(mols): + """Return set of (mapA,mapB,bondType,isAromatic) for all bonds in a list of mols. + Only includes bonds whose atoms both have map numbers.""" + out = set() + for m in mols: + for b in m.GetBonds(): + a = b.GetBeginAtom(); z = b.GetEndAtom() + ma = a.GetAtomMapNum(); mz = z.GetAtomMapNum() + if ma and mz: + a1, a2 = (ma, mz) if ma < mz else (mz, ma) + out.add((a1, a2, int(b.GetBondType()), b.GetIsAromatic())) + return out + +def _reacting_maps(rxn): + """Map numbers of atoms participating in any bond change.""" + r_bonds = _bond_set(list(rxn.GetReactants())) + p_bonds = _bond_set(list(rxn.GetProducts())) + changed = (r_bonds ^ p_bonds) + maps = set() + for a1, a2, *_ in changed: + maps.add(a1); maps.add(a2) + return maps, r_bonds, p_bonds + +def _first_shell_signature(prod_mols, center_maps): + """Signature = tuple of atoms (with features) and bonds in subgraph: + center atoms + neighbors (radius 1) on product side.""" + idx = {} + for m in prod_mols: + for a in m.GetAtoms(): + amap = a.GetAtomMapNum() + if amap: + idx[amap] = (m, a.GetIdx()) + + atom_keys = [] + bond_keys = set() + + def atom_feat(a: rdchem.Atom): + try: + hyb = int(a.GetHybridization()) + except Exception: + hyb = -1 + return ( + a.GetSymbol(), + a.GetIsAromatic(), + a.IsInRing(), + a.GetFormalCharge(), + hyb, + ) + + queue = set(center_maps) + for cmap in list(center_maps): + m, ai = idx.get(cmap, (None, None)) + if m is None: + continue + a = m.GetAtomWithIdx(ai) + for nb in a.GetNeighbors(): + if nb.GetAtomMapNum(): + queue.add(nb.GetAtomMapNum()) + + for amap in queue: + if amap in idx: + m, ai = idx[amap] + a = m.GetAtomWithIdx(ai) + atom_keys.append((amap, atom_feat(a))) + + amap_set = set(queue) + for m in prod_mols: + for b in m.GetBonds(): + a = b.GetBeginAtom(); z = b.GetEndAtom() + ma = a.GetAtomMapNum(); mz = z.GetAtomMapNum() + if ma and mz and (ma in amap_set) and (mz in amap_set): + a1, a2 = (ma, mz) if ma < mz else (mz, ma) + bond_keys.add((a1, a2, int(b.GetBondType()), b.GetIsAromatic())) + + atoms_canon = tuple((feat) for _, feat in sorted(atom_keys, key=lambda x: x[0])) + bonds_canon = tuple(sorted([(t, aro) for _, _, t, aro in bond_keys])) + + return atoms_canon, bonds_canon + +def compute_center_group_key(template_smarts: str) -> str: + """ + group key: exact reacting center + first shell on product side AND bond-change multiset. + """ + rxn = _rxn_from_smarts(template_smarts) + if rxn is None: + return hashlib.sha256(("RAW|" + template_smarts).encode()).hexdigest() + + center_maps, r_bonds, p_bonds = _reacting_maps(rxn) + + formed = tuple(sorted(p_bonds - r_bonds)) + broken = tuple(sorted(r_bonds - p_bonds)) + changed = tuple(sorted(r_bonds & p_bonds)) + + atoms_sig, bonds_sig = _first_shell_signature(list(rxn.GetProducts()), center_maps) + + payload = ("L1|", atoms_sig, bonds_sig, formed, broken, changed) + raw = repr(payload).encode() + return hashlib.sha256(raw).hexdigest() diff --git a/sources/retrosynthesis/routes.py b/sources/retrosynthesis/routes.py index fd14b65..61faf30 100644 --- a/sources/retrosynthesis/routes.py +++ b/sources/retrosynthesis/routes.py @@ -1,75 +1,177 @@ import os import json import time -from flask import request +from flask import request, jsonify +from pathlib import Path import sources.retrosynthesis.startup from sources import app -ACCESS_KEY = os.getenv("KEY", 'retro_key') +ACCESS_KEY = os.getenv("KEY", "retro_key") +BATCH_OUT_DIR = Path(os.getenv("results", ".")) +BATCH_OUT_FILE = BATCH_OUT_DIR / "trees_batch.json" +BATCH_PROGRESS = {} -@app.route('/', methods=['GET']) +@app.route("/", methods=["GET"]) def service_check(): - finder = sources.retrosynthesis.startup.make_config() - # stock1 = finder.config.stock - page_data = {'Message': 'Retrosynthesis service is running', 'Timestamp': time.time(), 'Version': 'v2.05.25'} - json_dump = json.dumps(page_data) - return json_dump + _ = sources.retrosynthesis.startup.make_config() + page_data = { + "Message": "Retrosynthesis service is running", + "Timestamp": time.time(), + "Version": "v2.05.25", + } + return jsonify(page_data) -@app.route('/retrosynthesis_api/', methods=['GET']) +@app.route("/retrosynthesis_api/", methods=["GET"]) def retrosynthesis(): - access_key = str(request.args.get('key')) - if access_key != ACCESS_KEY: - print("Invalid key") - return json.dumps({'Message': 'Invalid key', 'Timestamp': time.time()}) + access_key = str(request.args.get("key", "")) + if access_key != ACCESS_KEY: + return jsonify({"Message": "Invalid key", "Timestamp": time.time()}), 401 - smiles = str(request.args.get('smiles')) - enhancement = str(request.args.get('enhancement', 'Default')) - finder = sources.retrosynthesis.startup.make_config() - finder.config.search.algorithm_config["enhancement"] = enhancement - finder.config.search.iteration_limit = int(request.args.get('iterations')) - finder.config.search.max_transforms = int(request.args.get('max_depth')) - finder.config.search.time_limit = int(request.args.get('time_limit')) - solved_route_dict_ls, raw_routes = retrosynthesis_process(smiles, finder) - page_data = {'Message': solved_route_dict_ls, 'Raw_Routes': raw_routes, 'Timestamp': time.time()} + smiles = str(request.args.get("smiles", "")).strip() + if not smiles: + return jsonify({"Message": "Missing 'smiles'", "Timestamp": time.time()}), 400 - json_dump = json.dumps(page_data) - return json_dump + enhancement = str(request.args.get("enhancement", "Default")) + iterations = int(request.args.get("iterations", 100)) + max_depth = int(request.args.get("max_depth", 7)) + time_limit = int(request.args.get("time_limit", 60)) + + finder = sources.retrosynthesis.startup.make_config() + finder.config.search.algorithm_config["enhancement"] = enhancement + finder.config.search.iteration_limit = iterations + finder.config.search.max_transforms = max_depth + finder.config.search.time_limit = time_limit + + solved_route_dict_ls, raw_routes = retrosynthesis_process(smiles, finder) + page_data = { + "Message": solved_route_dict_ls, + "Raw_Routes": raw_routes, + "Timestamp": time.time(), + } + return jsonify(page_data) def retrosynthesis_process(smiles, finder): - """ - Takes a SMILES string and a pre-configured finder object and returns a list of retrosynthetic routes as dictionaries. """ - print(f"Running retrosynthesis for SMILES: {smiles}") - - from rdkit import Chem - from aizynthfinder.interfaces import aizynthcli - from sources.retrosynthesis.classes import RetroRoute - - mol = Chem.MolFromSmiles(smiles) - if not mol: - raise ValueError("Invalid SMILES string") - print(f"Molecule generated: {mol}") - aizynthcli._process_single_smiles(smiles, finder, None, False, None, [], None) - routes = finder.routes - solved_routes = [] - for idx, node in enumerate(routes.nodes): - if node.is_solved is True: - solved_routes.append(routes[idx]) - solved_routes = solved_routes[0:10] - solved_route_dict = {} - for idx, route in enumerate(solved_routes, 1): - retro_route = RetroRoute(route['dict']) - retro_route.find_child_nodes2(retro_route.route_dict) - route_dic = { - 'score': route['all_score']['state score'], - 'steps': retro_route.reactions, - 'depth': route['node'].state.max_transforms, - } - solved_route_dict[f"Route {idx}"] = route_dic - route_dicts = routes.dicts[0:10] - raw_routes = [route_dict for route_dict in route_dicts] - - return solved_route_dict, raw_routes \ No newline at end of file + Takes a SMILES string and a pre-configured finder object and returns + a list of retrosynthetic routes as dictionaries. + """ + print(f"Running retrosynthesis for SMILES: {smiles}") + + from rdkit import Chem + from aizynthfinder.interfaces import aizynthcli + from sources.retrosynthesis.classes import RetroRoute + + mol = Chem.MolFromSmiles(smiles) + if not mol: + raise ValueError("Invalid SMILES string") + + aizynthcli._process_single_smiles(smiles, finder, None, False, None, [], None) + + routes = finder.routes + solved_routes = [] + for idx, node in enumerate(routes.nodes): + if node.is_solved is True: + solved_routes.append(routes[idx]) + + solved_routes = solved_routes[0:10] + solved_route_dict = {} + for idx, route in enumerate(solved_routes, 1): + retro_route = RetroRoute(route["dict"]) + retro_route.find_child_nodes2(retro_route.route_dict) + route_dic = { + "score": route["all_score"]["state score"], + "steps": retro_route.reactions, + "depth": route["node"].state.max_transforms, + } + solved_route_dict[f"Route {idx}"] = route_dic + + route_dicts = routes.dicts[0:10] + raw_routes = [route_dict for route_dict in route_dicts] + + return solved_route_dict, raw_routes + +@app.route("/retrosynthesis_batch_progress/", methods=["GET"]) +def retrosynthesis_batch_progress(batch_id): + info = BATCH_PROGRESS.get(batch_id) + if not info: + return jsonify({"status": "unknown"}), 404 + return jsonify(info) + +@app.route("/retrosynthesis_batch/", methods=["POST", "OPTIONS"]) +def retrosynthesis_batch(): + if request.method == "OPTIONS": + return ("", 204) + + payload = request.get_json(silent=True) or {} + access_key = str(payload.get("key", "")) + if access_key != ACCESS_KEY: + return jsonify({"Message": "Invalid key", "Timestamp": time.time()}), 401 + + smiles_list = payload.get("smiles_list") or [] + if not isinstance(smiles_list, list) or not smiles_list: + return jsonify({"Message": "smiles_list must be a non-empty list", "Timestamp": time.time()}), 400 + + enhancement = str(payload.get("enhancement", "Default")) + iterations = int(payload.get("iterations", 100)) + max_depth = int(payload.get("max_depth", 7)) + time_limit = int(payload.get("time_limit", 60)) + + batch_id = str(payload.get("batch_id", "")) or str(int(time.time())) + BATCH_PROGRESS[batch_id] = {"total": len(smiles_list), "current": 0, "smiles": "", "status": "running"} + + results = [] + try: + for idx, smi in enumerate(smiles_list, start=1): + BATCH_PROGRESS[batch_id].update({"current": idx, "smiles": smi, "status": "running"}) + + try: + finder = sources.retrosynthesis.startup.make_config() + finder.config.search.algorithm_config["enhancement"] = enhancement + finder.config.search.iteration_limit = iterations + finder.config.search.max_transforms = max_depth + finder.config.search.time_limit = time_limit + + solved_route_dict_ls, raw_routes = retrosynthesis_process(smi, finder) + results.append( + {"smiles": smi, "status": "ok", "routes": solved_route_dict_ls, "raw_routes": raw_routes}) + except Exception as e: + results.append({"smiles": smi, "status": "failed", "error": str(e), "routes": {}, "raw_routes": []}) + + BATCH_PROGRESS[batch_id].update({"status": "done"}) + except Exception: + BATCH_PROGRESS[batch_id].update({"status": "failed"}) + raise + + batch_entries = [] + for item in results: + batch_entries.append({ + "target": item["smiles"], + "status": item.get("status", "ok"), + "routes": item.get("routes", {}), + "raw_routes": item.get("raw_routes", []), + "timestamp": time.time(), + }) + BATCH_OUT_DIR.mkdir(parents=True, exist_ok=True) + existing = [] + if BATCH_OUT_FILE.exists(): + try: + with BATCH_OUT_FILE.open("r", encoding="utf-8") as f: + data = json.load(f) + if isinstance(data, dict) and "entries" in data and isinstance(data["entries"], list): + existing = data["entries"] + elif isinstance(data, list): + existing = data + except Exception: + existing = [] + combined = existing + batch_entries + tmp_path = BATCH_OUT_FILE.with_suffix(".json.tmp") + with tmp_path.open("w", encoding="utf-8") as f: + json.dump(combined, f, ensure_ascii=False, indent=2) + tmp_path.replace(BATCH_OUT_FILE) + + return jsonify( + {"Message": "Batch completed", "results": results, "Timestamp": time.time()} + ) diff --git a/sources/retrosynthesis/startup.py b/sources/retrosynthesis/startup.py index 0c208f5..f06d6fa 100644 --- a/sources/retrosynthesis/startup.py +++ b/sources/retrosynthesis/startup.py @@ -10,10 +10,14 @@ def make_config(): stock_request = request.args.get('stock') - stock_file_dict = {'emolecules': 'emolecules.hdf5', - 'zinc': 'zinc_stock_17_04_20.hdf5'} + stock_file_dict = { + 'zinc': 'zinc_stock_17_04_20.hdf5', + 'overlay': 'stock_additions_2025_08.hdf5', + 'emols': 'zinc_and_emol_inchi_key.bloom', + 'alcohols': 'dummy_alcohols.hdf5', + 'naturals': 'np_08_25.hdf5', + 'non_iso_naturals': 'ninp_08_25.hdf5'} - chosen_stock_file = stock_file_dict.get(stock_request, 'bloom') AIZYNTH = { 'expansion': { @@ -24,9 +28,18 @@ def make_config(): } }, - #'stock': {'zinc': os.path.join(BASEDIR, 'config_files', 'zinc_stock_17_04_20.hdf5')} - 'stock': {'bloom': os.path.join(BASEDIR, 'config_files', 'zinc_and_emol_inchi_key.bloom')} - , + 'stock': { + 'zinc': os.path.join(BASEDIR, 'config_files', 'zinc_stock_17_04_20.hdf5'), + 'overlay': os.path.join(BASEDIR, 'config_files', 'stock_additions_2025_08.hdf5'), + 'alcohols': os.path.join(BASEDIR, 'config_files', 'dummy_alcohols.hdf5'), + 'naturals': os.path.join(BASEDIR, 'config_files', 'np_08_25.hdf5'), + 'non_iso_naturals': os.path.join(BASEDIR, 'config_files', 'ninp_08_25.hdf5'), + 'paroutes_n1': os.path.join(BASEDIR, 'config_files', 'n1_stock.hdf5'), + 'paroutes_n5': os.path.join(BASEDIR, 'config_files', 'n5_stock.hdf5'), + 'askos_stock': os.path.join(BASEDIR, 'config_files', 'askcos_stock.hdf5') + }, + + #'stock': {'bloom': os.path.join(BASEDIR, 'config_files', 'zinc_and_emol_inchi_key.bloom')}, 'config_file': os.path.join(BASEDIR, 'config_files', 'aizynthfinder_config.yml'), 'properties': { 'max_transforms': 10, @@ -39,18 +52,16 @@ def make_config(): print(AIZYNTH) aizynth_config_dic = AIZYNTH print(2) - # initiate object containing all required arguments args = AiZynthArgs("placeholder", aizynth_config_dic['expansion'], aizynth_config_dic['stock']) + args.stock = stock_request or 'zinc,overlay' print(3) - # AiZynthFinder object contains results data finder = aizynthfinder.AiZynthFinder(configdict=aizynth_config_dic) print(4) - # set up stocks, policies, then start single smiles process aizynthcli._select_stocks(finder, args) print(5) finder.expansion_policy.select(args.policy or finder.expansion_policy.items[0]) print(6) finder.filter_policy.select(args.filter) print(7) - return finder \ No newline at end of file + return finder From 4e18dd408534500392c574945408125dc62f5bf7 Mon Sep 17 00:00:00 2001 From: tonbl Date: Mon, 20 Oct 2025 16:18:28 +0100 Subject: [PATCH 2/3] recheck required --- .../context/policy/expansion_strategies.py | 207 ++++++++++++------ sources/retrosynthesis/routes.py | 73 +++++- sources/retrosynthesis/startup.py | 2 +- 3 files changed, 202 insertions(+), 80 deletions(-) diff --git a/aizynthfinder/context/policy/expansion_strategies.py b/aizynthfinder/context/policy/expansion_strategies.py index 0020e76..49051f6 100644 --- a/aizynthfinder/context/policy/expansion_strategies.py +++ b/aizynthfinder/context/policy/expansion_strategies.py @@ -290,6 +290,12 @@ def __init__(self, key: str, config: Configuration, **kwargs: str) -> None: self.rescale_prior: bool = bool(kwargs.get("rescale_prior", False)) self.chiral_fingerprints = bool(kwargs.get("chiral_fingerprints", False)) + self.grouping_mode = str(kwargs.get("grouping_mode", "auto")).lower() + self.group_per_group_k = int(kwargs.get("group_per_group_k", 2)) + self.total_target_templates = int(kwargs.get("total_target_templates", 100)) + self.min_group_score = float(kwargs.get("min_group_score", 0.15)) + self.fixed_num_groups = int(kwargs.get("fixed_num_groups", 10)) + self._logger.info( f"Loading template-based expansion policy model from {source} to {self.key}" ) @@ -313,14 +319,14 @@ def __init__(self, key: str, config: Configuration, **kwargs: str) -> None: f"output dimensions of the model ({self.model.output_size})" ) self._cache: Dict[str, Tuple[np.ndarray, np.ndarray]] = {} - self._cutoff_impl_name: str = str(kwargs.get("cutoff_impl", "templates")).lower() - self._cutoff_impl = getattr(self, f"_cutoff_predictions_{self._cutoff_impl_name}", None) - if self._cutoff_impl is None: - raise PolicyException(f"Unknown cutoff_impl='{self._cutoff_impl_name}'. Use 'templates' or 'groups'.") + # self._cutoff_impl_name: str = str(kwargs.get("cutoff_impl", "templates")).lower() + # self._cutoff_impl = getattr(self, f"_cutoff_predictions_{self._cutoff_impl_name}", None) + # if self._cutoff_impl is None: + # raise PolicyException(f"Unknown cutoff_impl='{self._cutoff_impl_name}'. Use 'templates' or 'groups'.") self._group_index = None - if self._cutoff_impl_name == "groups": - self._build_group_index() + # if self._cutoff_impl_name == "groups": + # self._build_group_index() def get_actions( self, @@ -439,7 +445,6 @@ def _dbg_dump_selection(self, preds: np.ndarray, selected: np.ndarray, tag: str) cum = float(np.nansum(preds[selected])) print(f"[DEBUG cutoff:{tag}] selected {len(selected)} templates " f"(mass {cum:.6g} of {total:.6g}); showing top {topn}") - # order by prob descending order = np.argsort(preds[selected])[::-1] for rank, rel in enumerate(order[:topn], start=1): pos = int(selected[rel]) @@ -447,88 +452,156 @@ def _dbg_dump_selection(self, preds: np.ndarray, selected: np.ndarray, tag: str) print(f" #{rank:>2} p={p:.6g} {self._dbg_row_info(pos)}") def _cutoff_predictions(self, predictions: np.ndarray) -> np.ndarray: - """ - Select up to TOTAL_TARGET templates: - If GROUPING: - 1) Rank groups by total probability; keep top NUM_GROUPS groups. - 2) From each kept group, take PER_GROUP templates - 3) Top up to TOTAL_TARGET with global best - Else: - Take global top TOTAL_TARGET templates. - """ - - # ---------------------------------- - GROUPING = True - NUM_GROUPS = 20 - PER_GROUP = 5 - TOTAL_TARGET = 100 - FILL_FROM_GLOBAL = True - # ---------------------------------- - preds = predictions.copy().astype(float) - if getattr(self, "mask", None) is not None: preds[~self.mask] = 0.0 all_idx = np.arange(preds.size, dtype=np.int32) global_order = all_idx[np.argsort(preds)[::-1]] - if not GROUPING: - return global_order[: min(TOTAL_TARGET, global_order.size)] + takeN = min(self.total_target_templates, global_order.size) + + if self.grouping_mode not in ("auto", "fixed"): + return global_order[:takeN] + if getattr(self, "_group_index", None) is None: self._build_group_index() gi = getattr(self, "_group_index", None) - - if not gi or NUM_GROUPS <= 0 or PER_GROUP <= 0: - return global_order[: min(TOTAL_TARGET, global_order.size)] + if not gi: + return global_order[:takeN] group_items = list(gi.items()) - group_sums = np.array( - [float(preds[np.asarray(members, dtype=np.int32)].sum()) for _, members in group_items], - dtype=float, - ) + group_members = [np.asarray(members, dtype=np.int32) for _, members in group_items] + group_scores = np.array([float(preds[m].sum()) for m in group_members], dtype=float) - # Order groups by total mass - g_order = np.argsort(group_sums)[::-1] + order = np.argsort(group_scores)[::-1] + pre = order[: min(50, order.size)] + + # print(f"[groups] top {len(pre)} group scores:") + # for i, gi_pos in enumerate(pre, start=1): + # print(f" #{i:>2} score={group_scores[gi_pos]:.6g}") + + if self.grouping_mode == "fixed": + keep = [int(gi_pos) for gi_pos in order[: self.fixed_num_groups]] + else: + threshold = self.min_group_score + keep = [int(gi_pos) for gi_pos in pre if group_scores[gi_pos] >= threshold] + + if len(keep) == 0: + return global_order[:takeN] selected: list[int] = [] + chosen = set() + k = max(0, int(self.group_per_group_k)) - # Take top NUM_GROUPS groups - for g_pos in g_order[: min(NUM_GROUPS, len(g_order))]: - _, members = group_items[g_pos] - members = np.asarray(members, dtype=np.int32) - if members.size == 0: + for gi_pos in keep: + members = group_members[gi_pos] + if members.size == 0 or k == 0: continue - - # Top PER_GROUP within this group local_sorted = members[np.argsort(preds[members])[::-1]] - take_k = int(min(PER_GROUP, local_sorted.size)) - - count = 0 - for idx in local_sorted: - if count >= take_k: + take_k = min(k, int(local_sorted.size)) + for x in local_sorted[:take_k]: + if len(selected) >= takeN: break - if preds[idx] > 0.0: - selected.append(int(idx)) - count += 1 - - if len(selected) >= TOTAL_TARGET: + xi = int(x) + if preds[xi] > 0.0 and xi not in chosen: + selected.append(xi) + chosen.add(xi) + if len(selected) >= takeN: break - # Top up from global templates if needed - if FILL_FROM_GLOBAL and len(selected) < TOTAL_TARGET: - chosen = set(selected) + if len(selected) < takeN: for idx in global_order: - if len(selected) >= TOTAL_TARGET: + if len(selected) >= takeN: break - if idx not in chosen and preds[idx] > 0.0: - selected.append(int(idx)) - - if not selected: - return global_order[: min(TOTAL_TARGET, global_order.size)] - if len(selected) > TOTAL_TARGET: - selected = selected[:TOTAL_TARGET] - - return np.asarray(selected, dtype=np.int32) + xi = int(idx) + if preds[xi] > 0.0 and xi not in chosen: + selected.append(xi) + + return np.asarray(selected[:takeN], dtype=np.int32) + + # def _cutoff_predictions(self, predictions: np.ndarray) -> np.ndarray: + # """ + # Select up to TOTAL_TARGET templates: + # If GROUPING: + # 1) Rank groups by total probability; keep top NUM_GROUPS groups. + # 2) From each kept group, take PER_GROUP templates + # 3) Top up to TOTAL_TARGET with global best + # Else: + # Take global top TOTAL_TARGET templates. + # """ + # + # # ---------------------------------- + # GROUPING = True + # NUM_GROUPS = 10 + # PER_GROUP = 1 + # TOTAL_TARGET = 100 + # FILL_FROM_GLOBAL = True + # # ---------------------------------- + # + # preds = predictions.copy().astype(float) + # + # if getattr(self, "mask", None) is not None: + # preds[~self.mask] = 0.0 + # + # all_idx = np.arange(preds.size, dtype=np.int32) + # global_order = all_idx[np.argsort(preds)[::-1]] + # if not GROUPING: + # return global_order[: min(TOTAL_TARGET, global_order.size)] + # if getattr(self, "_group_index", None) is None: + # self._build_group_index() + # gi = getattr(self, "_group_index", None) + # + # if not gi or NUM_GROUPS <= 0 or PER_GROUP <= 0: + # return global_order[: min(TOTAL_TARGET, global_order.size)] + # + # group_items = list(gi.items()) + # group_sums = np.array( + # [float(preds[np.asarray(members, dtype=np.int32)].sum()) for _, members in group_items], + # dtype=float, + # ) + # + # # Order groups by total mass + # g_order = np.argsort(group_sums)[::-1] + # + # selected: list[int] = [] + # + # # Take top NUM_GROUPS groups + # for g_pos in g_order[: min(NUM_GROUPS, len(g_order))]: + # _, members = group_items[g_pos] + # members = np.asarray(members, dtype=np.int32) + # if members.size == 0: + # continue + # + # # Top PER_GROUP within this group + # local_sorted = members[np.argsort(preds[members])[::-1]] + # take_k = int(min(PER_GROUP, local_sorted.size)) + # + # count = 0 + # for idx in local_sorted: + # if count >= take_k: + # break + # if preds[idx] > 0.0: + # selected.append(int(idx)) + # count += 1 + # + # if len(selected) >= TOTAL_TARGET: + # break + # + # # Top up from global templates if needed + # if FILL_FROM_GLOBAL and len(selected) < TOTAL_TARGET: + # chosen = set(selected) + # for idx in global_order: + # if len(selected) >= TOTAL_TARGET: + # break + # if idx not in chosen and preds[idx] > 0.0: + # selected.append(int(idx)) + # + # if not selected: + # return global_order[: min(TOTAL_TARGET, global_order.size)] + # if len(selected) > TOTAL_TARGET: + # selected = selected[:TOTAL_TARGET] + # + # return np.asarray(selected, dtype=np.int32) def _cutoff_predictions_templates(self, predictions: np.ndarray) -> np.ndarray: """ diff --git a/sources/retrosynthesis/routes.py b/sources/retrosynthesis/routes.py index 61faf30..4732c98 100644 --- a/sources/retrosynthesis/routes.py +++ b/sources/retrosynthesis/routes.py @@ -23,21 +23,29 @@ def service_check(): return jsonify(page_data) -@app.route("/retrosynthesis_api/", methods=["GET"]) +@app.route("/retrosynthesis_api/", methods=["GET", "POST"]) def retrosynthesis(): - access_key = str(request.args.get("key", "")) + if request.method == "POST": + payload = request.get_json(silent=True) or {} + access_key = str(payload.get("key", "")) + smiles = str(payload.get("smiles", "")).strip() + enhancement = str(payload.get("enhancement", "Default")) + iterations = int(payload.get("iterations", 100)) + max_depth = int(payload.get("max_depth", 7)) + time_limit = int(payload.get("time_limit", 60)) + else: + access_key = str(request.args.get("key", "")) + smiles = str(request.args.get("smiles", "")).strip().replace(" ", "+") + enhancement = str(request.args.get("enhancement", "Default")) + iterations = int(request.args.get("iterations", 100)) + max_depth = int(request.args.get("max_depth", 7)) + time_limit = int(request.args.get("time_limit", 60)) + if access_key != ACCESS_KEY: return jsonify({"Message": "Invalid key", "Timestamp": time.time()}), 401 - - smiles = str(request.args.get("smiles", "")).strip() if not smiles: return jsonify({"Message": "Missing 'smiles'", "Timestamp": time.time()}), 400 - enhancement = str(request.args.get("enhancement", "Default")) - iterations = int(request.args.get("iterations", 100)) - max_depth = int(request.args.get("max_depth", 7)) - time_limit = int(request.args.get("time_limit", 60)) - finder = sources.retrosynthesis.startup.make_config() finder.config.search.algorithm_config["enhancement"] = enhancement finder.config.search.iteration_limit = iterations @@ -45,6 +53,7 @@ def retrosynthesis(): finder.config.search.time_limit = time_limit solved_route_dict_ls, raw_routes = retrosynthesis_process(smiles, finder) + page_data = { "Message": solved_route_dict_ls, "Raw_Routes": raw_routes, @@ -53,6 +62,7 @@ def retrosynthesis(): return jsonify(page_data) + def retrosynthesis_process(smiles, finder): """ Takes a SMILES string and a pre-configured finder object and returns @@ -69,14 +79,12 @@ def retrosynthesis_process(smiles, finder): raise ValueError("Invalid SMILES string") aizynthcli._process_single_smiles(smiles, finder, None, False, None, [], None) - routes = finder.routes solved_routes = [] for idx, node in enumerate(routes.nodes): if node.is_solved is True: solved_routes.append(routes[idx]) - solved_routes = solved_routes[0:10] solved_route_dict = {} for idx, route in enumerate(solved_routes, 1): retro_route = RetroRoute(route["dict"]) @@ -88,11 +96,38 @@ def retrosynthesis_process(smiles, finder): } solved_route_dict[f"Route {idx}"] = route_dic - route_dicts = routes.dicts[0:10] + route_dicts = routes.dicts raw_routes = [route_dict for route_dict in route_dicts] return solved_route_dict, raw_routes + +def _min_tree(n, is_mol_root=False): + """Minimize a route tree to the exact shape the converter expects. + - mol: keep type, smiles, and (first) child + - reaction: keep type and children + """ + if not isinstance(n, dict): + return None + t = n.get("type") + if t == "mol": + out = {"type": "mol", "smiles": n.get("smiles", "")} + children = n.get("children", []) + if children: + c0 = _min_tree(children[0]) + if c0 is not None: + out["children"] = [c0] + return out + if t == "reaction": + kids = [] + for c in n.get("children", []): + k = _min_tree(c) + if k is not None: + kids.append(k) + return {"type": "reaction", "children": kids} + return None + + @app.route("/retrosynthesis_batch_progress/", methods=["GET"]) def retrosynthesis_batch_progress(batch_id): info = BATCH_PROGRESS.get(batch_id) @@ -145,6 +180,20 @@ def retrosynthesis_batch(): BATCH_PROGRESS[batch_id].update({"status": "failed"}) raise + min_trees = [] + for item in results: + for r in item.get("raw_routes", []): + t = _min_tree(r) + if t and t.get("type") == "mol" and t.get("smiles"): + min_trees.append(t) + + BATCH_OUT_DIR.mkdir(parents=True, exist_ok=True) + out_name = BATCH_OUT_DIR / "compact_routes.json" + tmp_path = out_name.with_suffix(".json.tmp") + with tmp_path.open("w", encoding="utf-8") as f: + json.dump(min_trees, f, ensure_ascii=False, indent=2) + tmp_path.replace(out_name) + batch_entries = [] for item in results: batch_entries.append({ diff --git a/sources/retrosynthesis/startup.py b/sources/retrosynthesis/startup.py index f06d6fa..ef261f5 100644 --- a/sources/retrosynthesis/startup.py +++ b/sources/retrosynthesis/startup.py @@ -24,7 +24,7 @@ def make_config(): 'full': { 'type': 'template-based', 'model': os.path.join(BASEDIR, 'config_files', 'uspto_model.onnx'), - 'template': os.path.join(BASEDIR, 'config_files', 'uspto_templates.csv.gz') + 'template': os.path.join(BASEDIR, 'config_files', 'uspto_templates_grouped.csv.gz') } }, From 1ba6bc6d127df765233e46c44934c7cbd0ec131d Mon Sep 17 00:00:00 2001 From: tonbl Date: Wed, 22 Oct 2025 14:38:12 +0100 Subject: [PATCH 3/3] added grouping controls --- sources/retrosynthesis/startup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sources/retrosynthesis/startup.py b/sources/retrosynthesis/startup.py index ef261f5..e294943 100644 --- a/sources/retrosynthesis/startup.py +++ b/sources/retrosynthesis/startup.py @@ -30,7 +30,7 @@ def make_config(): 'stock': { 'zinc': os.path.join(BASEDIR, 'config_files', 'zinc_stock_17_04_20.hdf5'), - 'overlay': os.path.join(BASEDIR, 'config_files', 'stock_additions_2025_08.hdf5'), + #'overlay': os.path.join(BASEDIR, 'config_files', 'stock_additions_2025_08.hdf5'), 'alcohols': os.path.join(BASEDIR, 'config_files', 'dummy_alcohols.hdf5'), 'naturals': os.path.join(BASEDIR, 'config_files', 'np_08_25.hdf5'), 'non_iso_naturals': os.path.join(BASEDIR, 'config_files', 'ninp_08_25.hdf5'),