Skip to content

Commit

Permalink
Made mol prop dataloader more modular and
Browse files Browse the repository at this point in the history
adaptable, added tests.
  • Loading branch information
leojklarner committed Nov 2, 2023
1 parent 8a1d9a8 commit cd2b86b
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 102 deletions.
122 changes: 122 additions & 0 deletions Untitled.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "c5871038-462e-4001-9555-668be272e460",
"metadata": {},
"outputs": [],
"source": [
"from gauche.dataloader import MolPropLoader"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2232be68-ac1a-46c3-85db-9a4901693ccc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/mnt/c/Users/lklar/OneDrive - Nexus365/SABS/DPhil/Publications/gauche/gauche/gauche/dataloader/mol_prop.py\n",
"/mnt/c/Users/lklar/OneDrive - Nexus365/SABS/DPhil/Publications/gauche/gauche/data/property_prediction/Photoswitch.csv\n",
"Found invalid labels [nan nan nan nan nan nan nan nan nan nan nan nan nan] at indices [41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 158]\n"
]
}
],
"source": [
"dl = MolPropLoader()\n",
"dl.load_benchmark(\"Photoswitch\")\n",
"dl.featurize(\"molecular_graphs\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4d8896c0-9e74-414c-a92b-6d009f2f6b3a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
" [0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
" [0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],\n",
" [0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import networkx as nx\n",
"nx.to_numpy_array(dl.features[0])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "36a65b60-7a4c-4aac-bfdc-3121c5d5b0e0",
"metadata": {},
"outputs": [
{
"ename": "AttributeError",
"evalue": "'Graph' object has no attribute 'adjacency_matrix'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m dl\u001b[38;5;241m.\u001b[39mfeatures[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39madjacency_matrix\n",
"\u001b[0;31mAttributeError\u001b[0m: 'Graph' object has no attribute 'adjacency_matrix'"
]
}
],
"source": [
"adjacency_matrixdl.features[0].adjacency_matrix"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95720543-2790-4471-9e76-003f2522b169",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "gauche",
"language": "python",
"name": "gauche"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
22 changes: 16 additions & 6 deletions gauche/dataloader/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
type validation and feature extraction functionalities.
"""

from typing import Optional
from abc import ABCMeta, abstractmethod

from sklearn.model_selection import train_test_split
Expand Down Expand Up @@ -34,27 +35,36 @@ def labels(self, value):
raise NotImplementedError

@abstractmethod
def validate(self, drop=True):
"""Checks whether the loaded data is a valid instance of the specified
data type, potentially dropping invalid entries.
def validate(
self, drop: Optional[bool] = True, canonicalize: Optional[bool] = True
):
"""Checks whether the loaded data is a valid instance
of the specified data type, optionally dropping invalid
entries and standardizing the remaining ones.
:param drop: whether to drop invalid entries
:type drop: bool
:param canonicalize: whether to standardize the data
:type canonicalize: bool
"""
raise NotImplementedError

@abstractmethod
def featurize(self, representation):
def featurize(self, representation: str, **kwargs):
"""Transforms the features to the specified representation (in-place).
:param representation: desired feature format
:type representation: str
:param kwargs: additional keyword arguments for the representation function
:type kwargs: dict
"""
raise NotImplementedError

def split_and_scale(
self, test_size=0.2, scale_labels=True, scale_features=False
self,
test_size: int = 0.2,
scale_labels: bool = True,
scale_features: bool = False,
):
"""Splits the data into training and testing sets.
Expand Down
Loading

0 comments on commit cd2b86b

Please sign in to comment.