Skip to content

Commit

Permalink
fix linter errors in test_model
Browse files Browse the repository at this point in the history
  • Loading branch information
SarahAlidoost committed Aug 21, 2024
1 parent 2b2a3ed commit 4ea862e
Showing 1 changed file with 98 additions and 48 deletions.
146 changes: 98 additions & 48 deletions dgl_ptm/tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,39 @@
import pytest
import dgl_ptm
import os
import xarray as xr
import shutil
from pathlib import Path

import pytest
import torch
import xarray as xr

import dgl_ptm
from dgl_ptm.model.data_collection import data_collection
from dgl_ptm.model.initialize_model import sample_distribution_tensor
from pathlib import Path

os.environ["DGLBACKEND"] = "pytorch"

@pytest.fixture
def ptm_model():
model = dgl_ptm.PovertyTrapModel(model_identifier='ptm_step', root_path='test_models')
model = dgl_ptm.PovertyTrapModel(
model_identifier='ptm_step', root_path='test_models'
)
model.set_model_parameters(overwrite=True)
model.initialize_model()
return model

@pytest.fixture
def data_collection_model():
model = dgl_ptm.PovertyTrapModel(model_identifier='data_collection', root_path='test_models')
model = dgl_ptm.PovertyTrapModel(
model_identifier='data_collection', root_path='test_models'
)
model.set_model_parameters(overwrite=True)
model.initialize_model()
return model

@pytest.fixture
def initialize_model_model():
model = dgl_ptm.PovertyTrapModel(model_identifier='initialize_model', root_path='test_models')
model = dgl_ptm.PovertyTrapModel(
model_identifier='initialize_model', root_path='test_models'
)
model.set_model_parameters(overwrite=True)
model.initialize_model()
return model
Expand Down Expand Up @@ -104,7 +110,8 @@ def test_data_collection_period_and_list(self, ptm_model):
model.config.step_target = 10 # run the model till step 10

