Skip to content

Commit

Permalink
Add add_parent function (#268)
Browse files Browse the repository at this point in the history
* Add initial working and experimental doc

* Added add_parent function under utils/structure

* Tests run successfully

* Test folder updates, and run successfully

* Add_parent added to API file

* remove the tutorial from this PR

* fix tests and mypy errors

---------

Co-authored-by: Nicolas Legrand <nicolas.legrand@cfin.au.dk>
Co-authored-by: LegrandNico <nicolas.legrand@cas.au.dk>
  • Loading branch information
3 people authored Dec 20, 2024
1 parent 3d69771 commit d3d2417
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ Utilities for manipulating neural networks.
to_pandas
add_edges
get_input_idxs
add_parent
remove_node

Math
Expand Down
2 changes: 2 additions & 0 deletions pyhgf/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .add_edges import add_edges
from .add_parent import add_parent
from .beliefs_propagation import beliefs_propagation
from .fill_categorical_state_node import fill_categorical_state_node
from .get_input_idxs import get_input_idxs
Expand All @@ -9,6 +10,7 @@

__all__ = [
"add_edges",
"add_parent",
"beliefs_propagation",
"fill_categorical_state_node",
"get_input_idxs",
Expand Down
81 changes: 81 additions & 0 deletions pyhgf/utils/add_parent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Author: Louie Mølgaard Hessellund <hessellundlouie@gmail.com>

from typing import Dict, Tuple

from pyhgf.typing import AdjacencyLists, Edges
from pyhgf.utils.add_edges import add_edges


def add_parent(
attributes: Dict, edges: Edges, index: int, coupling_type: str, mean: float
) -> Tuple[Dict, Edges]:
r"""Add a new continuous-state parent node to the attributes and edges of a network.
Parameters
----------
attributes :
The attributes of the existing network.
edges :
The edges of the existing network.
index :
The index of the node you want to connect a new parent node to.
coupling_type :
The type of coupling you want between the existing node and it's new parent.
Can be either `"value"` or `"volatility"`.
mean :
The mean value of the new parent node.
Returns
-------
attributes :
The updated attributes of the existing network.
edges :
The updated edges of the existing network.
"""
# Get index for node to be added
new_node_idx = len(edges)

# Add new node to attributes
attributes[new_node_idx] = {
"mean": mean,
"expected_mean": mean,
"precision": 1.0,
"expected_precision": 1.0,
"volatility_coupling_children": None,
"volatility_coupling_parents": None,
"value_coupling_children": None,
"value_coupling_parents": None,
"tonic_volatility": -4.0,
"tonic_drift": 0.0,
"autoconnection_strength": 1.0,
"observed": 1,
"temp": {
"effective_precision": 0.0,
"value_prediction_error": 0.0,
"volatility_prediction_error": 0.0,
},
}

# Add new AdjacencyList with empty values, to Edges tuple
new_adj_list = AdjacencyLists(
node_type=2,
value_parents=None,
volatility_parents=None,
value_children=None,
volatility_children=None,
coupling_fn=(None,),
)
edges = edges + (new_adj_list,)

# Use add_edges to integrate the altered attributes and edges
attributes, edges = add_edges(
attributes=attributes,
edges=edges,
kind=coupling_type,
parent_idxs=new_node_idx,
children_idxs=index,
)

# Return new attributes and edges
return attributes, edges
23 changes: 21 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pyhgf import load_data
from pyhgf.model import Network
from pyhgf.typing import AdjacencyLists
from pyhgf.utils import list_branches, remove_node
from pyhgf.utils import add_parent, list_branches, remove_node


def test_imports():
Expand Down Expand Up @@ -94,9 +94,28 @@ def test_set_update_sequence():
assert len(updates) == 3


def test_add_parent():
"""Test the add_parent function."""
network = (
Network()
.add_nodes(n_nodes=4)
.add_nodes(value_children=2)
.add_nodes(value_children=3)
)
attributes, edges, _ = network.get_network()
new_attributes, new_edges = add_parent(attributes, edges, 1, "volatility", 1.0)

assert len(new_attributes) == 8
assert len(new_edges) == 7

new_attributes, new_edges = add_parent(attributes, edges, 1, "value", 1.0)

assert len(new_attributes) == 8
assert len(new_edges) == 7


def test_remove_node():
"""Test the remove_node function."""
# a standard binary HGF
network = (
Network()
.add_nodes(n_nodes=4)
Expand Down

0 comments on commit d3d2417

Please sign in to comment.