Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 168 additions & 2 deletions src/molmole_research/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,19 @@

import json
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Iterable, List, Optional, Set

import datasets
from PIL import Image
import typer
from huggingface_hub import snapshot_download

try:
from rdkit import Chem # type: ignore[assignment]
except ImportError: # pragma: no cover - optional dependency
Chem = None # type: ignore[assignment]

app = typer.Typer(add_completion=False, help="Download the MolMole dataset")

Expand All @@ -39,6 +46,159 @@ def _extract_smiles(record: Dict[str, Any]) -> Optional[str]:
return None


def _unique_preserve_order(values: Iterable[str]) -> List[str]:
"""Return a list with duplicates removed while preserving order."""

seen: Set[str] = set()
result: List[str] = []
for value in values:
if value not in seen:
seen.add(value)
result.append(value)
return result


def _mol_to_smiles(mol_path: Path) -> Optional[str]:
    """Best-effort conversion of a MOL file into a canonical SMILES string.

    Returns ``None`` when RDKit is not installed, when the file cannot be
    parsed into a molecule, or when RDKit raises while reading/serialising.
    """

    if Chem is None:  # pragma: no cover - RDKit optional at runtime
        return None
    try:
        molecule = Chem.MolFromMolFile(str(mol_path))  # type: ignore[arg-type]
        if molecule is None:
            return None
        return Chem.MolToSmiles(molecule, canonical=True)  # type: ignore[call-arg]
    except Exception:  # pragma: no cover - RDKit parsing errors are rare
        return None


def _integrate_additional_metadata(dataset: str, out: Path, labels: List[Dict[str, Any]]) -> None:
    """Augment labels with MolMole metadata if available.

    The MolMole dataset repository contains additional JSON annotations with
    bounding boxes and corresponding MOL files for each chemical structure.
    This function downloads those assets (if present) and enriches the
    ``labels`` list with bounding boxes, SMILES strings derived from the MOL
    files, and copies the raw assets into ``out``.

    All failures (network, cache, malformed JSON, copy errors) are swallowed
    so that a missing metadata repository never breaks the main download.
    ``labels`` is mutated in place; nothing is returned.
    """

    # Fetch only the metadata subtrees of the dataset repo; any failure
    # (offline, auth, unknown repo) simply means "no extra metadata".
    try:
        repo_root = Path(
            snapshot_download(
                dataset,
                repo_type="dataset",
                allow_patterns=["json/*", "mol/*"],
            )
        )
    except Exception:  # pragma: no cover - network/cache issues handled gracefully
        return

    json_dir = repo_root / "json"
    if not json_dir.exists():
        return
    json_files = sorted(json_dir.glob("*.json"))
    if not json_files:
        return

    # Only the alphabetically-first JSON file is used — assumes the repo
    # ships a single annotations file (TODO confirm for multi-file repos).
    annotations_path = json_files[0]
    try:
        annotations = json.loads(annotations_path.read_text(encoding="utf-8"))
    except Exception:
        return

    # Index the per-image annotation records by file name for O(1) lookup.
    images_meta = {
        item.get("file_name"): item
        for item in annotations.get("images", [])
        if isinstance(item, dict) and item.get("file_name")
    }

    mol_repo_dir = repo_root / "mol"
    mol_out_dir = out / "mol"
    copied_mols: Set[str] = set()  # avoids re-copying the same MOL file twice
    structures_added = 0

    for entry in labels:
        fname = entry.get("file_name")
        if not fname:
            continue
        meta = images_meta.get(fname)
        if not meta:
            continue

        # Fill in image dimensions only if the downloader did not set them.
        entry.setdefault("height", meta.get("height"))
        entry.setdefault("width", meta.get("width"))

        structures: List[Dict[str, Any]] = entry.get("structures", [])
        texts: List[Dict[str, Any]] = entry.get("texts", [])

        for bbox in meta.get("dla_bboxes", []):
            if not isinstance(bbox, dict):
                continue
            category_id = bbox.get("category_id")
            bbox_id = bbox.get("id")
            bbox_coords = bbox.get("bbox")
            # category_id 1 => chemical structure, 2 => text region
            # (mapping inferred from the labels written below — TODO confirm
            # against the MolMole annotation schema).
            if category_id == 1 and bbox_id is not None:
                # MOL file names follow "<image-stem>_bbox_<id>.mol".
                mol_name = f"{Path(fname).stem}_bbox_{bbox_id}.mol"
                mol_repo_path = mol_repo_dir / mol_name
                smiles = None
                mol_relative = None
                if mol_repo_path.exists():
                    mol_out_dir.mkdir(parents=True, exist_ok=True)
                    if mol_name not in copied_mols:
                        shutil.copy2(mol_repo_path, mol_out_dir / mol_name)
                        copied_mols.add(mol_name)
                    smiles = _mol_to_smiles(mol_out_dir / mol_name)
                    # Stored relative to ``out`` so labels.json stays portable.
                    mol_relative = str(Path("mol") / mol_name)
                structures.append(
                    {
                        "id": bbox_id,
                        "bbox": bbox_coords,
                        "category": "structure",
                        "mol_file": mol_relative,
                        "smiles": smiles,
                    }
                )
                structures_added += 1
                if smiles:
                    # Append rather than overwrite: the entry may already
                    # carry a dataset-derived SMILES list.
                    smiles_list = entry.get("smiles_list", [])
                    if isinstance(smiles_list, list):
                        smiles_list.append(smiles)
                        entry["smiles_list"] = smiles_list
                    else:
                        entry["smiles_list"] = [smiles]
            elif category_id == 2:
                texts.append(
                    {
                        "id": bbox_id,
                        "bbox": bbox_coords,
                        "category": "text",
                    }
                )

        if structures:
            entry["structures"] = structures
        if texts:
            entry["texts"] = texts

        # Normalise the SMILES list: drop blanks/non-strings, dedupe in order,
        # and promote the first entry to the top-level "smiles" when missing.
        smiles_list = entry.get("smiles_list") or []
        if isinstance(smiles_list, list):
            entry["smiles_list"] = _unique_preserve_order(
                [s for s in smiles_list if isinstance(s, str) and s.strip()]
            )
            if entry["smiles_list"] and not entry.get("smiles"):
                entry["smiles"] = entry["smiles_list"][0]

    # Copy the annotations JSON for reference
    annotations_out = out / "annotations.json"
    try:
        shutil.copy2(annotations_path, annotations_out)
    except Exception:  # pragma: no cover - copying issues should not fail download
        pass

    if structures_added:
        typer.echo(f"Integrated {structures_added} structures from repository metadata")


