Skip to content

Commit

Permalink
[Dataset] Add dataset (#192)
Browse files Browse the repository at this point in the history
* 2.1

* datasets修改

给datasets中的所有属性添加force_reload参数

* remove files

* update

* update

* 2.2修改
  • Loading branch information
jyyy6565 authored Jan 21, 2024
1 parent eaa51ac commit b9e4cd6
Show file tree
Hide file tree
Showing 28 changed files with 495 additions and 72 deletions.
14 changes: 10 additions & 4 deletions gammagl/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class Dataset(Dataset):
:obj:`gammagl.data.Graph` object and returns a boolean
value, indicating whether the graph object should be included in the
final dataset. (default: :obj:`None`)
force_reload: bool, optional
    Whether to re-process the dataset. (default: :obj:`False`)
"""

Expand Down Expand Up @@ -82,7 +84,8 @@ def get(self, idx: int) -> Graph:
def __init__(self, root: Optional[str] = None,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None):
pre_filter: Optional[Callable] = None,
force_reload: bool = False):
super().__init__()

self.raw_root = root
Expand All @@ -98,6 +101,7 @@ def __init__(self, root: Optional[str] = None,
self.pre_transform = pre_transform
self.pre_filter = pre_filter
self._indices: Optional[Sequence] = None
self.force_reload = force_reload

# When finished, record the dataset path to .ggl/datasets.json
# so that next time this dataset can be reused instead of downloaded again.
Expand Down Expand Up @@ -316,17 +320,19 @@ def _process(self):
f"The `pre_transform` argument differs from the one used in "
f"the pre-processed version of this dataset. If you want to "
f"make use of another pre-processing technique, make sure to "
f"sure to delete '{self.processed_dir}' first")
f"sure to delete '{self.processed_dir}' first"
f"`force_reload=True` explicitly to reload the dataset.")

f = osp.join(self.processed_dir, tlx.BACKEND + '_pre_filter.pt')
if osp.exists(f) and self.load_with_pickle(f) != _repr(self.pre_filter):
warnings.warn(
"The `pre_filter` argument differs from the one used in the "
"pre-processed version of this dataset. If you want to make "
"use of another pre-fitering technique, make sure to delete "
"'{self.processed_dir}' first")
"'{self.processed_dir}' first"
"`force_reload=True` explicitly to reload the dataset.")

if files_exist(self.processed_paths): # pragma: no cover
if not self.force_reload and files_exist(self.processed_paths): # pragma: no cover
# self.process()
return

Expand Down
8 changes: 6 additions & 2 deletions gammagl/data/in_memory_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class InMemoryDataset(Dataset):
:obj:`gammagl.data.Graph` object and returns a boolean
value, indicating whether the graph object should be included in the
final dataset. (default: :obj:`None`)
force_reload: bool, optional
    Whether to re-process the dataset. (default: :obj:`False`)
"""
@property
Expand All @@ -54,12 +56,14 @@ def process(self):
def __init__(self, root: Optional[str] = None,
             transform: Optional[Callable] = None,
             pre_transform: Optional[Callable] = None,
             pre_filter: Optional[Callable] = None,
             force_reload: bool = False):
    """Initialize the in-memory dataset.

    Parameters mirror :class:`Dataset`; ``force_reload=True`` forces the
    cached processed files to be regenerated.
    """
    # Forward force_reload by keyword — consistent with the dataset
    # subclasses in this change set, and robust against future insertions
    # of positional parameters in the base signature.
    super().__init__(root, transform, pre_transform, pre_filter,
                     force_reload=force_reload)
    # Collated graph data and the per-graph slicing metadata are filled
    # in later (by process()/load_data); None until then.
    self.data = None
    self.slices = None
    self._data_list: Optional[List[Graph]] = None


@property
def num_classes(self) -> int:
r"""Returns the number of classes in the dataset."""
Expand Down
4 changes: 3 additions & 1 deletion gammagl/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .aminer import AMiner
from .polblogs import PolBlogs
from .wikics import WikiCS
from .molecule_net import MoleculeNet

__all__ = [
'Amazon',
Expand All @@ -37,7 +38,8 @@
'ZINC',
'AMiner',
'PolBlogs',
'WikiCS'
'WikiCS',
'MoleculeNet'
]

classes = __all__
4 changes: 2 additions & 2 deletions gammagl/datasets/alircd.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ class AliRCD(InMemoryDataset):
"""
url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/AliRCD_session1.zip"

def __init__(self, root=None, transform=None, pre_transform=None, force_reload: bool = False):
    """Load the AliRCD (session 1) dataset.

    ``force_reload=True`` re-runs processing even when cached files exist.
    """
    # NOTE(review): hard-coded sizes for AliRCD_session1 — presumably the
    # expected edge/node counts or byte sizes of the downloaded files;
    # confirm against the download/process code before changing the URL.
    self.edge_size = 157814864
    self.node_size = 13806619
    # PEP 8: no spaces around '=' in keyword arguments.
    super().__init__(root, transform, pre_transform, force_reload=force_reload)
    self.data, self.slices = self.load_data(self.processed_paths[0])

@property
Expand Down
7 changes: 5 additions & 2 deletions gammagl/datasets/amazon.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class Amazon(InMemoryDataset):
an :obj:`gammagl.data.Graph` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
Stats:
.. list-table::
Expand Down Expand Up @@ -58,10 +60,11 @@ class Amazon(InMemoryDataset):

def __init__(self, root: str = None, name: str = 'computers',
             transform: Optional[Callable] = None,
             pre_transform: Optional[Callable] = None,
             force_reload: bool = False):
    """Load the Amazon 'computers' or 'photo' dataset.

    ``name`` is case-insensitive; ``force_reload=True`` re-processes the
    dataset even when cached files exist.
    """
    self.name = name.lower()
    assert self.name in ['computers', 'photo']
    # PEP 8: no spaces around '=' in keyword arguments.
    super().__init__(root, transform, pre_transform, force_reload=force_reload)
    self.data, self.slices = self.load_data(self.processed_paths[0])

@property
Expand Down
7 changes: 5 additions & 2 deletions gammagl/datasets/aminer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,18 @@ class AMiner(InMemoryDataset):
an :obj:`gammagl.data.HeteroGraph` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
"""

url = 'https://www.dropbox.com/s/1bnz8r7mofx0osf/net_aminer.zip?dl=1'
y_url = 'https://www.dropbox.com/s/nkocx16rpl4ydde/label.zip?dl=1'

def __init__(self, root: Optional[str] = None, transform: Optional[Callable] = None,
             pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None,
             force_reload: bool = False):
    """Load the AMiner heterogeneous dataset.

    ``force_reload=True`` re-processes the dataset even when cached files
    exist.
    """
    # PEP 8: no spaces around '=' in keyword arguments.
    super().__init__(root, transform, pre_transform, pre_filter,
                     force_reload=force_reload)
    self.data, self.slices = self.load_data(self.processed_paths[0])

@property
Expand Down
7 changes: 5 additions & 2 deletions gammagl/datasets/coauthor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class Coauthor(InMemoryDataset):
an :obj:`gammagl.data.Graph` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
Stats:
.. list-table::
Expand Down Expand Up @@ -58,10 +60,11 @@ class Coauthor(InMemoryDataset):

def __init__(self, root: str = None, name: str = 'cs',
             transform: Optional[Callable] = None,
             pre_transform: Optional[Callable] = None,
             force_reload: bool = False):
    """Load the Coauthor 'cs' or 'physics' dataset.

    ``name`` is case-insensitive; ``force_reload=True`` re-processes the
    dataset even when cached files exist.
    """
    # Lower-case once instead of calling name.lower() twice.
    key = name.lower()
    assert key in ['cs', 'physics']
    # Canonical capitalization used by the rest of the class —
    # presumably matches the remote file naming; confirm against raw_dir.
    self.name = 'CS' if key == 'cs' else 'Physics'
    # PEP 8: no spaces around '=' in keyword arguments.
    super().__init__(root, transform, pre_transform, force_reload=force_reload)
    self.data, self.slices = self.load_data(self.processed_paths[0])

@property
Expand Down
6 changes: 4 additions & 2 deletions gammagl/datasets/dblp.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,16 @@ class DBLP(InMemoryDataset):
an :obj:`gammagl.data.HeteroGraph` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
"""

url = 'https://www.dropbox.com/s/yh4grpeks87ugr2/DBLP_processed.zip?dl=1'

def __init__(self, root: str = None, transform: Optional[Callable] = None,
             pre_transform: Optional[Callable] = None, force_reload: bool = False):
    """Load the DBLP heterogeneous dataset.

    ``force_reload=True`` re-processes the dataset even when cached files
    exist.
    """
    # PEP 8: no spaces around '=' in keyword arguments.
    super().__init__(root, transform, pre_transform, force_reload=force_reload)
    self.data, self.slices = self.load_data(self.processed_paths[0])

@property
Expand Down
6 changes: 4 additions & 2 deletions gammagl/datasets/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,20 @@ class Entities(InMemoryDataset):
an :obj:`gammagl.data.Graph` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
"""

url = 'https://data.dgl.ai/dataset/{}.tgz'

def __init__(self, root: str = None, name: str = 'aifb', hetero: bool = False,
             transform: Optional[Callable] = None,
             pre_transform: Optional[Callable] = None, force_reload: bool = False):
    """Load one of the Entities datasets ('aifb', 'am', 'mutag', 'bgs').

    ``name`` is case-insensitive; ``force_reload=True`` re-processes the
    dataset even when cached files exist.
    """
    self.name = name.lower()
    self.hetero = hetero
    # Validate before the base-class __init__ may download/process.
    assert self.name in ['aifb', 'am', 'mutag', 'bgs']
    # PEP 8: no spaces around '=' in keyword arguments.
    super().__init__(root, transform, pre_transform, force_reload=force_reload)
    self.data, self.slices = self.load_data(self.processed_paths[0])

@property
Expand Down
6 changes: 4 additions & 2 deletions gammagl/datasets/flickr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class Flickr(InMemoryDataset):
an :obj:`gammagl.data.Graph` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
Tip
---
Expand All @@ -51,8 +53,8 @@ class Flickr(InMemoryDataset):
role_id = '1htXCtuktuCW8TR8KiKfrFDAxUgekQoV7'

def __init__(self, root: str = None, transform: Optional[Callable] = None,
             pre_transform: Optional[Callable] = None, force_reload: bool = False):
    """Load the Flickr dataset.

    ``force_reload=True`` re-processes the dataset even when cached files
    exist.
    """
    # PEP 8: no spaces around '=' in keyword arguments.
    super().__init__(root, transform, pre_transform, force_reload=force_reload)
    self.data, self.slices = self.load_data(self.processed_paths[0])

@property
Expand Down
7 changes: 5 additions & 2 deletions gammagl/datasets/hgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class HGBDataset(InMemoryDataset):
an :class:`gammmgl.transform` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
"""

Expand All @@ -100,10 +102,11 @@ class HGBDataset(InMemoryDataset):

def __init__(self, root: str = None, name: str = 'acm',
             transform: Optional[Callable] = None,
             pre_transform: Optional[Callable] = None,
             force_reload: bool = False):
    """Load one of the HGB benchmark datasets (keys of ``self.names``).

    ``name`` is case-insensitive; ``force_reload=True`` re-processes the
    dataset even when cached files exist.
    """
    self.name = name.lower()
    # Membership test directly on the dict: equivalent to
    # `in set(self.names.keys())` without building a throwaway set.
    assert self.name in self.names
    # PEP 8: no spaces around '=' in keyword arguments.
    super().__init__(root, transform, pre_transform, force_reload=force_reload)
    self.data, self.slices = self.load_data(self.processed_paths[0])

@property
Expand Down
6 changes: 4 additions & 2 deletions gammagl/datasets/imdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,16 @@ class IMDB(InMemoryDataset):
an :obj:`gammagl.data.HeteroGraph` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
"""

url = 'https://www.dropbox.com/s/g0btk9ctr1es39x/IMDB_processed.zip?dl=1'

def __init__(self, root: str = None, transform: Optional[Callable] = None,
             pre_transform: Optional[Callable] = None, force_reload: bool = False):
    """Load the IMDB heterogeneous dataset.

    ``force_reload=True`` re-processes the dataset even when cached files
    exist.
    """
    # PEP 8: no spaces around '=' in keyword arguments.
    super().__init__(root, transform, pre_transform, force_reload=force_reload)
    self.data, self.slices = self.load_data(self.processed_paths[0])

@property
Expand Down
4 changes: 2 additions & 2 deletions gammagl/datasets/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class MLDataset(InMemoryDataset):
# url = 'https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/MovieLens/ml-100k.zip'

def __init__(self, root=None, split='train', transform=None, pre_transform=None,
pre_filter=None, dataset_name='ml-100k'):
pre_filter=None, dataset_name='ml-100k', force_reload: bool = False):
assert split in ['train', 'val', 'valid', 'test']