# Set periodical progress check as well as
# collecting data before and after specific step and at the end of the process.
# collecting data before and after specific step and at the end of the
# process.
# Note that the pediod and list could have overlapping values;
# this will result in collecting the data once at that step.
model.steering_parameters['data_collection_period'] = 4
Expand All @@ -128,9 +135,13 @@ def test_data_collection_period_and_list(self, ptm_model):
class TestDataCollection:
def test_data_collection(self, data_collection_model):
model = data_collection_model
data_collection(model.graph, timestep=0, npath = model.steering_parameters['npath'],
epath = model.steering_parameters['epath'], ndata = model.steering_parameters['ndata'],
edata = model.steering_parameters['edata'], format = model.steering_parameters['format'],
data_collection(model.graph,
timestep=0,
npath = model.steering_parameters['npath'],
epath = model.steering_parameters['epath'],
ndata = model.steering_parameters['ndata'],
edata = model.steering_parameters['edata'],
format = model.steering_parameters['format'],
mode = model.steering_parameters['mode'])

assert Path('test_models/data_collection/agent_data.zarr').exists()
Expand All @@ -139,9 +150,13 @@ def test_data_collection(self, data_collection_model):
def test_data_collection_timestep1(self, data_collection_model):
model = data_collection_model
model.step() # timestep 0
data_collection(model.graph, timestep=1, npath = model.steering_parameters['npath'],
epath = model.steering_parameters['epath'], ndata = model.steering_parameters['ndata'],
edata = model.steering_parameters['edata'], format = model.steering_parameters['format'],
data_collection(model.graph,
timestep=1,
npath = model.steering_parameters['npath'],
epath = model.steering_parameters['epath'],
ndata = model.steering_parameters['ndata'],
edata = model.steering_parameters['edata'],
format = model.steering_parameters['format'],
mode = model.steering_parameters['mode'])

assert Path('test_models/data_collection/agent_data.zarr').exists()
Expand All @@ -159,20 +174,25 @@ def test_data_collection_timestep1(self, data_collection_model):

class TestInitializeModel:
def test_set_model_parameters(self):
model = dgl_ptm.PovertyTrapModel(model_identifier='initialize_model', root_path='test_models')
model = dgl_ptm.PovertyTrapModel(
model_identifier='initialize_model', root_path='test_models'
)
model.set_model_parameters(overwrite=True)

assert model.step_count == 0
assert Path(model.steering_parameters['npath']) == Path('test_models/initialize_model/agent_data.zarr')
assert Path(model.steering_parameters['epath']) == Path('test_models/initialize_model/edge_data')
assert Path('test_models/initialize_model/initialize_model_0.yaml').exists()
initialize_dir ='test_models/initialize_model'
assert Path(model.steering_parameters['npath']) == Path(f'{initialize_dir}/agent_data.zarr') # noqa: E501
assert Path(model.steering_parameters['epath']) == Path(f'{initialize_dir}/edge_data') # noqa: E501
assert Path(f'{initialize_dir}/initialize_model_0.yaml').exists()
assert model.steering_parameters['edata'] == ['all']
assert model.steering_parameters['format'] == 'xarray'
assert model.config.number_agents == 100
assert model.config.step_target == 5

def test_set_model_parameters_with_file(self, config_file):
model = dgl_ptm.PovertyTrapModel(model_identifier='initialize_model', root_path='test_models')
model = dgl_ptm.PovertyTrapModel(
model_identifier='initialize_model', root_path='test_models'
)
model.set_model_parameters(parameter_file_path=config_file)

assert model._model_identifier == 'initialize_model'
Expand All @@ -182,33 +202,42 @@ def test_set_model_parameters_with_file(self, config_file):
assert model.config.number_agents == 100

def test_set_model_parameters_with_kwargs(self):
model = dgl_ptm.PovertyTrapModel(model_identifier='initialize_model', root_path='test_models')
model.set_model_parameters(steering_parameters={'del_method': 'probability','del_threshold': 0.04})
model = dgl_ptm.PovertyTrapModel(
model_identifier='initialize_model', root_path='test_models'
)
model.set_model_parameters(
steering_parameters={'del_method': 'probability','del_threshold': 0.04}
)

assert model.steering_parameters['del_method'] == 'probability'
assert model.steering_parameters['del_threshold'] == 0.04
assert model.config.number_agents == 100

def test_set_model_parameters_with_file_and_kwargs(self, config_file):
model = dgl_ptm.PovertyTrapModel(model_identifier='initialize_model', root_path='test_models')
model = dgl_ptm.PovertyTrapModel(
model_identifier='initialize_model', root_path='test_models'
)
model.set_model_parameters(
parameter_file_path=config_file,
steering_parameters={'del_method': 'probability','del_threshold': 0.06}
)

assert model._model_identifier == 'initialize_model' # Note, not 'new_model' as set in config_file.
# Note, not 'new_model' as set in config_file.
assert model._model_identifier == 'initialize_model'
assert model.steering_parameters['del_method'] == 'probability'
assert model.steering_parameters['del_threshold'] == 0.06
assert model.steering_parameters['step_type'] == 'custom'
assert model.config.number_agents == 100

def test_save_model_parameters(self):
model = dgl_ptm.PovertyTrapModel(model_identifier='initialize_model', root_path='test_models')
model = dgl_ptm.PovertyTrapModel(
model_identifier='initialize_model', root_path='test_models'
)
model.set_model_parameters(overwrite=True)
model.save_model_parameters(overwrite=False)
model.save_model_parameters()

# Saving to unique config files becomes useful when starting multiple runs from the same step.
# Saving to unique config files becomes useful when starting multiple
# runs from the same step.
assert Path('test_models/initialize_model/initialize_model_0.yaml').exists()
assert Path('test_models/initialize_model/initialize_model_0_1.yaml').exists()
assert Path('test_models/initialize_model/initialize_model_0_2.yaml').exists()
Expand All @@ -220,22 +249,26 @@ def test_initialize_model(self, initialize_model_model):
assert str(model.graph.device) == 'cpu'

def test_create_network(self):
model = dgl_ptm.PovertyTrapModel(model_identifier='initialize_model', root_path='test_models')
model = dgl_ptm.PovertyTrapModel(
model_identifier='initialize_model', root_path='test_models'
)
model.set_model_parameters(overwrite=True)
model.create_network()

assert model.graph is not None
assert model.graph.number_of_nodes() == 100

def test_initialize_model_properties(self):
model = dgl_ptm.PovertyTrapModel(model_identifier='initialize_model', root_path='test_models')
model = dgl_ptm.PovertyTrapModel(
model_identifier='initialize_model', root_path='test_models'
)
model.set_model_parameters(overwrite=True)
model.initialize_model_properties()

modelTheta = torch.tensor([1., 1., 1., 1., 1.])

def test_initialize_agent_properties(self):
model = dgl_ptm.PovertyTrapModel(model_identifier='initialize_model', root_path='test_models')
model = dgl_ptm.PovertyTrapModel(
model_identifier='initialize_model', root_path='test_models'
)
model.set_model_parameters(overwrite=True)
model.create_network()
model.initialize_agent_properties()
Expand Down Expand Up @@ -274,14 +307,19 @@ def test_model_init_savestate(self, initialize_model_model):
assert Path('test_models/initialize_model/graph.bin').exists()
assert Path('test_models/initialize_model/generator_state.bin').exists()
assert Path('test_models/initialize_model/process_version.md').exists()
assert model.inputs["step_count"] == 4 # Note that the inputs are set at the end of the last step, which is the step before the step target.

# Note that the inputs are set at the end of the last step, which is the
# step before the step target.
assert model.inputs["step_count"] == 4

def test_model_init_savestate_not_default(self, initialize_model_model):
model = initialize_model_model
model.config.checkpoint_period = 2
model.run()

assert model.inputs["step_count"] == 4 # Note that the inputs are set at the end of the last step, which is the step before the step target.
# Note that the inputs are set at the end of the last step, which is the
# step before the step target.
assert model.inputs["step_count"] == 4

def test_model_init_restart(self, initialize_model_model):
model = initialize_model_model
Expand All @@ -303,19 +341,25 @@ def test_model_init_restart(self, initialize_model_model):
stored_generator_state = set(model.inputs["generator_state"].tolist())

assert model.inputs is not None
assert model.inputs["step_count"] == 4 # Note that the inputs are set at the end of the last step, which is the step before the step target.
# Note that the inputs are set at the end of the last step, which is the
# step before the step target.
assert model.inputs["step_count"] == 4

assert stored_generator_state == expected_generator_state
assert Path('test_models/initialize_model/initialize_model_2.yaml').exists() # The second run also saves its config at the start.

# The second run also saves its config at the start.
assert Path('test_models/initialize_model/initialize_model_2.yaml').exists()

def test_model_milestone(self, initialize_model_model):
model = initialize_model_model
model.config.milestones = [2]
model.run()

assert model.inputs is not None
assert Path('test_models/initialize_model/milestone_2/graph.bin').exists()
assert Path('test_models/initialize_model/milestone_2/generator_state.bin').exists()
assert Path('test_models/initialize_model/milestone_2/process_version.md').exists()
milstone_dir = 'test_models/initialize_model/milestone_2'
assert Path(f'{milstone_dir}/graph.bin').exists()
assert Path(f'{milstone_dir}/generator_state.bin').exists()
assert Path(f'{milstone_dir}/process_version.md').exists()
assert model.inputs["step_count"] == 2

def test_model_milestone_continue(self, initialize_model_model):
Expand All @@ -330,7 +374,7 @@ def test_model_milestone_continue(self, initialize_model_model):

model.initialize_model(restart=(1,0))
model.config.step_target = 5 # continue the model and run till step 5
model.config.description = "Policy 0: just run for a while using default parameters."
model.config.description = "Policy 0: just run using default parameters."
assert model.step_count == 1 # The step count is that of the milestone.
model.run()
assert model.config.step_target == 5
Expand All @@ -340,7 +384,9 @@ def test_model_milestone_continue(self, initialize_model_model):
assert model.inputs["step_count"] == 1
assert model.step_count == 5
assert stored_generator_state == expected_generator_state
assert Path('test_models/initialize_model/initialize_model_1.yaml').exists() # The second run also saves its config at the start.

# The second run also saves its config at the start.
assert Path('test_models/initialize_model/initialize_model_1.yaml').exists()

def test_model_milestone_multiple(self, initialize_model_model):
model = initialize_model_model
Expand All @@ -362,14 +408,16 @@ def test_model_milestone_multiple(self, initialize_model_model):
assert model.inputs["step_count"] == 3

# Note, the first instance of a milestone at step 3 is stored in milestone_3
assert Path('test_models/initialize_model/milestone_3/graph.bin').exists()
assert Path('test_models/initialize_model/milestone_3/generator_state.bin').exists()
assert Path('test_models/initialize_model/milestone_3/process_version.md').exists()
milstone_dir = 'test_models/initialize_model/milestone_3'
assert Path(f'{milstone_dir}/graph.bin').exists()
assert Path(f'{milstone_dir}/generator_state.bin').exists()
assert Path(f'{milstone_dir}/process_version.md').exists()

# Note, the second instance of a milestone at step 3 is stored in milestone_3_1
assert Path('test_models/initialize_model/milestone_3_1/graph.bin').exists()
assert Path('test_models/initialize_model/milestone_3_1/generator_state.bin').exists()
assert Path('test_models/initialize_model/milestone_3_1/process_version.md').exists()
milstone_dir = 'test_models/initialize_model/milestone_3_1'
assert Path(f'{milstone_dir}/graph.bin').exists()
assert Path(f'{milstone_dir}/generator_state.bin').exists()
assert Path(f'{milstone_dir}/process_version.md').exists()

# Continue from the second milestone.
model.initialize_model(restart=(3,1))
Expand All @@ -382,4 +430,6 @@ def test_model_milestone_multiple(self, initialize_model_model):
assert model.inputs["step_count"] == 3
assert model.step_count == 5
assert stored_generator_state == expected_generator_state
assert Path('test_models/initialize_model/initialize_model_3.yaml').exists() # The third run also saves its config at the start.

# The third run also saves its config at the start.
assert Path('test_models/initialize_model/initialize_model_3.yaml').exists()

0 comments on commit 4ea862e

Please sign in to comment.