Skip to content

Commit

Permalink
Merge branch 'development' into 53_save_model_state_customization
Browse files Browse the repository at this point in the history
  • Loading branch information
thijsvl committed Aug 20, 2024
2 parents 480763c + 6f12b26 commit 3c8eae4
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 18 deletions.
36 changes: 36 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
name: Linting with ruff # settings are in pyproject.toml

on:
pull_request: # only run on pull requests for now
branches: [ "master", "development" ]

jobs:
lint:
runs-on: ubuntu-latest
defaults:
run:
working-directory: ./dgl_ptm
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Set up Python 3.11
uses: actions/setup-python@v3
with:
python-version: "3.11"
- name: Get changed python files in dgl_ptm # for now, we only lint the changed files
id: files
run: |
echo "files=$(git diff --name-only --relative --diff-filter=d origin/$GITHUB_BASE_REF...origin/$GITHUB_HEAD_REF -- "*.py" | tr '\n' ' ')" >> $GITHUB_ENV
- name: Install dependencies # ruff tools are installed in the dev dependencies
if: env.files != ''
run: |
python -m pip install --upgrade pip
python -m pip install .[dev]
- name: Check code style against standards
if: env.files != ''
run: |
ruff check $(echo ${{ env.files }} | tr ' ' '\n')
- name: Print message when files is empty
if: env.files == ''
run: echo "No python files were changed."
11 changes: 7 additions & 4 deletions dgl_ptm/dgl_ptm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
"""Documentation about dgl_ptm"""
"""Documentation about dgl_ptm."""
import logging

#from dgl_ptm.model import initialize_model, step
from dgl_ptm.model.initialize_model import PovertyTrapModel
# from dgl_ptm.agent import agent_update
# from dgl_ptm.agentInteraction import trade_money
# from dgl_ptm.network import global_attachment, link_deletion
# from dgl_ptm.model import initialize_model, data_collection, step
# from dgl_ptm.util import *
from dgl_ptm import config

#from dgl_ptm.model import initialize_model, step
from dgl_ptm.model.initialize_model import PovertyTrapModel

__all__ = ['config', 'PovertyTrapModel']

logging.getLogger(__name__).addHandler(logging.NullHandler())

__author__ = "Team Atlas"
__email__ = "p.chandramouli@esciencecenter.nl"
__email__ = "m.grootes@esciencecenter.nl"
__version__ = "0.1.0"
10 changes: 3 additions & 7 deletions dgl_ptm/dgl_ptm/agent/wealth_consumption.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def _nn_bellman_past_shock_consumption(model_graph,model_params, timestep, devi
if model_params['nn_path']==None:
print("No consumption model path provided!")

#load model
# Load model

estimator,scale = load_consumption_model(model_params['nn_path'],device)

Expand All @@ -197,21 +197,17 @@ def _nn_bellman_past_shock_consumption(model_graph,model_params, timestep, devi

input = torch.cat((model_graph.ndata['alpha'].unsqueeze(1), model_graph.ndata['wealth'].unsqueeze(1), model_graph.ndata['sigma'].unsqueeze(1), model_graph.ndata['theta'].unsqueeze(1)), dim=1)

#forward pass to get predictions
# Forward pass to get predictions
with torch.no_grad():

pred=estimator(input)

#print(" went forward, writing values")


model_graph.ndata['m'],model_graph.ndata['i_a']=model_graph.ndata['a_table'][torch.arange(model_graph.ndata['a_table'].size(0)),:,torch.argmin(torch.abs(pred[:, 0].unsqueeze(1) - model_graph.ndata['a_table'][:,1,:]), dim=1)].unbind(dim=1)

#print("Cleaning output and checking for violations")

#Clean Consumption
# Clean Consumption
model_graph.ndata['wealth_consumption']=(pred[:,1]*scale).clamp_(min=0)
#print(" violation check")

# Check for violations
# An equation violation occurs when personally shocked, depreciated k + income - consumption - i_a is less than or equal to 0
Expand Down
24 changes: 19 additions & 5 deletions dgl_ptm/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ classifiers=[
'Intended Audience :: Developers',
'License :: OSI Approved :: Apache Software License',
'Natural Language :: English',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
]

Expand Down Expand Up @@ -89,16 +87,32 @@ select = [
"N", # PEP8-naming
"UP", # pyupgrade (upgrade syntax to current syntax)
"PLE", # Pylint error https://github.com/charliermarsh/ruff#error-ple
"PLR", # Pylint refactor (e.g. too-many-arguments)
"PLW", # Pylint warning (useless-else-on-loop)
]
extend-select = [
"D401", # First line should be in imperative mood
"D400", # First line should end in a period.
"TID252", # No relative imports (not pep8 compliant)
]
ignore = [
"D100", "D101", "D104", "D105", "D106", "D107", "D203", "D213"
] # docstring style
"PLR2004", # magic value used in comparsion.
"PLR0913", # too many arguments
]

line-length = 88

exclude = ["docs", "build"]

# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
target-version = "py39"
target-version = "py311"

[tool.ruff.per-file-ignores]
"tests/**" = ["D"]

[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.ruff.lint.mccabe]
max-complexity = 10
5 changes: 3 additions & 2 deletions dgl_ptm/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from dgl_ptm.config import Config


@pytest.fixture
def config_parameters():
return {
Expand Down Expand Up @@ -43,7 +44,7 @@ def test_to_yaml(tmp_path):
cfg = Config()
cfg.to_yaml(tmp_path / "config.yaml")

with open(tmp_path / "config.yaml", "r") as f:
with open(tmp_path / "config.yaml") as f:
cfg_dict = yaml.safe_load(f)
assert cfg_dict["_model_identifier"] == "test"
assert cfg_dict["number_agents"] == 100
Expand All @@ -58,7 +59,7 @@ def test_invalid_fields(config_parameters):
_ = Config.from_dict(config_parameters)

def test_invalid_values(config_parameters):
"""Test that invalid values are not accepted."""
"""Test invalid values."""
config_parameters["number_agents"] = -100
with pytest.raises(ValueError):
_ = Config.from_dict(config_parameters)

0 comments on commit 3c8eae4

Please sign in to comment.