Skip to content

Commit

Permalink
fixing long time action flow
Browse files Browse the repository at this point in the history
  • Loading branch information
levtelyatnikov committed Dec 24, 2024
1 parent d461cbd commit 2102532
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 24 deletions.
15 changes: 10 additions & 5 deletions test/data/load/test_datasetloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,13 @@ def _gather_config_files(self, base_dir: Path) -> List[str]:
config_base_dir = base_dir / "configs/dataset"
# Below the datasets that have some default transforms manually overriten with no_transform,
exclude_datasets = {"karate_club.yaml",
# Below are the datasets that take quite some time to load and process
"mantra_name.yaml", "mantra_orientation.yaml","mantra_genus.yaml", # "mantra_betti_numbers.yaml", "mantra_genus.yaml",
# Below are the datasets whose default transforms we manually overrode with no_transform,
# due to lack of default transform for domain2domain
"REDDIT-BINARY.yaml", "IMDB-MULTI.yaml", "IMDB-BINARY.yaml", #"ZINC.yaml"
}

# Below are the datasets that take quite some time to load and process
self.long_running_datasets = {"mantra_name.yaml", "mantra_orientation.yaml", "mantra_genus.yaml", "mantra_betti_numbers.yaml",}


for dir_path in config_base_dir.iterdir():
Expand Down Expand Up @@ -79,12 +80,16 @@ def _load_dataset(self, data_domain: str, config_file: str) -> Tuple[Any, Dict]:
parameters = hydra.compose(
config_name="run.yaml",
overrides=[f"dataset={data_domain}/{config_file}", f"model=graph/gat"],
return_hydra_config=True

return_hydra_config=True,
)
dataset_loader = hydra.utils.instantiate(parameters.dataset.loader)
print(repr(dataset_loader))
return dataset_loader.load()

if config_file in self.long_running_datasets:
dataset, data_dir = dataset_loader.load(slice=100)
else:
dataset, data_dir = dataset_loader.load()
return dataset, data_dir

def test_dataset_loading_states(self):
"""Test different states and scenarios during dataset loading."""
Expand Down
9 changes: 8 additions & 1 deletion topobenchmark/data/datasets/mantra_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class MantraDataset(InMemoryDataset):
Name of the dataset.
parameters : DictConfig
Configuration parameters for the dataset.
**kwargs : dict
Additional keyword arguments.
Attributes
----------
Expand All @@ -50,6 +52,7 @@ def __init__(
root: str,
name: str,
parameters: DictConfig,
**kwargs,
) -> None:
self.parameters = parameters
self.manifold_dim = parameters.manifold_dim
Expand All @@ -58,7 +61,10 @@ def __init__(
self.name = "_".join(
[name, str(self.version), f"manifold_dim_{self.manifold_dim}"]
)

if kwargs.get("slice", None):
self.slice = 100
else:
self.slice = None
super().__init__(
root,
)
Expand Down Expand Up @@ -183,6 +189,7 @@ def process(self) -> None:
osp.join(self.raw_dir, self.raw_file_names[0]),
self.manifold_dim,
self.task_variable,
self.slice,
)

data_list = data
Expand Down
9 changes: 7 additions & 2 deletions topobenchmark/data/loaders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,20 @@ def load_dataset(self) -> torch_geometric.data.Data:
"""
raise NotImplementedError

def load(self) -> tuple[torch_geometric.data.Data, str]:
def load(self, **kwargs) -> tuple[torch_geometric.data.Data, str]:
"""Load data.
Parameters
----------
**kwargs : dict
Additional keyword arguments.
Returns
-------
tuple[torch_geometric.data.Data, str]
Tuple containing the loaded data and the data directory.
"""
dataset = self.load_dataset()
dataset = self.load_dataset(**kwargs)
data_dir = self.get_data_dir()

return dataset, data_dir
42 changes: 28 additions & 14 deletions topobenchmark/data/loaders/simplicial/mantra_dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,31 @@
class MantraSimplicialDatasetLoader(AbstractLoader):
"""Load Mantra dataset with configurable parameters.
Note: for the simplicial datasets it is necessary to include DatasetLoader into the name of the class!
Parameters
----------
parameters : DictConfig
Configuration parameters containing:
- data_dir: Root directory for data
- data_name: Name of the dataset
- other relevant parameters
Note: for the simplicial datasets it is necessary to include DatasetLoader into the name of the class!
Parameters
----------
parameters : DictConfig
Configuration parameters containing:
- data_dir: Root directory for data
- data_name: Name of the dataset
- other relevant parameters
**kwargs : dict
Additional keyword arguments.
"""

def __init__(self, parameters: DictConfig) -> None:
super().__init__(parameters)
def __init__(self, parameters: DictConfig, **kwargs) -> None:
super().__init__(parameters, **kwargs)

def load_dataset(self) -> MantraDataset:
def load_dataset(self, **kwargs) -> MantraDataset:
"""Load the Citation Hypergraph dataset.
Parameters
----------
**kwargs : dict
Additional keyword arguments for dataset initialization.
Returns
-------
MantraDataset
Expand All @@ -37,13 +45,18 @@ def load_dataset(self) -> MantraDataset:
If dataset loading fails.
"""

dataset = self._initialize_dataset()
dataset = self._initialize_dataset(**kwargs)
self.data_dir = self.get_data_dir()
return dataset

def _initialize_dataset(self) -> MantraDataset:
def _initialize_dataset(self, **kwargs) -> MantraDataset:
"""Initialize the Citation Hypergraph dataset.
Parameters
----------
**kwargs : dict
Additional keyword arguments for dataset initialization.
Returns
-------
MantraDataset
Expand All @@ -53,4 +66,5 @@ def _initialize_dataset(self) -> MantraDataset:
root=str(self.root_data_dir),
name=self.parameters.data_name,
parameters=self.parameters,
**kwargs,
)
6 changes: 4 additions & 2 deletions topobenchmark/data/utils/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def download_file_from_link(
print("Failed to download the file.")


def read_ndim_manifolds(path, dim, y_val="betti_numbers"):
def read_ndim_manifolds(path, dim, y_val="betti_numbers", slice=None):
"""Load MANTRA dataset.
Parameters
Expand All @@ -127,6 +127,8 @@ def read_ndim_manifolds(path, dim, y_val="betti_numbers"):
y_val : str, optional
The triangulation information to use as label. Can be one of ['betti_numbers', 'torsion_coefficients',
'name', 'genus', 'orientable'] (default: "betti_numbers").
slice : int, optional
Slice of the dataset to load. If None, load the entire dataset (default: None). Used for testing.
Returns
-------
Expand Down Expand Up @@ -171,7 +173,7 @@ def read_ndim_manifolds(path, dim, y_val="betti_numbers"):

data_list = []
# For each manifold
for manifold in manifold_list:
for manifold in manifold_list[:slice]:
n_vertices = manifold["n_vertices"]
x = torch.ones(n_vertices, 1)
y_value = manifold[y_val]
Expand Down

0 comments on commit 2102532

Please sign in to comment.