From b62b2f5bf08c4f37071a0245dbf27b4df6b7e61a Mon Sep 17 00:00:00 2001 From: Dmitry Date: Tue, 28 Oct 2025 09:35:27 +0100 Subject: [PATCH 1/2] Enhance MolMole downloader with metadata integration --- src/molmole_research/downloader.py | 172 ++++++++++++++++++++++++++++- src/molmole_research/evaluator.py | 69 +++++++++--- tests/test_downloader.py | 75 ++++++++++++- tests/test_evaluator.py | 34 +++++- 4 files changed, 330 insertions(+), 20 deletions(-) diff --git a/src/molmole_research/downloader.py b/src/molmole_research/downloader.py index ed194ae..7fbe1ca 100644 --- a/src/molmole_research/downloader.py +++ b/src/molmole_research/downloader.py @@ -15,12 +15,19 @@ import json import os +import shutil from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Set import datasets from PIL import Image import typer +from huggingface_hub import snapshot_download + +try: + from rdkit import Chem # type: ignore[assignment] +except ImportError: # pragma: no cover - optional dependency + Chem = None # type: ignore[assignment] app = typer.Typer(add_completion=False, help="Download the MolMole dataset") @@ -39,6 +46,161 @@ def _extract_smiles(record: Dict[str, Any]) -> Optional[str]: return None +def _unique_preserve_order(values: Iterable[str]) -> List[str]: + """Return a list with duplicates removed while preserving order.""" + + seen: Set[str] = set() + result: List[str] = [] + for value in values: + if value not in seen: + seen.add(value) + result.append(value) + return result + + +def _mol_to_smiles(mol_path: Path) -> Optional[str]: + """Convert a MOL file into a SMILES string using RDKit if available.""" + + if Chem is None: # pragma: no cover - RDKit optional at runtime + return None + try: + mol = Chem.MolFromMolFile(str(mol_path)) # type: ignore[arg-type] + if mol is None: + return None + return Chem.MolToSmiles(mol, canonical=True) # type: ignore[call-arg] + except Exception: # pragma: no cover - RDKit parsing errors are rare + return None + + +def _integrate_additional_metadata( + dataset: str, out: Path, labels: List[Dict[str, Any]] +) -> None: + """Augment labels with MolMole metadata if available. + + The MolMole dataset repository contains additional JSON annotations with + bounding boxes and corresponding MOL files for each chemical structure. + This function downloads those assets (if present) and enriches the + ``labels`` list with bounding boxes, SMILES strings derived from the MOL + files, and copies the raw assets into ``out``. + """ + + try: + repo_root = Path( + snapshot_download( + dataset, + repo_type="dataset", + allow_patterns=["json/*", "mol/*"], + ) + ) + except Exception: # pragma: no cover - network/cache issues handled gracefully + return + + json_dir = repo_root / "json" + if not json_dir.exists(): + return + json_files = sorted(json_dir.glob("*.json")) + if not json_files: + return + + annotations_path = json_files[0] + try: + annotations = json.loads(annotations_path.read_text(encoding="utf-8")) + except Exception: + return + + images_meta = { + item.get("file_name"): item + for item in annotations.get("images", []) + if isinstance(item, dict) and item.get("file_name") + } + + mol_repo_dir = repo_root / "mol" + mol_out_dir = out / "mol" + copied_mols: Set[str] = set() + structures_added = 0 + + for entry in labels: + fname = entry.get("file_name") + if not fname: + continue + meta = images_meta.get(fname) + if not meta: + continue + + entry.setdefault("height", meta.get("height")) + entry.setdefault("width", meta.get("width")) + + structures: List[Dict[str, Any]] = entry.get("structures", []) + texts: List[Dict[str, Any]] = entry.get("texts", []) + + for bbox in meta.get("dla_bboxes", []): + if not isinstance(bbox, dict): + continue + category_id = bbox.get("category_id") + bbox_id = bbox.get("id") + bbox_coords = bbox.get("bbox") + if category_id == 1 and bbox_id is not None: + mol_name = f"{Path(fname).stem}_bbox_{bbox_id}.mol" + mol_repo_path = mol_repo_dir / mol_name + smiles = None + mol_relative = None + if mol_repo_path.exists(): + mol_out_dir.mkdir(parents=True, exist_ok=True) + if mol_name not in copied_mols: + shutil.copy2(mol_repo_path, mol_out_dir / mol_name) + copied_mols.add(mol_name) + smiles = _mol_to_smiles(mol_out_dir / mol_name) + mol_relative = str(Path("mol") / mol_name) + structures.append( + { + "id": bbox_id, + "bbox": bbox_coords, + "category": "structure", + "mol_file": mol_relative, + "smiles": smiles, + } + ) + structures_added += 1 + if smiles: + smiles_list = entry.get("smiles_list", []) + if isinstance(smiles_list, list): + smiles_list.append(smiles) + entry["smiles_list"] = smiles_list + else: + entry["smiles_list"] = [smiles] + elif category_id == 2: + texts.append( + { + "id": bbox_id, + "bbox": bbox_coords, + "category": "text", + } + ) + + if structures: + entry["structures"] = structures + if texts: + entry["texts"] = texts + + smiles_list = entry.get("smiles_list") or [] + if isinstance(smiles_list, list): + entry["smiles_list"] = _unique_preserve_order( + [s for s in smiles_list if isinstance(s, str) and s.strip()] + ) + if entry["smiles_list"] and not entry.get("smiles"): + entry["smiles"] = entry["smiles_list"][0] + + # Copy the annotations JSON for reference + annotations_out = out / "annotations.json" + try: + shutil.copy2(annotations_path, annotations_out) + except Exception: # pragma: no cover - copying issues should not fail download + pass + + if structures_added: + typer.echo(f"Integrated {structures_added} structures from repository metadata") + + @app.command("run") def download_dataset( dataset: str = typer.Option( @@ -100,7 +262,13 @@ def download_dataset( # Extract SMILES (if available) smiles = _extract_smiles(record) - labels.append({"file_name": fname, "smiles": smiles}) + entry: Dict[str, Any] = {"file_name": fname, "smiles": smiles} + if smiles: + entry["smiles_list"] = [smiles] + labels.append(entry) + + # Enrich labels with additional metadata (MolMole repository only) + _integrate_additional_metadata(dataset, out, labels) # Save labels JSON labels_path = out / "labels.json" diff --git a/src/molmole_research/evaluator.py b/src/molmole_research/evaluator.py index 395e0bc..5b6d8f2 100644 --- a/src/molmole_research/evaluator.py +++ b/src/molmole_research/evaluator.py @@ -17,7 +17,7 @@ import json from pathlib import Path -from typing import Dict, List +from typing import Dict, Iterable, List, Set import typer @@ -100,13 +100,53 @@ def evaluate( typer.echo(f"Missing labels.json in {dataset_dir}; run the downloader first") raise typer.Exit(1) with labels_path.open("r", encoding="utf-8") as fh: - labels_list: List[Dict[str, str | None]] = json.load(fh) - labels: Dict[str, str] = {} + labels_list: List[Dict[str, object]] = json.load(fh) + + def _collect_smiles(entry: Dict[str, object]) -> List[str]: + values: List[str] = [] + raw_smiles = entry.get("smiles") + if isinstance(raw_smiles, str) and raw_smiles.strip(): + values.append(raw_smiles.strip()) + smiles_list = entry.get("smiles_list") + if isinstance(smiles_list, Iterable) and not isinstance(smiles_list, (str, bytes)): + values.extend( + s.strip() + for s in smiles_list + if isinstance(s, str) and s.strip() + ) + for struct in entry.get("structures", []) if isinstance(entry.get("structures"), list) else []: + if isinstance(struct, dict): + struct_smiles = struct.get("smiles") + if isinstance(struct_smiles, str) and struct_smiles.strip(): + values.append(struct_smiles.strip()) + # Deduplicate while preserving order + seen: Set[str] = set() + unique: List[str] = [] + for value in values: + if value not in seen: + seen.add(value) + unique.append(value) + return unique + + labels: Dict[str, Dict[str, Set[str]]] = {} for entry in labels_list: - fname = entry.get("file_name") - smiles = entry.get("smiles") - if fname and smiles: - labels[fname] = smiles + fname = entry.get("file_name") if isinstance(entry, dict) else None + if not isinstance(fname, str) or not fname: + continue + smiles_values = _collect_smiles(entry) + if not smiles_values: + continue + canonical_set: Set[str] = set() + inchi_set: Set[str] = set() + for smiles in smiles_values: + canonical = _canonical_smiles(smiles) + if canonical: + canonical_set.add(canonical) + inchi = _inchi_key(smiles) + if inchi: + inchi_set.add(inchi) + if canonical_set or inchi_set: + labels[fname] = {"canonical": canonical_set, "inchi": inchi_set} total = 0 correct_smiles = 0 @@ -125,19 +165,20 @@ def evaluate( pred_text = record.get("text") if not fname or pred_text is None: continue - gt_smiles = labels.get(fname) - if not gt_smiles: + gt_sets = labels.get(fname) + if not gt_sets: continue total += 1 - # Canonicalise ground truth - gt_canonical = _canonical_smiles(gt_smiles) - gt_inchi = _inchi_key(gt_smiles) # Canonicalise prediction pred_canonical = _canonical_smiles(pred_text.strip()) if pred_text.strip() else None pred_inchi = _inchi_key(pred_text.strip()) if pred_text.strip() else None - if gt_canonical and pred_canonical and gt_canonical == pred_canonical: + if ( + pred_canonical + and gt_sets["canonical"] + and pred_canonical in gt_sets["canonical"] + ): correct_smiles += 1 - if gt_inchi and pred_inchi and gt_inchi == pred_inchi: + if pred_inchi and gt_sets["inchi"] and pred_inchi in gt_sets["inchi"]: correct_inchi += 1 accuracy_smiles = correct_smiles / total if total else 0.0 diff --git a/tests/test_downloader.py b/tests/test_downloader.py index ab3b8f0..770ed83 100644 --- a/tests/test_downloader.py +++ b/tests/test_downloader.py @@ -9,9 +9,11 @@ from __future__ import annotations import json +from pathlib import Path from PIL import Image import datasets +from rdkit import Chem from molmole_research.downloader import download_dataset @@ -35,16 +37,70 @@ def _make_dummy_dataset(num_items: int = 2): ) +def _create_dummy_repo(repo_dir: Path) -> None: + """Create a fake HuggingFace repository layout with annotations and MOL files.""" + + json_dir = repo_dir / "json" + mol_dir = repo_dir / "mol" + json_dir.mkdir(parents=True, exist_ok=True) + mol_dir.mkdir(parents=True, exist_ok=True) + + annotations = { + "images": [ + { + "file_name": "item_0.png", + "height": 16, + "width": 16, + "dla_bboxes": [ + {"id": 0, "bbox": [0, 0, 4, 4], "category_id": 1}, + {"id": 1, "bbox": [4, 4, 4, 4], "category_id": 2}, + {"id": 2, "bbox": [8, 8, 4, 4], "category_id": 1}, + ], + "reactions": [], + }, + { + "file_name": "item_1.png", + "height": 16, + "width": 16, + "dla_bboxes": [ + {"id": 0, "bbox": [1, 1, 4, 4], "category_id": 1}, + ], + "reactions": [], + }, + ] + } + (json_dir / "annotations.json").write_text(json.dumps(annotations), encoding="utf-8") + + smiles_map = { + "item_0_bbox_0.mol": "C", + "item_0_bbox_2.mol": "CC", + "item_1_bbox_0.mol": "O", + } + for name, smiles in smiles_map.items(): + mol = Chem.MolFromSmiles(smiles) + assert mol is not None + block = Chem.MolToMolBlock(mol) + (mol_dir / name).write_text(block, encoding="utf-8") + + def test_downloader_saves_images_and_labels(monkeypatch, tmp_path): """Verify that the downloader writes images and a labels.json file.""" dummy_ds = _make_dummy_dataset(3) + repo_dir = tmp_path / "hf_repo" + _create_dummy_repo(repo_dir) + # Patch datasets.load_dataset to return our dummy dataset regardless of the input def fake_load_dataset(*args, **kwargs): # noqa: D401 return dummy_ds monkeypatch.setattr(datasets, "load_dataset", fake_load_dataset) + def fake_snapshot_download(*args, **kwargs): + return str(repo_dir) + + monkeypatch.setattr("molmole_research.downloader.snapshot_download", fake_snapshot_download) + out_dir = tmp_path / "download" download_dataset(dataset="dummy", split="train", out=out_dir) # Check that images were saved @@ -58,5 +114,20 @@ def fake_load_dataset(*args, **kwargs): # noqa: D401 assert labels_path.exists() labels = json.loads(labels_path.read_text()) assert len(labels) == 3 - for entry in labels: - assert entry["smiles"] == "C" + + item0 = next(entry for entry in labels if entry["file_name"] == "item_0.png") + assert item0["smiles"] == "C" + assert item0["smiles_list"] == ["C", "CC"] + assert len(item0["structures"]) == 2 + for struct in item0["structures"]: + assert struct["mol_file"].startswith("mol/") + assert (out_dir / struct["mol_file"]).exists() + + item1 = next(entry for entry in labels if entry["file_name"] == "item_1.png") + assert item1["smiles"] == "C" + # SMILES list should include dataset smile plus MOL-derived value + assert item1["smiles_list"] == ["C", "O"] + assert len(item1["structures"]) == 1 + + # Ensure annotations JSON is copied over + assert (out_dir / "annotations.json").exists() diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py index 31b05b6..c28e540 100644 --- a/tests/test_evaluator.py +++ b/tests/test_evaluator.py @@ -30,8 +30,8 @@ def test_evaluator_accuracy(tmp_path): """Evaluate a small set of predictions and verify the computed accuracy.""" # Ground truth labels labels = [ - {"file_name": "img_0.png", "smiles": "C"}, - {"file_name": "img_1.png", "smiles": "CC"}, + {"file_name": "img_0.png", "smiles": "C", "smiles_list": ["C"]}, + {"file_name": "img_1.png", "smiles_list": ["CC"]}, ] dataset_dir = tmp_path / "data" _write_labels(dataset_dir, labels) @@ -57,3 +57,33 @@ def test_evaluator_accuracy(tmp_path): assert metrics["correct_smiles"] == 1 assert metrics["correct_inchi"] == 1 assert abs(metrics["accuracy_smiles"] - 0.5) < 1e-6 + + +def test_evaluator_handles_multiple_smiles(tmp_path): + """Predictions should match if any ground truth SMILES is correct.""" + + labels = [ + { + "file_name": "img_multi.png", + "smiles": None, + "smiles_list": ["CC", "O"], + "structures": [{"smiles": "CC"}, {"smiles": "O"}], + } + ] + dataset_dir = tmp_path / "data" + _write_labels(dataset_dir, labels) + + preds = [{"file_name": "img_multi.png", "text": "O"}] + pred_path = tmp_path / "preds.jsonl" + _write_predictions(pred_path, preds) + + results_dir = tmp_path / "results" + evaluator.evaluate(pred=pred_path, dataset_dir=dataset_dir, out=results_dir) + + metrics_files = list(results_dir.iterdir()) + assert metrics_files + with metrics_files[0].open("r", encoding="utf-8") as fh: + metrics = json.load(fh) + assert metrics["total"] == 1 + assert metrics["correct_smiles"] == 1 + assert metrics["correct_inchi"] == 1 From d9fac73eeb093df26bf273533b35f4e05edfc6fd Mon Sep 17 00:00:00 2001 From: Dmitry Date: Tue, 28 Oct 2025 13:39:04 +0100 Subject: [PATCH 2/2] Format MolMole modules --- src/molmole_research/downloader.py | 4 +--- src/molmole_research/evaluator.py | 16 +++++----------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/src/molmole_research/downloader.py b/src/molmole_research/downloader.py index 7fbe1ca..80209d4 100644 --- a/src/molmole_research/downloader.py +++ b/src/molmole_research/downloader.py @@ -72,9 +72,7 @@ def _mol_to_smiles(mol_path: Path) -> Optional[str]: return None -def _integrate_additional_metadata( - dataset: str, out: Path, labels: List[Dict[str, Any]] -) -> None: +def _integrate_additional_metadata(dataset: str, out: Path, labels: List[Dict[str, Any]]) -> None: """Augment labels with MolMole metadata if available. The MolMole dataset repository contains additional JSON annotations with diff --git a/src/molmole_research/evaluator.py b/src/molmole_research/evaluator.py index 5b6d8f2..5eea592 100644 --- a/src/molmole_research/evaluator.py +++ b/src/molmole_research/evaluator.py @@ -109,12 +109,10 @@ def _collect_smiles(entry: Dict[str, object]) -> List[str]: values.append(raw_smiles.strip()) smiles_list = entry.get("smiles_list") if isinstance(smiles_list, Iterable) and not isinstance(smiles_list, (str, bytes)): - values.extend( - s.strip() - for s in smiles_list - if isinstance(s, str) and s.strip() - ) - for struct in entry.get("structures", []) if isinstance(entry.get("structures"), list) else []: + values.extend(s.strip() for s in smiles_list if isinstance(s, str) and s.strip()) + for struct in ( + entry.get("structures", []) if isinstance(entry.get("structures"), list) else [] + ): if isinstance(struct, dict): struct_smiles = struct.get("smiles") if isinstance(struct_smiles, str) and struct_smiles.strip(): @@ -172,11 +170,7 @@ def _collect_smiles(entry: Dict[str, object]) -> List[str]: # Canonicalise prediction pred_canonical = _canonical_smiles(pred_text.strip()) if pred_text.strip() else None pred_inchi = _inchi_key(pred_text.strip()) if pred_text.strip() else None - if ( - pred_canonical - and gt_sets["canonical"] - and pred_canonical in gt_sets["canonical"] - ): + if pred_canonical and gt_sets["canonical"] and pred_canonical in gt_sets["canonical"]: correct_smiles += 1 if pred_inchi and gt_sets["inchi"] and pred_inchi in gt_sets["inchi"]: correct_inchi += 1