Skip to content

Commit

Permalink
fix linter errors in data_collection
Browse files Browse the repository at this point in the history
  • Loading branch information
SarahAlidoost committed Aug 21, 2024
1 parent 4ea862e commit 3ec4e87
Showing 1 changed file with 60 additions and 34 deletions.
94 changes: 60 additions & 34 deletions dgl_ptm/dgl_ptm/model/data_collection.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,39 @@
import xarray as xr
from pathlib import Path
"""The module to collect data from agents and edges."""

import os
import dgl
from pathlib import Path

def data_collection(agent_graph, timestep, npath='./agent_data', epath='./edge_data', ndata = ['all'], edata = ['all'], format = 'xarray', mode = 'w-'):
'''
data_collection - collects data from agents and edges.
import xarray as xr

Args:
agent_graph: DGLGraph with agent nodes and edges connecting agents.
timestep = current timestep to name folder for edge properties
npath = path to store node data.
epath = path to store edge data with one file for each timestep.
ndata = node data properties to be stored.
['all'] implies all node properties will be saved
edata = edge data properties to be stored.
['all'] implies all edge properties will be saved
format = storage format
['xarray'] saves the properties in zarr format with xarray dataset
mode = zarr write mode.

Output:
def data_collection(agent_graph,
timestep,
npath='./agent_data',
epath='./edge_data',
ndata=None,
edata=None,
format = 'xarray',
mode = 'w-'):
"""data_collection - collects data from agents and edges.
'''
if ndata == ['all']:
Args:
agent_graph: DGLGraph with agent nodes and edges connecting agents.
timestep: current timestep to name folder for edge properties
npath: path to store node data.
epath: path to store edge data with one file for each timestep.
ndata: node data properties to be stored.
['all'] implies all node properties will be saved
edata: edge data properties to be stored.
['all'] implies all edge properties will be saved
format: storage format
['xarray'] saves the properties in zarr format with xarray dataset
mode: zarr write mode.
"""
if ndata is None or ndata == ['all']:
ndata = list(agent_graph.node_attr_schemes().keys())
if ndata[0] == 'all_except':
ndata = list(agent_graph.node_attr_schemes().keys() - ndata[1])
if edata == ['all']:
if edata is None or edata == ['all']:
edata = list(agent_graph.edge_attr_schemes().keys())

_node_property_collector(agent_graph, npath, ndata, timestep, format, mode)
Expand All @@ -40,40 +46,60 @@ def _node_property_collector(agent_graph, npath, ndata, timestep, format, mode):
agent_data_instance = xr.Dataset()
for prop in ndata:
_check_nprop_in_graph(agent_graph, prop)
agent_data_instance = agent_data_instance.assign(prop=(['n_agents','n_time'], agent_graph.ndata[prop][:,None].cpu().numpy()))
agent_data_instance = agent_data_instance.rename(name_dict={'prop':prop})
agent_data_instance = agent_data_instance.assign(
prop=(['n_agents','n_time'], agent_graph.ndata[prop][:,None].cpu().numpy()) # noqa: E501
)
agent_data_instance = agent_data_instance.rename(
name_dict={'prop':prop}
)
if timestep == 0:
agent_data_instance.to_zarr(npath, mode = mode)
else:
agent_data_instance.to_zarr(npath, append_dim='n_time')
else:
raise NotImplementedError("Only 'xarray' format currrent available")
else:
raise NotImplementedError("Data collection currently only implemented for pytorch backend")
raise NotImplementedError(
"Data collection currently only implemented for pytorch backend"
)


def _edge_property_collector(agent_graph, epath, edata, timestep, format, mode):
    """Write the selected edge properties of one timestep to its own zarr store.

    Args:
        agent_graph: graph object providing ``edges()`` and ``edata``
            (presumably a DGLGraph -- confirm against the caller).
        epath: directory that receives one ``<timestep>.zarr`` store per call.
        edata: iterable of edge property names to store.
        timestep: current timestep; only used to name the output store.
        format: storage format; only ``'xarray'`` is implemented.
        mode: write mode forwarded to ``Dataset.to_zarr``.

    Raises:
        KeyError: if the ``DGLBACKEND`` environment variable is not set.
        NotImplementedError: if the backend is not ``pytorch`` or ``format``
            is not ``'xarray'``.
        ValueError: if a name in ``edata`` is not an edge property of the graph.
    """
    # Guard clauses first: nothing below is valid for other backends/formats.
    if os.environ["DGLBACKEND"] != "pytorch":
        raise NotImplementedError(
            "Data collection currently only implemented for pytorch backend"
        )
    if format != 'xarray':
        raise NotImplementedError("Only 'xarray' format currently available")

    # Source and destination node ids become coordinates along 'n_edges'.
    endpoints = agent_graph.edges()
    edge_data_instance = xr.Dataset(
        coords={
            "source": (["n_edges"], endpoints[0].cpu()),
            "dest": (["n_edges"], endpoints[1].cpu()),
        }
    )

    for prop in edata:
        _check_eprop_in_graph(agent_graph, prop)
        # [:, None] appends a length-1 axis that is labelled 'time'.
        values = agent_graph.edata[prop][:, None].cpu().numpy()
        # Assign under the real property name directly instead of assigning a
        # placeholder variable and renaming it afterwards.
        edge_data_instance = edge_data_instance.assign(
            {prop: (['n_edges', 'time'], values)}
        )

    edge_data_instance.to_zarr(Path(epath) / f"{timestep}.zarr", mode=mode)

def _check_nprop_in_graph(agent_graph, prop):
    """Raise ``ValueError`` unless ``prop`` is a node property of the graph.

    Args:
        agent_graph: graph object providing ``node_attr_schemes()``.
        prop: name of the node property to look up.

    Raises:
        ValueError: if ``prop`` is not among the keys returned by
            ``agent_graph.node_attr_schemes()``.
    """
    node_props = agent_graph.node_attr_schemes().keys()
    if prop not in node_props:
        # The trailing space matters: adjacent string literals are joined
        # verbatim, so without it the two sentences run together.
        raise ValueError(
            f"{prop} is not a node property. "
            f"Please choose from {node_props}"
        )

def _check_eprop_in_graph(agent_graph, prop):
    """Raise ``ValueError`` unless ``prop`` is an edge property of the graph.

    Args:
        agent_graph: graph object providing ``edge_attr_schemes()``.
        prop: name of the edge property to look up.

    Raises:
        ValueError: if ``prop`` is not among the keys returned by
            ``agent_graph.edge_attr_schemes()``.
    """
    edge_props = agent_graph.edge_attr_schemes().keys()
    if prop not in edge_props:
        # The trailing space matters: adjacent string literals are joined
        # verbatim, so without it the two sentences run together.
        raise ValueError(
            f"{prop} is not an edge property. "
            f"Please choose from {edge_props}"
        )

0 comments on commit 3ec4e87

Please sign in to comment.