From 7971bae2995f2e2cae6126a92b26dc9f173f4748 Mon Sep 17 00:00:00 2001 From: Dmitry Date: Tue, 28 Oct 2025 16:03:42 +0100 Subject: [PATCH 1/2] Remove annotation hints from holistic extractor prompt --- README.md | 19 +- src/molmole_research/downloader.py | 429 +++++++++++++------------ src/molmole_research/evaluator.py | 490 ++++++++++++++++++++--------- src/molmole_research/extractor.py | 210 ++++++++++--- tests/test_downloader.py | 125 +++----- tests/test_evaluator.py | 98 ++++-- tests/test_extractor.py | 82 +++-- 7 files changed, 910 insertions(+), 543 deletions(-) diff --git a/README.md b/README.md index e211c81..44ae9c2 100644 --- a/README.md +++ b/README.md @@ -143,18 +143,21 @@ print clear instructions for manual download. --dataset-dir data/images --out results ``` - This will create a JSONL file in `results/` whose name begins with the - model name and includes the current timestamp. Each line contains the - image file name, the extracted text and any additional metadata returned by - the model. The extractor includes a template prompt that instructs the - model to convert the chemical diagram into a SMILES string. You may + This command now performs **holistic page analysis**. Each JSONL line + corresponds to a single patent page and stores the prompt, raw model + response and any parsed JSON payload containing `structures` (with + model-assigned identifiers and SMILES strings) and `reactions` + (reactant/product identifier sets plus optional conditions). The default + prompt keeps the model blind to MolMole annotations so it must discover the + number of structures and reactions directly from the image. You may customise the prompt or provide an alternative generation function by editing `extractor.py`. 5. **Evaluate predictions**: use the evaluator script to compare your - predictions against the ground truth SMILES strings. The evaluator - canonicalises SMILES using RDKit (if available) and computes accuracy - using both SMILES and InChI‑Key matching: + predictions against the ground truth SMILES strings and reaction graphs. + The evaluator canonicalises SMILES using RDKit (if available), matches + predictions per structure identifier, and reports SMILES/InChI accuracy, + mean Tanimoto similarity and reaction precision/recall/F1: ```bash python -m molmole_research.evaluator --pred results/gpt-4o-vision_*.jsonl \ diff --git a/src/molmole_research/downloader.py b/src/molmole_research/downloader.py index 80209d4..0190f28 100644 --- a/src/molmole_research/downloader.py +++ b/src/molmole_research/downloader.py @@ -1,51 +1,32 @@ -"""Dataset downloader for the MolMole benchmark. - -This module provides a CLI for downloading the MolMole validation dataset (or -any HuggingFace dataset compatible with the `datasets` library) into a local -directory. The MolMole dataset released by LG AI Research is hosted on -HuggingFace under the name `doxa-friend/MolMole_Patent300`. Each entry -contains a page‑level PNG image and associated annotations (SMILES strings, -MOL files, bounding boxes, etc.) depending on the subset. - -The downloader saves all images into an `images/` subdirectory and writes a -`labels.json` file containing the ground truth SMILES (if available). +"""Download the MolMole dataset and derive canonical SMILES labels. + +The MolMole benchmark is distributed on HuggingFace as a dataset repository +containing page level images, JSON annotations and MOL files for each detected +structure. This module provides a CLI that downloads the repository snapshot +into a user specified directory, preserving the original file hierarchy. After +the download completes, the MOL files are converted into canonical SMILES using +RDKit and collated into a ``labels.json`` file that is easier to consume during +evaluation. """ from __future__ import annotations import json -import os -import shutil from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Set +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple -import datasets -from PIL import Image import typer from huggingface_hub import snapshot_download +from tqdm import tqdm -try: +try: # pragma: no cover - optional dependency at import time from rdkit import Chem # type: ignore[assignment] -except ImportError: # pragma: no cover - optional dependency +except ImportError: # pragma: no cover - handled gracefully at runtime Chem = None # type: ignore[assignment] app = typer.Typer(add_completion=False, help="Download the MolMole dataset") -def _extract_smiles(record: Dict[str, Any]) -> Optional[str]: - """Try to extract a SMILES string from a dataset record. - - The MolMole dataset stores SMILES strings under different keys depending on - the subset. This function attempts several common field names and returns - the first non‑empty value, or ``None`` if no SMILES field exists. - """ - for key in ("smiles", "SMILES", "label", "labels", "text"): - value = record.get(key) - if isinstance(value, str) and value.strip(): - return value.strip() - return None - - def _unique_preserve_order(values: Iterable[str]) -> List[str]: """Return a list with duplicates removed while preserving order.""" @@ -58,11 +39,18 @@ def _unique_preserve_order(values: Iterable[str]) -> List[str]: return result +def _ensure_rdkit() -> None: + """Raise a runtime error if RDKit is not available.""" + + if Chem is None: # pragma: no cover - depends on runtime installation + raise RuntimeError( + "RDKit is required to convert MOL files to SMILES. Install the 'rdkit' package before running the downloader." + ) + + def _mol_to_smiles(mol_path: Path) -> Optional[str]: - """Convert a MOL file into a SMILES string using RDKit if available.""" + """Convert a MOL file into a canonical SMILES string.""" - if Chem is None: # pragma: no cover - RDKit optional at runtime - return None try: mol = Chem.MolFromMolFile(str(mol_path)) # type: ignore[arg-type] if mol is None: @@ -72,131 +60,176 @@ def _mol_to_smiles(mol_path: Path) -> Optional[str]: return None -def _integrate_additional_metadata(dataset: str, out: Path, labels: List[Dict[str, Any]]) -> None: - """Augment labels with MolMole metadata if available. - - The MolMole dataset repository contains additional JSON annotations with - bounding boxes and corresponding MOL files for each chemical structure. - This function downloads those assets (if present) and enriches the - ``labels`` list with bounding boxes, SMILES strings derived from the MOL - files, and copies the raw assets into ``out``. - """ +def _discover_images_dir(dataset_root: Path, preferred: Optional[str]) -> Path: + """Return the directory containing page level images.""" - try: - repo_root = Path( - snapshot_download( - dataset, - repo_type="dataset", - allow_patterns=["json/*", "mol/*"], + if preferred: + candidate = dataset_root / preferred + if not candidate.exists(): + raise FileNotFoundError( + f"Requested images directory '{preferred}' was not found under {dataset_root}." ) - ) - except Exception: # pragma: no cover - network/cache issues handled gracefully - return + return candidate - json_dir = repo_root / "json" - if not json_dir.exists(): - return - json_files = sorted(json_dir.glob("*.json")) - if not json_files: - return + for child in sorted(dataset_root.iterdir()): + if child.is_dir() and child.name.lower().startswith("images"): + return child + raise FileNotFoundError( + f"Could not locate an images directory inside {dataset_root}. Expected a folder such as 'images_300'." + ) - annotations_path = json_files[0] - try: - annotations = json.loads(annotations_path.read_text(encoding="utf-8")) - except Exception: - return - - images_meta = { - item.get("file_name"): item - for item in annotations.get("images", []) - if isinstance(item, dict) and item.get("file_name") + +def _discover_annotation_file(dataset_root: Path, filename: Optional[str]) -> Path: + """Locate the annotation JSON shipped with the dataset.""" + + search_dirs = [dataset_root / "json", dataset_root] + for directory in search_dirs: + if not directory.exists(): + continue + if filename: + candidate = directory / filename + if candidate.exists(): + return candidate + else: + json_files = sorted(directory.glob("*.json")) + if json_files: + return json_files[0] + + location = dataset_root / "json" / (filename or "*.json") + raise FileNotFoundError(f"Unable to locate annotation file at {location}") + + +def _download_snapshot(dataset: str, out: Path, revision: Optional[str]) -> Path: + """Download the dataset snapshot into ``out`` and return the local path.""" + + typer.echo(f"Downloading dataset {dataset} to {out} …") + kwargs: Dict[str, Any] = { + "repo_type": "dataset", + "local_dir": str(out), + } + if revision: + kwargs["revision"] = revision + snapshot_path = Path(snapshot_download(dataset, **kwargs)) + typer.echo(f"Snapshot available at {snapshot_path}") + return snapshot_path + + +def _build_labels( + dataset_root: Path, + images_dir: Path, + annotation_file: Path, + mol_dir: Path, +) -> Tuple[List[Dict[str, Any]], Dict[str, int]]: + """Create structured labels enriched with canonical SMILES.""" + + data = json.loads(annotation_file.read_text(encoding="utf-8")) + records: List[Dict[str, Any]] = [] + + category_names: Dict[int, str] = { + category.get("id"): category.get("name", "unknown") + for category in data.get("categories", []) + if isinstance(category, dict) } - mol_repo_dir = repo_root / "mol" - mol_out_dir = out / "mol" - copied_mols: Set[str] = set() - structures_added = 0 + total_structures = 0 + smiles_converted = 0 + conversion_failures = 0 + missing_mols = 0 + reactions_total = 0 - for entry in labels: - fname = entry.get("file_name") - if not fname: + try: + image_prefix = images_dir.relative_to(dataset_root) + except ValueError: # pragma: no cover - defensive fallback + image_prefix = images_dir + + for image_meta in tqdm(data.get("images", []), desc="Processing annotations", unit="image"): + if not isinstance(image_meta, dict): continue - meta = images_meta.get(fname) - if not meta: + file_name = image_meta.get("file_name") + if not file_name: continue - entry.setdefault("height", meta.get("height")) - entry.setdefault("width", meta.get("width")) + entry: Dict[str, Any] = { + "file_name": str(image_prefix / file_name), + "width": image_meta.get("width"), + "height": image_meta.get("height"), + } + if image_meta.get("id") is not None: + entry["page_id"] = image_meta["id"] - structures: List[Dict[str, Any]] = entry.get("structures", []) - texts: List[Dict[str, Any]] = entry.get("texts", []) + layout_boxes: List[Dict[str, Any]] = [] + structures: List[Dict[str, Any]] = [] + smiles_candidates: List[str] = [] - for bbox in meta.get("dla_bboxes", []): + for bbox in image_meta.get("dla_bboxes", []): if not isinstance(bbox, dict): continue - category_id = bbox.get("category_id") + bbox_id = bbox.get("id") - bbox_coords = bbox.get("bbox") - if category_id == 1 and bbox_id is not None: - mol_name = f"{Path(fname).stem}_bbox_{bbox_id}.mol" - mol_repo_path = mol_repo_dir / mol_name - smiles = None - mol_relative = None - if mol_repo_path.exists(): - mol_out_dir.mkdir(parents=True, exist_ok=True) - if mol_name not in copied_mols: - shutil.copy2(mol_repo_path, mol_out_dir / mol_name) - copied_mols.add(mol_name) - smiles = _mol_to_smiles(mol_out_dir / mol_name) - mol_relative = str(Path("mol") / mol_name) - structures.append( - { - "id": bbox_id, - "bbox": bbox_coords, - "category": "structure", - "mol_file": mol_relative, - "smiles": smiles, - } - ) - structures_added += 1 - if smiles: - smiles_list = entry.get("smiles_list", []) - if isinstance(smiles_list, list): - smiles_list.append(smiles) - entry["smiles_list"] = smiles_list + if bbox_id is None: + continue + + category_id = bbox.get("category_id") + category_name = category_names.get(category_id, "unknown") + + record: Dict[str, Any] = { + "id": bbox_id, + "bbox": bbox.get("bbox"), + "category_id": category_id, + "category_name": category_name, + } + + if category_id == 1: + total_structures += 1 + mol_name = f"{Path(file_name).stem}_bbox_{bbox_id}.mol" + mol_path = mol_dir / mol_name + smiles: Optional[str] = None + mol_relative: Optional[str] = None + + if mol_path.exists(): + mol_relative = str(mol_path.relative_to(dataset_root)) + smiles = _mol_to_smiles(mol_path) + if smiles: + smiles_converted += 1 + smiles_candidates.append(smiles) else: - entry["smiles_list"] = [smiles] - elif category_id == 2: - texts.append( - { - "id": bbox_id, - "bbox": bbox_coords, - "category": "text", - } - ) + conversion_failures += 1 + else: + missing_mols += 1 - if structures: - entry["structures"] = structures - if texts: - entry["texts"] = texts + record.update({ + "mol_file": mol_relative, + "smiles": smiles, + }) + structures.append(dict(record)) - smiles_list = entry.get("smiles_list") or [] - if isinstance(smiles_list, list): - entry["smiles_list"] = _unique_preserve_order( - [s for s in smiles_list if isinstance(s, str) and s.strip()] - ) - if entry["smiles_list"] and not entry.get("smiles"): - entry["smiles"] = entry["smiles_list"][0] + layout_boxes.append(record) - # Copy the annotations JSON for reference - annotations_out = out / "annotations.json" - try: - shutil.copy2(annotations_path, annotations_out) - except Exception: # pragma: no cover - copying issues should not fail download - pass + if layout_boxes: + entry["layout_boxes"] = layout_boxes + if structures: + entry["structures"] = structures + if smiles_candidates: + entry["smiles_list"] = _unique_preserve_order(smiles_candidates) + entry["smiles"] = entry["smiles_list"][0] + + reactions = image_meta.get("reactions") + if isinstance(reactions, list) and reactions: + entry["reactions"] = reactions + reactions_total += len(reactions) + + records.append(entry) + + stats = { + "images": len(records), + "structures_total": total_structures, + "structures_with_smiles": smiles_converted, + "conversion_failures": conversion_failures, + "missing_mol_files": missing_mols, + "reactions_total": reactions_total, + } - if structures_added: - typer.echo(f"Integrated {structures_added} structures from repository metadata") + return records, stats @app.command("run") @@ -204,76 +237,62 @@ def download_dataset( dataset: str = typer.Option( "doxa-friend/MolMole_Patent300", help="HuggingFace dataset identifier" ), - split: str = typer.Option("train", help="Which split to download (e.g., train/validation)"), - out: Path = typer.Option(Path("data/images"), help="Output directory for images and labels"), + out: Path = typer.Option( + Path("data/molmole"), + help="Directory where the dataset snapshot and derived labels will be stored", + ), + images_subdir: Optional[str] = typer.Option( + None, + help="Name of the images folder inside the dataset (defaults to auto-detect)", + ), + annotations_file: Optional[str] = typer.Option( + None, + help="JSON file containing annotations (defaults to the first JSON in the json/ directory)", + ), + revision: Optional[str] = typer.Option( + None, + help="Optional dataset revision to download", + ), ) -> None: - """Download the specified dataset and save it locally. + """Download MolMole assets and generate canonical SMILES labels.""" + + try: + _ensure_rdkit() + except RuntimeError as exc: + typer.echo(str(exc)) + raise typer.Exit(1) - This function loads the dataset via ``datasets.load_dataset``, iterates over - each entry, saves the image to disk and collects the ground truth SMILES - string (if present). The images are stored in an ``images/`` subfolder of - ``out`` and labels are written to ``labels.json``. - """ out = Path(out).resolve() - images_dir = out / "images" - images_dir.mkdir(parents=True, exist_ok=True) + out.mkdir(parents=True, exist_ok=True) - typer.echo(f"Loading dataset {dataset} (split={split}) …") - try: - ds = datasets.load_dataset(dataset, split=split) - except Exception as exc: # pragma: no cover - network errors - typer.echo( - f"Failed to download dataset {dataset}. Please ensure you have access and" - f" have run `huggingface-cli login` if required. Error: {exc}" - ) + dataset_root = _download_snapshot(dataset, out, revision) + + images_dir = _discover_images_dir(dataset_root, images_subdir) + mol_dir = dataset_root / "mol" + if not mol_dir.exists(): + typer.echo(f"The dataset snapshot in {dataset_root} does not contain a 'mol' directory.") raise typer.Exit(1) - labels: List[Dict[str, Any]] = [] - for idx, record in enumerate(ds): - # Retrieve the image - image: Image.Image - image = record.get("image") # Datasets library returns PIL.Image objects - if not isinstance(image, Image.Image): - # Some datasets return file paths instead; load with PIL - image_path = record.get("path") or record.get("file_path") or record.get("file_name") - if image_path and os.path.exists(image_path): - image = Image.open(image_path) - else: - typer.echo(f"Record {idx} does not contain an image; skipping") - continue + annotation_path = _discover_annotation_file(dataset_root, annotations_file) + + typer.echo( + f"Building labels from {annotation_path.relative_to(dataset_root)} and MOL files in {mol_dir.relative_to(dataset_root)}" + ) + labels, stats = _build_labels(dataset_root, images_dir, annotation_path, mol_dir) + + labels_path = dataset_root / "labels.json" + labels_path.write_text(json.dumps(labels, indent=2), encoding="utf-8") + + typer.echo(f"Wrote labels with canonical SMILES to {labels_path}") + typer.echo( + "Summary: " + f"{stats['images']} images, {stats['structures_total']} structures, " + f"{stats['structures_with_smiles']} successfully converted, " + f"{stats['conversion_failures']} conversion failures, " + f"{stats['missing_mol_files']} missing MOL files, " + f"{stats['reactions_total']} reactions annotated." + ) + - # Determine a file name - fname = record.get("file_name") or record.get("image_id") or record.get("id") - if not fname: - fname = f"{idx}.png" - # Ensure the file has a png extension - if not str(fname).lower().endswith((".png", ".jpg", ".jpeg")): - fname = f"{fname}.png" - out_path = images_dir / fname - # Save image - try: - image.save(out_path) - except Exception: - # Ensure directory exists and convert to RGB - image = image.convert("RGB") - image.save(out_path) - - # Extract SMILES (if available) - smiles = _extract_smiles(record) - entry: Dict[str, Any] = {"file_name": fname, "smiles": smiles} - if smiles: - entry["smiles_list"] = [smiles] - labels.append(entry) - - # Enrich labels with additional metadata (MolMole repository only) - _integrate_additional_metadata(dataset, out, labels) - - # Save labels JSON - labels_path = out / "labels.json" - with labels_path.open("w", encoding="utf-8") as fh: - json.dump(labels, fh, indent=2) - typer.echo(f"Downloaded {len(labels)} items to {images_dir} and wrote labels.json") - - -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover - CLI entry point app() diff --git a/src/molmole_research/evaluator.py b/src/molmole_research/evaluator.py index 5eea592..b8dbb66 100644 --- a/src/molmole_research/evaluator.py +++ b/src/molmole_research/evaluator.py @@ -1,96 +1,301 @@ -"""Evaluation of OCSR predictions on the MolMole dataset. +"""Evaluation of MolMole holistic page predictions. -This module compares predicted SMILES strings against ground truth SMILES and -computes simple accuracy metrics. Two forms of matching are provided: +The extractor now emits one JSONL record per patent page containing the raw +model response and (when possible) a parsed JSON payload with ``structures`` +and ``reactions``. This evaluator reproduces the molecule conversion metrics +from the MolMole paper while additionally reporting reaction graph precision, +recall and F1. -* **SMILES canonical equality** – predictions and ground truth are converted - into canonical SMILES using RDKit and compared. -* **InChI‑Key equality** – predictions and ground truth are converted into - InChI keys via RDKit and compared. +Metrics include: -The evaluator reads a JSONL file produced by the extractor, loads the -corresponding ground truth SMILES from ``labels.json``, and writes a JSON -summary of accuracy metrics. +* **SMILES accuracy** – fraction of structures whose predicted SMILES equals + the ground truth canonical SMILES. +* **InChI accuracy** – fraction of structures with matching InChI keys. +* **Tanimoto similarity** – mean similarity between predicted and ground truth + Morgan fingerprints (radius 2, 2048 bits). +* **Reaction precision/recall/F1** – overlap between predicted and annotated + reaction edges (reactant/product identifier sets plus optional conditions). """ from __future__ import annotations import json +import re from pathlib import Path -from typing import Dict, Iterable, List, Set +from statistics import mean +from typing import Dict, Iterable, List, Optional, Set, Tuple import typer try: - # RDKit is used for canonical SMILES and InChI generation. It may not be - # available in all environments, so we make its import optional and - # gracefully degrade to string comparison if absent. + # RDKit is required for MolMole-style evaluation. from rdkit import Chem # type: ignore[assignment] -except ImportError: + from rdkit.Chem import AllChem, DataStructs # type: ignore[attr-defined] +except ImportError: # pragma: no cover - handled at runtime Chem = None # type: ignore[assignment] + AllChem = None # type: ignore[assignment] + DataStructs = None # type: ignore[assignment] + from tqdm import tqdm app = typer.Typer(add_completion=False, help="Evaluate OCSR predictions") -def _canonical_smiles(smiles: str) -> str | None: - """ - Convert a SMILES string to its canonical form using RDKit if available. - - If RDKit is not installed, the function returns the input string stripped - of whitespace. If conversion fails, ``None`` is returned. - """ - # If RDKit is available, use it for canonicalisation - if Chem is not None: - try: - mol = Chem.MolFromSmiles(smiles) # type: ignore[call-arg] - if mol is None: - return None - return Chem.MolToSmiles(mol, canonical=True) # type: ignore[call-arg] - except Exception: +def _canonical_smiles(smiles: str) -> Optional[str]: + """Return the canonical SMILES for ``smiles`` or ``None`` if invalid.""" + + if Chem is None: + stripped = smiles.strip() + return stripped or None + try: + mol = Chem.MolFromSmiles(smiles) # type: ignore[call-arg] + if mol is None: return None - # Fallback: return the raw string for naive comparison - return smiles.strip() if smiles.strip() else None - - -def _inchi_key(smiles: str) -> str | None: - """ - Return the InChI‑Key for a SMILES string using RDKit if available. - - If RDKit is not installed, the function falls back to returning the - canonical SMILES string for equality checking. If conversion fails, - ``None`` is returned. - """ - if Chem is not None: - try: - mol = Chem.MolFromSmiles(smiles) # type: ignore[call-arg] - if mol is None: - return None - return Chem.inchi.MolToInchiKey(mol) # type: ignore[call-arg] - except Exception: + return Chem.MolToSmiles(mol, canonical=True) # type: ignore[call-arg] + except Exception: + return None + + +def _inchi_key(smiles: str) -> Optional[str]: + """Return the InChI key for ``smiles`` if possible.""" + + if Chem is None: + stripped = smiles.strip() + return stripped or None + try: + mol = Chem.MolFromSmiles(smiles) # type: ignore[call-arg] + if mol is None: return None - # Fallback: use canonical SMILES as a proxy for InChI equality - canonical = smiles.strip() if smiles.strip() else None - return canonical + return Chem.inchi.MolToInchiKey(mol) # type: ignore[call-arg] + except Exception: + return None + + +def _fingerprint(smiles: str): + """Compute a Morgan fingerprint for ``smiles`` or return ``None``.""" + + if Chem is None or AllChem is None: # pragma: no cover - RDKit guard + return None + try: + mol = Chem.MolFromSmiles(smiles) # type: ignore[call-arg] + if mol is None: + return None + return AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048) + except Exception: + return None + + +def _tanimoto(fp_a, fp_b) -> Optional[float]: # type: ignore[override] + if DataStructs is None or fp_a is None or fp_b is None: # pragma: no cover + return None + return float(DataStructs.TanimotoSimilarity(fp_a, fp_b)) + + +def _normalise_prediction(text: str) -> Optional[str]: + """Extract the first valid SMILES token from ``text``.""" + + cleaned = text.strip() + if not cleaned: + return None + + for token in re.split(r"[\s,;]+", cleaned): + if not token: + continue + canonical = _canonical_smiles(token) + if canonical: + return canonical + return _canonical_smiles(cleaned) + + +def _load_labels(labels_path: Path) -> List[Dict[str, object]]: + data = json.loads(labels_path.read_text(encoding="utf-8")) + if not isinstance(data, list): + raise ValueError("labels.json must contain a list of page records") + return data + + +def _index_structures( + labels: Iterable[Dict[str, object]] +) -> Dict[Tuple[str, int], Dict[str, object]]: + mapping: Dict[Tuple[str, int], Dict[str, object]] = {} + for page in labels: + if not isinstance(page, dict): + continue + file_name = page.get("file_name") + if not isinstance(file_name, str) or not file_name: + continue + structures = page.get("structures") + if not isinstance(structures, list): + continue + for struct in structures: + if not isinstance(struct, dict): + continue + struct_id = struct.get("id") + smiles = struct.get("smiles") + if struct_id is None or not isinstance(smiles, str) or not smiles.strip(): + continue + canonical = _canonical_smiles(smiles) + if not canonical: + continue + mapping[(file_name, int(struct_id))] = { + "canonical": canonical, + "inchi": _inchi_key(canonical), + "fingerprint": _fingerprint(canonical), + } + return mapping + + +def _index_reactions(labels: Iterable[Dict[str, object]]) -> Dict[str, Set[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[str, ...]]]]: + """Collect ground truth reactions per page.""" + + reaction_map: Dict[str, Set[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[str, ...]]]] = {} + for page in labels: + if not isinstance(page, dict): + continue + file_name = page.get("file_name") + if not isinstance(file_name, str) or not file_name: + continue + reactions = page.get("reactions") + if not isinstance(reactions, list): + continue + normalised: Set[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[str, ...]]] = set() + for reaction in reactions: + normalised_entry = _normalise_reaction(reaction) if isinstance(reaction, dict) else None + if normalised_entry: + normalised.add(normalised_entry) + if normalised: + reaction_map[file_name] = normalised + return reaction_map + + +def _normalise_reaction(entry: Dict[str, object]) -> Optional[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[str, ...]]]: + """Convert a reaction dict into a hashable tuple.""" + + if not isinstance(entry, dict): + return None + + def _normalise_ids(values) -> Tuple[int, ...]: + ids: List[int] = [] + if isinstance(values, list): + for item in values: + try: + ids.append(int(item)) + except (TypeError, ValueError): + continue + return tuple(sorted(set(ids))) + + reactants = _normalise_ids(entry.get("reactants")) + products = _normalise_ids(entry.get("products")) + if not reactants or not products: + return None + + conditions: List[str] = [] + if isinstance(entry.get("conditions"), list): + for cond in entry["conditions"]: + if isinstance(cond, str): + stripped = cond.strip() + if stripped: + conditions.append(stripped) + return reactants, products, tuple(sorted(set(conditions))) + + +def _load_predictions( + pred_path: Path, +) -> Tuple[ + Dict[Tuple[str, int], Dict[str, object]], + Dict[str, Set[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[str, ...]]]], +]: + """Load structure and reaction predictions from the extractor output.""" + + predictions: Dict[Tuple[str, int], Dict[str, object]] = {} + reaction_predictions: Dict[str, Set[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[str, ...]]]] = {} + + with pred_path.open("r", encoding="utf-8") as fh: + for line in tqdm(fh, desc="Evaluating", unit="page"): + line = line.strip() + if not line: + continue + try: + record = json.loads(line) + except json.JSONDecodeError: + continue + + file_name = record.get("file_name") + if not isinstance(file_name, str) or not file_name: + continue + + structures = record.get("predicted_structures") + if isinstance(structures, list): + for struct in structures: + if not isinstance(struct, dict): + continue + struct_id = struct.get("id") + smiles_value = struct.get("smiles") + if smiles_value is None: + smiles_value = struct.get("text") + if struct_id is None or not isinstance(smiles_value, str): + continue + try: + struct_id_int = int(struct_id) + except (TypeError, ValueError): + continue + canonical = _normalise_prediction(smiles_value) + predictions[(file_name, struct_id_int)] = { + "canonical": canonical, + "inchi": _inchi_key(canonical) if canonical else None, + "fingerprint": _fingerprint(canonical) if canonical else None, + } + else: + # Backwards compatibility with per-structure JSONL records. + struct_id = record.get("structure_id") + text = record.get("text") + if struct_id is not None and isinstance(text, str): + try: + struct_id_int = int(struct_id) + except (TypeError, ValueError): + struct_id_int = None + if struct_id_int is not None: + canonical = _normalise_prediction(text) + predictions[(file_name, struct_id_int)] = { + "canonical": canonical, + "inchi": _inchi_key(canonical) if canonical else None, + "fingerprint": _fingerprint(canonical) if canonical else None, + } + + reactions = record.get("predicted_reactions") + if isinstance(reactions, list): + normalised = { + r + for r in (_normalise_reaction(item) for item in reactions) + if r is not None + } + if normalised: + reaction_predictions[file_name] = normalised + + return predictions, reaction_predictions + + +def _mean(values: Iterable[float]) -> float: + collected = list(values) + if not collected: + return 0.0 + return float(mean(collected)) @app.command("run") def evaluate( pred: Path = typer.Option(..., exists=True, help="JSONL predictions from the extractor"), dataset_dir: Path = typer.Option( - Path("data/images"), exists=True, help="Directory containing labels.json" + Path("data/molmole"), exists=True, help="Directory containing labels.json" ), out: Path = typer.Option(Path("results"), help="Directory where metrics will be saved"), ) -> None: - """Evaluate predictions against ground truth SMILES. - - The evaluator expects a ``labels.json`` file in ``dataset_dir`` containing a - list of objects with keys ``file_name`` and ``smiles``. The predictions - file must be a JSONL where each line has at least ``file_name`` and - ``text`` fields. For each common file name, the predicted text is - canonicalised and compared to the ground truth using both canonical SMILES - and InChI keys. - """ + """Evaluate predictions using MolMole conversion metrics.""" + + if Chem is None or AllChem is None or DataStructs is None: + typer.echo("RDKit with InChI support is required for evaluation") + raise typer.Exit(1) + dataset_dir = dataset_dir.resolve() out = out.resolve() out.mkdir(parents=True, exist_ok=True) @@ -99,97 +304,92 @@ def evaluate( if not labels_path.exists(): typer.echo(f"Missing labels.json in {dataset_dir}; run the downloader first") raise typer.Exit(1) - with labels_path.open("r", encoding="utf-8") as fh: - labels_list: List[Dict[str, object]] = json.load(fh) - - def _collect_smiles(entry: Dict[str, object]) -> List[str]: - values: List[str] = [] - raw_smiles = entry.get("smiles") - if isinstance(raw_smiles, str) and raw_smiles.strip(): - values.append(raw_smiles.strip()) - smiles_list = entry.get("smiles_list") - if isinstance(smiles_list, Iterable) and not isinstance(smiles_list, (str, bytes)): - values.extend(s.strip() for s in smiles_list if isinstance(s, str) and s.strip()) - for struct in ( - entry.get("structures", []) if isinstance(entry.get("structures"), list) else [] - ): - if isinstance(struct, dict): - struct_smiles = struct.get("smiles") - if isinstance(struct_smiles, str) and struct_smiles.strip(): - values.append(struct_smiles.strip()) - # Deduplicate while preserving order - seen: Set[str] = set() - unique: List[str] = [] - for value in values: - if value not in seen: - seen.add(value) - unique.append(value) - return unique - - labels: Dict[str, Dict[str, Set[str]]] = {} - for entry in labels_list: - fname = entry.get("file_name") if isinstance(entry, dict) else None - if not isinstance(fname, str) or not fname: - continue - smiles_values = _collect_smiles(entry) - if not smiles_values: - continue - canonical_set: Set[str] = set() - inchi_set: Set[str] = set() - for smiles in smiles_values: - canonical = _canonical_smiles(smiles) - if canonical: - canonical_set.add(canonical) - inchi = _inchi_key(smiles) - if inchi: - inchi_set.add(inchi) - if canonical_set or inchi_set: - labels[fname] = {"canonical": canonical_set, "inchi": inchi_set} - - total = 0 + + labels = _load_labels(labels_path) + structures = _index_structures(labels) + reaction_truth = _index_reactions(labels) + if not structures: + typer.echo("No evaluable structures found in labels.json") + raise typer.Exit(1) + + predictions, reaction_predictions = _load_predictions(pred) + + total = len(structures) correct_smiles = 0 correct_inchi = 0 - # Evaluate each prediction - with pred.open("r", encoding="utf-8") as fh: - for line in tqdm(fh, desc="Evaluating", unit="pred"): - line = line.strip() - if not line: - continue - try: - record = json.loads(line) - except json.JSONDecodeError: - continue - fname = record.get("file_name") - pred_text = record.get("text") - if not fname or pred_text is None: - continue - gt_sets = labels.get(fname) - if not gt_sets: - continue - total += 1 - # Canonicalise prediction - pred_canonical = _canonical_smiles(pred_text.strip()) if pred_text.strip() else None - pred_inchi = _inchi_key(pred_text.strip()) if pred_text.strip() else None - if pred_canonical and gt_sets["canonical"] and pred_canonical in gt_sets["canonical"]: - correct_smiles += 1 - if pred_inchi and gt_sets["inchi"] and pred_inchi in gt_sets["inchi"]: - correct_inchi += 1 - - accuracy_smiles = correct_smiles / total if total else 0.0 - accuracy_inchi = correct_inchi / total if total else 0.0 - result = { - "total": total, + invalid_predictions = 0 + missing_predictions = 0 + tanimoto_scores: List[float] = [] + + for key, gt in structures.items(): + pred_entry = predictions.get(key) + if pred_entry is None: + missing_predictions += 1 + tanimoto_scores.append(0.0) + continue + + canonical = pred_entry["canonical"] + if not canonical: + invalid_predictions += 1 + tanimoto_scores.append(0.0) + continue + + if canonical == gt.get("canonical"): + correct_smiles += 1 + if pred_entry.get("inchi") and pred_entry["inchi"] == gt.get("inchi"): + correct_inchi += 1 + + score = _tanimoto(pred_entry.get("fingerprint"), gt.get("fingerprint")) + tanimoto_scores.append(score if score is not None else 0.0) + + metrics = { + "structures_total": total, + "predictions_total": len(predictions), + "missing_predictions": missing_predictions, + "invalid_predictions": invalid_predictions, "correct_smiles": correct_smiles, - "accuracy_smiles": accuracy_smiles, + "accuracy_smiles": correct_smiles / total if total else 0.0, "correct_inchi": correct_inchi, - "accuracy_inchi": accuracy_inchi, + "accuracy_inchi": correct_inchi / total if total else 0.0, + "tanimoto_mean": _mean(tanimoto_scores), } - # Save metrics + + # Reaction metrics + gt_reactions_total = sum(len(edges) for edges in reaction_truth.values()) + pred_reactions_total = sum(len(edges) for edges in reaction_predictions.values()) + correct_reactions = 0 + + if gt_reactions_total or pred_reactions_total: + for file_name, gt_edges in reaction_truth.items(): + predicted_edges = reaction_predictions.get(file_name, set()) + correct_reactions += len(gt_edges & predicted_edges) + + precision = ( + correct_reactions / pred_reactions_total if pred_reactions_total else 0.0 + ) + recall = correct_reactions / gt_reactions_total if gt_reactions_total else 0.0 + f1 = ( + 2 * precision * recall / (precision + recall) + if precision and recall + else 0.0 + ) + + metrics.update( + { + "reactions_total": gt_reactions_total, + "reactions_predicted": pred_reactions_total, + "reactions_correct": correct_reactions, + "reactions_precision": precision, + "reactions_recall": recall, + "reactions_f1": f1, + } + ) + out_file = out / (pred.stem + "_metrics.json") with out_file.open("w", encoding="utf-8") as fh: - json.dump(result, fh, indent=2) + json.dump(metrics, fh, indent=2) typer.echo(f"Wrote evaluation metrics to {out_file}") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover - CLI entry point app() diff --git a/src/molmole_research/extractor.py b/src/molmole_research/extractor.py index 09847fe..2f0f839 100644 --- a/src/molmole_research/extractor.py +++ b/src/molmole_research/extractor.py @@ -1,25 +1,28 @@ -"""Run OCSR extraction using vision‑enabled language models. - -This module defines a CLI that iterates over a directory of images and calls a -specified model to convert each image into a SMILES string (or other textual -representation). By default the extractor uses the OpenAI ChatCompletion -endpoint with a vision‑enabled model such as ``gpt-4o`` or -``gpt-4-vision-preview``. You may override the API base and key via -environment variables or command‑line options. - -The output is saved as a JSONL file where each line contains the image file -name, the extracted text, and optionally the raw model response. The output -file name is prefixed with the model name and timestamp to facilitate -experiment management. +"""Run holistic OCSR extraction on MolMole patent pages. + +MolMole pages often contain several molecular depictions connected by reaction +arrows. Rather than cropping every bounding box independently, the extractor +now feeds the *entire* patent page to a multimodal model and asks for all +visible structures (and reactions) in a single pass. This mirrors the +evaluation setup described in the MolMole paper where the model must reason +about the full context of the drawing. + +Each prediction written to disk corresponds to a single page and includes the +raw model response together with any parsed ``structures``/``reactions`` JSON +payload returned by the model. Downstream evaluation can therefore associate +predicted SMILES with MolMole's annotated structure identifiers while also +tracking reaction graphs. """ from __future__ import annotations import base64 import datetime as dt +import io import json +import re from pathlib import Path -from typing import List, Optional +from typing import Dict, Iterable, List, Optional import typer @@ -31,21 +34,98 @@ import openai # type: ignore[assignment] except ImportError: openai = None # type: ignore[assignment] +from PIL import Image from tqdm import tqdm app = typer.Typer(add_completion=False, help="Run extraction experiments") -def _encode_image_to_base64(path: Path) -> str: - """Read an image file and return a base64‑encoded string.""" - with open(path, "rb") as fh: - data = fh.read() - return base64.b64encode(data).decode("utf-8") +def _encode_bytes_to_base64(data: bytes, image_format: str = "png") -> str: + """Return a base64 data URI for the provided image bytes.""" + + encoded = base64.b64encode(data).decode("utf-8") + return f"data:image/{image_format};base64,{encoded}" + + +def _load_labels(labels_path: Path) -> List[Dict[str, object]]: + """Load ``labels.json`` produced by the downloader.""" + + data = json.loads(labels_path.read_text(encoding="utf-8")) + if not isinstance(data, list): # pragma: no cover - defensive + raise ValueError("labels.json must contain a list of page records") + return data + + +def _iterate_pages( + dataset_root: Path, labels: Iterable[Dict[str, object]] +) -> Iterable[Dict[str, object]]: + """Yield page dictionaries with resolved image paths.""" + + for page in labels: + if not isinstance(page, dict): + continue + file_name = page.get("file_name") + if not isinstance(file_name, str) or not file_name: + continue + image_path = dataset_root / file_name + yield { + "file_name": file_name, + "page_id": page.get("page_id"), + "image_path": image_path, + "width": page.get("width"), + "height": page.get("height"), + } + + +def _format_prompt(page: Dict[str, object]) -> str: + """Create the default holistic extraction prompt for a page.""" + + return ( + "You are analysing a patent page that may contain multiple chemical structures and" + " reaction schemes. Inspect the entire image and enumerate every distinct molecule" + " you can identify without external hints. Assign integer identifiers starting at" + " 0 in reading order (left-to-right, top-to-bottom). For each molecule, provide a" + " canonical SMILES string and optionally short notes. When reaction arrows connect" + " structures, describe the reactions by referencing the identifiers of the" + " reactant and product molecules. Respond in JSON with this shape:\n" + '{"structures": [{"id": int, "smiles": str, "notes": optional str}], ' + '"reactions": [{"reactants": [int], "products": [int], "conditions": [str]}]}.' + " Use empty strings for structures you cannot recognise." + ) + + +def _parse_json_response(text: str) -> Optional[Dict[str, object]]: + """Extract the first valid JSON object from ``text`` if present.""" + + cleaned = text.strip() + if not cleaned: + return None + + # Try a direct parse first. + try: + parsed = json.loads(cleaned) + if isinstance(parsed, dict): + return parsed + except json.JSONDecodeError: + pass + + # Attempt to locate the first JSON object in the text (e.g. inside a code block). + match = re.search(r"\{.*\}", cleaned, re.DOTALL) + if not match: + return None + snippet = match.group(0) + try: + parsed = json.loads(snippet) + if isinstance(parsed, dict): + return parsed + except json.JSONDecodeError: + return None + return None def call_openai_model( model: str, - image_path: Path, + image_bytes: bytes, prompt: str, api_key: Optional[str] = None, api_base: Optional[str] = None, @@ -56,7 +136,7 @@ def call_openai_model( Args: model: name of the OpenAI model (e.g. ``gpt-4o``, ``gpt-4-vision-preview``) - image_path: path to the image file to process + image_bytes: PNG-encoded bytes representing the cropped structure prompt: instruction prompt to accompany the image api_key: OpenAI API key (overrides environment variable) api_base: custom API base for self‑hosted deployments (optional) @@ -82,8 +162,7 @@ def call_openai_model( openai.api_base = api_base # type: ignore[assignment] # Prepare the payload using the data URI scheme - b64 = _encode_image_to_base64(image_path) - data_uri = f"data:image/png;base64,{b64}" + data_uri = _encode_bytes_to_base64(image_bytes) messages = [ { "role": "system", @@ -118,23 +197,21 @@ def call_openai_model( raise RuntimeError(f"OpenAI API call failed: {exc}") -def default_prompt() -> str: - """Return a default prompt instructing the model to produce a SMILES string.""" - return ( - "Given an image of a chemical structure, output only the corresponding SMILES string. " - "If the structure cannot be interpreted, return an empty string." - ) +def default_prompt(page: Dict[str, object]) -> str: + """Return the holistic extraction prompt for a page.""" + + return _format_prompt(page) @app.command("run") def run_extraction( model: str = typer.Option(..., "--model", help="Name of the vision model to use (e.g. gpt-4o)"), dataset_dir: Path = typer.Option( - Path("data/images/images"), + Path("data/molmole"), exists=True, file_okay=False, dir_okay=True, - help="Directory containing images downloaded by the downloader", + help="Directory containing the MolMole snapshot", ), out: Path = typer.Option( Path("results"), help="Directory where the JSONL results will be saved" @@ -143,36 +220,60 @@ def run_extraction( api_key: Optional[str] = typer.Option(None, help="Override OpenAI API key"), temperature: float = typer.Option(0.0, help="Sampling temperature"), max_tokens: int = typer.Option(256, help="Maximum number of tokens in the response"), + labels_path: Optional[Path] = typer.Option( + None, + help="Path to labels.json (defaults to /labels.json)", + ), ) -> None: """Run the extraction pipeline over all images in ``dataset_dir``. - The function iterates over each image file (with extensions .png, .jpg, .jpeg), - calls the selected model and writes a JSONL file with predictions. The - output file is named ``_.jsonl`` and stored in ``out``. + The extractor loads ``labels.json`` generated by the downloader to discover + the available pages and, for each page, sends the complete image to the + multimodal model. The prompt does not reveal any annotation metadata: the + model must infer how many structures and reactions are present directly from + the pixels. One JSONL record is produced per page. """ dataset_dir = dataset_dir.resolve() out = out.resolve() out.mkdir(parents=True, exist_ok=True) - prompt = default_prompt() - # Collect image paths - image_paths: List[Path] = [] - for ext in ("*.png", "*.jpg", "*.jpeg"): - image_paths.extend(dataset_dir.glob(ext)) - if not image_paths: - typer.echo(f"No images found in {dataset_dir}") + labels_file = labels_path or (dataset_dir / "labels.json") + if not labels_file.exists(): + typer.echo(f"Could not find labels.json at {labels_file}; run the downloader first") + raise typer.Exit(1) + + labels = _load_labels(labels_file) + + pages = list(_iterate_pages(dataset_dir, labels)) + if not pages: + typer.echo("No pages found in labels.json") raise typer.Exit(1) timestamp = dt.datetime.now().strftime("%Y%m%d_%H%M%S") out_file = out / f"{model}_{timestamp}.jsonl" - typer.echo(f"Processing {len(image_paths)} images with model {model} …") + typer.echo(f"Processing {len(pages)} pages with model {model} …") with out_file.open("w", encoding="utf-8") as fh: - for img_path in tqdm(sorted(image_paths), desc="Extracting", unit="img"): + for page in tqdm(pages, desc="Extracting", unit="page"): + image_path = page.get("image_path") + if not isinstance(image_path, Path): # pragma: no cover - defensive + continue + + try: + with Image.open(image_path).convert("RGB") as image: + buffer = io.BytesIO() + image.save(buffer, format="PNG") + image_bytes = buffer.getvalue() + except FileNotFoundError: + typer.echo(f"Missing image {image_path}") + continue + + prompt = default_prompt(page) + try: text = call_openai_model( model=model, - image_path=img_path, + image_bytes=image_bytes, prompt=prompt, api_key=api_key, api_base=api_base, @@ -180,9 +281,26 @@ def run_extraction( max_tokens=max_tokens, ) except Exception as exc: # pragma: no cover - network errors - typer.echo(f"Failed to process {img_path.name}: {exc}") + typer.echo( + f"Failed to process page {page.get('file_name')}: {exc}" + ) text = "" - record = {"file_name": img_path.name, "text": text} + + parsed = _parse_json_response(text) if text else None + + record: Dict[str, object] = { + "file_name": page.get("file_name"), + "page_id": page.get("page_id"), + "prompt": prompt, + "response": text, + } + + if parsed: + if isinstance(parsed.get("structures"), list): + record["predicted_structures"] = parsed["structures"] + if isinstance(parsed.get("reactions"), list): + record["predicted_reactions"] = parsed["reactions"] + fh.write(json.dumps(record) + "\n") typer.echo(f"Saved predictions to {out_file}") diff --git a/tests/test_downloader.py b/tests/test_downloader.py index 770ed83..e0f0d5d 100644 --- a/tests/test_downloader.py +++ b/tests/test_downloader.py @@ -1,80 +1,52 @@ -"""Unit tests for the dataset downloader. - -These tests ensure that the downloader correctly saves images and labels -when provided with a small synthetic dataset. External network calls to -HuggingFace are patched out to avoid slowdowns and dependence on external -services. -""" +"""Unit tests for the dataset downloader.""" from __future__ import annotations import json from pathlib import Path -from PIL import Image -import datasets from rdkit import Chem from molmole_research.downloader import download_dataset -def _make_dummy_dataset(num_items: int = 2): - """Construct a small in‑memory dataset with PIL images and SMILES strings.""" - images = [] - smiles = [] - file_names = [] - for i in range(num_items): - img = Image.new("RGB", (2, 2), color=(i * 20, i * 20, i * 20)) - images.append(img) - smiles.append("C") # simplest molecule - file_names.append(f"item_{i}.png") - return datasets.Dataset.from_dict( - { - "image": images, - "smiles": smiles, - "file_name": file_names, - } - ) - - def _create_dummy_repo(repo_dir: Path) -> None: """Create a fake HuggingFace repository layout with annotations and MOL files.""" json_dir = repo_dir / "json" mol_dir = repo_dir / "mol" + images_dir = repo_dir / "images_300" json_dir.mkdir(parents=True, exist_ok=True) mol_dir.mkdir(parents=True, exist_ok=True) + images_dir.mkdir(parents=True, exist_ok=True) annotations = { + "categories": [ + {"id": 1, "name": "structure"}, + {"id": 2, "name": "text"}, + ], "images": [ { - "file_name": "item_0.png", - "height": 16, - "width": 16, + "id": 10, + "file_name": "page_0.png", + "height": 100, + "width": 100, "dla_bboxes": [ - {"id": 0, "bbox": [0, 0, 4, 4], "category_id": 1}, - {"id": 1, "bbox": [4, 4, 4, 4], "category_id": 2}, - {"id": 2, "bbox": [8, 8, 4, 4], "category_id": 1}, + {"id": 0, "bbox": [0, 0, 40, 40], "category_id": 1}, + {"id": 1, "bbox": [50, 50, 30, 20], "category_id": 2}, + {"id": 2, "bbox": [60, 0, 30, 30], "category_id": 1}, ], - "reactions": [], - }, - { - "file_name": "item_1.png", - "height": 16, - "width": 16, - "dla_bboxes": [ - {"id": 0, "bbox": [1, 1, 4, 4], "category_id": 1}, + "reactions": [ + {"reactants": [0], "conditions": [1], "products": [2]} ], - "reactions": [], - }, - ] + } + ], } (json_dir / "annotations.json").write_text(json.dumps(annotations), encoding="utf-8") smiles_map = { - "item_0_bbox_0.mol": "C", - "item_0_bbox_2.mol": "CC", - "item_1_bbox_0.mol": "O", + "page_0_bbox_0.mol": "C", + "page_0_bbox_2.mol": "CC", } for name, smiles in smiles_map.items(): mol = Chem.MolFromSmiles(smiles) @@ -83,51 +55,40 @@ def _create_dummy_repo(repo_dir: Path) -> None: (mol_dir / name).write_text(block, encoding="utf-8") -def test_downloader_saves_images_and_labels(monkeypatch, tmp_path): - """Verify that the downloader writes images and a labels.json file.""" - dummy_ds = _make_dummy_dataset(3) +def test_downloader_builds_labels_with_reactions(monkeypatch, tmp_path): + """Verify that the downloader writes labels with structures and reactions.""" repo_dir = tmp_path / "hf_repo" _create_dummy_repo(repo_dir) - # Patch datasets.load_dataset to return our dummy dataset regardless of the input - def fake_load_dataset(*args, **kwargs): # noqa: D401 - return dummy_ds - - monkeypatch.setattr(datasets, "load_dataset", fake_load_dataset) - - def fake_snapshot_download(*args, **kwargs): + def fake_snapshot_download(*args, **kwargs): # noqa: D401 return str(repo_dir) monkeypatch.setattr("molmole_research.downloader.snapshot_download", fake_snapshot_download) out_dir = tmp_path / "download" - download_dataset(dataset="dummy", split="train", out=out_dir) - # Check that images were saved - images_dir = out_dir / "images" - assert images_dir.exists() and images_dir.is_dir() - saved_files = sorted(p.name for p in images_dir.iterdir()) - assert saved_files == [f"item_{i}.png" for i in range(3)] - - # Check labels.json - labels_path = out_dir / "labels.json" + download_dataset(dataset="dummy", out=out_dir, images_subdir=None, annotations_file=None) + + labels_path = repo_dir / "labels.json" assert labels_path.exists() labels = json.loads(labels_path.read_text()) - assert len(labels) == 3 + assert len(labels) == 1 - item0 = next(entry for entry in labels if entry["file_name"] == "item_0.png") - assert item0["smiles"] == "C" - assert item0["smiles_list"] == ["C", "CC"] - assert len(item0["structures"]) == 2 - for struct in item0["structures"]: - assert struct["mol_file"].startswith("mol/") - assert (out_dir / struct["mol_file"]).exists() + entry = labels[0] + assert entry["file_name"] == "images_300/page_0.png" + assert entry["page_id"] == 10 + assert entry["smiles"] == "C" + assert entry["smiles_list"] == ["C", "CC"] - item1 = next(entry for entry in labels if entry["file_name"] == "item_1.png") - assert item1["smiles"] == "C" - # SMILES list should include dataset smile plus MOL-derived value - assert item1["smiles_list"] == ["C", "O"] - assert len(item1["structures"]) == 1 + layout_boxes = entry.get("layout_boxes") + assert isinstance(layout_boxes, list) and len(layout_boxes) == 3 + assert {box["category_name"] for box in layout_boxes} == {"structure", "text"} + + structures = entry.get("structures") + assert isinstance(structures, list) and len(structures) == 2 + for struct in structures: + assert struct["mol_file"].startswith("mol/") + assert struct["smiles"] in {"C", "CC"} - # Ensure annotations JSON is copied over - assert (out_dir / "annotations.json").exists() + reactions = entry.get("reactions") + assert reactions == [{"reactants": [0], "conditions": [1], "products": [2]}] diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py index c28e540..2f90988 100644 --- a/tests/test_evaluator.py +++ b/tests/test_evaluator.py @@ -1,8 +1,4 @@ -"""Tests for the evaluation script. - -These tests ensure that the evaluator correctly computes SMILES and InChI -accuracy. RDKit is used to canonicalise SMILES and generate InChI keys. -""" +"""Tests for the MolMole evaluator.""" from __future__ import annotations @@ -26,20 +22,35 @@ def _write_predictions(pred_path: Path, entries): fh.write(json.dumps(entry) + "\n") -def test_evaluator_accuracy(tmp_path): - """Evaluate a small set of predictions and verify the computed accuracy.""" - # Ground truth labels +def test_evaluator_metrics(tmp_path): + """Evaluate predictions and verify MolMole-style metrics.""" + labels = [ - {"file_name": "img_0.png", "smiles": "C", "smiles_list": ["C"]}, - {"file_name": "img_1.png", "smiles_list": ["CC"]}, + { + "file_name": "images_300/page.png", + "structures": [ + {"id": 0, "smiles": "C"}, + {"id": 1, "smiles": "CC"}, + ], + "reactions": [ + {"reactants": [0], "products": [1], "conditions": []} + ], + } ] - dataset_dir = tmp_path / "data" + dataset_dir = tmp_path / "dataset" _write_labels(dataset_dir, labels) - # Predictions: one correct, one incorrect preds = [ - {"file_name": "img_0.png", "text": "C"}, - {"file_name": "img_1.png", "text": "C"}, + { + "file_name": "images_300/page.png", + "predicted_structures": [ + {"id": 0, "smiles": "C"}, + {"id": 1, "smiles": ""}, + ], + "predicted_reactions": [ + {"reactants": [0], "products": [1], "conditions": []} + ], + } ] pred_path = tmp_path / "preds.jsonl" _write_predictions(pred_path, preds) @@ -47,33 +58,53 @@ def test_evaluator_accuracy(tmp_path): results_dir = tmp_path / "results" evaluator.evaluate(pred=pred_path, dataset_dir=dataset_dir, out=results_dir) - # Load metrics metrics_files = list(results_dir.iterdir()) assert metrics_files - with metrics_files[0].open("r", encoding="utf-8") as fh: - metrics = json.load(fh) - assert metrics["total"] == 2 - # Only the first prediction is correct + metrics = json.loads(metrics_files[0].read_text()) + + assert metrics["structures_total"] == 2 + assert metrics["predictions_total"] == 2 assert metrics["correct_smiles"] == 1 assert metrics["correct_inchi"] == 1 + assert metrics["invalid_predictions"] == 1 + assert metrics["missing_predictions"] == 0 assert abs(metrics["accuracy_smiles"] - 0.5) < 1e-6 + assert abs(metrics["accuracy_inchi"] - 0.5) < 1e-6 + assert abs(metrics["tanimoto_mean"] - 0.5) < 1e-6 + assert metrics["reactions_total"] == 1 + assert metrics["reactions_predicted"] == 1 + assert metrics["reactions_correct"] == 1 + assert abs(metrics["reactions_precision"] - 1.0) < 1e-6 + assert abs(metrics["reactions_recall"] - 1.0) < 1e-6 + assert abs(metrics["reactions_f1"] - 1.0) < 1e-6 -def test_evaluator_handles_multiple_smiles(tmp_path): - """Predictions should match if any ground truth SMILES is correct.""" +def test_evaluator_handles_missing_predictions(tmp_path): + """Missing predictions should be counted and penalised.""" labels = [ { - "file_name": "img_multi.png", - "smiles": None, - "smiles_list": ["CC", "O"], - "structures": [{"smiles": "CC"}, {"smiles": "O"}], + "file_name": "images_300/page.png", + "structures": [ + {"id": 0, "smiles": "C"}, + {"id": 1, "smiles": "CC"}, + ], + "reactions": [ + {"reactants": [0], "products": [1], "conditions": []} + ], } ] - dataset_dir = tmp_path / "data" + dataset_dir = tmp_path / "dataset" _write_labels(dataset_dir, labels) - preds = [{"file_name": "img_multi.png", "text": "O"}] + preds = [ + { + "file_name": "images_300/page.png", + "predicted_structures": [ + {"id": 0, "smiles": "C"}, + ], + } + ] pred_path = tmp_path / "preds.jsonl" _write_predictions(pred_path, preds) @@ -81,9 +112,12 @@ def test_evaluator_handles_multiple_smiles(tmp_path): evaluator.evaluate(pred=pred_path, dataset_dir=dataset_dir, out=results_dir) metrics_files = list(results_dir.iterdir()) - assert metrics_files - with metrics_files[0].open("r", encoding="utf-8") as fh: - metrics = json.load(fh) - assert metrics["total"] == 1 + metrics = json.loads(metrics_files[0].read_text()) + + assert metrics["structures_total"] == 2 + assert metrics["missing_predictions"] == 1 assert metrics["correct_smiles"] == 1 - assert metrics["correct_inchi"] == 1 + assert metrics["predictions_total"] == 1 + assert metrics.get("reactions_total", 0) == 1 + assert metrics.get("reactions_predicted", 0) == 0 + assert metrics.get("reactions_correct", 0) == 0 diff --git a/tests/test_extractor.py b/tests/test_extractor.py index 576b341..f78e185 100644 --- a/tests/test_extractor.py +++ b/tests/test_extractor.py @@ -1,12 +1,8 @@ -"""Tests for the OCSR extractor. - -We patch the OpenAI API call so that no network requests are made during -testing. The extractor should iterate through the dataset directory, call -``call_openai_model`` on each image and write a JSONL file with predictions. -""" +"""Tests for the MolMole extractor.""" from __future__ import annotations +import io import json from pathlib import Path @@ -15,21 +11,47 @@ from molmole_research import extractor -def _create_images(dir_path: Path, count: int) -> None: - dir_path.mkdir(parents=True, exist_ok=True) - for i in range(count): - img = Image.new("RGB", (4, 4), color=(i * 50, i * 50, i * 50)) - img.save(dir_path / f"img_{i}.png") +def _create_dataset(tmp_path: Path) -> Path: + dataset_dir = tmp_path / "dataset" + images_dir = dataset_dir / "images_300" + images_dir.mkdir(parents=True, exist_ok=True) + + image_path = images_dir / "page.png" + Image.new("RGB", (100, 100), color="white").save(image_path) + + labels = [ + { + "file_name": "images_300/page.png", + "page_id": 1, + "structures": [ + {"id": 0, "bbox": [0, 0, 50, 50]}, + {"id": 1, "bbox": [50, 50, 50, 50]}, + ], + } + ] + (dataset_dir / "labels.json").write_text(json.dumps(labels), encoding="utf-8") + return dataset_dir + +def test_extractor_produces_page_level_predictions(monkeypatch, tmp_path): + """The extractor should emit one record per page with parsed structures.""" -def test_extractor_produces_predictions(monkeypatch, tmp_path): - """The extractor should write a JSONL file containing predictions for each image.""" - dataset_dir = tmp_path / "data" - _create_images(dataset_dir, 2) + dataset_dir = _create_dataset(tmp_path) - # Patch call_openai_model to return a fixed SMILES string - def fake_call_openai_model(*args, **kwargs): # noqa: D401 - return "C" + def fake_call_openai_model(*, image_bytes, **kwargs): # noqa: D401 + image = Image.open(io.BytesIO(image_bytes)) + assert image.size == (100, 100) + return json.dumps( + { + "structures": [ + {"id": 0, "smiles": "C"}, + {"id": 1, "smiles": "CC"}, + ], + "reactions": [ + {"reactants": [0], "products": [1], "conditions": []} + ], + } + ) monkeypatch.setattr(extractor, "call_openai_model", fake_call_openai_model) @@ -42,13 +64,23 @@ def fake_call_openai_model(*args, **kwargs): # noqa: D401 api_key=None, temperature=0.0, max_tokens=32, + labels_path=None, ) - # There should be exactly one JSONL file in out_dir + files = list(out_dir.iterdir()) assert len(files) == 1 and files[0].suffix == ".jsonl" - # Check its contents - contents = files[0].read_text().strip().splitlines() - assert len(contents) == 2 - for line in contents: - record = json.loads(line) - assert record["text"] == "C" + + with files[0].open("r", encoding="utf-8") as fh: + lines = [json.loads(line) for line in fh if line.strip()] + + assert len(lines) == 1 + record = lines[0] + assert record["file_name"] == "images_300/page.png" + assert record["predicted_structures"] == [ + {"id": 0, "smiles": "C"}, + {"id": 1, "smiles": "CC"}, + ] + assert record["predicted_reactions"] == [ + {"reactants": [0], "products": [1], "conditions": []} + ] + assert "structures_expected" not in record From 6e179a3fbd30c4cbc52a54384fc594f4b4d9ad11 Mon Sep 17 00:00:00 2001 From: Dmitry Date: Tue, 28 Oct 2025 16:27:35 +0100 Subject: [PATCH 2/2] Adopt native typing conventions and add repo guidelines --- AGENTS.md | 23 +++++ src/molmole_research/__init__.py | 24 +---- src/molmole_research/downloader.py | 77 +++++++++------- src/molmole_research/evaluator.py | 142 ++++++++++++++++------------- src/molmole_research/extractor.py | 32 +++---- src/molmole_research/runner.py | 14 ++- tests/test_downloader.py | 6 +- tests/test_evaluator.py | 12 +-- tests/test_extractor.py | 10 +- 9 files changed, 172 insertions(+), 168 deletions(-) create mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..fb909af --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,23 @@ +## Developer Workflow +Follow SOLID principles. +Always ensure that tests pass before committing. + + +## Coding Standards +- Keep routers thin; logic in `services/`. +- Repositories are the only DB touchpoint from services. +- DTOs (schemas) are versioned; never leak ORM models. +- Return 201 on creates, 202 on async accepted; 429 on limits; 422 on validation. +- Docstrings: Google style; type hints everywhere. +- Avoid defining `__all__` unless absolutely necessary, and never use `from module import *`; prefer explicit imports (ruff F403/F405 guard against this). +- Keep package `__init__.py` files empty aside from optional module docstrings; avoid re-exporting symbols from them. +- Never add `from __future__ import annotations`; native Python 3.12 typing is required. +- Prefer native union syntax (e.g., `str | None`) instead of `typing.Optional[...]` or `typing.Union`. + +## Development Workflow + +- All PRs require: ✅ lint, ✅ type-check, ✅ tests. +- Before committing, run all relevant project checks locally and ensure they pass. +- Write concise PR descriptions (why + what), not just code diffs. +- Instructions from user may be asked in any language, but code and response should be in English. +- Do not use `docker compose up -d` (or any detached Compose command) in development-facing scripts or docs. diff --git a/src/molmole_research/__init__.py b/src/molmole_research/__init__.py index 7a9c337..403d6b0 100644 --- a/src/molmole_research/__init__.py +++ b/src/molmole_research/__init__.py @@ -1,23 +1 @@ -"""Top‑level package for MolMole OCSR research environment. - -This package contains modules to download the MolMole benchmark dataset, -extract chemical structures from images using vision‑language models, -evaluate predictions against ground truth and orchestrate experiments. - -The package is designed to be executed via the command line, for example: - - python -m molmole_research.downloader --help - python -m molmole_research.extractor --help - python -m molmole_research.evaluator --help - python -m molmole_research.runner --help - -""" - -__all__ = [ - "downloader", - "extractor", - "evaluator", - "runner", -] - -__version__ = "0.1.0" +"""Top-level package for MolMole OCSR research environment.""" diff --git a/src/molmole_research/downloader.py b/src/molmole_research/downloader.py index 0190f28..4a69972 100644 --- a/src/molmole_research/downloader.py +++ b/src/molmole_research/downloader.py @@ -9,11 +9,9 @@ evaluation. """ -from __future__ import annotations - import json from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Iterable import typer from huggingface_hub import snapshot_download @@ -27,11 +25,11 @@ app = typer.Typer(add_completion=False, help="Download the MolMole dataset") -def _unique_preserve_order(values: Iterable[str]) -> List[str]: +def _unique_preserve_order(values: Iterable[str]) -> list[str]: """Return a list with duplicates removed while preserving order.""" - seen: Set[str] = set() - result: List[str] = [] + seen: set[str] = set() + result: list[str] = [] for value in values: if value not in seen: seen.add(value) @@ -48,7 +46,7 @@ def _ensure_rdkit() -> None: ) -def _mol_to_smiles(mol_path: Path) -> Optional[str]: +def _mol_to_smiles(mol_path: Path) -> str | None: """Convert a MOL file into a canonical SMILES string.""" try: @@ -60,7 +58,7 @@ def _mol_to_smiles(mol_path: Path) -> Optional[str]: return None -def _discover_images_dir(dataset_root: Path, preferred: Optional[str]) -> Path: +def _discover_images_dir(dataset_root: Path, preferred: str | None) -> Path: """Return the directory containing page level images.""" if preferred: @@ -79,7 +77,7 @@ def _discover_images_dir(dataset_root: Path, preferred: Optional[str]) -> Path: ) -def _discover_annotation_file(dataset_root: Path, filename: Optional[str]) -> Path: +def _discover_annotation_file(dataset_root: Path, filename: str | None) -> Path: """Locate the annotation JSON shipped with the dataset.""" search_dirs = [dataset_root / "json", dataset_root] @@ -99,11 +97,11 @@ def _discover_annotation_file(dataset_root: Path, filename: Optional[str]) -> Pa raise FileNotFoundError(f"Unable to locate annotation file at {location}") -def _download_snapshot(dataset: str, out: Path, revision: Optional[str]) -> Path: +def _download_snapshot(dataset: str, out: Path, revision: str | None) -> Path: """Download the dataset snapshot into ``out`` and return the local path.""" typer.echo(f"Downloading dataset {dataset} to {out} …") - kwargs: Dict[str, Any] = { + kwargs: dict[str, Any] = { "repo_type": "dataset", "local_dir": str(out), } @@ -119,17 +117,21 @@ def _build_labels( images_dir: Path, annotation_file: Path, mol_dir: Path, -) -> Tuple[List[Dict[str, Any]], Dict[str, int]]: +) -> tuple[list[dict[str, Any]], dict[str, int]]: """Create structured labels enriched with canonical SMILES.""" data = json.loads(annotation_file.read_text(encoding="utf-8")) - records: List[Dict[str, Any]] = [] + records: list[dict[str, Any]] = [] - category_names: Dict[int, str] = { - category.get("id"): category.get("name", "unknown") - for category in data.get("categories", []) - if isinstance(category, dict) - } + category_names: dict[int, str] = {} + for category in data.get("categories", []): + if not isinstance(category, dict): + continue + category_id = category.get("id") + if not isinstance(category_id, int): + continue + name = category.get("name") + category_names[category_id] = name if isinstance(name, str) else "unknown" total_structures = 0 smiles_converted = 0 @@ -149,7 +151,7 @@ def _build_labels( if not file_name: continue - entry: Dict[str, Any] = { + entry: dict[str, Any] = { "file_name": str(image_prefix / file_name), "width": image_meta.get("width"), "height": image_meta.get("height"), @@ -157,9 +159,9 @@ def _build_labels( if image_meta.get("id") is not None: entry["page_id"] = image_meta["id"] - layout_boxes: List[Dict[str, Any]] = [] - structures: List[Dict[str, Any]] = [] - smiles_candidates: List[str] = [] + layout_boxes: list[dict[str, Any]] = [] + structures: list[dict[str, Any]] = [] + smiles_candidates: list[str] = [] for bbox in image_meta.get("dla_bboxes", []): if not isinstance(bbox, dict): @@ -169,10 +171,13 @@ def _build_labels( if bbox_id is None: continue - category_id = bbox.get("category_id") - category_name = category_names.get(category_id, "unknown") + category_id_raw = bbox.get("category_id") + category_id = category_id_raw if isinstance(category_id_raw, int) else None + category_name = ( + category_names.get(category_id, "unknown") if category_id is not None else "unknown" + ) - record: Dict[str, Any] = { + record: dict[str, Any] = { "id": bbox_id, "bbox": bbox.get("bbox"), "category_id": category_id, @@ -183,8 +188,8 @@ def _build_labels( total_structures += 1 mol_name = f"{Path(file_name).stem}_bbox_{bbox_id}.mol" mol_path = mol_dir / mol_name - smiles: Optional[str] = None - mol_relative: Optional[str] = None + smiles: str | None = None + mol_relative: str | None = None if mol_path.exists(): mol_relative = str(mol_path.relative_to(dataset_root)) @@ -197,10 +202,12 @@ def _build_labels( else: missing_mols += 1 - record.update({ - "mol_file": mol_relative, - "smiles": smiles, - }) + record.update( + { + "mol_file": mol_relative, + "smiles": smiles, + } + ) structures.append(dict(record)) layout_boxes.append(record) @@ -220,7 +227,7 @@ def _build_labels( records.append(entry) - stats = { + stats: dict[str, int] = { "images": len(records), "structures_total": total_structures, "structures_with_smiles": smiles_converted, @@ -241,15 +248,15 @@ def download_dataset( Path("data/molmole"), help="Directory where the dataset snapshot and derived labels will be stored", ), - images_subdir: Optional[str] = typer.Option( + images_subdir: str | None = typer.Option( None, help="Name of the images folder inside the dataset (defaults to auto-detect)", ), - annotations_file: Optional[str] = typer.Option( + annotations_file: str | None = typer.Option( None, help="JSON file containing annotations (defaults to the first JSON in the json/ directory)", ), - revision: Optional[str] = typer.Option( + revision: str | None = typer.Option( None, help="Optional dataset revision to download", ), diff --git a/src/molmole_research/evaluator.py b/src/molmole_research/evaluator.py index b8dbb66..bf97d7e 100644 --- a/src/molmole_research/evaluator.py +++ b/src/molmole_research/evaluator.py @@ -17,13 +17,11 @@ reaction edges (reactant/product identifier sets plus optional conditions). """ -from __future__ import annotations - import json import re from pathlib import Path from statistics import mean -from typing import Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Iterable, TypedDict, cast import typer @@ -38,10 +36,38 @@ from tqdm import tqdm +if TYPE_CHECKING: # pragma: no cover - typing helpers only + from rdkit.DataStructs.cDataStructs import ExplicitBitVect +else: # pragma: no cover - runtime fallback when RDKit stubs unavailable + ExplicitBitVect = Any + +ReactionEdge = tuple[tuple[int, ...], tuple[int, ...], tuple[str, ...]] + + +class StructureInfo(TypedDict): + """Canonicalised ground-truth structure information.""" + + canonical: str + inchi: str | None + fingerprint: ExplicitBitVect | None + + +class PredictionInfo(TypedDict, total=False): + """Normalised prediction payload for a single structure.""" + + canonical: str | None + inchi: str | None + fingerprint: ExplicitBitVect | None + + +ReactionMap = dict[str, set[ReactionEdge]] +StructureMap = dict[tuple[str, int], StructureInfo] +PredictionMap = dict[tuple[str, int], PredictionInfo] + app = typer.Typer(add_completion=False, help="Evaluate OCSR predictions") -def _canonical_smiles(smiles: str) -> Optional[str]: +def _canonical_smiles(smiles: str) -> str | None: """Return the canonical SMILES for ``smiles`` or ``None`` if invalid.""" if Chem is None: @@ -56,7 +82,7 @@ def _canonical_smiles(smiles: str) -> Optional[str]: return None -def _inchi_key(smiles: str) -> Optional[str]: +def _inchi_key(smiles: str) -> str | None: """Return the InChI key for ``smiles`` if possible.""" if Chem is None: @@ -71,7 +97,7 @@ def _inchi_key(smiles: str) -> Optional[str]: return None -def _fingerprint(smiles: str): +def _fingerprint(smiles: str) -> ExplicitBitVect | None: """Compute a Morgan fingerprint for ``smiles`` or return ``None``.""" if Chem is None or AllChem is None: # pragma: no cover - RDKit guard @@ -80,18 +106,22 @@ def _fingerprint(smiles: str): mol = Chem.MolFromSmiles(smiles) # type: ignore[call-arg] if mol is None: return None - return AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048) + allchem = cast(Any, AllChem) + return cast( + ExplicitBitVect | None, + allchem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048), + ) except Exception: return None -def _tanimoto(fp_a, fp_b) -> Optional[float]: # type: ignore[override] +def _tanimoto(fp_a: ExplicitBitVect | None, fp_b: ExplicitBitVect | None) -> float | None: if DataStructs is None or fp_a is None or fp_b is None: # pragma: no cover return None return float(DataStructs.TanimotoSimilarity(fp_a, fp_b)) -def _normalise_prediction(text: str) -> Optional[str]: +def _normalise_prediction(text: str) -> str | None: """Extract the first valid SMILES token from ``text``.""" cleaned = text.strip() @@ -107,17 +137,15 @@ def _normalise_prediction(text: str) -> Optional[str]: return _canonical_smiles(cleaned) -def _load_labels(labels_path: Path) -> List[Dict[str, object]]: +def _load_labels(labels_path: Path) -> list[dict[str, object]]: data = json.loads(labels_path.read_text(encoding="utf-8")) if not isinstance(data, list): raise ValueError("labels.json must contain a list of page records") return data -def _index_structures( - labels: Iterable[Dict[str, object]] -) -> Dict[Tuple[str, int], Dict[str, object]]: - mapping: Dict[Tuple[str, int], Dict[str, object]] = {} +def _index_structures(labels: Iterable[dict[str, object]]) -> StructureMap: + mapping: StructureMap = {} for page in labels: if not isinstance(page, dict): continue @@ -137,18 +165,18 @@ def _index_structures( canonical = _canonical_smiles(smiles) if not canonical: continue - mapping[(file_name, int(struct_id))] = { - "canonical": canonical, - "inchi": _inchi_key(canonical), - "fingerprint": _fingerprint(canonical), - } + mapping[(file_name, int(struct_id))] = StructureInfo( + canonical=canonical, + inchi=_inchi_key(canonical), + fingerprint=_fingerprint(canonical), + ) return mapping -def _index_reactions(labels: Iterable[Dict[str, object]]) -> Dict[str, Set[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[str, ...]]]]: +def _index_reactions(labels: Iterable[dict[str, object]]) -> ReactionMap: """Collect ground truth reactions per page.""" - reaction_map: Dict[str, Set[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[str, ...]]]] = {} + reaction_map: ReactionMap = {} for page in labels: if not isinstance(page, dict): continue @@ -158,7 +186,7 @@ def _index_reactions(labels: Iterable[Dict[str, object]]) -> Dict[str, Set[Tuple reactions = page.get("reactions") if not isinstance(reactions, list): continue - normalised: Set[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[str, ...]]] = set() + normalised: set[ReactionEdge] = set() for reaction in reactions: normalised_entry = _normalise_reaction(reaction) if isinstance(reaction, dict) else None if normalised_entry: @@ -168,14 +196,14 @@ def _index_reactions(labels: Iterable[Dict[str, object]]) -> Dict[str, Set[Tuple return reaction_map -def _normalise_reaction(entry: Dict[str, object]) -> Optional[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[str, ...]]]: +def _normalise_reaction(entry: dict[str, object]) -> ReactionEdge | None: """Convert a reaction dict into a hashable tuple.""" if not isinstance(entry, dict): return None - def _normalise_ids(values) -> Tuple[int, ...]: - ids: List[int] = [] + def _normalise_ids(values: object) -> tuple[int, ...]: + ids: list[int] = [] if isinstance(values, list): for item in values: try: @@ -189,9 +217,10 @@ def _normalise_ids(values) -> Tuple[int, ...]: if not reactants or not products: return None - conditions: List[str] = [] - if isinstance(entry.get("conditions"), list): - for cond in entry["conditions"]: + conditions: list[str] = [] + conditions_raw = entry.get("conditions") + if isinstance(conditions_raw, list): + for cond in conditions_raw: if isinstance(cond, str): stripped = cond.strip() if stripped: @@ -199,16 +228,11 @@ def _normalise_ids(values) -> Tuple[int, ...]: return reactants, products, tuple(sorted(set(conditions))) -def _load_predictions( - pred_path: Path, -) -> Tuple[ - Dict[Tuple[str, int], Dict[str, object]], - Dict[str, Set[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[str, ...]]]], -]: +def _load_predictions(pred_path: Path) -> tuple[PredictionMap, ReactionMap]: """Load structure and reaction predictions from the extractor output.""" - predictions: Dict[Tuple[str, int], Dict[str, object]] = {} - reaction_predictions: Dict[str, Set[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[str, ...]]]] = {} + predictions: PredictionMap = {} + reaction_predictions: ReactionMap = {} with pred_path.open("r", encoding="utf-8") as fh: for line in tqdm(fh, desc="Evaluating", unit="page"): @@ -240,11 +264,11 @@ def _load_predictions( except (TypeError, ValueError): continue canonical = _normalise_prediction(smiles_value) - predictions[(file_name, struct_id_int)] = { - "canonical": canonical, - "inchi": _inchi_key(canonical) if canonical else None, - "fingerprint": _fingerprint(canonical) if canonical else None, - } + predictions[(file_name, struct_id_int)] = PredictionInfo( + canonical=canonical, + inchi=_inchi_key(canonical) if canonical else None, + fingerprint=_fingerprint(canonical) if canonical else None, + ) else: # Backwards compatibility with per-structure JSONL records. struct_id = record.get("structure_id") @@ -256,18 +280,16 @@ def _load_predictions( struct_id_int = None if struct_id_int is not None: canonical = _normalise_prediction(text) - predictions[(file_name, struct_id_int)] = { - "canonical": canonical, - "inchi": _inchi_key(canonical) if canonical else None, - "fingerprint": _fingerprint(canonical) if canonical else None, - } + predictions[(file_name, struct_id_int)] = PredictionInfo( + canonical=canonical, + inchi=_inchi_key(canonical) if canonical else None, + fingerprint=_fingerprint(canonical) if canonical else None, + ) reactions = record.get("predicted_reactions") if isinstance(reactions, list): normalised = { - r - for r in (_normalise_reaction(item) for item in reactions) - if r is not None + r for r in (_normalise_reaction(item) for item in reactions) if r is not None } if normalised: reaction_predictions[file_name] = normalised @@ -319,7 +341,7 @@ def evaluate( correct_inchi = 0 invalid_predictions = 0 missing_predictions = 0 - tanimoto_scores: List[float] = [] + tanimoto_scores: list[float] = [] for key, gt in structures.items(): pred_entry = predictions.get(key) @@ -328,18 +350,20 @@ def evaluate( tanimoto_scores.append(0.0) continue - canonical = pred_entry["canonical"] + canonical = pred_entry.get("canonical") if not canonical: invalid_predictions += 1 tanimoto_scores.append(0.0) continue - if canonical == gt.get("canonical"): + if canonical == gt["canonical"]: correct_smiles += 1 - if pred_entry.get("inchi") and pred_entry["inchi"] == gt.get("inchi"): + inchi_value = pred_entry.get("inchi") + if inchi_value and inchi_value == gt["inchi"]: correct_inchi += 1 - score = _tanimoto(pred_entry.get("fingerprint"), gt.get("fingerprint")) + fingerprint_value = pred_entry.get("fingerprint") + score = _tanimoto(fingerprint_value, gt["fingerprint"]) tanimoto_scores.append(score if score is not None else 0.0) metrics = { @@ -364,15 +388,9 @@ def evaluate( predicted_edges = reaction_predictions.get(file_name, set()) correct_reactions += len(gt_edges & predicted_edges) - precision = ( - correct_reactions / pred_reactions_total if pred_reactions_total else 0.0 - ) + precision = correct_reactions / pred_reactions_total if pred_reactions_total else 0.0 recall = correct_reactions / gt_reactions_total if gt_reactions_total else 0.0 - f1 = ( - 2 * precision * recall / (precision + recall) - if precision and recall - else 0.0 - ) + f1 = 2 * precision * recall / (precision + recall) if precision and recall else 0.0 metrics.update( { diff --git a/src/molmole_research/extractor.py b/src/molmole_research/extractor.py index 2f0f839..d4ac54e 100644 --- a/src/molmole_research/extractor.py +++ b/src/molmole_research/extractor.py @@ -14,15 +14,13 @@ tracking reaction graphs. """ -from __future__ import annotations - import base64 import datetime as dt import io import json import re from pathlib import Path -from typing import Dict, Iterable, List, Optional +from typing import Iterable import typer @@ -47,7 +45,7 @@ def _encode_bytes_to_base64(data: bytes, image_format: str = "png") -> str: return f"data:image/{image_format};base64,{encoded}" -def _load_labels(labels_path: Path) -> List[Dict[str, object]]: +def _load_labels(labels_path: Path) -> list[dict[str, object]]: """Load ``labels.json`` produced by the downloader.""" data = json.loads(labels_path.read_text(encoding="utf-8")) @@ -57,8 +55,8 @@ def _load_labels(labels_path: Path) -> List[Dict[str, object]]: def _iterate_pages( - dataset_root: Path, labels: Iterable[Dict[str, object]] -) -> Iterable[Dict[str, object]]: + dataset_root: Path, labels: Iterable[dict[str, object]] +) -> Iterable[dict[str, object]]: """Yield page dictionaries with resolved image paths.""" for page in labels: @@ -77,7 +75,7 @@ def _iterate_pages( } -def _format_prompt(page: Dict[str, object]) -> str: +def _format_prompt(page: dict[str, object]) -> str: """Create the default holistic extraction prompt for a page.""" return ( @@ -94,7 +92,7 @@ def _format_prompt(page: Dict[str, object]) -> str: ) -def _parse_json_response(text: str) -> Optional[Dict[str, object]]: +def _parse_json_response(text: str) -> dict[str, object] | None: """Extract the first valid JSON object from ``text`` if present.""" cleaned = text.strip() @@ -127,8 +125,8 @@ def call_openai_model( model: str, image_bytes: bytes, prompt: str, - api_key: Optional[str] = None, - api_base: Optional[str] = None, + api_key: str | None = None, + api_base: str | None = None, temperature: float = 0.0, max_tokens: int = 256, ) -> str: @@ -197,7 +195,7 @@ def call_openai_model( raise RuntimeError(f"OpenAI API call failed: {exc}") -def default_prompt(page: Dict[str, object]) -> str: +def default_prompt(page: dict[str, object]) -> str: """Return the holistic extraction prompt for a page.""" return _format_prompt(page) @@ -216,11 +214,11 @@ def run_extraction( out: Path = typer.Option( Path("results"), help="Directory where the JSONL results will be saved" ), - api_base: Optional[str] = typer.Option(None, help="Override OpenAI API base URL"), - api_key: Optional[str] = typer.Option(None, help="Override OpenAI API key"), + api_base: str | None = typer.Option(None, help="Override OpenAI API base URL"), + api_key: str | None = typer.Option(None, help="Override OpenAI API key"), temperature: float = typer.Option(0.0, help="Sampling temperature"), max_tokens: int = typer.Option(256, help="Maximum number of tokens in the response"), - labels_path: Optional[Path] = typer.Option( + labels_path: Path | None = typer.Option( None, help="Path to labels.json (defaults to /labels.json)", ), @@ -281,14 +279,12 @@ def run_extraction( max_tokens=max_tokens, ) except Exception as exc: # pragma: no cover - network errors - typer.echo( - f"Failed to process page {page.get('file_name')}: {exc}" - ) + typer.echo(f"Failed to process page {page.get('file_name')}: {exc}") text = "" parsed = _parse_json_response(text) if text else None - record: Dict[str, object] = { + record: dict[str, object] = { "file_name": page.get("file_name"), "page_id": page.get("page_id"), "prompt": prompt, diff --git a/src/molmole_research/runner.py b/src/molmole_research/runner.py index af28d9f..0c9951b 100644 --- a/src/molmole_research/runner.py +++ b/src/molmole_research/runner.py @@ -24,14 +24,12 @@ (using the underlying extractor) and then evaluate them automatically. """ -from __future__ import annotations - import json import subprocess import sys from datetime import datetime from pathlib import Path -from typing import Dict, List, Optional +from typing import Any import typer import yaml @@ -39,7 +37,7 @@ app = typer.Typer(add_completion=False, help="Run multiple extraction experiments") -def _find_latest_prediction(out_dir: Path, model_prefix: str) -> Optional[Path]: +def _find_latest_prediction(out_dir: Path, model_prefix: str) -> Path | None: """Find the most recent prediction file for a given model prefix. Prediction files are expected to be named ``_YYYYMMDD_HHMMSS.jsonl``. @@ -64,8 +62,8 @@ def run_experiments( results_dir: Path = typer.Option( Path("results"), help="Directory to store predictions and metrics" ), - api_key: Optional[str] = typer.Option(None, help="Override OpenAI API key"), - api_base: Optional[str] = typer.Option(None, help="Override OpenAI API base URL"), + api_key: str | None = typer.Option(None, help="Override OpenAI API key"), + api_base: str | None = typer.Option(None, help="Override OpenAI API base URL"), out: Path = typer.Option(Path("results/summary.json"), help="Path to save summary metrics"), ) -> None: """Run a series of extraction experiments defined in a YAML configuration. @@ -92,7 +90,7 @@ def run_experiments( typer.echo("No experiments defined in configuration file") raise typer.Exit(1) - summary: List[Dict[str, any]] = [] + summary: list[dict[str, Any]] = [] for exp in experiments: name = exp.get("name") or exp.get("model") or "experiment" model = exp["model"] @@ -149,7 +147,7 @@ def run_experiments( with open(metrics_path, "r", encoding="utf-8") as fh: metrics = json.load(fh) else: - metrics = {} + metrics: dict[str, Any] = {} metrics["experiment"] = name metrics["model"] = model metrics["timestamp"] = datetime.now().isoformat() diff --git a/tests/test_downloader.py b/tests/test_downloader.py index e0f0d5d..e48faab 100644 --- a/tests/test_downloader.py +++ b/tests/test_downloader.py @@ -1,7 +1,5 @@ """Unit tests for the dataset downloader.""" -from __future__ import annotations - import json from pathlib import Path @@ -36,9 +34,7 @@ def _create_dummy_repo(repo_dir: Path) -> None: {"id": 1, "bbox": [50, 50, 30, 20], "category_id": 2}, {"id": 2, "bbox": [60, 0, 30, 30], "category_id": 1}, ], - "reactions": [ - {"reactants": [0], "conditions": [1], "products": [2]} - ], + "reactions": [{"reactants": [0], "conditions": [1], "products": [2]}], } ], } diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py index 2f90988..5cbb137 100644 --- a/tests/test_evaluator.py +++ b/tests/test_evaluator.py @@ -32,9 +32,7 @@ def test_evaluator_metrics(tmp_path): {"id": 0, "smiles": "C"}, {"id": 1, "smiles": "CC"}, ], - "reactions": [ - {"reactants": [0], "products": [1], "conditions": []} - ], + "reactions": [{"reactants": [0], "products": [1], "conditions": []}], } ] dataset_dir = tmp_path / "dataset" @@ -47,9 +45,7 @@ def test_evaluator_metrics(tmp_path): {"id": 0, "smiles": "C"}, {"id": 1, "smiles": ""}, ], - "predicted_reactions": [ - {"reactants": [0], "products": [1], "conditions": []} - ], + "predicted_reactions": [{"reactants": [0], "products": [1], "conditions": []}], } ] pred_path = tmp_path / "preds.jsonl" @@ -89,9 +85,7 @@ def test_evaluator_handles_missing_predictions(tmp_path): {"id": 0, "smiles": "C"}, {"id": 1, "smiles": "CC"}, ], - "reactions": [ - {"reactants": [0], "products": [1], "conditions": []} - ], + "reactions": [{"reactants": [0], "products": [1], "conditions": []}], } ] dataset_dir = tmp_path / "dataset" diff --git a/tests/test_extractor.py b/tests/test_extractor.py index f78e185..c897cd4 100644 --- a/tests/test_extractor.py +++ b/tests/test_extractor.py @@ -1,7 +1,5 @@ """Tests for the MolMole extractor.""" -from __future__ import annotations - import io import json from pathlib import Path @@ -47,9 +45,7 @@ def fake_call_openai_model(*, image_bytes, **kwargs): # noqa: D401 {"id": 0, "smiles": "C"}, {"id": 1, "smiles": "CC"}, ], - "reactions": [ - {"reactants": [0], "products": [1], "conditions": []} - ], + "reactions": [{"reactants": [0], "products": [1], "conditions": []}], } ) @@ -80,7 +76,5 @@ def fake_call_openai_model(*, image_bytes, **kwargs): # noqa: D401 {"id": 0, "smiles": "C"}, {"id": 1, "smiles": "CC"}, ] - assert record["predicted_reactions"] == [ - {"reactants": [0], "products": [1], "conditions": []} - ] + assert record["predicted_reactions"] == [{"reactants": [0], "products": [1], "conditions": []}] assert "structures_expected" not in record