diff --git a/test/data/load/test_datasetloaders.py b/test/data/load/test_datasetloaders.py index 4fa52a83..058d5001 100644 --- a/test/data/load/test_datasetloaders.py +++ b/test/data/load/test_datasetloaders.py @@ -38,12 +38,13 @@ def _gather_config_files(self, base_dir: Path) -> List[str]: config_base_dir = base_dir / "configs/dataset" # Below the datasets that have some default transforms manually overriten with no_transform, exclude_datasets = {"karate_club.yaml", - # Below the datasets that takes quite some time to load and process - "mantra_name.yaml", "mantra_orientation.yaml","mantra_genus.yaml", # "mantra_betti_numbers.yaml", "mantra_genus.yaml", # Below the datasets that have some default transforms with we manually overriten with no_transform, # due to lack of default transform for domain2domain "REDDIT-BINARY.yaml", "IMDB-MULTI.yaml", "IMDB-BINARY.yaml", #"ZINC.yaml" } + + # Below the datasets that takes quite some time to load and process + self.long_running_datasets = {"mantra_name.yaml", "mantra_orientation.yaml", "mantra_genus.yaml", "mantra_betti_numbers.yaml",} for dir_path in config_base_dir.iterdir(): @@ -79,12 +80,16 @@ def _load_dataset(self, data_domain: str, config_file: str) -> Tuple[Any, Dict]: parameters = hydra.compose( config_name="run.yaml", overrides=[f"dataset={data_domain}/{config_file}", f"model=graph/gat"], - return_hydra_config=True - + return_hydra_config=True, ) dataset_loader = hydra.utils.instantiate(parameters.dataset.loader) print(repr(dataset_loader)) - return dataset_loader.load() + + if config_file in self.long_running_datasets: + dataset, data_dir = dataset_loader.load(slice=100) + else: + dataset, data_dir = dataset_loader.load() + return dataset, data_dir def test_dataset_loading_states(self): """Test different states and scenarios during dataset loading.""" diff --git a/topobenchmark/data/datasets/mantra_dataset.py b/topobenchmark/data/datasets/mantra_dataset.py index 6d98b74a..c68dfa10 100644 --- a/topobenchmark/data/datasets/mantra_dataset.py +++ b/topobenchmark/data/datasets/mantra_dataset.py @@ -25,6 +25,8 @@ class MantraDataset(InMemoryDataset): Name of the dataset. parameters : DictConfig Configuration parameters for the dataset. + **kwargs : dict + Additional keyword arguments. Attributes ---------- @@ -50,6 +52,7 @@ def __init__( root: str, name: str, parameters: DictConfig, + **kwargs, ) -> None: self.parameters = parameters self.manifold_dim = parameters.manifold_dim @@ -58,7 +61,10 @@ def __init__( self.name = "_".join( [name, str(self.version), f"manifold_dim_{self.manifold_dim}"] ) - + if kwargs.get("slice", None): + self.slice = 100 + else: + self.slice = None super().__init__( root, ) @@ -183,6 +189,7 @@ def process(self) -> None: osp.join(self.raw_dir, self.raw_file_names[0]), self.manifold_dim, self.task_variable, + self.slice, ) data_list = data diff --git a/topobenchmark/data/loaders/base.py b/topobenchmark/data/loaders/base.py index 7f4446fe..66b08cb0 100755 --- a/topobenchmark/data/loaders/base.py +++ b/topobenchmark/data/loaders/base.py @@ -45,15 +45,20 @@ def load_dataset(self) -> torch_geometric.data.Data: """ raise NotImplementedError - def load(self) -> tuple[torch_geometric.data.Data, str]: + def load(self, **kwargs) -> tuple[torch_geometric.data.Data, str]: """Load data. + Parameters + ---------- + **kwargs : dict + Additional keyword arguments. + Returns ------- tuple[torch_geometric.data.Data, str] Tuple containing the loaded data and the data directory. """ - dataset = self.load_dataset() + dataset = self.load_dataset(**kwargs) data_dir = self.get_data_dir() return dataset, data_dir diff --git a/topobenchmark/data/loaders/simplicial/mantra_dataset_loader.py b/topobenchmark/data/loaders/simplicial/mantra_dataset_loader.py index 47b0858c..b264b611 100644 --- a/topobenchmark/data/loaders/simplicial/mantra_dataset_loader.py +++ b/topobenchmark/data/loaders/simplicial/mantra_dataset_loader.py @@ -9,23 +9,31 @@ class MantraSimplicialDatasetLoader(AbstractLoader): """Load Mantra dataset with configurable parameters. - Note: for the simplicial datasets it is necessary to include DatasetLoader into the name of the class! - - Parameters - ---------- - parameters : DictConfig - Configuration parameters containing: - - data_dir: Root directory for data - - data_name: Name of the dataset - - other relevant parameters + Note: for the simplicial datasets it is necessary to include DatasetLoader into the name of the class! + + Parameters + ---------- + parameters : DictConfig + Configuration parameters containing: + - data_dir: Root directory for data + - data_name: Name of the dataset + - other relevant parameters + + **kwargs : dict + Additional keyword arguments. """ - def __init__(self, parameters: DictConfig) -> None: - super().__init__(parameters) + def __init__(self, parameters: DictConfig, **kwargs) -> None: + super().__init__(parameters, **kwargs) - def load_dataset(self) -> MantraDataset: + def load_dataset(self, **kwargs) -> MantraDataset: """Load the Citation Hypergraph dataset. + Parameters + ---------- + **kwargs : dict + Additional keyword arguments for dataset initialization. + Returns ------- CitationHypergraphDataset @@ -37,13 +45,18 @@ def load_dataset(self) -> MantraDataset: If dataset loading fails. """ - dataset = self._initialize_dataset() + dataset = self._initialize_dataset(**kwargs) self.data_dir = self.get_data_dir() return dataset - def _initialize_dataset(self) -> MantraDataset: + def _initialize_dataset(self, **kwargs) -> MantraDataset: """Initialize the Citation Hypergraph dataset. + Parameters + ---------- + **kwargs : dict + Additional keyword arguments for dataset initialization. + Returns ------- CitationHypergraphDataset @@ -53,4 +66,5 @@ def _initialize_dataset(self) -> MantraDataset: root=str(self.root_data_dir), name=self.parameters.data_name, parameters=self.parameters, + **kwargs, ) diff --git a/topobenchmark/data/utils/io_utils.py b/topobenchmark/data/utils/io_utils.py index d63db391..7e33d31b 100644 --- a/topobenchmark/data/utils/io_utils.py +++ b/topobenchmark/data/utils/io_utils.py @@ -115,7 +115,7 @@ def download_file_from_link( print("Failed to download the file.") -def read_ndim_manifolds(path, dim, y_val="betti_numbers"): +def read_ndim_manifolds(path, dim, y_val="betti_numbers", slice=None): """Load MANTRA dataset. Parameters @@ -127,6 +127,8 @@ def read_ndim_manifolds(path, dim, y_val="betti_numbers"): y_val : str, optional The triangulation information to use as label. Can be one of ['betti_numbers', 'torsion_coefficients', 'name', 'genus', 'orientable'] (default: "orientable"). + slice : int, optional + Slice of the dataset to load. If None, load the entire dataset (default: None). Used for testing. Returns ------- @@ -171,7 +173,7 @@ def read_ndim_manifolds(path, dim, y_val="betti_numbers"): data_list = [] # For each manifold - for manifold in manifold_list: + for manifold in manifold_list[:slice]: n_vertices = manifold["n_vertices"] x = torch.ones(n_vertices, 1) y_value = manifold[y_val]