assert dataset_name in ['ml-100k', 'ml-1m', 'ml-10m', 'ml-20m']
Expand All @@ -21,7 +21,7 @@ def __init__(self, root=None, split='train', transform=None, pre_transform=None,
url_post = '.zip'
self.url = f'{url_pre}{self.dataset_name}{url_post}'

super().__init__(root, transform, pre_transform, pre_filter)
super().__init__(root, transform, pre_transform, pre_filter, force_reload = force_reload)

self.data, self.slices = self.load_data(self.processed_paths[0])

Expand Down
6 changes: 4 additions & 2 deletions gammagl/datasets/modelnet40.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,16 @@ class ModelNet40(InMemoryDataset):
(:obj:`"train"`, :obj:`"test"`).
num_points: int, optional
The number of points used to train or test.
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
"""
url = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'

def __init__(self, root=None, transform=None, pre_transform=None, pre_filter=None,
             split='train', num_points=1024, force_reload: bool = False):
    """Load the ModelNet40 point-cloud dataset.

    ``split`` selects which processed file is loaded ('train' or 'test');
    ``num_points`` is the number of points per sample;
    ``force_reload=True`` re-processes the dataset even when cached files
    exist.
    """
    # Validate `split` BEFORE the base-class __init__ may trigger the
    # expensive download/process step, so bad input fails fast.
    assert split in ['train', 'test']
    self.num_points = num_points
    self.split = split
    # Python 3 zero-argument super() replaces super(ModelNet40, self);
    # PEP 8: no spaces around '=' in keyword arguments.
    super().__init__(root, transform, pre_transform, pre_filter,
                     force_reload=force_reload)
    path = self.processed_paths[0] if split == 'train' else self.processed_paths[1]
    self.data, self.slices = self.load_data(path)
Expand Down
Loading

0 comments on commit b9e4cd6

Please sign in to comment.