-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcreate_geospatial.py
38 lines (31 loc) · 1.15 KB
/
create_geospatial.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import logging
from functools import partial
import hydra
from omegaconf import DictConfig
from torch_geometric.transforms import Compose
from etnn.geospatial import pm25cc, transforms
logger = logging.getLogger(__name__)


@hydra.main(config_path="conf/conf_geospatial", config_name="config", version_base=None)
def main(cfg: DictConfig):
    """Build the GeospatialCC dataset described by the Hydra config.

    Each boolean flag under ``cfg.dataset`` that is truthy contributes one
    pre-transform. The transforms are chained in this fixed order:

    1. ``standardize``
    2. ``randomize_x0`` (randomizes the ``"x_0"`` key)
    3. ``virtual_node``
    4. ``squash_to_graph``
    5. ``add_positions``

    The chain is wrapped in ``Compose`` and handed to ``pm25cc.PM25CC`` as
    ``pre_transform``. The dataset root is ``data/geospatialcc_<dataset_name>``.

    Args:
        cfg: Hydra config. Reads the five flags above from ``cfg.dataset``,
            plus ``cfg.dataset_name`` and ``cfg.force_reload``.
    """
    dataset_cfg = cfg.dataset

    # (config flag, zero-arg factory) pairs, in application order. The
    # factories defer touching ``transforms`` until a flag is known to be
    # enabled, exactly like a plain if-chain would.
    transform_table = (
        ("standardize", lambda: transforms.standardize_cc),
        ("randomize_x0", lambda: partial(transforms.randomize, keys=["x_0"])),
        ("virtual_node", lambda: transforms.add_virtual_node),
        ("squash_to_graph", lambda: transforms.squash_cc),
        ("add_positions", lambda: transforms.add_pos_to_cc),
    )
    enabled_steps = [
        build() for flag, build in transform_table if getattr(dataset_cfg, flag)
    ]
    composed = Compose(enabled_steps)

    dataset_root = f"data/geospatialcc_{cfg.dataset_name}"
    dataset = pm25cc.PM25CC(
        dataset_root,
        pre_transform=composed,
        force_reload=cfg.force_reload,
    )

    logger.info(
        f"Created GeospatialCC dataset generated and stored in '{dataset.root}'."
    )


if __name__ == "__main__":
    # Hydra's decorator parses the command line / config files and supplies cfg.
    main()