Skip to content

Commit

Permalink
give nodes a weight attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
sidnarayanan committed Sep 11, 2024
1 parent 5d88e93 commit 366df10
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 10 deletions.
38 changes: 28 additions & 10 deletions ldp/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ def __init__(self, root_id: str | UUID):
transitions in an LDP. Any path from the root node to a terminal
node constitutes a complete LDP.
A node may be assigned an arbitrary weight, which will be treated
as a relative probability of sampling that node. For example, if
A(weight=1) and B(weight=2) are both children of the same node,
then we treat B as twice as likely as A.
Args:
root_id: A unique identifier for the root node of the tree.
All IDs of transitions added to this tree must begin with
Expand All @@ -151,10 +156,12 @@ def __init__(self, root_id: str | UUID):
self.tree = nx.DiGraph() # the actual tree
self.rev_tree = nx.DiGraph() # the same as self.tree, but with reversed edges

self._add_node(self.root_id, transition=None)
self._add_node(self.root_id, transition=None, weight=1.0)

def _add_node(self, step_id: str, transition: Transition | None) -> None:
self.tree.add_node(step_id, transition=transition)
def _add_node(
    self, step_id: str, transition: Transition | None, weight: float
) -> None:
    """Register ``step_id`` in both the forward and the reversed graph.

    The transition and weight are stored as node attributes on
    ``self.tree`` only; ``self.rev_tree`` receives the bare node ID.

    Args:
        step_id: Identifier of the node to add.
        transition: Transition stored on the node, or None.
        weight: Weight stored on the node.
    """
    node_attrs = {"transition": transition, "weight": weight}
    self.tree.add_node(step_id, **node_attrs)
    self.rev_tree.add_node(step_id)

def _add_edge(self, parent_step_id: str, child_step_id: str) -> None:
Expand All @@ -167,13 +174,19 @@ def get_transition(self, step_id: str) -> Transition:

return cast(Transition, self.tree.nodes[step_id]["transition"])

def add_transition(self, step_id: str, step: Transition) -> None:
def get_weight(self, step_id: str) -> float:
    """Return the weight stored on the node identified by ``step_id``.

    Args:
        step_id: Identifier of a node present in ``self.tree``.

    Returns:
        The value of the node's ``"weight"`` attribute.
    """
    node_attrs = self.tree.nodes[step_id]
    return cast(float, node_attrs["weight"])

def add_transition(
self, step_id: str, step: Transition, weight: float = 1.0
) -> None:
"""Add a transition to the tree.
Args:
step_id: A unique identifier for the transition being added.
The expected form of the step ID is "{parent step ID}:{step index}".
step: The transition to add.
weight: Weight of the transition. Defaults to 1.0.
"""
root_id, *step_ids = step_id.split(":")
assert (
Expand All @@ -185,7 +198,7 @@ def add_transition(self, step_id: str, step: Transition) -> None:
step_id not in self.tree
), f"Step ID {step_id} already exists in the tree."

self._add_node(step_id, transition=step)
self._add_node(step_id, transition=step, weight=weight)

parent_id = ":".join([root_id, *step_ids[:-1]])
if parent_id in self.tree:
Expand Down Expand Up @@ -259,11 +272,13 @@ def assign_mc_value_estimates(self, discount_factor: float = 1.0) -> None:

if children := list(self.tree.successors(step_id)):
# V_{t+1}(s') = sum_{a'} p(a'|s') * Q_{t+1}(s', a')
# Here we assume p(a'|s') is uniform.
# Here we assume p(a'|s') is uniform over the sampled actions.
# TODO: don't make that assumption where a logprob is available
weights = [self.get_weight(child_id) for child_id in children]
steps = [self.get_transition(child_id) for child_id in children]
v_tp1 = sum(
self.get_transition(child_id).value for child_id in children
) / len(children)
w * step.value for w, step in zip(weights, steps, strict=True)
) / sum(weights)
else:
v_tp1 = 0.0

Expand Down Expand Up @@ -304,16 +319,19 @@ def merge_identical_nodes(
action_str = ""

step_hash = hash((state_hash, join(step.observation), action_str))
step_weight = self.get_weight(step_id)

if step_hash in seen_nodes:
node_remap[step_id] = seen_nodes[step_hash]
merged_node_id = node_remap[step_id] = seen_nodes[step_hash]
# Not sure if this is the fastest way to do this
new_tree.tree.nodes[merged_node_id]["weight"] += step_weight
else:
node_remap[step_id] = seen_nodes[step_hash] = step_id
parent_id = node_remap[":".join(step_id.split(":")[:-1])]

# manually add transitions, since the step_id substring relationship
# will be broken
new_tree._add_node(step_id, transition=step)
new_tree._add_node(step_id, transition=step, weight=step_weight)
new_tree._add_edge(parent_id, step_id)

return new_tree
7 changes: 7 additions & 0 deletions tests/test_data_structures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import networkx as nx
import pytest

from ldp.data_structures import Transition, TransitionTree
Expand Down Expand Up @@ -79,3 +80,9 @@ def test_tree_node_merging():

assert len(tree.tree.nodes) == 5
assert len(merged_tree.tree.nodes) == 3

node_weights = [
merged_tree.get_weight(step_id)
for step_id in nx.topological_sort(merged_tree.tree)
]
assert node_weights == [1, 2, 2] # 1 for the root, 2 for the others

0 comments on commit 366df10

Please sign in to comment.