From b9e4cd6cbfda323f310e9c919903c9e41b3f1e39 Mon Sep 17 00:00:00 2001 From: jyyy6565 <100702463+jyyy6565@users.noreply.github.com> Date: Sun, 21 Jan 2024 18:41:56 +0800 Subject: [PATCH] [Dataset] Add dataset (#192) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 2.1 * datasets修改 给datasets中的所有属性添加force_reload参数 * remove files * update * update * 2.2修改 --- gammagl/data/dataset.py | 14 +- gammagl/data/in_memory_dataset.py | 8 +- gammagl/datasets/__init__.py | 4 +- gammagl/datasets/alircd.py | 4 +- gammagl/datasets/amazon.py | 7 +- gammagl/datasets/aminer.py | 7 +- gammagl/datasets/coauthor.py | 7 +- gammagl/datasets/dblp.py | 6 +- gammagl/datasets/entities.py | 6 +- gammagl/datasets/flickr.py | 6 +- gammagl/datasets/hgb.py | 7 +- gammagl/datasets/imdb.py | 6 +- gammagl/datasets/ml.py | 4 +- gammagl/datasets/modelnet40.py | 6 +- gammagl/datasets/molecule_net.py | 218 ++++++++++++++++++++++++++ gammagl/datasets/planetoid.py | 7 +- gammagl/datasets/polblogs.py | 6 +- gammagl/datasets/ppi.py | 4 +- gammagl/datasets/reddit.py | 6 +- gammagl/datasets/shapenet.py | 6 +- gammagl/datasets/tu_dataset.py | 7 +- gammagl/datasets/webkb.py | 7 +- gammagl/datasets/wikics.py | 50 +++--- gammagl/datasets/wikipedia_network.py | 7 +- gammagl/datasets/zinc.py | 7 +- gammagl/utils/__init__.py | 4 +- gammagl/utils/convert.py | 2 +- gammagl/utils/smiles.py | 144 +++++++++++++++++ 28 files changed, 495 insertions(+), 72 deletions(-) create mode 100644 gammagl/datasets/molecule_net.py create mode 100644 gammagl/utils/smiles.py diff --git a/gammagl/data/dataset.py b/gammagl/data/dataset.py index b3d5cdc3..55c9bb0b 100644 --- a/gammagl/data/dataset.py +++ b/gammagl/data/dataset.py @@ -48,6 +48,8 @@ class Dataset(Dataset): :obj:`gammagl.data.Graph` object and returns a boolean value, indicating whether the graph object should be included in the final dataset. 
(default: :obj:`None`) + force_reload: bool, optional + Whether to re-process the dataset. (default: :obj:`False`) """ @@ -82,7 +84,8 @@ def get(self, idx: int) -> Graph: def __init__(self, root: Optional[str] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, - pre_filter: Optional[Callable] = None): + pre_filter: Optional[Callable] = None, + force_reload: bool = False): super().__init__() self.raw_root = root @@ -98,6 +101,7 @@ def __init__(self, root: Optional[str] = None, self.pre_transform = pre_transform self.pre_filter = pre_filter self._indices: Optional[Sequence] = None + self.force_reload = force_reload # when finished,record dataset path to .ggl/datasets.json # next time will use this dataset to avoid download repeatedly. @@ -316,7 +320,8 @@ def _process(self): f"The `pre_transform` argument differs from the one used in " f"the pre-processed version of this dataset. If you want to " f"make use of another pre-processing technique, make sure to " - f"sure to delete '{self.processed_dir}' first") + f"delete '{self.processed_dir}' first, or pass " + f"`force_reload=True` explicitly to reload the dataset.") f = osp.join(self.processed_dir, tlx.BACKEND + '_pre_filter.pt') if osp.exists(f) and self.load_with_pickle(f) != _repr(self.pre_filter): @@ -324,9 +329,10 @@ def _process(self): "The `pre_filter` argument differs from the one used in the " "pre-processed version of this dataset. 
If you want to make " "use of another pre-fitering technique, make sure to delete " - "'{self.processed_dir}' first") + f"'{self.processed_dir}' first, or pass " + "`force_reload=True` explicitly to reload the dataset.") - if files_exist(self.processed_paths): # pragma: no cover + if not self.force_reload and files_exist(self.processed_paths): # pragma: no cover # self.process() return diff --git a/gammagl/data/in_memory_dataset.py b/gammagl/data/in_memory_dataset.py index 8bcdfbad..a5b51d62 100644 --- a/gammagl/data/in_memory_dataset.py +++ b/gammagl/data/in_memory_dataset.py @@ -35,6 +35,8 @@ class InMemoryDataset(Dataset): :obj:`gammagl.data.Graph` object and returns a boolean value, indicating whether the graph object should be included in the final dataset. (default: :obj:`None`) + force_reload: bool, optional + Whether to re-process the dataset. (default: :obj:`False`) """ @property @@ -54,12 +56,14 @@ def process(self): def __init__(self, root: Optional[str] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, - pre_filter: Optional[Callable] = None): - super().__init__(root, transform, pre_transform, pre_filter) + pre_filter: Optional[Callable] = None, + force_reload: bool = False): + super().__init__(root, transform, pre_transform, pre_filter, force_reload) self.data = None self.slices = None self._data_list: Optional[List[Graph]] = None + @property def num_classes(self) -> int: r"""Returns the number of classes in the dataset.""" diff --git a/gammagl/datasets/__init__.py b/gammagl/datasets/__init__.py index dd416877..8954e8d4 100644 --- a/gammagl/datasets/__init__.py +++ b/gammagl/datasets/__init__.py @@ -17,6 +17,7 @@ from .aminer import AMiner from .polblogs import PolBlogs from .wikics import WikiCS +from .molecule_net import MoleculeNet __all__ = [ 'Amazon', @@ -37,7 +38,8 @@ 'ZINC', 'AMiner', 'PolBlogs', - 'WikiCS' + 'WikiCS', + 'MoleculeNet' ] classes = __all__ diff --git a/gammagl/datasets/alircd.py 
b/gammagl/datasets/alircd.py index aedfb940..b2ec7068 100644 --- a/gammagl/datasets/alircd.py +++ b/gammagl/datasets/alircd.py @@ -14,10 +14,10 @@ class AliRCD(InMemoryDataset): """ url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/AliRCD_session1.zip" - def __init__(self, root=None, transform=None, pre_transform=None): + def __init__(self, root=None, transform=None, pre_transform=None, force_reload: bool = False): self.edge_size = 157814864 self.node_size = 13806619 - super().__init__(root, transform, pre_transform) + super().__init__(root, transform, pre_transform, force_reload = force_reload) self.data, self.slices = self.load_data(self.processed_paths[0]) @property diff --git a/gammagl/datasets/amazon.py b/gammagl/datasets/amazon.py index 0cafc3e1..bb0e4907 100644 --- a/gammagl/datasets/amazon.py +++ b/gammagl/datasets/amazon.py @@ -31,6 +31,8 @@ class Amazon(InMemoryDataset): an :obj:`gammagl.data.Graph` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) Stats: .. 
list-table:: @@ -58,10 +60,11 @@ class Amazon(InMemoryDataset): def __init__(self, root: str = None, name: str = 'computers', transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None): + pre_transform: Optional[Callable] = None, + force_reload: bool = False): self.name = name.lower() assert self.name in ['computers', 'photo'] - super().__init__(root, transform, pre_transform) + super().__init__(root, transform, pre_transform, force_reload = force_reload) self.data, self.slices = self.load_data(self.processed_paths[0]) @property diff --git a/gammagl/datasets/aminer.py b/gammagl/datasets/aminer.py index cf6aa0ba..d80aba8c 100644 --- a/gammagl/datasets/aminer.py +++ b/gammagl/datasets/aminer.py @@ -41,6 +41,8 @@ class AMiner(InMemoryDataset): an :obj:`gammagl.data.HeteroGraph` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) """ @@ -48,8 +50,9 @@ class AMiner(InMemoryDataset): y_url = 'https://www.dropbox.com/s/nkocx16rpl4ydde/label.zip?dl=1' def __init__(self, root: Optional[str] = None, transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None): - super().__init__(root, transform, pre_transform, pre_filter) + pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, + force_reload: bool = False): + super().__init__(root, transform, pre_transform, pre_filter, force_reload = force_reload) self.data, self.slices = self.load_data(self.processed_paths[0]) @property diff --git a/gammagl/datasets/coauthor.py b/gammagl/datasets/coauthor.py index daa11d25..c842bb95 100644 --- a/gammagl/datasets/coauthor.py +++ b/gammagl/datasets/coauthor.py @@ -31,6 +31,8 @@ class Coauthor(InMemoryDataset): an :obj:`gammagl.data.Graph` object and returns a transformed version. 
The data object will be transformed before being saved to disk. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) Stats: .. list-table:: @@ -58,10 +60,11 @@ class Coauthor(InMemoryDataset): def __init__(self, root: str = None, name: str = 'cs', transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None): + pre_transform: Optional[Callable] = None, + force_reload: bool = False): assert name.lower() in ['cs', 'physics'] self.name = 'CS' if name.lower() == 'cs' else 'Physics' - super().__init__(root, transform, pre_transform) + super().__init__(root, transform, pre_transform, force_reload = force_reload) self.data, self.slices = self.load_data(self.processed_paths[0]) @property diff --git a/gammagl/datasets/dblp.py b/gammagl/datasets/dblp.py index 54b25b24..6e4d4cb2 100644 --- a/gammagl/datasets/dblp.py +++ b/gammagl/datasets/dblp.py @@ -40,14 +40,16 @@ class DBLP(InMemoryDataset): an :obj:`gammagl.data.HeteroGraph` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. 
+ (default: :obj:`False`) """ url = 'https://www.dropbox.com/s/yh4grpeks87ugr2/DBLP_processed.zip?dl=1' def __init__(self, root: str = None, transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None): - super().__init__(root, transform, pre_transform) + pre_transform: Optional[Callable] = None, force_reload: bool = False): + super().__init__(root, transform, pre_transform, force_reload = force_reload) self.data, self.slices = self.load_data(self.processed_paths[0]) @property diff --git a/gammagl/datasets/entities.py b/gammagl/datasets/entities.py index 0b68d82c..b398f831 100644 --- a/gammagl/datasets/entities.py +++ b/gammagl/datasets/entities.py @@ -42,6 +42,8 @@ class Entities(InMemoryDataset): an :obj:`gammagl.data.Graph` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) """ @@ -49,11 +51,11 @@ class Entities(InMemoryDataset): def __init__(self, root: str = None, name: str = 'aifb', hetero: bool = False, transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None): + pre_transform: Optional[Callable] = None, force_reload: bool = False): self.name = name.lower() self.hetero = hetero assert self.name in ['aifb', 'am', 'mutag', 'bgs'] - super().__init__(root, transform, pre_transform) + super().__init__(root, transform, pre_transform, force_reload = force_reload) self.data, self.slices = self.load_data(self.processed_paths[0]) @property diff --git a/gammagl/datasets/flickr.py b/gammagl/datasets/flickr.py index 166a3d24..783f3164 100644 --- a/gammagl/datasets/flickr.py +++ b/gammagl/datasets/flickr.py @@ -27,6 +27,8 @@ class Flickr(InMemoryDataset): an :obj:`gammagl.data.Graph` object and returns a transformed version. The data object will be transformed before being saved to disk. 
(default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) Tip --- @@ -51,8 +53,8 @@ class Flickr(InMemoryDataset): role_id = '1htXCtuktuCW8TR8KiKfrFDAxUgekQoV7' def __init__(self, root: str = None, transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None): - super().__init__(root, transform, pre_transform) + pre_transform: Optional[Callable] = None, force_reload: bool = False): + super().__init__(root, transform, pre_transform, force_reload = force_reload) self.data, self.slices = self.load_data(self.processed_paths[0]) @property diff --git a/gammagl/datasets/hgb.py b/gammagl/datasets/hgb.py index 0b996f05..1e01f8ad 100644 --- a/gammagl/datasets/hgb.py +++ b/gammagl/datasets/hgb.py @@ -77,6 +77,8 @@ class HGBDataset(InMemoryDataset): an :class:`gammmgl.transform` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) """ @@ -100,10 +102,11 @@ class HGBDataset(InMemoryDataset): def __init__(self, root: str = None, name: str = 'acm', transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None): + pre_transform: Optional[Callable] = None, + force_reload: bool = False): self.name = name.lower() assert self.name in set(self.names.keys()) - super().__init__(root, transform, pre_transform) + super().__init__(root, transform, pre_transform, force_reload = force_reload) self.data, self.slices = self.load_data(self.processed_paths[0]) @property diff --git a/gammagl/datasets/imdb.py b/gammagl/datasets/imdb.py index dafd5c33..afda4b80 100644 --- a/gammagl/datasets/imdb.py +++ b/gammagl/datasets/imdb.py @@ -36,14 +36,16 @@ class IMDB(InMemoryDataset): an :obj:`gammagl.data.HeteroGraph` object and returns a transformed version. The data object will be transformed before being saved to disk. 
(default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) """ url = 'https://www.dropbox.com/s/g0btk9ctr1es39x/IMDB_processed.zip?dl=1' def __init__(self, root: str = None, transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None): - super().__init__(root, transform, pre_transform) + pre_transform: Optional[Callable] = None, force_reload: bool = False): + super().__init__(root, transform, pre_transform, force_reload = force_reload) self.data, self.slices = self.load_data(self.processed_paths[0]) @property diff --git a/gammagl/datasets/ml.py b/gammagl/datasets/ml.py index d1695a04..00d74da7 100644 --- a/gammagl/datasets/ml.py +++ b/gammagl/datasets/ml.py @@ -11,7 +11,7 @@ class MLDataset(InMemoryDataset): # url = 'https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/MovieLens/ml-100k.zip' def __init__(self, root=None, split='train', transform=None, pre_transform=None, - pre_filter=None, dataset_name='ml-100k'): + pre_filter=None, dataset_name='ml-100k', force_reload: bool = False): assert split in ['train', 'val', 'valid', 'test'] assert dataset_name in ['ml-100k', 'ml-1m', 'ml-10m', 'ml-20m'] @@ -21,7 +21,7 @@ def __init__(self, root=None, split='train', transform=None, pre_transform=None, url_post = '.zip' self.url = f'{url_pre}{self.dataset_name}{url_post}' - super().__init__(root, transform, pre_transform, pre_filter) + super().__init__(root, transform, pre_transform, pre_filter, force_reload = force_reload) self.data, self.slices = self.load_data(self.processed_paths[0]) diff --git a/gammagl/datasets/modelnet40.py b/gammagl/datasets/modelnet40.py index ce393545..1cfcca71 100644 --- a/gammagl/datasets/modelnet40.py +++ b/gammagl/datasets/modelnet40.py @@ -34,14 +34,16 @@ class ModelNet40(InMemoryDataset): (:obj:`"train"`, :obj:`"test"`). num_points: int, optional The number of points used to train or test. 
+ force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) """ url = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip' - def __init__(self, root=None, transform=None, pre_transform=None, pre_filter=None, split='train', num_points=1024): + def __init__(self, root=None, transform=None, pre_transform=None, pre_filter=None, split='train', num_points=1024, force_reload: bool = False): self.num_points = num_points self.split = split - super(ModelNet40, self).__init__(root, transform, pre_transform, pre_filter) + super(ModelNet40, self).__init__(root, transform, pre_transform, pre_filter, force_reload = force_reload) assert split in ['train', 'test'] path = self.processed_paths[0] if split == 'train' else self.processed_paths[1] self.data, self.slices = self.load_data(path) diff --git a/gammagl/datasets/molecule_net.py b/gammagl/datasets/molecule_net.py new file mode 100644 index 00000000..20c01818 --- /dev/null +++ b/gammagl/datasets/molecule_net.py @@ -0,0 +1,218 @@ +import os +import os.path as osp +import re +from typing import Callable, Dict, Optional, Tuple, Union + +from gammagl.data import download_url +from gammagl.data import InMemoryDataset +from gammagl.data.extract import extract_gz +import tensorlayerx as tlx + +# from GammaGL.gammagl.utils.smiles import from_smiles +from ..utils.smiles import from_smiles + +class MoleculeNet(InMemoryDataset): + r"""The `MoleculeNet `_ benchmark + collection from the `"MoleculeNet: A Benchmark for Molecular Machine + Learning" `_ paper, containing datasets + from physical chemistry, biophysics and physiology. + All datasets come with the additional node and edge features introduced by + the :ogb:`null` + `Open Graph Benchmark `_. + + Args: + root (str): Root directory where the dataset should be saved. 
name (str): The name of the dataset (:obj:`"ESOL"`, :obj:`"FreeSolv"`, + :obj:`"Lipo"`, :obj:`"PCBA"`, :obj:`"MUV"`, :obj:`"HIV"`, + :obj:`"BACE"`, :obj:`"BBBP"`, :obj:`"Tox21"`, :obj:`"ToxCast"`, + :obj:`"SIDER"`, :obj:`"ClinTox"`). + transform (callable, optional): A function/transform that takes in an + :obj:`gammagl.data.Graph` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform (callable, optional): A function/transform that takes in + an :obj:`gammagl.data.Graph` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + pre_filter (callable, optional): A function that takes in an + :obj:`gammagl.data.Graph` object and returns a boolean + value, indicating whether the data object should be included in the + final dataset. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) + + **STATS:** + + .. 
list-table:: + :widths: 20 10 10 10 10 10 + :header-rows: 1 + + * - Name + - #graphs + - #nodes + - #edges + - #features + - #classes + * - ESOL + - 1,128 + - ~13.3 + - ~27.4 + - 9 + - 1 + * - FreeSolv + - 642 + - ~8.7 + - ~16.8 + - 9 + - 1 + * - Lipophilicity + - 4,200 + - ~27.0 + - ~59.0 + - 9 + - 1 + * - PCBA + - 437,929 + - ~26.0 + - ~56.2 + - 9 + - 128 + * - MUV + - 93,087 + - ~24.2 + - ~52.6 + - 9 + - 17 + * - HIV + - 41,127 + - ~25.5 + - ~54.9 + - 9 + - 1 + * - BACE + - 1513 + - ~34.1 + - ~73.7 + - 9 + - 1 + * - BBBP + - 2,050 + - ~23.9 + - ~51.6 + - 9 + - 1 + * - Tox21 + - 7,831 + - ~18.6 + - ~38.6 + - 9 + - 12 + * - ToxCast + - 8,597 + - ~18.7 + - ~38.4 + - 9 + - 617 + * - SIDER + - 1,427 + - ~33.6 + - ~70.7 + - 9 + - 27 + * - ClinTox + - 1,484 + - ~26.1 + - ~55.5 + - 9 + - 2 + """ + + url = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/{}' + + # Format: name: (display_name, url_name, csv_name, smiles_idx, y_idx) + names: Dict[str, Tuple[str, str, str, int, Union[int, slice]]] = { + 'esol': ('ESOL', 'delaney-processed.csv', 'delaney-processed', -1, -2), + 'freesolv': ('FreeSolv', 'SAMPL.csv', 'SAMPL', 1, 2), + 'lipo': ('Lipophilicity', 'Lipophilicity.csv', 'Lipophilicity', 2, 1), + 'pcba': ('PCBA', 'pcba.csv.gz', 'pcba', -1, slice(0, 128)), + 'muv': ('MUV', 'muv.csv.gz', 'muv', -1, slice(0, 17)), + 'hiv': ('HIV', 'HIV.csv', 'HIV', 0, -1), + 'bace': ('BACE', 'bace.csv', 'bace', 0, 2), + 'bbbp': ('BBBP', 'BBBP.csv', 'BBBP', -1, -2), + 'tox21': ('Tox21', 'tox21.csv.gz', 'tox21', -1, slice(0, 12)), + 'toxcast': + ('ToxCast', 'toxcast_data.csv.gz', 'toxcast_data', 0, slice(1, 618)), + 'sider': ('SIDER', 'sider.csv.gz', 'sider', 0, slice(1, 28)), + 'clintox': ('ClinTox', 'clintox.csv.gz', 'clintox', 0, slice(1, 3)), + } + + def __init__( + self, + root: str, + name: str, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + force_reload: bool = False, + ) -> None: + self.name = 
name.lower() + assert self.name in self.names.keys() + super().__init__(root, transform, pre_transform, pre_filter, + force_reload=force_reload) + self.data, self.slices = self.load_data(self.processed_paths[0]) + + @property + def raw_dir(self) -> str: + return osp.join(self.root, self.name, 'raw') + + @property + def processed_dir(self) -> str: + return osp.join(self.root, self.name, 'processed') + + @property + def raw_file_names(self) -> str: + return f'{self.names[self.name][2]}.csv' + + @property + def processed_file_names(self) -> str: + return 'data.pt' + + def download(self) -> None: + url = self.url.format(self.names[self.name][1]) + path = download_url(url, self.raw_dir) + if self.names[self.name][1][-2:] == 'gz': + extract_gz(path, self.raw_dir) + os.unlink(path) + + def process(self) -> None: + with open(self.raw_paths[0], 'r') as f: + dataset = f.read().split('\n')[1:-1] + dataset = [x for x in dataset if len(x) > 0] # Filter empty lines. + + data_list = [] + for line in dataset: + line = re.sub(r'\".*\"', '', line) # Replace ".*" strings. 
+ values = line.split(',') + + smiles = values[self.names[self.name][3]] + labels = values[self.names[self.name][4]] + labels = labels if isinstance(labels, list) else [labels] + + ys = [float(y) if len(y) > 0 else float('NaN') for y in labels] + + y = tlx.convert_to_tensor(ys, dtype=tlx.float32).reshape(1, -1) + data = from_smiles(smiles) + data.y = y + + if self.pre_filter is not None and not self.pre_filter(data): + continue + + if self.pre_transform is not None: + data = self.pre_transform(data) + + data_list.append(data) + + self.save_data(self.collate(data_list), self.processed_paths[0]) + + def __repr__(self) -> str: + return f'{self.names[self.name][0]}({len(self)})' diff --git a/gammagl/datasets/planetoid.py b/gammagl/datasets/planetoid.py index 52f5db33..f81ea197 100644 --- a/gammagl/datasets/planetoid.py +++ b/gammagl/datasets/planetoid.py @@ -60,6 +60,8 @@ class Planetoid(InMemoryDataset): an :obj:`gammagl.data.Graph` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. 
+ (default: :obj:`False`) Tip --- @@ -94,10 +96,11 @@ class Planetoid(InMemoryDataset): def __init__(self, root: str = None, name: str = 'cora', split: str = "public", num_train_per_class: int = 20, num_val: int = 500, num_test: int = 1000, transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None): + pre_transform: Optional[Callable] = None, + force_reload: bool = False): self.name = name - super().__init__(root, transform, pre_transform) + super().__init__(root, transform, pre_transform, force_reload = force_reload) self.data, self.slices = self.load_data(self.processed_paths[0]) self.split = split assert self.split in ['public', 'full', 'random'] diff --git a/gammagl/datasets/polblogs.py b/gammagl/datasets/polblogs.py index 9f3791a4..66a1403c 100644 --- a/gammagl/datasets/polblogs.py +++ b/gammagl/datasets/polblogs.py @@ -33,6 +33,8 @@ class PolBlogs(InMemoryDataset): an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. 
+ (default: :obj:`False`) **STATS:** @@ -54,8 +56,8 @@ class PolBlogs(InMemoryDataset): url = 'https://netset.telecom-paris.fr/datasets/polblogs.tar.gz' def __init__(self, root: str = None, transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None): - super().__init__(root, transform, pre_transform) + pre_transform: Optional[Callable] = None, force_reload: bool = False): + super().__init__(root, transform, pre_transform, force_reload = force_reload) self.data, self.slices = self.load_data(self.processed_paths[0]) @property diff --git a/gammagl/datasets/ppi.py b/gammagl/datasets/ppi.py index 4646e1d3..11ca266a 100644 --- a/gammagl/datasets/ppi.py +++ b/gammagl/datasets/ppi.py @@ -14,11 +14,11 @@ class PPI(InMemoryDataset): url = 'https://data.dgl.ai/dataset/ppi.zip' def __init__(self, root=None, split='train', transform=None, pre_transform=None, - pre_filter=None): + pre_filter=None, force_reload: bool = False): assert split in ['train', 'val', 'valid', 'test'] - super().__init__(root, transform, pre_transform, pre_filter) + super().__init__(root, transform, pre_transform, pre_filter, force_reload = force_reload) if split == 'train': self.data, self.slices = self.load_data(self.processed_paths[0]) diff --git a/gammagl/datasets/reddit.py b/gammagl/datasets/reddit.py index 39d44a97..21c7cb7b 100644 --- a/gammagl/datasets/reddit.py +++ b/gammagl/datasets/reddit.py @@ -28,13 +28,15 @@ class Reddit(InMemoryDataset): an :obj:`gammagl.data.Graph` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. 
+ (default: :obj:`False`) """ url = 'https://data.dgl.ai/dataset/reddit.zip' - def __init__(self, root=None, transform=None, pre_transform=None): - super().__init__(root, transform, pre_transform) + def __init__(self, root=None, transform=None, pre_transform=None, force_reload: bool = False): + super().__init__(root, transform, pre_transform, force_reload = force_reload) self.data, self.slices = self.load_data(self.processed_paths[0]) @property diff --git a/gammagl/datasets/shapenet.py b/gammagl/datasets/shapenet.py index 9322978b..0ac23dab 100644 --- a/gammagl/datasets/shapenet.py +++ b/gammagl/datasets/shapenet.py @@ -62,6 +62,8 @@ class ShapeNet(InMemoryDataset): :obj:`gammagl.data.Graph` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) """ @@ -108,14 +110,14 @@ class ShapeNet(InMemoryDataset): def __init__(self, root=None, categories=None, include_normals=True, split='trainval', transform=None, pre_transform=None, - pre_filter=None): + pre_filter=None, force_reload: bool = False): if categories is None: categories = list(self.category_ids.keys()) if isinstance(categories, str): categories = [categories] assert all(category in self.category_ids for category in categories) self.categories = categories - super().__init__(root, transform, pre_transform, pre_filter) + super().__init__(root, transform, pre_transform, pre_filter, force_reload = force_reload) if split == 'train': path = self.processed_paths[0] diff --git a/gammagl/datasets/tu_dataset.py b/gammagl/datasets/tu_dataset.py index cc26e30d..5d4dd15a 100644 --- a/gammagl/datasets/tu_dataset.py +++ b/gammagl/datasets/tu_dataset.py @@ -66,6 +66,8 @@ class TUDataset(InMemoryDataset): cleaned: bool, optional If :obj:`True`, the dataset will contain only non-isomorphic graphs. 
(default: :obj:`False`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) Tip --- @@ -132,10 +134,11 @@ def __init__(self, root: str = None, name: str = 'MUTAG', pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, use_node_attr: bool = False, use_edge_attr: bool = False, - cleaned: bool = False): + cleaned: bool = False, + force_reload: bool = False): self.name = name self.cleaned = cleaned - super().__init__(root, transform, pre_transform, pre_filter) + super().__init__(root, transform, pre_transform, pre_filter, force_reload = force_reload) self.data, self.slices, self.sizes = self.load_data(self.processed_paths[0]) if self.data.x is not None and not use_node_attr: num_node_attributes = self.num_node_attributes diff --git a/gammagl/datasets/webkb.py b/gammagl/datasets/webkb.py index c4403b60..eb9ee2b4 100644 --- a/gammagl/datasets/webkb.py +++ b/gammagl/datasets/webkb.py @@ -32,15 +32,18 @@ class WebKB(InMemoryDataset): an :obj:`gammagl.data.Graph` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. 
+ (default: :obj:`False`) + """ url = 'https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master' - def __init__(self, root=None, name='cornell', transform=None, pre_transform=None): + def __init__(self, root=None, name='cornell', transform=None, pre_transform=None, force_reload: bool = False): self.name = name.lower() assert self.name in ['cornell', 'texas', 'wisconsin'] - super().__init__(root, transform, pre_transform) + super().__init__(root, transform, pre_transform, force_reload = force_reload) self.data, self.slices = self.load_data(self.processed_paths[0]) @property diff --git a/gammagl/datasets/wikics.py b/gammagl/datasets/wikics.py index 3d8ef0ff..f2347ff3 100644 --- a/gammagl/datasets/wikics.py +++ b/gammagl/datasets/wikics.py @@ -10,34 +10,38 @@ class WikiCS(InMemoryDataset): r"""The semi-supervised Wikipedia-based dataset from the - `"Wiki-CS: A Wikipedia-Based Benchmark for Graph Neural Networks" - `_ paper, containing 11,701 nodes, - 216,123 edges, 10 classes and 20 different training splits. + `"Wiki-CS: A Wikipedia-Based Benchmark for Graph Neural Networks" + `_ paper, containing 11,701 nodes, + 216,123 edges, 10 classes and 20 different training splits. - Parameters - ---------- - root: str, optional - Root directory where the dataset should be saved. - transform: callable, optional - A function/transform that takes in an - :obj:`gammagl.data.Graph` object and returns a transformed - version. The data object will be transformed before every access. - (default: :obj:`None`) - pre_transform: callable, optional - A function/transform that takes in - an :obj:`gammagl.data.Graph` object and returns a - transformed version. The data object will be transformed before - being saved to disk. (default: :obj:`None`) - is_undirected: bool, optional - Whether the graph is undirected. - (default: :obj:`True`) - """ + Parameters + ---------- + root: str, optional + Root directory where the dataset should be saved. 
+ transform: callable, optional + A function/transform that takes in an + :obj:`gammagl.data.Graph` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform: callable, optional + A function/transform that takes in + an :obj:`gammagl.data.Graph` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + is_undirected: bool, optional + Whether the graph is undirected. + (default: :obj:`True`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) + + """ url = 'https://github.com/pmernyei/wiki-cs-dataset/raw/master/dataset' def __init__(self, root: str = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, - is_undirected: Optional[bool] = None): + is_undirected: Optional[bool] = None, + force_reload: bool = False): if is_undirected is None: warnings.warn( f"The {self.__class__.__name__} dataset now returns an " @@ -45,7 +49,7 @@ def __init__(self, root: str = None, transform: Optional[Callable] = None, f"'is_undirected=False' to restore the old behavior.") is_undirected = True self.is_undirected = is_undirected - super().__init__(root, transform, pre_transform) + super().__init__(root, transform, pre_transform, force_reload = force_reload) self.data, self.slices = self.load_data(self.processed_paths[0]) @property diff --git a/gammagl/datasets/wikipedia_network.py b/gammagl/datasets/wikipedia_network.py index 2f5dd57f..f7391858 100644 --- a/gammagl/datasets/wikipedia_network.py +++ b/gammagl/datasets/wikipedia_network.py @@ -43,6 +43,8 @@ class WikipediaNetwork(InMemoryDataset): an :obj:`gammagl.data.Graph` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. 
+ (default: :obj:`False`) """ @@ -52,14 +54,15 @@ class WikipediaNetwork(InMemoryDataset): def __init__(self, root: str = None, name: str = 'chameleon', geom_gcn_preprocess: bool = True, transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None): + pre_transform: Optional[Callable] = None, + force_reload: bool = False): self.name = name.lower() self.geom_gcn_preprocess = geom_gcn_preprocess assert self.name in ['chameleon', 'crocodile', 'squirrel'] if geom_gcn_preprocess and self.name == 'crocodile': raise AttributeError("The dataset 'crocodile' is not available in " "case 'geom_gcn_preprocess=True'") - super().__init__(root, transform, pre_transform) + super().__init__(root, transform, pre_transform, force_reload = force_reload) self.data, self.slices = self.load_data(self.processed_paths[0]) @property diff --git a/gammagl/datasets/zinc.py b/gammagl/datasets/zinc.py index 0424b916..474d5d75 100644 --- a/gammagl/datasets/zinc.py +++ b/gammagl/datasets/zinc.py @@ -63,6 +63,9 @@ class ZINC(InMemoryDataset): :obj:`gammagl.data.Graph` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. 
+ (default: :obj:`False`) + """ url = 'https://www.dropbox.com/s/feo9qle74kg48gy/molecules.zip?dl=1' @@ -70,10 +73,10 @@ class ZINC(InMemoryDataset): 'benchmarking-gnns/master/data/molecules/{}.index') def __init__(self, root: str = None, subset=False, split='train', transform=None, - pre_transform=None, pre_filter=None): + pre_transform=None, pre_filter=None, force_reload: bool = False): self.subset = subset assert split in ['train', 'val', 'test'] - super().__init__(root, transform, pre_transform, pre_filter) + super().__init__(root, transform, pre_transform, pre_filter, force_reload = force_reload) path = osp.join(self.processed_dir, f'{split}.pt') self.data, self.slices = self.load_data(path) diff --git a/gammagl/utils/__init__.py b/gammagl/utils/__init__.py index 87e19b09..b9fea04b 100644 --- a/gammagl/utils/__init__.py +++ b/gammagl/utils/__init__.py @@ -16,6 +16,7 @@ from .read_embeddings import read_embeddings from .homophily import homophily from .to_dense_adj import to_dense_adj +from .smiles import from_smiles __all__ = [ 'calc_A_norm_hat', @@ -37,7 +38,8 @@ 'to_scipy_sparse_matrix', 'read_embeddings', 'homophily', - 'to_dense_adj' + 'to_dense_adj', + 'from_smiles' ] diff --git a/gammagl/utils/convert.py b/gammagl/utils/convert.py index 834a727a..26081c89 100644 --- a/gammagl/utils/convert.py +++ b/gammagl/utils/convert.py @@ -29,4 +29,4 @@ def to_scipy_sparse_matrix(edge_index, edge_attr = None, num_nodes = None): assert edge_attr.shape[0] == row.shape[0] num_nodes = maybe_num_nodes(edge_index, num_nodes) - return ssp.coo_matrix((edge_attr, (row, col)), (num_nodes, num_nodes)) \ No newline at end of file + return ssp.coo_matrix((edge_attr, (row, col)), (num_nodes, num_nodes)) diff --git a/gammagl/utils/smiles.py b/gammagl/utils/smiles.py new file mode 100644 index 00000000..58d88e85 --- /dev/null +++ b/gammagl/utils/smiles.py @@ -0,0 +1,144 @@ +from typing import List, Dict, Any +import tensorlayerx as tlx + + +x_map: Dict[str, List[Any]] = { + 
'atomic_num': + list(range(0, 119)), + 'chirality': [ + 'CHI_UNSPECIFIED', + 'CHI_TETRAHEDRAL_CW', + 'CHI_TETRAHEDRAL_CCW', + 'CHI_OTHER', + 'CHI_TETRAHEDRAL', + 'CHI_ALLENE', + 'CHI_SQUAREPLANAR', + 'CHI_TRIGONALBIPYRAMIDAL', + 'CHI_OCTAHEDRAL', + ], + 'degree': + list(range(0, 11)), + 'formal_charge': + list(range(-5, 7)), + 'num_hs': + list(range(0, 9)), + 'num_radical_electrons': + list(range(0, 5)), + 'hybridization': [ + 'UNSPECIFIED', + 'S', + 'SP', + 'SP2', + 'SP3', + 'SP3D', + 'SP3D2', + 'OTHER', + ], + 'is_aromatic': [False, True], + 'is_in_ring': [False, True], +} + +e_map: Dict[str, List[Any]] = { + 'bond_type': [ + 'UNSPECIFIED', + 'SINGLE', + 'DOUBLE', + 'TRIPLE', + 'QUADRUPLE', + 'QUINTUPLE', + 'HEXTUPLE', + 'ONEANDAHALF', + 'TWOANDAHALF', + 'THREEANDAHALF', + 'FOURANDAHALF', + 'FIVEANDAHALF', + 'AROMATIC', + 'IONIC', + 'HYDROGEN', + 'THREECENTER', + 'DATIVEONE', + 'DATIVE', + 'DATIVEL', + 'DATIVER', + 'OTHER', + 'ZERO', + ], + 'stereo': [ + 'STEREONONE', + 'STEREOANY', + 'STEREOZ', + 'STEREOE', + 'STEREOCIS', + 'STEREOTRANS', + ], + 'is_conjugated': [False, True], +} + + + +def from_smiles(smiles: str, with_hydrogen: bool = False, + kekulize: bool = False): + r"""Converts a SMILES string to a :class:`gammagl.data.Graph` + instance. + Args: + smiles (str): The SMILES string. + with_hydrogen (bool, optional): If set to :obj:`True`, will store + hydrogens in the molecule graph. (default: :obj:`False`) + kekulize (bool, optional): If set to :obj:`True`, converts aromatic + bonds to single/double bonds. 
(default: :obj:`False`) + """ + from rdkit import Chem, RDLogger + + # from gammagl.data import + from gammagl.data import Graph + + RDLogger.DisableLog('rdApp.*') + + mol = Chem.MolFromSmiles(smiles) + + if mol is None: + mol = Chem.MolFromSmiles('') + if with_hydrogen: + mol = Chem.AddHs(mol) + if kekulize: + Chem.Kekulize(mol) + + xs: List[List[int]] = [] + for atom in mol.GetAtoms(): + row: List[int] = [] + row.append(x_map['atomic_num'].index(atom.GetAtomicNum())) + row.append(x_map['chirality'].index(str(atom.GetChiralTag()))) + row.append(x_map['degree'].index(atom.GetTotalDegree())) + row.append(x_map['formal_charge'].index(atom.GetFormalCharge())) + row.append(x_map['num_hs'].index(atom.GetTotalNumHs())) + row.append(x_map['num_radical_electrons'].index( + atom.GetNumRadicalElectrons())) + row.append(x_map['hybridization'].index(str(atom.GetHybridization()))) + row.append(x_map['is_aromatic'].index(atom.GetIsAromatic())) + row.append(x_map['is_in_ring'].index(atom.IsInRing())) + xs.append(row) + + x = tlx.convert_to_tensor(xs, dtype=tlx.int64).reshape(-1, 9) + + edge_indices, edge_attrs = [], [] + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + + e = [] + e.append(e_map['bond_type'].index(str(bond.GetBondType()))) + e.append(e_map['stereo'].index(str(bond.GetStereo()))) + e.append(e_map['is_conjugated'].index(bond.GetIsConjugated())) + + edge_indices += [[i, j], [j, i]] + edge_attrs += [e, e] + + edge_index = tlx.convert_to_tensor(edge_indices) + edge_index = edge_index.t().to(tlx.int64).reshape(2, -1) + edge_attr = tlx.convert_to_tensor(edge_attrs, dtype=tlx.int64).reshape(-1, 3) + + if edge_index.numel() > 0: # Sort indices. + perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort() + edge_index, edge_attr = edge_index[:, perm], edge_attr[perm] + + return Graph(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles) \ No newline at end of file