@app.command("run")
def download_dataset(
dataset: str = typer.Option(
Expand Down Expand Up @@ -100,7 +260,13 @@ def download_dataset(

# Extract SMILES (if available)
smiles = _extract_smiles(record)
labels.append({"file_name": fname, "smiles": smiles})
entry: Dict[str, Any] = {"file_name": fname, "smiles": smiles}
if smiles:
entry["smiles_list"] = [smiles]
labels.append(entry)

# Enrich labels with additional metadata (MolMole repository only)
_integrate_additional_metadata(dataset, out, labels)

# Save labels JSON
labels_path = out / "labels.json"
Expand Down
63 changes: 49 additions & 14 deletions src/molmole_research/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import json
from pathlib import Path
from typing import Dict, List
from typing import Dict, Iterable, List, Set

import typer

Expand Down Expand Up @@ -100,13 +100,51 @@ def evaluate(
typer.echo(f"Missing labels.json in {dataset_dir}; run the downloader first")
raise typer.Exit(1)
with labels_path.open("r", encoding="utf-8") as fh:
labels_list: List[Dict[str, str | None]] = json.load(fh)
labels: Dict[str, str] = {}
labels_list: List[Dict[str, object]] = json.load(fh)

def _collect_smiles(entry: Dict[str, object]) -> List[str]:
values: List[str] = []
raw_smiles = entry.get("smiles")
if isinstance(raw_smiles, str) and raw_smiles.strip():
values.append(raw_smiles.strip())
smiles_list = entry.get("smiles_list")
if isinstance(smiles_list, Iterable) and not isinstance(smiles_list, (str, bytes)):
values.extend(s.strip() for s in smiles_list if isinstance(s, str) and s.strip())
for struct in (
entry.get("structures", []) if isinstance(entry.get("structures"), list) else []
):
if isinstance(struct, dict):
struct_smiles = struct.get("smiles")
if isinstance(struct_smiles, str) and struct_smiles.strip():
values.append(struct_smiles.strip())
# Deduplicate while preserving order
seen: Set[str] = set()
unique: List[str] = []
for value in values:
if value not in seen:
seen.add(value)
unique.append(value)
return unique

labels: Dict[str, Dict[str, Set[str]]] = {}
for entry in labels_list:
fname = entry.get("file_name")
smiles = entry.get("smiles")
if fname and smiles:
labels[fname] = smiles
fname = entry.get("file_name") if isinstance(entry, dict) else None
if not isinstance(fname, str) or not fname:
continue
smiles_values = _collect_smiles(entry)
if not smiles_values:
continue
canonical_set: Set[str] = set()
inchi_set: Set[str] = set()
for smiles in smiles_values:
canonical = _canonical_smiles(smiles)
if canonical:
canonical_set.add(canonical)
inchi = _inchi_key(smiles)
if inchi:
inchi_set.add(inchi)
if canonical_set or inchi_set:
labels[fname] = {"canonical": canonical_set, "inchi": inchi_set}

total = 0
correct_smiles = 0
Expand All @@ -125,19 +163,16 @@ def evaluate(
pred_text = record.get("text")
if not fname or pred_text is None:
continue
gt_smiles = labels.get(fname)
if not gt_smiles:
gt_sets = labels.get(fname)
if not gt_sets:
continue
total += 1
# Canonicalise ground truth
gt_canonical = _canonical_smiles(gt_smiles)
gt_inchi = _inchi_key(gt_smiles)
# Canonicalise prediction
pred_canonical = _canonical_smiles(pred_text.strip()) if pred_text.strip() else None
pred_inchi = _inchi_key(pred_text.strip()) if pred_text.strip() else None
if gt_canonical and pred_canonical and gt_canonical == pred_canonical:
if pred_canonical and gt_sets["canonical"] and pred_canonical in gt_sets["canonical"]:
correct_smiles += 1
if gt_inchi and pred_inchi and gt_inchi == pred_inchi:
if pred_inchi and gt_sets["inchi"] and pred_inchi in gt_sets["inchi"]:
correct_inchi += 1

accuracy_smiles = correct_smiles / total if total else 0.0
Expand Down
75 changes: 73 additions & 2 deletions tests/test_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from __future__ import annotations

import json
from pathlib import Path

from PIL import Image
import datasets
from rdkit import Chem

from molmole_research.downloader import download_dataset

Expand All @@ -35,16 +37,70 @@ def _make_dummy_dataset(num_items: int = 2):
)


def _create_dummy_repo(repo_dir: Path) -> None:
    """Create a fake HuggingFace repository layout with annotations and MOL files."""

    annotations_dir = repo_dir / "json"
    structures_dir = repo_dir / "mol"
    for directory in (annotations_dir, structures_dir):
        directory.mkdir(parents=True, exist_ok=True)

    annotations = {
        "images": [
            {
                "file_name": "item_0.png",
                "height": 16,
                "width": 16,
                "dla_bboxes": [
                    {"id": 0, "bbox": [0, 0, 4, 4], "category_id": 1},
                    {"id": 1, "bbox": [4, 4, 4, 4], "category_id": 2},
                    {"id": 2, "bbox": [8, 8, 4, 4], "category_id": 1},
                ],
                "reactions": [],
            },
            {
                "file_name": "item_1.png",
                "height": 16,
                "width": 16,
                "dla_bboxes": [
                    {"id": 0, "bbox": [1, 1, 4, 4], "category_id": 1},
                ],
                "reactions": [],
            },
        ]
    }
    (annotations_dir / "annotations.json").write_text(json.dumps(annotations), encoding="utf-8")

    # Write one MOL file per structure bbox, named "<stem>_bbox_<id>.mol".
    for mol_name, smiles in (
        ("item_0_bbox_0.mol", "C"),
        ("item_0_bbox_2.mol", "CC"),
        ("item_1_bbox_0.mol", "O"),
    ):
        molecule = Chem.MolFromSmiles(smiles)
        assert molecule is not None
        (structures_dir / mol_name).write_text(Chem.MolToMolBlock(molecule), encoding="utf-8")


def test_downloader_saves_images_and_labels(monkeypatch, tmp_path):
"""Verify that the downloader writes images and a labels.json file."""
dummy_ds = _make_dummy_dataset(3)

repo_dir = tmp_path / "hf_repo"
_create_dummy_repo(repo_dir)

# Patch datasets.load_dataset to return our dummy dataset regardless of the input
def fake_load_dataset(*args, **kwargs): # noqa: D401
return dummy_ds

monkeypatch.setattr(datasets, "load_dataset", fake_load_dataset)

def fake_snapshot_download(*args, **kwargs):
return str(repo_dir)

monkeypatch.setattr("molmole_research.downloader.snapshot_download", fake_snapshot_download)

out_dir = tmp_path / "download"
download_dataset(dataset="dummy", split="train", out=out_dir)
# Check that images were saved
Expand All @@ -58,5 +114,20 @@ def fake_load_dataset(*args, **kwargs): # noqa: D401
assert labels_path.exists()
labels = json.loads(labels_path.read_text())
assert len(labels) == 3
for entry in labels:
assert entry["smiles"] == "C"

item0 = next(entry for entry in labels if entry["file_name"] == "item_0.png")
assert item0["smiles"] == "C"
assert item0["smiles_list"] == ["C", "CC"]
assert len(item0["structures"]) == 2
for struct in item0["structures"]:
assert struct["mol_file"].startswith("mol/")
assert (out_dir / struct["mol_file"]).exists()

item1 = next(entry for entry in labels if entry["file_name"] == "item_1.png")
assert item1["smiles"] == "C"
# SMILES list should include dataset smile plus MOL-derived value
assert item1["smiles_list"] == ["C", "O"]
assert len(item1["structures"]) == 1

# Ensure annotations JSON is copied over
assert (out_dir / "annotations.json").exists()
Loading