From 2ba1161751d746ccaf979760c8052d9b4701548b Mon Sep 17 00:00:00 2001 From: Max Zuo Date: Wed, 19 Jun 2024 14:33:12 -0400 Subject: [PATCH 1/5] rustworkx migration. Plot still not there yet --- planetarium/graph.py | 281 ++++++++++++++++++++++++++++------- planetarium/metric.py | 334 +++++++++--------------------------------- planetarium/oracle.py | 318 +++++++++++++++++++++------------------- poetry.lock | 82 ++++++++++- pyproject.toml | 1 + tests/test_metric.py | 188 ++++++++---------------- tests/test_oracle.py | 24 ++- utils.py | 188 +++++++++++++++++++----- 8 files changed, 773 insertions(+), 643 deletions(-) diff --git a/planetarium/graph.py b/planetarium/graph.py index 595c5a0..d6dd426 100644 --- a/planetarium/graph.py +++ b/planetarium/graph.py @@ -1,7 +1,8 @@ +from typing import Any, Iterable + import abc import enum -import networkx as nx -import typing +import rustworkx as rx class Label(str, enum.Enum): @@ -9,9 +10,72 @@ class Label(str, enum.Enum): PREDICATE = "predicate" -class PlanGraph(nx.MultiDiGraph, metaclass=abc.ABCMeta): +class Scene(str, enum.Enum): + INIT = "init" + GOAL = "goal" + + +class PlanGraphNode: + def __init__( + self, + node: str, + name: str, + label: Label, + typing: str | None = None, + scene: Scene | None = None, + ): + self.node = node + self.name = name + self.label = label + self.typing = typing + self.scene = scene + + def __eq__(self, other: "PlanGraphNode") -> bool: + return ( + isinstance(other, PlanGraphNode) + and self.node == other.node + and self.name == other.name + and self.label == other.label + and self.typing == other.typing + and self.scene == other.scene + ) + + def __hash__(self) -> int: + return hash((self.name, self.label, (*sorted(self.typing),), self.scene)) + + +class PlanGraphEdge: + def __init__( + self, + predicate: str, + position: int | None = None, + scene: Scene | None = None, + ): + self.predicate = predicate + self.position = position + self.scene = scene + + def __eq__(self, other: "PlanGraphEdge") -> bool: + return ( + isinstance(other, PlanGraphEdge) + and self.predicate == other.predicate + and self.position == other.position + and self.scene == other.scene + ) + + def __hash__(self) -> int: + return hash((self.predicate, self.position, self.scene)) + + def __repr__(self) -> str: + return f"PlanGraphEdge(predicate={self.predicate}, position={self.position}, scene={self.scene})" + + def __str__(self) -> str: + return f"PlanGraphEdge(predicate={self.predicate}, position={self.position}, scene={self.scene})" + + +class PlanGraph(metaclass=abc.ABCMeta): """ - Subclass of nx.MultiDiGraph representing a scene graph. + Subclass of rx.PyDiGraph representing a scene graph. Attributes: constants (property): A dictionary of constant nodes in the scene graph. @@ -21,8 +85,8 @@ class PlanGraph(nx.MultiDiGraph, metaclass=abc.ABCMeta): def __init__( self, - constants: list[dict[str, typing.Any]], - predicates: list[dict[str, typing.Any]], + constants: list[dict[str, Any]], + predicates: list[dict[str, Any]], domain: str | None = None, ): """ @@ -34,37 +98,83 @@ def __init__( domain (str, optional): The domain of the scene graph. Defaults to None. """ - super().__init__() self._constants = constants self._predicates = predicates self._domain = domain + self.graph = rx.PyDiGraph() + for constant in constants: self.add_node( - constant["name"], - name=constant["name"], - typing=constant["typing"], - label=Label.CONSTANT, + PlanGraphNode( + constant["name"], + name=constant["name"], + label=Label.CONSTANT, + typing=constant["typing"], + ) ) - def _add_predicate(self, predicate: dict[str, typing.Any], **kwargs): + @property + def _node_lookup(self) -> dict[str, tuple[int, PlanGraphNode]]: + return {node.node: (index, node) for index, node in enumerate(self.nodes)} + + @property + def nodes(self) -> list[PlanGraphNode]: + return self.graph.nodes() + + @property + def edges(self) -> set[tuple[PlanGraphNode, PlanGraphNode, PlanGraphEdge]]: + return [ + (self.nodes[u], self.nodes[v], data) + for u, v, data in self.graph.edge_index_map().values() + ] + + def add_node(self, node: PlanGraphNode): + if node in self.nodes: + raise ValueError(f"Node {node.name} already exists in the graph.") + self.graph.add_node(node) + + def add_edge( + self, u: str | PlanGraphNode, v: str | PlanGraphNode, edge: PlanGraphEdge + ): + if isinstance(u, PlanGraphNode): + u_index = self.nodes.index(u) + else: + u_index, _ = self._node_lookup[u] + + if isinstance(v, PlanGraphNode): + v_index = self.nodes.index(v) + else: + v_index, _ = self._node_lookup[v] + + self.graph.add_edge(u_index, v_index, edge) + + def _add_predicate( + self, + predicate: dict[str, Any], + scene: Scene | None = None, + ): """ Add a predicate to the plan graph. Parameters: predicate (dict): A dictionary representing the predicate. + scene (Scene, optional): The scene in which the predicate occurs. """ predicate_name = self._build_unique_predicate_name( predicate_name=predicate["typing"], argument_names=predicate["parameters"], ) self.add_node( - predicate_name, - name=predicate_name, - typing=predicate["typing"], - label=Label.PREDICATE, + PlanGraphNode( + predicate_name, + name=predicate_name, + label=Label.PREDICATE, + typing=predicate["typing"], + scene=scene, + ) ) for position, parameter_name in enumerate(predicate["parameters"]): @@ -73,21 +183,77 @@ def _add_predicate(self, predicate: dict[str, typing.Any], **kwargs): self.add_edge( predicate_name, parameter_name, - position=position, - predicate=predicate["typing"], - **kwargs, + PlanGraphEdge( + predicate=predicate["typing"], + position=position, + scene=scene, + ), ) + def in_degree(self, node: str | PlanGraphNode) -> int: + if isinstance(node, PlanGraphNode): + return self.graph.in_degree(self.nodes.index(node)) + else: + return self.graph.in_degree(self._node_lookup[node][0]) + + def out_degree(self, node: str | PlanGraphNode) -> int: + if isinstance(node, PlanGraphNode): + return self.graph.out_degree(self.nodes.index(node)) + else: + return self.graph.out_degree(self._node_lookup[node][0]) + + def predecessors(self, node: str | PlanGraphNode) -> list[PlanGraphNode]: + if isinstance(node, PlanGraphNode): + preds = self.graph.predecessors(self.nodes.index(node)) + else: + preds = self.graph.predecessors(self._node_lookup[node][0]) + + return preds + + def successors(self, node: str | PlanGraphNode) -> list[PlanGraphNode]: + if isinstance(node, PlanGraphNode): + succs = self.graph.successors(self.nodes.index(node)) + else: + succs = self.graph.successors(self._node_lookup[node][0]) + + return [self.nodes[succ] for succ in succs] + + def in_edges( + self, node: str | PlanGraphNode + ) -> ( + list[tuple[PlanGraphNode, PlanGraphNode, PlanGraphEdge]] + | list[tuple[PlanGraphNode, PlanGraphNode]] + ): + if isinstance(node, PlanGraphNode): + edges = self.graph.in_edges(self.nodes.index(node)) + else: + edges = self.graph.in_edges(self._node_lookup[node][0]) + + return [(self.nodes[u], self.nodes[v]) for u, v, _ in edges] + + def out_edges( + self, node: str | PlanGraphNode + ) -> ( + list[tuple[PlanGraphNode, PlanGraphNode, PlanGraphEdge]] + | list[tuple[PlanGraphNode, PlanGraphNode]] + ): + if isinstance(node, PlanGraphNode): + edges = self.graph.out_edges(self.nodes.index(node)) + else: + edges = self.graph.out_edges(self._node_lookup[node][0]) + + return [(self.nodes[u], self.nodes[v], data) for u, v, data in edges] + @staticmethod def _build_unique_predicate_name( - predicate_name: str, argument_names: typing.Iterable[str] + predicate_name: str, argument_names: Iterable[str] ) -> str: """ Build a unique name for a predicate based on its name and argument names. Parameters: predicate_name (str): The name of the predicate. - argument_names (typing.Iterable[str]): Sequence of argument names + argument_names (Iterable[str]): Sequence of argument names for the predicate. Returns: @@ -113,9 +279,7 @@ def constants(self) -> dict: Returns: dict: A dictionary containing constant nodes. """ - return dict( - filter(lambda node: node[1]["label"] == Label.CONSTANT, self.nodes.items()) - ) + return {node.name: node for node in self.nodes if node.label == Label.CONSTANT} @property def predicates(self) -> dict: @@ -125,14 +289,29 @@ def predicates(self) -> dict: Returns: dict: A dictionary containing predicate nodes. """ - return dict( - filter(lambda node: node[1]["label"] == Label.PREDICATE, self.nodes.items()) + return {node.name: node for node in self.nodes if node.label == Label.PREDICATE} + + def __eq__(self, other: "PlanGraph") -> bool: + """ + Check if two plan graphs are equal. + + Parameters: + other (PlanGraph): The other plan graph to compare. + + Returns: + bool: True if the plan graphs are equal, False otherwise. + """ + return ( + isinstance(other, PlanGraph) + and set(self.constants) == set(other.constants) + and set(self.predicates) == set(other.predicates) + and self.domain == other.domain ) class SceneGraph(PlanGraph): """ - Subclass of nx.MultiDiGraph representing a scene graph. + Subclass of PlanGraph representing a scene graph. Attributes: constants (property): A dictionary of constant nodes in the scene graph. @@ -142,8 +321,8 @@ class SceneGraph(PlanGraph): def __init__( self, - constants: list[dict[str, typing.Any]], - predicates: list[dict[str, typing.Any]], + constants: list[dict[str, Any]], + predicates: list[dict[str, Any]], domain: str | None = None, ): """ @@ -164,7 +343,7 @@ def __init__( class ProblemGraph(PlanGraph): """ - Subclass of nx.MultiDiGraph representing a scene graph. + Subclass of PlanGraph representing a scene graph. Attributes: constants (property): A dictionary of constant nodes in the scene graph. @@ -175,9 +354,9 @@ class ProblemGraph(PlanGraph): def __init__( self, - constants: list[dict[str, typing.Any]], - init_predicates: list[dict[str, typing.Any]], - goal_predicates: list[dict[str, typing.Any]], + constants: list[dict[str, Any]], + init_predicates: list[dict[str, Any]], + goal_predicates: list[dict[str, Any]], domain: str | None = None, ): """ @@ -192,19 +371,25 @@ def __init__( domain (str, optional): The domain of the scene graph. Defaults to None. """ - super().__init__(constants, init_predicates + goal_predicates, domain=domain) self._init_predicates = init_predicates self._goal_predicates = goal_predicates for scene, predicates in ( - ("init", init_predicates), - ("goal", goal_predicates), + (Scene.INIT, init_predicates), + (Scene.GOAL, goal_predicates), ): for predicate in predicates: self._add_predicate(predicate, scene=scene) + def __eq__(self, other: "ProblemGraph") -> bool: + return ( + super().__eq__(other) + and set(self.init_predicates) == set(other.init_predicates) + and set(self.goal_predicates) == set(other.goal_predicates) + ) + @property def init_predicates(self) -> dict: """ @@ -213,13 +398,11 @@ def init_predicates(self) -> dict: Returns: dict: A dictionary containing predicate nodes. """ - return dict( - filter( - lambda node: node[1]["label"] == Label.PREDICATE - and node[1]["scene"] == "init", - self.nodes.items(), - ) - ) + return { + node.name: node + for node in self.nodes + if node.label == Label.PREDICATE and node.scene == Scene.INIT + } @property def goal_predicates(self) -> dict: @@ -229,13 +412,11 @@ def goal_predicates(self) -> dict: Returns: dict: A dictionary containing predicate nodes. """ - return dict( - filter( - lambda node: node[1]["label"] == Label.PREDICATE - and node[1]["scene"] == "goal", - self.nodes.items(), - ) - ) + return { + node.name: node + for node in self.nodes + if node.label == Label.PREDICATE and node.scene == Scene.GOAL + } def decompose(self) -> tuple[SceneGraph, SceneGraph]: """ diff --git a/planetarium/metric.py b/planetarium/metric.py index 419c9e4..03fcbd7 100644 --- a/planetarium/metric.py +++ b/planetarium/metric.py @@ -1,205 +1,111 @@ import functools -import networkx as nx -import time +import rustworkx as rx import typing -from planetarium.graph import Label, SceneGraph, ProblemGraph +from planetarium import graph -Node = dict[str, typing.Any] - -def _preserves_mapping(source: Node, target: Node, mapping: dict) -> bool: +def _preserves_mapping( + source: graph.PlanGraphNode, + target: graph.PlanGraphNode, + mapping: dict, +) -> bool: """ Check if a mapping is preserved between the nodes. Parameters: - source (Node): The source node. - target (Node): The target node. + source (graph.PlanGraphNode): The source node. + target (graph.PlanGraphNode): The target node. mapping (dict): The mapping between node names. Returns: bool: True if the mapping preserves names, False otherwise. """ return ( - source["label"] == Label.CONSTANT - and target["label"] == Label.CONSTANT - and mapping[source["name"]] == target["name"] + source.label == graph.Label.CONSTANT + and target.label == graph.Label.CONSTANT + and mapping[source.name] == target.name ) -def _same_typing(source: Node, target: Node) -> bool: +def _same_typing(source: graph.PlanGraphNode, target: graph.PlanGraphNode) -> bool: """ Check if the typing of two nodes is the same. Parameters: - source (Node): The source node. - target (Node): The target node. + source (graph.PlanGraphNode): The source node. + target (graph.PlanGraphNode): The target node. Returns: bool: True if typings are the same, False otherwise. """ return ( - source["label"] == Label.CONSTANT - and target["label"] == Label.CONSTANT - and source["typing"] == target["typing"] + source.label == graph.Label.CONSTANT + and target.label == graph.Label.CONSTANT + and source.typing == target.typing ) -def _matching(source: Node, target: Node, mapping: typing.Optional[dict]) -> bool: +def _node_matching( + source: graph.PlanGraphNode, + target: graph.PlanGraphNode, + mapping: typing.Optional[dict], +) -> bool: """ Check if two nodes match based on their labels, positions, and typings. Parameters: - source (Node): The source node. - target (Node): The target node. + source (graph.PlanGraphNode): The source node. + target (graph.PlanGraphNode): The target node. mapping (Optional[dict]): The mapping between node names. Returns: bool: True if nodes match, False otherwise. """ - match (source["label"], target["label"]): - case (Label.CONSTANT, Label.CONSTANT): + print(source, target, "yeehaw") + match (source.label, target.label): + case (graph.Label.CONSTANT, graph.Label.CONSTANT): return _same_typing(source, target) and ( _preserves_mapping(source, target, mapping) if mapping else True ) - case (Label.PREDICATE, Label.PREDICATE): + case (graph.Label.PREDICATE, graph.Label.PREDICATE): # type of predicate should be the same as well - return source["typing"] == target["typing"] + return source.typing == target.typing case _: return False -def _map( - source: SceneGraph, - target: SceneGraph, - mapping: typing.Optional[dict] = None, -) -> list[dict]: - """ - Find all valid isomorphic mappings between nodes of two scene graphs. - - Parameters: - source (SceneGraph): The source scene graph. - target (SceneGraph): The target scene graph. - mapping (Optional[dict]): The initial mapping between node names. - - Returns: - list: A list of dictionaries representing valid mappings. - """ - if not nx.faster_could_be_isomorphic(source, target): - return [] - - matching = functools.partial(_matching, mapping=mapping) - mapper = nx.isomorphism.MultiDiGraphMatcher( - source, - target, - node_match=matching, - edge_match=nx.isomorphism.categorical_edge_match( - ["position", "predicate"], - default=[-1, ""], - ), - ) - - if mapper.is_isomorphic(): - mapper.initialize() - return mapper.match() - else: - return [] - - -def _distance( - source: SceneGraph, target: SceneGraph, mapping: typing.Optional[dict] = None -) -> int: - """ - Calculate the graph edit distance between two scene graphs. - - Parameters: - source (SceneGraph): The source scene graph. - target (SceneGraph): The target scene graph. - mapping (Optional[dict]): The initial mapping between node names. - - Returns: - int: The graph edit distance. - """ - matching = functools.partial(_matching, mapping=mapping) - - return nx.graph_edit_distance( - source, - target, - node_match=matching, - edge_match=nx.isomorphism.categorical_edge_match( - ["position", "predicate"], - default=[-1, ""], - ), - ) - - -def _minimal_mappings( - source: SceneGraph, - target: SceneGraph, - timeout: float | None = None, -) -> typing.Tuple[list, int]: +def _edge_matching( + source: graph.PlanGraphEdge, + target: graph.PlanGraphEdge, + attributes: dict[str, str | int | graph.Scene, graph.Label] = {}, +) -> bool: """ - Calculate the graph edit distance between two scene graphs. + Check if two edges match based on their attributes. Parameters: - source (SceneGraph): The source scene graph. - target (SceneGraph): The target scene graph. - max_attempts (int): The maximum number of edit path iterations to - consider. + source (graph.PlanGraphEdge): The source edge. + target (graph.PlanGraphEdge): The target edge. + attributes (dict): The attributes to match. Returns: - tuple: - - list: A list of dictionaries representing valid mappings. - - int: The graph edit distance. - - bool: True if the timeout was reached, False otherwise. + bool: True if edges match, False otherwise. """ - start_time = time.perf_counter() + def _getattr(obj, attr): + v = getattr(obj, attr, attributes[attr]) + if v is None: + v = attributes[attr] + return v - def timed_out() -> bool: - return bool(timeout and time.perf_counter() - start_time > timeout) + return all(_getattr(source, attr) == _getattr(target, attr) for attr in attributes) - # try isomorphism first: - iso_mappings = _map(source, target) - if iso_mappings: - return [[(k, v) for k, v in m.items()] for m in iso_mappings], 0.0, False - - # if it is not isomorphic, try edit distance: - edit_path_gen = nx.similarity.optimize_edit_paths( - source, - target, - node_match=nx.isomorphism.categorical_node_match( - ["label", "typing"], - default=["", ""], - ), - edge_match=nx.isomorphism.categorical_edge_match( - ["position", "predicate"], - default=[-1, ""], - ), - strictly_decreasing=False, - timeout=timeout, - ) - paths = [] - bestcost = float("inf") - for vertex_path, _, cost in edit_path_gen: - if bestcost != float("inf") and cost < bestcost: - paths = [] - - paths.append(vertex_path) - bestcost = cost - if timed_out(): - break - - return paths, bestcost, timed_out() - - -def map( - source: ProblemGraph | SceneGraph, - target: ProblemGraph | SceneGraph, +def isomorphic( + source: graph.ProblemGraph | graph.SceneGraph, + target: graph.ProblemGraph | graph.SceneGraph, mapping: typing.Optional[dict] = None, - return_mappings: bool = False, -) -> list[dict]: +) -> bool: """ Find all valid isomorphic mappings between nodes of two scene graphs. @@ -207,37 +113,27 @@ def map( source (ProblemGraph): The source problem graph. target (ProblemGraph): The target problem graph. mapping (Optional[dict]): The initial mapping between node names. - return_mappings (bool): If True, the function will return a list of - dictionaries representing valid mappings. If False, the function - will return a boolean indicating if there is a valid mapping. Returns: - list: A list of dictionaries representing valid mappings. + bool: True if there is a valid mapping, False otherwise. """ - if not nx.faster_could_be_isomorphic(source, target): - return [] if return_mappings else False - - matching = functools.partial(_matching, mapping=mapping) - mapper = nx.isomorphism.MultiDiGraphMatcher( - source, - target, - node_match=matching, - edge_match=nx.isomorphism.categorical_edge_match( - ["scene", "position", "predicate"], - default=["", -1, ""], - ), + node_matching = functools.partial(_node_matching, mapping=mapping) + edge_matching = functools.partial( + _edge_matching, + attributes={"position": -1, "predicate": "", "scene": None}, ) - # print('mapping', isinstance(source, ProblemGraph), source.__class__) - if return_mappings: - return list(mapper.isomorphisms_iter()) if mapper.is_isomorphic() else [] - else: - return mapper.is_isomorphic() + return rx.is_isomorphic( + source.graph, + target.graph, + node_matcher=node_matching, + edge_matcher=edge_matching, + ) def equals( - source: ProblemGraph, - target: ProblemGraph, + source: graph.ProblemGraph, + target: graph.ProblemGraph, is_placeholder: bool = False, ) -> bool: """ @@ -254,110 +150,18 @@ def equals( Returns: bool: True if there is a valid mapping, False otherwise. """ + if source == target: + return True if not is_placeholder: - return nx.utils.graphs_equal(source, target) or map(source, target) + return isomorphic(source, target) else: source_init, source_goal = source.decompose() target_init, target_goal = target.decompose() - if nx.utils.graphs_equal( - source_init, - target_init, - ) and nx.utils.graphs_equal( - source_goal, - target_goal, - ): + if source_init == target_init and source_goal == target_goal: return True - valid_init = map(source_init, target_init) - valid_goal = map(source_goal, target_goal) + valid_init = isomorphic(source_init, target_init) + valid_goal = isomorphic(source_goal, target_goal) return valid_init and valid_goal - - -def distance( - initial_source: SceneGraph, - initial_target: SceneGraph, - goal_source: SceneGraph, - goal_target: SceneGraph, - timeout: float | None = None, -) -> tuple[float, float]: - """ - Calculate the graph edit distance between initial and goal scene graphs. - - Parameters: - initial_source (SceneGraph): The initial source scene graph. - initial_target (SceneGraph): The initial target scene graph. - goal_source (SceneGraph): The goal source scene graph. - goal_target (SceneGraph): The goal target scene graph. - timeout (Optional[float]): The maximum number of seconds to spend. - - Returns: - tuple: - - float: The graph edit distance between the two initial scenes. - - bool: True if the timeout was reached, False otherwise while - calculating initial scene distance. - - float: The graph edit distance between the two goal scenes. - - bool: True if the timeout was reached, False otherwise while - calculating goal scene distance. - """ - - start_time = time.perf_counter() - - def timed_out() -> bool: - return bool(timeout and time.perf_counter() - start_time > timeout) - - def mapping_to_fn(mapping: list) -> typing.Callable[[typing.Any], bool]: - """ - Convert a mapping to a matching function. - - Parameters: - mapping (list): The mapping between node names. - - Returns: - callable: A matching function. - """ - map_dict = {k: v for k, v in mapping} - - def matching(source: Node, target: Node) -> bool: - return ( - source["name"] not in map_dict - or map_dict.get(source["name"]) == target["name"] - ) - - return matching - - if equals( - ProblemGraph.join(initial_source, goal_source), - ProblemGraph.join(initial_target, goal_target), - is_placeholder=False, - ): - return 0.0, False, 0.0, False - - minimal_mappings, init_dist, approx_init_dist = _minimal_mappings( - initial_source, - initial_target, - timeout=timeout, - ) - - goal_dist = float("inf") - for mapping in minimal_mappings: - # use the mapping from the initial graph - edit_path_gen = nx.similarity.optimize_edit_paths( - goal_source, - goal_target, - node_match=mapping_to_fn(mapping), - edge_match=nx.isomorphism.categorical_edge_match( - ["position", "predicate"], - default=[-1, ""], - ), - timeout=timeout, - ) - - for _, _, cost in edit_path_gen: - if cost < goal_dist: - goal_dist = cost - if timed_out(): - break - - return init_dist, approx_init_dist, goal_dist, timed_out() diff --git a/planetarium/oracle.py b/planetarium/oracle.py index 5f9c204..338dda7 100644 --- a/planetarium/oracle.py +++ b/planetarium/oracle.py @@ -1,15 +1,15 @@ -from typing import Generator +from typing import Any from collections import defaultdict import copy import enum -import networkx as nx +import rustworkx as rx from planetarium import graph -class ReductionNode(tuple, enum.Enum): +class ReducedNode(tuple, enum.Enum): TABLE = ("table", ("blocksworld",)) CLEAR = ("clear", ("blocksworld", "gripper")) ARM = ("arm", ("blocksworld",)) @@ -19,13 +19,32 @@ class ReductionNode(tuple, enum.Enum): ROBBY = ("robby", ("gripper",)) -REDUCTION_NODES = [e.value for e in ReductionNode] +class ReducedGraph(graph.PlanGraph): + def __init__( + self, + constants: list[dict[str, Any]], + predicates: list[dict[str, Any]], + domain: str, + ): + super().__init__(constants, predicates, domain=domain) + + for e in ReducedNode: + predicate, r_node_domains = e.value + if self.domain in r_node_domains: + self.add_node( + graph.PlanGraphNode( + e, + name=predicate, + label=graph.Label.PREDICATE, + typing={predicate}, + ), + ) def _reduce_blocksworld( scene: graph.SceneGraph, validate: bool = True, -) -> tuple[nx.MultiDiGraph, nx.MultiDiGraph]: +) -> ReducedGraph: """Reduces a blocksworld scene graph to a Directed Acyclic Graph. Args: @@ -40,86 +59,73 @@ def _reduce_blocksworld( blocksworld) and if validate is True. Returns: - nx.MultiDiGraph: The reduced scene graph. + ReducedGraph: The reduced problem graph. """ nodes = defaultdict(list) - for node, node_attr in scene.nodes(data=True): - nodes[node_attr["label"]].append((node, node_attr)) - - reduced = nx.MultiDiGraph() - for e in ReductionNode: - predicate, r_node_domains = e.value - if "blocksworld" in r_node_domains: - reduced.add_node( - e, - name=predicate, - label=graph.Label.PREDICATE, - typing={predicate}, - ) + for node in scene.nodes: + nodes[node.label].append(node) + + reduced = ReducedGraph( + constants=scene._constants, + predicates=scene._predicates, + domain="blocksworld", + ) - for obj, attrs in nodes[graph.Label.CONSTANT]: - reduced.add_node(obj, **attrs) if "arm-empty" in scene.predicates: reduced.add_edge( - ReductionNode.CLEAR, - ReductionNode.ARM, - pred="arm-empty", + ReducedNode.CLEAR, + ReducedNode.ARM, + graph.PlanGraphEdge(predicate="arm-empty"), ) pred_nodes = set() - for node, obj, edge_attr in scene.edges(data=True): - pred = edge_attr["predicate"] + for node, obj, edge in scene.edges: + pred = edge.predicate + reduced_edge = graph.PlanGraphEdge(predicate=pred) if node in pred_nodes: continue elif pred == "on-table": - reduced.add_edge(obj, ReductionNode.TABLE, pred="on-table") + reduced.add_edge(obj, ReducedNode.TABLE, reduced_edge) elif pred == "clear": - reduced.add_edge(ReductionNode.CLEAR, obj, pred="clear") + reduced.add_edge(ReducedNode.CLEAR, obj, reduced_edge) elif pred == "on": - pos = edge_attr["position"] + pos = edge.position other_obj, *_ = [ - v - for _, v, a in scene.out_edges(node, data=True) - if a["position"] == 1 - pos + v.node for _, v, a in scene.out_edges(node) if a.position == 1 - pos ] if pos == 0: - reduced.add_edge(obj, other_obj, pred="on") + reduced.add_edge(obj, other_obj, reduced_edge) elif pred == "holding": - reduced.add_edge(obj, ReductionNode.ARM, pred="holding") - + reduced.add_edge(obj, ReducedNode.ARM, reduced_edge) pred_nodes.add(node) if validate: - if not nx.is_directed_acyclic_graph(reduced): + if not rx.is_directed_acyclic_graph(reduced.graph): raise ValueError("Scene graph is not a Directed Acyclic Graph.") for node in reduced.nodes: - if (node != ReductionNode.TABLE and reduced.in_degree(node) > 1) or ( - node != ReductionNode.CLEAR and reduced.out_degree(node) > 1 - ): + if ( + node.node != ReducedNode.TABLE and reduced.in_degree(node.node) > 1 + ) or (node.node != ReducedNode.CLEAR and reduced.out_degree(node.node) > 1): raise ValueError( f"Node {node} has multiple parents/children. (not possible in blocksworld)." ) - if reduced.in_degree(ReductionNode.ARM) == 1: - obj = next(reduced.predecessors(ReductionNode.ARM)) + if reduced.in_degree(ReducedNode.ARM) == 1: + obj = reduced.predecessors(ReducedNode.ARM)[0] if ( - obj != ReductionNode.CLEAR + obj.node != ReducedNode.CLEAR and reduced.in_degree(obj) == 1 - and next(reduced.predecessors(obj)) != ReductionNode.CLEAR + and reduced.predecessors(obj)[0].node != ReducedNode.CLEAR ): raise ValueError("Object on arm is connected to another object.") - reduced._domain = scene._domain - reduced._constants = scene._constants - reduced._predicates = scene._predicates - return reduced def _reduce_gripper( scene: graph.SceneGraph, validate: bool = True, -) -> nx.MultiDiGraph: +) -> ReducedGraph: """Reduces a gripper scene graph to a Directed Acyclic Graph. Args: @@ -128,68 +134,55 @@ def _reduce_gripper( reprsentation is valid and a DAG. Defaults to True. Returns: - nx.SceneGraph: The reduced problem graph. + ReducedGraph: The reduced problem graph. """ nodes = defaultdict(list) - for node, node_attr in scene.nodes(data=True): - nodes[node_attr["label"]].append((node, node_attr)) - - reduced = nx.MultiDiGraph() - for e in ReductionNode: - predicate, r_node_domains = e.value - if "gripper" in r_node_domains: - reduced.add_node( - e, - name=predicate, - label=graph.Label.PREDICATE, - typing={predicate}, - ) + for node in scene.nodes: + nodes[node.label].append(node) - for obj, attrs in nodes[graph.Label.CONSTANT]: - reduced.add_node(obj, **copy.deepcopy(attrs)) + reduced = ReducedGraph( + constants=scene._constants, + predicates=scene._predicates, + domain="gripper", + ) pred_nodes = set() - for node, obj, edge_attr in scene.edges(data=True): - pred = edge_attr["predicate"] + for node, obj, edge in scene.edges: + pred = edge.predicate + reduced_edge = graph.PlanGraphEdge(predicate=pred) if node in pred_nodes: continue elif pred == "at-robby": - reduced.add_edge(ReductionNode.ROBBY, obj, pred=pred) + reduced.add_edge(ReducedNode.ROBBY, obj, reduced_edge) elif pred == "free": - reduced.add_edge(ReductionNode.CLEAR, obj, pred=pred) + reduced.add_edge(ReducedNode.CLEAR, obj, reduced_edge) elif pred == "ball": - reduced.add_edge(ReductionNode.BALLS, obj, pred=pred) + reduced.add_edge(ReducedNode.BALLS, obj, reduced_edge) elif pred == "gripper": - reduced.add_edge(ReductionNode.GRIPPERS, obj, pred=pred) + reduced.add_edge(ReducedNode.GRIPPERS, obj, reduced_edge) elif pred == "room": - reduced.add_edge(ReductionNode.ROOMS, obj, pred=pred) + reduced.add_edge(ReducedNode.ROOMS, obj, reduced_edge) elif pred in {"carry", "at"}: - pos = edge_attr["position"] + pos = edge.position other_obj, *_ = [ - v - for _, v, a in scene.out_edges(node, data=True) - if a["position"] == 1 - pos + v for _, v, a in scene.out_edges(node) if a.position == 1 - pos ] if pos == 0: - reduced.add_edge(obj, other_obj, pred=pred) + reduced.add_edge(obj, other_obj, reduced_edge) pred_nodes.add(node) - if validate and not nx.is_directed_acyclic_graph(reduced): + if validate and not rx.is_directed_acyclic_graph(reduced.graph): raise ValueError("Scene graph is not a Directed Acyclic Graph.") - reduced._domain = scene._domain - reduced._constants = scene._constants - reduced._predicates = scene._predicates - return reduced -def _inflate_blocksworld(scene: nx.MultiDiGraph) -> graph.SceneGraph: +def _inflate_blocksworld(scene: ReducedGraph) -> graph.SceneGraph: """Respecify a blocksworld scene graph. Args: - scene (nx.MultiDiGraph): The reduced SceneGraph of a scene. + scene (ReducedGraph): The reduced SceneGraph of a scene. Returns: graph.SceneGraph: The respecified scene graph. @@ -197,55 +190,55 @@ def _inflate_blocksworld(scene: nx.MultiDiGraph) -> graph.SceneGraph: constants = [] predicates = [] - for node, attrs in scene.nodes(data=True): - if not isinstance(node, ReductionNode): - constants.append({"name": node, "typing": attrs["typing"]}) + for node in scene.nodes: + if not isinstance(node.node, ReducedNode): + constants.append({"name": node.node, "typing": node.typing}) for u, v, _ in scene.edges: - if u == ReductionNode.CLEAR and v == ReductionNode.ARM: + if u.node == ReducedNode.CLEAR and v.node == ReducedNode.ARM: predicates.append( { "typing": "arm-empty", "parameters": [], } ) - elif u == ReductionNode.CLEAR: + elif u.node == ReducedNode.CLEAR: predicates.append( { "typing": "clear", - "parameters": [v], + "parameters": [v.node], } ) - elif v == ReductionNode.TABLE: + elif v.node == ReducedNode.TABLE: predicates.append( { "typing": "on-table", - "parameters": [u], + "parameters": [u.node], } ) - elif v == ReductionNode.ARM: + elif v.node == ReducedNode.ARM: predicates.append( { "typing": "holding", - "parameters": [u], + "parameters": [u.node], } ) else: predicates.append( { "typing": "on", - "parameters": [u, v], + "parameters": [u.node, v.node], } ) return graph.SceneGraph(constants, predicates, domain="blocksworld") -def _inflate_gripper(scene: nx.MultiDiGraph) -> graph.SceneGraph: +def _inflate_gripper(scene: ReducedGraph) -> graph.SceneGraph: """Respecify a gripper scene graph. Args: - scene (nx.MultiDiGraph): The reduced SceneGraph of a scene. + scene (ReducedGraph): The reduced SceneGraph of a scene. Returns: graph.SceneGraph: The respecified scene graph. @@ -253,51 +246,51 @@ def _inflate_gripper(scene: nx.MultiDiGraph) -> graph.SceneGraph: constants = [] predicates = [] - for node, attrs in scene.nodes(data=True): - if not isinstance(node, ReductionNode): - constants.append({"name": node, "typing": attrs["typing"]}) + for node in scene.nodes: + if not isinstance(node.node, ReducedNode): + constants.append({"name": node.node, "typing": node.typing}) - for u, v, attr in scene.edges(data=True): - if u == ReductionNode.ROBBY: + for u, v, edge in scene.edges: + if u.node == ReducedNode.ROBBY: predicates.append( { "typing": "at-robby", - "parameters": [v], + "parameters": [v.node], } ) - elif u == ReductionNode.CLEAR: + elif u.node == ReducedNode.CLEAR: predicates.append( { "typing": "free", - "parameters": [v], + "parameters": [v.node], } ) - elif u == ReductionNode.BALLS: + elif u.node == ReducedNode.BALLS: predicates.append( { "typing": "ball", - "parameters": [v], + "parameters": [v.node], } ) - elif u == ReductionNode.GRIPPERS: + elif u.node == ReducedNode.GRIPPERS: predicates.append( { "typing": "gripper", - "parameters": [v], + "parameters": [v.node], } ) - elif u == ReductionNode.ROOMS: + elif u.node == ReducedNode.ROOMS: predicates.append( { "typing": "room", - "parameters": [v], + "parameters": [v.node], } ) else: predicates.append( { - "typing": attr["pred"], - "parameters": [u, v], + "typing": edge.predicate, + "parameters": [u.node, v.node], } ) @@ -305,12 +298,12 @@ def _inflate_gripper(scene: nx.MultiDiGraph) -> graph.SceneGraph: def _blocksworld_underspecified_blocks( - scene: nx.MultiDiGraph, + scene: ReducedGraph, ) -> tuple[set[str], set[str], bool]: """Finds blocks that are not fully specified. Args: - scene (nx.MultiDiGraph): The reduced SceneGraph of a scene. + scene (ReducedGraph): The reduced SceneGraph of a scene. Returns: tuple[set[str], set[str], bool]: The set of blocks that are not fully @@ -320,44 +313,47 @@ def _blocksworld_underspecified_blocks( """ top_blocks = set() bottom_blocks = set() - held_block = next(scene.predecessors(ReductionNode.ARM), None) - for node, attrs in scene.nodes(data=True, default=None): - if attrs.get("label") == graph.Label.CONSTANT: + arm_behavior_defined = scene.in_degree(ReducedNode.ARM) > 0 + held_block = ( + scene.predecessors(ReducedNode.ARM)[0] if arm_behavior_defined else None + ) + for node in scene.nodes: + if node.label == graph.Label.CONSTANT: if not scene.in_edges(node) and node != held_block: top_blocks.add(node) if not scene.out_edges(node): bottom_blocks.add(node) - return top_blocks, bottom_blocks, held_block is None + return top_blocks, bottom_blocks, not arm_behavior_defined -def _gripper_get_typed_objects(scene: nx.MultiDiGraph): +def _gripper_get_typed_objects(scene: ReducedGraph): rooms = set() balls = set() grippers = set() - for _, node in scene.out_edges(ReductionNode.ROOMS): + for _, node, _ in scene.out_edges(ReducedNode.ROOMS): rooms.add(node) - for _, node in scene.out_edges(ReductionNode.BALLS): + for _, node, _ in scene.out_edges(ReducedNode.BALLS): balls.add(node) - for _, node in scene.out_edges(ReductionNode.GRIPPERS): + for _, node, _ in scene.out_edges(ReducedNode.GRIPPERS): grippers.add(node) return { - ReductionNode.ROOMS: rooms, - ReductionNode.BALLS: balls, - ReductionNode.GRIPPERS: grippers, + ReducedNode.ROOMS: rooms, + ReducedNode.BALLS: balls, + ReducedNode.GRIPPERS: grippers, } def _gripper_underspecified_blocks( - init: nx.MultiDiGraph, - goal: nx.MultiDiGraph, + init: ReducedGraph, + goal: ReducedGraph, ) -> tuple[set[str], set[str], bool]: """Finds blocks that are not fully specified. Args: - init (nx.MultiDiGraph): The reduced SceneGraph of the initial scene. - goal (nx.MultiDiGraph): The reduced SceneGraph of the goal scene. + init (ReducedGraph): The reduced SceneGraph of the initial scene. + goal (ReducedGraph): The reduced SceneGraph of the goal scene. Returns: tuple[set[str], set[str]]: The set of blocks that are not fully @@ -372,19 +368,19 @@ def _gripper_underspecified_blocks( underspecified_balls = set() underspecified_grippers = set() - for ball in typed[ReductionNode.BALLS]: + for ball in typed[ReducedNode.BALLS]: ball_edges = [ node - for _, node in goal.out_edges(ball) - if not isinstance(node, ReductionNode) + for _, node, _ in goal.out_edges(ball) + if not isinstance(node, ReducedNode) ] if not ball_edges: underspecified_balls.add(ball) - for gripper in typed[ReductionNode.GRIPPERS]: + for gripper in typed[ReducedNode.GRIPPERS]: gripper_edges = [ node for node, _ in goal.in_edges(gripper) - if node == ReductionNode.CLEAR or not isinstance(node, ReductionNode) + if node == ReducedNode.CLEAR or not isinstance(node, ReducedNode) ] if not gripper_edges: underspecified_grippers.add(gripper) @@ -392,18 +388,18 @@ def _gripper_underspecified_blocks( return ( underspecified_balls, underspecified_grippers, - goal.out_degree(ReductionNode.ROBBY) == 0, + goal.out_degree(ReducedNode.ROBBY) == 0, ) def inflate( - scene: nx.MultiDiGraph, + scene: ReducedGraph, domain: str | None = None, ) -> graph.SceneGraph: """Inflate a reduced scene graph to a SceneGraph. Args: - scene (nx.MultiDiGraph): The reduced scene graph to respecify. + scene (ReducedGraph): The reduced scene graph to respecify. domain (str | None, optional): The domain of the scene graph. Defaults to None. @@ -423,14 +419,14 @@ def inflate( def _detached_blocks( nodesA: set[str], nodesB: set[str], - scene: nx.MultiDiGraph, + scene: ReducedGraph, ) -> tuple[set[str], set[str]]: """Finds nodes that are not connected to the rest of the scene graph. Args: nodesA (set[str]): The set of nodes to check. nodesB (set[str]): The set of nodes to check against. - scene (nx.MultiDiGraph): The scene graph to check against. + scene (ReducedGraph): The scene graph to check against. Returns: tuple[set[str], set[str]]: The set of nodes that are not connected to @@ -441,7 +437,9 @@ def _detached_blocks( for a in nodesA: for b in nodesB: - if not nx.has_path(scene, a, b) and not nx.has_path(scene, b, a): + a_index = scene.nodes.index(a) + b_index = scene.nodes.index(b) + if not rx.has_path(scene.graph, a_index, b_index, as_undirected=True): _nodesA.discard(a) _nodesB.discard(b) @@ -449,7 +447,7 @@ def _detached_blocks( def _fully_specify_blocksworld( - scene: nx.MultiDiGraph, + scene: ReducedGraph, ) -> graph.SceneGraph: """Fully specifies a blocksworld scene graph. @@ -457,7 +455,7 @@ def _fully_specify_blocksworld( edges that change the problem represented by the graph. Args: - scene (nx.MultiDiGraph): The reduced SceneGraph of a scene. + scene (ReducedGraph): The reduced SceneGraph of a scene. Returns: SceneGraph: The fully specified scene graph. @@ -467,29 +465,41 @@ def _fully_specify_blocksworld( top_blocks_, bottom_blocks_ = _detached_blocks(top_blocks, bottom_blocks, scene) for block in top_blocks_: - scene.add_edge(ReductionNode.CLEAR, block, pred="clear") + scene.add_edge( + ReducedNode.CLEAR, + block, + graph.PlanGraphEdge(predicate="clear"), + ) for block in bottom_blocks_: - scene.add_edge(block, ReductionNode.TABLE, pred="on-table") + scene.add_edge( + block, + ReducedNode.TABLE, + graph.PlanGraphEdge(predicate="on-table"), + ) # handle arm if arm_empty and not (top_blocks & bottom_blocks): - scene.add_edge(ReductionNode.CLEAR, ReductionNode.ARM, pred="arm-empty") + scene.add_edge( + ReducedNode.CLEAR, + ReducedNode.ARM, + graph.PlanGraphEdge(predicate="arm-empty"), + ) return scene def _fully_specify_gripper( - init: nx.MultiDiGraph, - goal: nx.MultiDiGraph, -) -> nx.MultiDiGraph: + init: ReducedGraph, + goal: ReducedGraph, +) -> ReducedGraph: """Fully specifies a gripper scene graph. Adds any missing edges to fully specify the scene graph, without adding edges that change the problem represented by the graph. Args: - init (nx.MultiDiGraph): The reduced SceneGraph of the initial scene. - goal (nx.MultiDiGraph): The reduced SceneGraph of the goal scene. + init (ReducedGraph): The reduced SceneGraph of the initial scene. + goal (ReducedGraph): The reduced SceneGraph of the goal scene. Returns: SceneGraph: The fully specified scene graph. @@ -503,7 +513,9 @@ def _fully_specify_gripper( if underspecified_grippers and not underspecified_balls: for gripper in underspecified_grippers: - scene.add_edge(ReductionNode.CLEAR, gripper, pred="free") + scene.add_edge( + ReducedNode.CLEAR, gripper, graph.PlanGraphEdge(predicate="free") + ) return scene @@ -607,7 +619,7 @@ def reduce( graph: graph.SceneGraph, domain: str | None = None, validate: bool = True, -) -> nx.MultiDiGraph: +) -> ReducedGraph: """Reduces a scene graph to a Directed Acyclic Graph. Args: @@ -621,7 +633,7 @@ def reduce( ValueError: If a certain domain is provided but not supported. Returns: - nx.MultiDiGraph: The reduced scene graph. + ReducedGraph: The reduced problem graph. """ domain = domain or graph.domain match domain: diff --git a/poetry.lock b/poetry.lock index dcda5f9..5703667 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "asttokens" @@ -649,6 +649,84 @@ files = [ {file = "ruff-0.1.11.tar.gz", hash = "sha256:f9d4d88cb6eeb4dfe20f9f0519bd2eaba8119bde87c3d5065c541dbae2b5a2cb"}, ] +[[package]] +name = "rustworkx" +version = "0.14.2" +description = "A python graph library implemented in Rust" +optional = false +python-versions = ">=3.8" +files = [ + {file = "rustworkx-0.14.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:a28a972dc7e0faf03f9f90c5be89328af8a71e609f311840e1a6abc6385edb79"}, + {file = "rustworkx-0.14.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:50e682b8fd2f11f9e99c309a01f7ed88a09ad32cda35b92c49835b1c9536ec65"}, + {file = "rustworkx-0.14.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6e1c3cf3d265835429074a1ecaa8f9bff327b188e1496a120bf8be8260a46453"}, + {file = "rustworkx-0.14.2-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a22c02f74bf391b48ae92f633083d068055f3ed85050e35fe6cda967ff8a825"}, + {file = "rustworkx-0.14.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:996bad21eacbe124dd1e6abca47dd69ade9db0d4df5dd29197694f5d8e0a8258"}, + {file = "rustworkx-0.14.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:95c4647461f05fd9f99bae52002a929e8628d4e5a2e732dbfd7abd00ae5257b7"}, + {file = "rustworkx-0.14.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:829444876bba1940fa3109998f3b6c9184256d91eea5f0e09d9e9f8f26bb4704"}, + {file = "rustworkx-0.14.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:987b430dce1351a0c761bd6eedb8f6999f48983c9d4b06bf4b0b9dc45d08be8d"}, + {file = "rustworkx-0.14.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:18ef16f9b6b4f1c0d458fde3f213b78436ac810d61cae60385696b411aa80e1d"}, + {file = "rustworkx-0.14.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fe30f1e22e69cbab4182d0017e21c345bf75f142a7b66a828227dd3c654d524c"}, + {file = "rustworkx-0.14.2-cp310-cp310-win32.whl", hash = "sha256:c1fe9f9ed18e270074d3632f6c70cc75c461535d9e76db39d1c0ab712bf64a7a"}, + {file = "rustworkx-0.14.2-cp310-cp310-win_amd64.whl", hash = "sha256:271b36412421d622e9e8cd27e2c6e1bd356e452f979edd41bb32d308df936f47"}, + {file = "rustworkx-0.14.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:c52e34ff4b08d1eaedd2ec906bca4317f4f852b36e4615d372b1ff2bb435ff26"}, + {file = "rustworkx-0.14.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c97dc0cf7efef033ce50fa570887f97896b0f449c841ec3b127ecb70b3c16c84"}, + {file = "rustworkx-0.14.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:114bec1606ae31c089ecf52aa511551c545c6ce0746d3e8766082ad450377a2c"}, + {file = "rustworkx-0.14.2-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:950fa4ffc1691081587c87c4e869a8f5c7d0672d35ce1ba7c69f758f90bfe8c0"}, + {file = "rustworkx-0.14.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2134aa9c2065ab6c934017b6909e224e860003eb5dbaa5d2c4e87fff1187459a"}, + {file = "rustworkx-0.14.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bec8f1f1a6fed3ffbf5348a2b9d700f0b840fed2faa6a5198838d0fa9674a781"}, + {file = "rustworkx-0.14.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b8046991499df7aa984b3d9092e4f013597901c919aaf6fa43147e8550685734"}, + {file = "rustworkx-0.14.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:acb4256fba2c4f5c4ec009f383623b6a7c0a2dbeed1b529d22a193113927364a"}, + {file = "rustworkx-0.14.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:be7be125f9313b58829f7202a66dc166b61bf3c4bbe0c509b8d6902ed0d2da45"}, + {file = "rustworkx-0.14.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5d00f87fce0e48c6d7af4b63ee635188178e91462b52ac900d36ec3184ce92fc"}, + {file = "rustworkx-0.14.2-cp311-cp311-win32.whl", hash = "sha256:521e0f432a94ac9a4c92f30a746b971f7e49476fd128d83d94d4b15a2c17245d"}, + {file = "rustworkx-0.14.2-cp311-cp311-win_amd64.whl", hash = "sha256:8fd20776c0f543340ef96450ba5d9d670b8d74396315f7191303a392844271e0"}, + {file = "rustworkx-0.14.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:7bb37e877653ae4b4d505fc7e5f7847ae06e6822b91cec56e9e851941a6a0ae7"}, + {file = "rustworkx-0.14.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:230808e3878236464ac00001d8b440382aa6230f0073554ec627580863e380cc"}, + {file = "rustworkx-0.14.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9a7cb7103ba88e12e3dd8e3b28365cbe971a8c158c1ee770646b2f3fd5cedab0"}, + {file = "rustworkx-0.14.2-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2637d0e496f34bac45f926b0aa12fb2e143581208f29a424cfb0eb5a7b5c3bfa"}, + {file = "rustworkx-0.14.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:692f78ee7f7a60d9c7082a5a26b4eefb697526f195172798389d7009510d84f3"}, + {file = "rustworkx-0.14.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91d5513e93b7c10fbce954771a74fc86d551eb33b9eb318eaa35d7668f9929da"}, + {file = "rustworkx-0.14.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:26076523a1c43e903c633f2375afac28fdbb83b9668bee00fae24d8c672bf6c9"}, + {file = "rustworkx-0.14.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6de5e2df15c415dfb6e5cb7175239d0862568cb10d028f451d358d101be5d8bf"}, + {file = "rustworkx-0.14.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:21c86c240628abc2123d7d1317647073a738bdfe143c55728261b66bc32806e2"}, + {file = "rustworkx-0.14.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0aa0277b931ca3fdfae07f8999b6a63dc9b89622b2fab820fa6bd95dd1e2e2eb"}, + {file = "rustworkx-0.14.2-cp312-cp312-win32.whl", hash = "sha256:fdc632673d4cd7f1cffe8ce13ea17dc361cf9d0d9f37dfa0888d94bdd5e6c159"}, + {file = "rustworkx-0.14.2-cp312-cp312-win_amd64.whl", hash = "sha256:47768f985f32ac1cd807af816fbd5f6e2433889793afdd838891ae516a95c8a6"}, + {file = "rustworkx-0.14.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:bfeee5a5be9eb71635a7897a6d2c034b1c01bf876fd15007b8bd4c6eaa8921e2"}, + {file = "rustworkx-0.14.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:4a20434c77f3daab043ab2f96386b5da871ebf15a5495f9ad5b916c3edf03e5c"}, + {file = "rustworkx-0.14.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d856549e874e064af136f2ce304eb896d32d8865c3e98f8d9e83b577f4c57f1d"}, + {file = "rustworkx-0.14.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2521a223fb5aab2a14351205456d02bd851e0ec6b0c028f5598fe14f292e881b"}, + {file = "rustworkx-0.14.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc9e7718eee8295cd5c11a5cf1c0fc7772e9c1dcc3d110edba4c77aad47e7f07"}, + {file = "rustworkx-0.14.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c23ef82b1d373e07c280b8b6927dbad3953597e34c752e14843ac3df722a621"}, + {file = "rustworkx-0.14.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a46e0d1398138a75fb909369ffe6dfdcec6bab4d21794e80a9abf45fd2823f68"}, + {file = "rustworkx-0.14.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c16fb941e8f48aea96ee38471a1ae770ec68623864a9b0e4760aabf82c41fc2b"}, + {file = "rustworkx-0.14.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:fa92a97e5d35c6901553a812f31ca18305922c0ef06c2d7a9d20fbcc0769b4d1"}, + {file = "rustworkx-0.14.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2e14b2956f2d06f5bb196bcd95f73008245eb6ffa9ee08f86ec369acf0cc04be"}, + {file = "rustworkx-0.14.2-cp38-cp38-win32.whl", hash = "sha256:816d33f69f4189376e1bb8132dea1deef1cd019b25bd281f01b7f394fcadbdad"}, + {file = "rustworkx-0.14.2-cp38-cp38-win_amd64.whl", hash = "sha256:4163f9c2c2d2158e053b30a39f74b0382b4c5a8a43f192c13b736e200b5e2025"}, + {file = "rustworkx-0.14.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a9b55a8f97799b159da96087176a0e97679dca0b6b5a14b3140aeda7e1050777"}, + {file = "rustworkx-0.14.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:edb2d67870e41d5a1e16288bca0758580fb6961e8b4dfc337557bdaab81ff016"}, + {file = "rustworkx-0.14.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f058bdb50c5b0be731b96ffd789c6cec2a99e7f757a57763b2cc56004ed95af6"}, + {file = "rustworkx-0.14.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ff900cb6ae2d4028ffe5a3075cfefa21b14929270844b172595e6de0d2f183eb"}, + {file = "rustworkx-0.14.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:630177a80c68823fb2dd94733298377bd52c2ce3f66758ea0a63966fc2d7c08f"}, + {file = "rustworkx-0.14.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6a79c177a9e4f1c623554e01319fcb7b2a062ae26def7b85dc1f0539b7cdd874"}, + {file = "rustworkx-0.14.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff956ee6c8224b8225478bb72103d4fc6dd4a247c066da30927776e1b05690"}, + {file = "rustworkx-0.14.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ae61a4c58186b4e428947b92ac2aa0557bcd5071fe8102a542c4337f64091766"}, + {file = "rustworkx-0.14.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5a96b6f96e1bb4e8ee337618d8af0a1aec16c2eda6ffd9968e16d161850d1e77"}, + {file = "rustworkx-0.14.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:49bc143729e0d64a51b0ec6d745665f067116db78ce958d5cbe0389e69c6e73c"}, + {file = "rustworkx-0.14.2-cp39-cp39-win32.whl", hash = "sha256:f11e858a1804d5e276d18d6fc197f797adf5da82cd3382550abeef50196c5a7e"}, + {file = "rustworkx-0.14.2-cp39-cp39-win_amd64.whl", hash = "sha256:b55e75ea35a225d6b0afbdd449665e3b907684347be6a38648bdbfd50e177bf0"}, + {file = "rustworkx-0.14.2.tar.gz", hash = "sha256:bd649322c0649b71fa18cc70a9af027b549560415fa860d6894736029c277b13"}, +] + +[package.dependencies] +numpy = ">=1.16.0,<2" + +[package.extras] +all = ["matplotlib (>=3.0)", "pillow (>=5.4)"] +graphviz = ["pillow (>=5.4)"] +mpl = ["matplotlib (>=3.0)"] + [[package]] name = "scipy" version = "1.11.4" @@ -783,4 +861,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "22b586300ccda99bd77a551075e007e69717bb841af205f3393989b5abfd4a82" +content-hash = "fd5472db53389211f49f7d86b12175d3aa1d9051be1ae942106a1ec0f986fe86" diff --git a/pyproject.toml b/pyproject.toml index 6020a01..3b866c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ networkx = "^3.2.1" pddl = {git = "https://github.com/maxzuo/pddl.git"} numpy = "^1.26.2" scipy = "^1.11.4" +rustworkx = "^0.14.2" [tool.poetry.group.dev.dependencies] diff --git a/tests/test_metric.py b/tests/test_metric.py index ddedd9d..727d48d 100644 --- a/tests/test_metric.py +++ b/tests/test_metric.py @@ -39,32 +39,44 @@ class TestConstantMatching: @pytest.fixture def source(self): """Fixture for a valid source constant.""" - return {"name": "o1", "typing": ["t1", "t2"], "label": graph.Label.CONSTANT} + return graph.PlanGraphNode( + "o1", "o1", typing=["t1", "t2"], label=graph.Label.CONSTANT + ) @pytest.fixture def target(self): """Fixture for a valid target constant.""" - return {"name": "c1", "typing": ["t1", "t2"], "label": graph.Label.CONSTANT} + return graph.PlanGraphNode( + "c1", "c1", typing=["t1", "t2"], label=graph.Label.CONSTANT + ) @pytest.fixture def source_incorrect_label(self): """Fixture for a source constant with an incorrect label.""" - return {"name": "o1", "typing": ["t1", "t2"], "label": graph.Label.PREDICATE} + return graph.PlanGraphNode( + "o1", "o1", typing=["t1", "t2"], label=graph.Label.PREDICATE + ) @pytest.fixture def target_incorrect_label(self): """Fixture for a target constant with an incorrect label.""" - return {"name": "c1", "typing": ["t1", "t2"], "label": graph.Label.PREDICATE} + return graph.PlanGraphNode( + "c1", "c1", typing=["t1", "t2"], label=graph.Label.PREDICATE + ) @pytest.fixture def source_incorrect_typing(self): """Fixture for a source constant with incorrect typing.""" - return {"name": "o1", "typing": ["ty1", "ty2"], "label": graph.Label.CONSTANT} + return graph.PlanGraphNode( + "o1", "o1", typing=["ty1", "ty2"], label=graph.Label.CONSTANT + ) @pytest.fixture def target_incorrect_typing(self): """Fixture for a target constant with incorrect typing.""" - return {"name": "c1", "typing": ["ty1", "ty2"], "label": graph.Label.CONSTANT} + return graph.PlanGraphNode( + "c1", "c1", typing=["ty1", "ty2"], label=graph.Label.CONSTANT + ) @pytest.fixture def mapping(self): @@ -78,8 +90,8 @@ def mapping_incorrect(self): def test_correct_matching(self, source, target, mapping): """Test correct matching between source and target constants.""" - assert metric._matching(source, target, None) - assert metric._matching(source, target, mapping) + assert metric._node_matching(source, target, None) + assert metric._node_matching(source, target, mapping) assert metric._same_typing(source, target) assert metric._preserves_mapping(source, target, mapping) @@ -88,10 +100,10 @@ def test_incorrect_label( self, source, target, source_incorrect_label, target_incorrect_label, mapping ): """Test incorrect label matching between source and target constants.""" - assert not metric._matching(source, target_incorrect_label, None) - assert not metric._matching(source_incorrect_label, target, None) - assert not metric._matching(source, target_incorrect_label, mapping) - assert not metric._matching(source_incorrect_label, target, mapping) + assert not metric._node_matching(source, target_incorrect_label, None) + assert not metric._node_matching(source_incorrect_label, target, None) + assert not metric._node_matching(source, target_incorrect_label, mapping) + assert not metric._node_matching(source_incorrect_label, target, mapping) assert not metric._preserves_mapping(source, target_incorrect_label, mapping) assert not metric._preserves_mapping(source_incorrect_label, target, mapping) @@ -103,10 +115,10 @@ def test_incorrect_typing( self, source, target, source_incorrect_typing, target_incorrect_typing, mapping ): """Test incorrect typing between source and target constants.""" - assert not metric._matching(source, target_incorrect_typing, None) - assert not metric._matching(source_incorrect_typing, target, None) - assert not metric._matching(source, target_incorrect_typing, mapping) - assert not metric._matching(source_incorrect_typing, target, mapping) + assert not metric._node_matching(source, target_incorrect_typing, None) + assert not metric._node_matching(source_incorrect_typing, target, None) + assert not metric._node_matching(source, target_incorrect_typing, mapping) + assert not metric._node_matching(source_incorrect_typing, target, mapping) assert metric._preserves_mapping(source, target_incorrect_typing, mapping) assert metric._preserves_mapping(source_incorrect_typing, target, mapping) @@ -116,7 +128,7 @@ def test_incorrect_typing( def test_incorrect_mapping(self, source, target, mapping_incorrect): """Test incorrect mapping between source and target constants.""" - assert not metric._matching(source, target, mapping_incorrect) + assert not metric._node_matching(source, target, mapping_incorrect) assert not metric._preserves_mapping(source, target, mapping_incorrect) @@ -128,42 +140,46 @@ class TestPredicateMatching: @pytest.fixture def source(self): """Fixture for a valid source predicate node.""" - return { - "name": "f-a1-a2", - "typing": "f", - "label": graph.Label.PREDICATE, - } + return graph.PlanGraphNode( + "f-a1-a2", + "f-a1-a2", + typing="f", + label=graph.Label.PREDICATE, + ) @pytest.fixture def target(self): """Fixture for a valid target predicate node.""" - return { - "name": "f-a1-a2", - "typing": "f", - "label": graph.Label.PREDICATE, - } + return graph.PlanGraphNode( + "f-a1-a2", + "f-a1-a2", + typing="f", + label=graph.Label.PREDICATE, + ) @pytest.fixture def source_incorrect_label(self): """Fixture for a source predicate node with an incorrect label.""" - return { - "name": "f-a1-a2", - "typing": "f", - "label": graph.Label.CONSTANT, - } + return graph.PlanGraphNode( + "f-a1-a2", + "f-a1-a2", + typing="f", + label=graph.Label.CONSTANT, + ) @pytest.fixture def target_incorrect_label(self): """Fixture for a target predicate node with an incorrect label.""" - return { - "name": "f-a1-a2", - "typing": "f", - "label": graph.Label.CONSTANT, - } + return graph.PlanGraphNode( + "f-a1-a2", + "f-a1-a2", + typing="f", + label=graph.Label.CONSTANT, + ) def test_correct_matching(self, source, target): """Test correct matching between source and target predicate nodes.""" - assert metric._matching(source, target, None) + assert metric._node_matching(source, target, None) def test_incorrect_label( self, @@ -173,8 +189,8 @@ def test_incorrect_label( target_incorrect_label, ): """Test incorrect label matching between source and target predicate nodes.""" - assert not metric._matching(source, target_incorrect_label, None) - assert not metric._matching(source_incorrect_label, target, None) + assert not metric._node_matching(source, target_incorrect_label, None) + assert not metric._node_matching(source_incorrect_label, target, None) class TestMetrics: @@ -189,12 +205,8 @@ def test_map(self, problem_string, two_initial_problem_string): initial, goal = problem_graph.decompose() - assert metric._map(initial, initial) != [] - assert metric._map(goal, goal) != [] - assert metric._map(initial, goal) == [] - - assert metric.map(problem_graph, problem_graph, return_mappings=True) != [] - assert metric.map(problem_graph, problem_graph2, return_mappings=True) == [] + assert metric.isomorphic(problem_graph, problem_graph) + assert not metric.isomorphic(problem_graph, problem_graph2) def test_validate(self, problem_string, two_initial_problem_string): """Test the validation function on graph pairs.""" @@ -202,111 +214,35 @@ def test_validate(self, problem_string, two_initial_problem_string): problem_graph2 = pddl.build(two_initial_problem_string) assert metric.equals(problem_graph, problem_graph, is_placeholder=True) - assert not metric.equals(problem_graph, problem_graph2, is_placeholder=True,) - - def test_distance_isomorphic(self, problem_string, renamed_problem_string): - """ - Test the distance function on graph pairs, considering only isomorphic cases. - """ - initial, goal = pddl.build(problem_string).decompose() - initial2, goal2 = pddl.build(renamed_problem_string).decompose() - - # Limiting the test to isomorphic cases to optimize execution time. - assert metric._distance(initial, initial) == 0.0 - assert metric.distance( - initial, - initial2, - goal, - goal2, - timeout=15.0, - ) == (0.0, False, 0.0, False) - - def test_distance(self, problem_string, wrong_problem_string): - """ - Test the distance function on graph pairs. - """ - initial, goal = pddl.build(problem_string).decompose() - initial2, goal2 = pddl.build(wrong_problem_string).decompose() - - assert metric._distance(initial, initial) == 0.0 - assert metric.distance(initial, initial2, goal, goal2) == ( - 0.0, - False, - 2.0, - False, + assert not metric.equals( + problem_graph, + problem_graph2, + is_placeholder=True, ) - def test_wrong_initial_scene( - self, - problem_string, - wrong_initial_problem_string, - ): - """ - Test the distance function on graph pairs. - """ - initial, goal = pddl.build(problem_string).decompose() - initial2, goal2 = pddl.build(wrong_initial_problem_string).decompose() - - # This function should timeout, so the value will be an approximation - assert metric._distance(initial, initial) == 0.0 - - init_distance, approx_init, _, _ = metric.distance( - initial, - initial2, - goal, - goal2, - timeout=2.0, - ) - - assert init_distance < 25.0 - assert approx_init - def test_swap(self, swap_problem_string, wrong_swap_problem_string): """ Test the distance function on graph pairs. """ swap_problem = pddl.build(swap_problem_string) - initial, goal = swap_problem.decompose() wrong_swap = pddl.build(wrong_swap_problem_string) - initial2, goal2 = wrong_swap.decompose() # Test validate assert metric.equals(swap_problem, swap_problem, is_placeholder=False) assert not metric.equals(swap_problem, wrong_swap, is_placeholder=False) assert metric.equals(swap_problem, wrong_swap, is_placeholder=True) - assert metric._distance(initial, initial) == 0.0 - assert metric.distance( - initial, - initial2, - goal, - goal2, - timeout=15.0, - ) == (0.0, False, 2.0, False) - def test_move(self, move_problem_string, wrong_move_problem_string): """ Test the distance function on graph pairs. """ move_problem = pddl.build(move_problem_string) - initial, goal = move_problem.decompose() wrong_move = pddl.build(wrong_move_problem_string) - initial2, goal2 = wrong_move.decompose() # Test validate assert metric.equals(move_problem, move_problem, is_placeholder=True) assert not metric.equals(move_problem, wrong_move, is_placeholder=True) - # Limiting the test to isomorphic cases to optimize execution time. - assert metric._distance(initial, initial) == 0.0 - assert metric.distance( - initial, - initial2, - goal, - goal2, - timeout=15.0, - ) == (2.0, False, 0.0, False) - def test_blocksworld_equivalence( self, blocksworld_fully_specified, diff --git a/tests/test_oracle.py b/tests/test_oracle.py index b7065d8..6523af2 100644 --- a/tests/test_oracle.py +++ b/tests/test_oracle.py @@ -2,8 +2,6 @@ from planetarium import graph, oracle, pddl -import networkx as nx - @pytest.fixture def blocksworld_fully_specified(): @@ -851,7 +849,7 @@ def gripper_missing_typing(): """ -def reduce_and_respecify(scene: graph.SceneGraph) -> bool: +def reduce_and_inflate(scene: graph.SceneGraph) -> bool: """Respecify a scene and check if it is equal to the original. Args: @@ -862,7 +860,7 @@ def reduce_and_respecify(scene: graph.SceneGraph) -> bool: """ reduced = oracle.reduce(scene, domain=scene.domain) respecified = oracle.inflate(reduced, domain=scene.domain) - return nx.utils.graphs_equal(scene, respecified) + return scene == respecified class TestBlocksworldOracle: @@ -944,7 +942,7 @@ def test_missing_ontables_and_clears(self, blocksworld_underspecified): is_placeholder=False, ) - def test_respecify( + def test_inflate( self, blocksworld_fully_specified, blocksworld_missing_clears, @@ -954,7 +952,7 @@ def test_respecify( blocksworld_holding, ): """ - Test the respecify function. + Test the inflate function. """ descs = [ @@ -968,8 +966,8 @@ def test_respecify( for desc in descs: init, goal = pddl.build(desc).decompose() - assert reduce_and_respecify(init) - assert reduce_and_respecify(goal) + assert reduce_and_inflate(init) + assert reduce_and_inflate(goal) def test_invalid( self, @@ -1021,14 +1019,14 @@ def test_fully_specified( ) assert not oracle.is_fully_specified(problem, is_placeholder=False) - def test_respecify(self, gripper_fully_specified): + def test_inflate(self, gripper_fully_specified): """ - Test the respecify function. + Test the inflate function. """ init, goal = pddl.build(gripper_fully_specified).decompose() - assert reduce_and_respecify(init) - assert reduce_and_respecify(goal) + assert reduce_and_inflate(init) + assert reduce_and_inflate(goal) def test_underspecified( self, @@ -1051,7 +1049,7 @@ def test_invalid(self, gripper_invalid): class TestUnsupportedDomain: - def test_reduce_and_respecify(self, gripper_fully_specified): + def test_reduce_and_inflate(self, gripper_fully_specified): problem = pddl.build(gripper_fully_specified) init, goal = problem.decompose() diff --git a/utils.py b/utils.py index 609b2bb..1d6a3a1 100644 --- a/utils.py +++ b/utils.py @@ -3,8 +3,10 @@ import yaml from datasets import Dataset +import matplotlib.collections import matplotlib.pyplot as plt -import networkx as nx +import numpy as np +import rustworkx as rx from planetarium import graph, oracle import llm_planner as llmp @@ -27,22 +29,21 @@ def apply_template( Returns: list[dict[str, str]]: Problem prompt. """ - return ( + return [ + { + "role": "user", + "content": f"{problem_prompt} {problem.natural_language} " + + f"{domain_prompt}\n{problem.domain}\n", + }, + ] + ( [ - { - "role": "user", - "content": f"{problem_prompt} {problem.natural_language} " - + f"{domain_prompt}\n{problem.domain}\n", - }, - ] - + ([ { "role": "assistant", "content": " " + problem.problem, }, ] if include_answer - else []) + else [] ) @@ -50,39 +51,158 @@ def strip(text: str, bos_token: str, eos_token: str) -> str: return text.removeprefix(bos_token) + eos_token -def plot(graph: graph.SceneGraph, already_reduced: bool = False): - """Plot a graph representation of the PDDL description. +def _layout(G: graph.PlanGraph, scale: float = 1.0): + """Position nodes in layers of straight lines. + + Source: https://github.com/networkx/networkx/blob/main/networkx/drawing/layout.py Args: - graph (nx.MultiDiGraph): The graph to plot. + G (rx.PyDiGraph): A directed graph. + scale (float, optional): Scale factor for positions. Defaults to 1. + + Returns: + dict: A dictionary of positions keyed by node. + """ - if not already_reduced: - graph = oracle.reduce(graph, validate=False) - for layer, nodes in enumerate(nx.topological_generations(graph)): - for node in nodes: - graph.nodes[node]["layer"] = layer - pos = nx.multipartite_layout( - graph, - align="horizontal", - subset_key="layer", - scale=-1, + + center = np.zeros(2) + if len(G.nodes) == 0: + return {} + + layers = rx.topological_generations(G.graph) + + pos = None + nodes = [] + width = len(layers) + for i, layer in enumerate(layers): + height = len(layer) + xs = np.repeat(i, height) + ys = np.arange(0, height, dtype=float) + offset = ((width - 1) / 2, (height - 1) / 2) + layer_pos = np.column_stack([xs, ys]) - offset + if pos is None: + pos = layer_pos + else: + pos = np.concatenate([pos, layer_pos]) + nodes.extend(layer) + + # Rescale + pos -= pos.mean(axis=0) + lim = np.abs(pos).max() + if lim > 0: + pos *= scale / lim + pos += center + # horizontal + pos = pos[:, ::-1] # swap x and y coords + pos = dict(zip(nodes, pos)) + return pos + + +def _draw( + G: graph.PlanGraph, + pos: dict, + ax: plt.Axes, + node_size: int = 300, + node_color="#1f78b4", + node_shape="o", + alpha=None, + cmap=None, + vmin=None, + vmax=None, + linewidths=None, + edgecolors=None, + label=None, + font_size: int = 12, + font_color="k", + font_family="sans-serif", + font_weight="normal", + bbox=None, + horizontalalignment="center", + verticalalignment="center", + clip_on=True, +): + """Draw the graph G with Matplotlib. + + Source: Source: https://github.com/networkx/networkx/blob/main/networkx/drawing/nx_pylab.py + """ + xy = np.asarray([pos[v] for v in G.graph.node_indices()]) + nodes_collection = ax.scatter( + xy[:, 0], + xy[:, 1], + s=node_size, + c=node_color, + marker=node_shape, + cmap=cmap, + vmin=vmin, + vmax=vmax, + alpha=alpha, + linewidths=linewidths, + edgecolors=edgecolors, + ) + nodes_collection.set_zorder(2) # nodes go on top of edges + # Add node labels: + labels = [node.node for node in G.nodes] + for n, label in enumerate(labels): + (x, y) = pos[n] + if not isinstance(label, str): + label = str(label) # this makes "1" and 1 labeled the same + ax.text( + x, + y, + label, + size=font_size, + color=font_color, + family=font_family, + weight=font_weight, + alpha=alpha, + horizontalalignment=horizontalalignment, + verticalalignment=verticalalignment, + transform=ax.transData, + bbox=bbox, + clip_on=clip_on, + ) + + # plot edges + edge_pos = np.asarray([(pos[u], pos[v]) for u, v, _ in G.graph.edge_index_map().values()]) + + edge_collection = matplotlib.collections.LineCollection( + edge_pos, + colors="k", + linewidths=1., + antialiaseds=(1,), + linestyle="solid", + alpha=alpha, ) - fig = plt.figure() + edge_collection.set_zorder(1) # edges go behind nodes + ax.add_collection(edge_collection) + + ax.tick_params( + axis="both", + which="both", + bottom=False, + left=False, + labelbottom=False, + labelleft=False, + ) + ax.set_axis_off() - nx.draw_networkx_nodes(graph, pos, ax=fig.gca()) - nx.draw_networkx_labels(graph, pos, ax=fig.gca()) - from collections import Counter - edge_counts = Counter(map(str, graph.edges())) +def plot(graph: graph.PlanGraph, reduce: bool = False): + """Plot a graph representation of the PDDL description. - curved_edges = [edge for edge in graph.edges() if edge_counts[str(edge)] > 1] - straight_edges = list(set(graph.edges()) - set(curved_edges)) + Args: + graph (graph.PlanGraph): The graph to plot. + already_reduced (bool, optional): Whether the graph is already reduced. + Defaults to False. + """ + if reduce: + graph = oracle.reduce(graph, validate=False) + # TODO: rx has no multipartite layout + pos = _layout(graph, scale=-1) - arc_rad = 0.25 - print(edge_counts) - nx.draw_networkx_edges(graph, pos, ax=fig.gca(), edgelist=straight_edges) - nx.draw_networkx_edges(graph, pos, ax=fig.gca(), edgelist=curved_edges, connectionstyle=f'arc3, rad = {arc_rad}') + fig = plt.figure() + _draw(graph, pos, ax=fig.gca()) return fig From 343cf1930820872982c452aee6b73ce270db3013 Mon Sep 17 00:00:00 2001 From: Max Zuo Date: Sun, 23 Jun 2024 00:15:49 -0400 Subject: [PATCH 2/5] rustworkx oracle --- evaluate.py | 8 +- planetarium/graph.py | 187 ++++++++++++------ planetarium/metric.py | 1 - planetarium/oracle.py | 448 +++++++++++++++++++++++++++--------------- tests/test_graph.py | 8 +- tests/test_metric.py | 2 - tests/test_oracle.py | 95 ++------- tests/test_pddl.py | 51 +++++ 8 files changed, 495 insertions(+), 305 deletions(-) diff --git a/evaluate.py b/evaluate.py index 79dda15..e89b320 100644 --- a/evaluate.py +++ b/evaluate.py @@ -210,7 +210,7 @@ def result(): problem_graph = pddl.build(problem_pddl) init, _ = problem_graph.decompose() - if len(llm_problem_graph._constants) != len(problem_graph._constants): + if len(llm_problem_graph.constants) != len(problem_graph.constants): resolved = True return result() @@ -255,8 +255,8 @@ def full_equivalence( bool: True if the scene graphs are equivalent, False otherwise. """ return metric.equals( - oracle.fully_specify(source), - oracle.fully_specify(target), + oracle.fully_specify(source, return_reduced=True), + oracle.fully_specify(target, return_reduced=True), is_placeholder=is_placeholder, ) @@ -612,8 +612,6 @@ def main(config_path: str): # Get LLM output first problems = load_ungenerated_problems(config, config_str, problem_ids) - print(config_str) - print(len(problems)) # if len(problems) > 0: # if config["evaluate"]["model"]["type"] == "openai": # generate_openai(problems, config, config_str) diff --git a/planetarium/graph.py b/planetarium/graph.py index d6dd426..0f58473 100644 --- a/planetarium/graph.py +++ b/planetarium/graph.py @@ -43,6 +43,12 @@ def __eq__(self, other: "PlanGraphNode") -> bool: def __hash__(self) -> int: return hash((self.name, self.label, (*sorted(self.typing),), self.scene)) + def __repr__(self) -> str: + return f"PlanGraphNode(node={self.node}, name={self.name}, label={self.label}, typing={self.typing}, scene={self.scene})" + + def __str__(self) -> str: + return f"PlanGraphNode(node={self.node}, name={self.name}, label={self.label}, typing={self.typing}, scene={self.scene})" + class PlanGraphEdge: def __init__( @@ -86,7 +92,6 @@ class PlanGraph(metaclass=abc.ABCMeta): def __init__( self, constants: list[dict[str, Any]], - predicates: list[dict[str, Any]], domain: str | None = None, ): """ @@ -94,16 +99,12 @@ def __init__( Parameters: constants (list): List of dictionaries representing constants. - predicates (list): List of dictionaries representing predicates. domain (str, optional): The domain of the scene graph. Defaults to None. """ super().__init__() - self._constants = constants - self._predicates = predicates self._domain = domain - self.graph = rx.PyDiGraph() for constant in constants: @@ -133,9 +134,30 @@ def edges(self) -> set[tuple[PlanGraphNode, PlanGraphNode, PlanGraphEdge]]: def add_node(self, node: PlanGraphNode): if node in self.nodes: - raise ValueError(f"Node {node.name} already exists in the graph.") + raise ValueError(f"Node {node} already exists in the graph.") self.graph.add_node(node) + def has_edge( + self, + u: str | PlanGraphNode, + v: str | PlanGraphNode, + edge: PlanGraphEdge | None = None, + ) -> bool: + if isinstance(u, PlanGraphNode): + u_index = self.nodes.index(u) + else: + u_index, _ = self._node_lookup[u] + + if isinstance(v, PlanGraphNode): + v_index = self.nodes.index(v) + else: + v_index, _ = self._node_lookup[v] + + if edge: + return (u_index, v_index, edge) in self.graph.edge_index_map().values() + else: + return self.graph.has_edge(u_index, v_index) + def add_edge( self, u: str | PlanGraphNode, v: str | PlanGraphNode, edge: PlanGraphEdge ): @@ -178,7 +200,7 @@ def _add_predicate( ) for position, parameter_name in enumerate(predicate["parameters"]): - if parameter_name not in self.constants: + if parameter_name not in [node.name for node in self.constant_nodes]: raise ValueError(f"Parameter {parameter_name} not found in constants.") self.add_edge( predicate_name, @@ -220,29 +242,23 @@ def successors(self, node: str | PlanGraphNode) -> list[PlanGraphNode]: def in_edges( self, node: str | PlanGraphNode - ) -> ( - list[tuple[PlanGraphNode, PlanGraphNode, PlanGraphEdge]] - | list[tuple[PlanGraphNode, PlanGraphNode]] - ): + ) -> list[tuple[PlanGraphNode, PlanGraphEdge]]: if isinstance(node, PlanGraphNode): edges = self.graph.in_edges(self.nodes.index(node)) else: edges = self.graph.in_edges(self._node_lookup[node][0]) - return [(self.nodes[u], self.nodes[v]) for u, v, _ in edges] + return [(self.nodes[u], edge) for u, _, edge in edges] def out_edges( self, node: str | PlanGraphNode - ) -> ( - list[tuple[PlanGraphNode, PlanGraphNode, PlanGraphEdge]] - | list[tuple[PlanGraphNode, PlanGraphNode]] - ): + ) -> list[tuple[PlanGraphNode, PlanGraphEdge]]: if isinstance(node, PlanGraphNode): edges = self.graph.out_edges(self.nodes.index(node)) else: edges = self.graph.out_edges(self._node_lookup[node][0]) - return [(self.nodes[u], self.nodes[v], data) for u, v, data in edges] + return [(self.nodes[v], edge) for _, v, edge in edges] @staticmethod def _build_unique_predicate_name( @@ -272,24 +288,45 @@ def domain(self) -> str | None: return self._domain @property - def constants(self) -> dict: - """ - Get a dictionary of constant nodes in the scene graph. + def constant_nodes(self) -> list[PlanGraphNode]: + """Get a list of constant nodes in the scene graph. Returns: - dict: A dictionary containing constant nodes. + list[PlanGraphNode]: A list of constant nodes. """ - return {node.name: node for node in self.nodes if node.label == Label.CONSTANT} + return [node for node in self.nodes if node.label == Label.CONSTANT] @property - def predicates(self) -> dict: - """ - Get a dictionary of predicate nodes in the scene graph. + def constants(self) -> list[dict[str, Any]]: + return [ + {"name": constant.name, "typing": constant.typing} + for constant in self.constant_nodes + ] + + @property + def predicate_nodes(self) -> list[PlanGraphNode]: + """Get a list of predicate nodes in the scene graph. Returns: - dict: A dictionary containing predicate nodes. + list[PlanGraphNode]: A list of predicate nodes. """ - return {node.name: node for node in self.nodes if node.label == Label.PREDICATE} + return [node for node in self.nodes if node.label == Label.PREDICATE] + + @property + def predicates(self) -> list[dict[str, Any]]: + predicates = [] + for node in self.predicate_nodes: + edges = self.out_edges(node) + edges.sort(key=lambda x: x[1].position) + predicates.append( + { + "typing": node.typing, + "parameters": [obj_node.name for obj_node, _ in edges], + "scene": node.scene, + } + ) + + return predicates def __eq__(self, other: "PlanGraph") -> bool: """ @@ -303,8 +340,8 @@ def __eq__(self, other: "PlanGraph") -> bool: """ return ( isinstance(other, PlanGraph) - and set(self.constants) == set(other.constants) - and set(self.predicates) == set(other.predicates) + and set(self.nodes) == set(other.nodes) + and set(self.edges) == set(other.edges) and self.domain == other.domain ) @@ -324,6 +361,7 @@ def __init__( constants: list[dict[str, Any]], predicates: list[dict[str, Any]], domain: str | None = None, + scene: Scene | None = None, ): """ Initialize the SceneGraph instance. @@ -333,12 +371,15 @@ def __init__( predicates (list): List of dictionaries representing predicates. domain (str, optional): The domain of the scene graph. Defaults to None. + scene (str, optional): The scene of the scene graph. """ - super().__init__(constants, predicates, domain=domain) + super().__init__(constants, domain=domain) + + self.scene = scene for predicate in predicates: - self._add_predicate(predicate) + self._add_predicate(predicate, scene=scene) class ProblemGraph(PlanGraph): @@ -371,10 +412,7 @@ def __init__( domain (str, optional): The domain of the scene graph. Defaults to None. """ - super().__init__(constants, init_predicates + goal_predicates, domain=domain) - - self._init_predicates = init_predicates - self._goal_predicates = goal_predicates + super().__init__(constants, domain=domain) for scene, predicates in ( (Scene.INIT, init_predicates), @@ -386,37 +424,67 @@ def __init__( def __eq__(self, other: "ProblemGraph") -> bool: return ( super().__eq__(other) - and set(self.init_predicates) == set(other.init_predicates) - and set(self.goal_predicates) == set(other.goal_predicates) + and set(self.init_predicate_nodes) == set(other.init_predicate_nodes) + and set(self.goal_predicate_nodes) == set(other.goal_predicate_nodes) ) @property - def init_predicates(self) -> dict: - """ - Get a dictionary of predicate nodes in the initial scene graph. + def init_predicate_nodes(self) -> list[PlanGraphNode]: + """Get a list of predicate nodes in the initial scene. Returns: - dict: A dictionary containing predicate nodes. + list[PlanGraphNode]: A list of predicate nodes in the initial scene. """ - return { - node.name: node + return [ + node for node in self.nodes if node.label == Label.PREDICATE and node.scene == Scene.INIT - } + ] @property - def goal_predicates(self) -> dict: - """ - Get a dictionary of predicate nodes in the initial scene graph. + def goal_predicate_nodes(self) -> list[PlanGraphNode]: + """Get a list of predicate nodes in the goal scene. Returns: - dict: A dictionary containing predicate nodes. + list[PlanGraphNode]: A list of predicate nodes in the goal scene. """ - return { - node.name: node + return [ + node for node in self.nodes if node.label == Label.PREDICATE and node.scene == Scene.GOAL - } + ] + + @property + def init_predicates(self) -> list[dict[str, Any]]: + predicates = [] + for node in self.init_predicate_nodes: + edges = self.out_edges(node) + edges.sort(key=lambda x: x[1].position) + predicates.append( + { + "typing": node.typing, + "parameters": [obj_node.name for obj_node, _ in edges], + "scene": node.scene, + } + ) + + return predicates + + @property + def goal_predicates(self) -> list[dict[str, Any]]: + predicates = [] + for node in self.goal_predicate_nodes: + edges = self.out_edges(node) + edges.sort(key=lambda x: x[1].position) + predicates.append( + { + "typing": node.typing, + "parameters": [obj_node.name for obj_node, _ in edges], + "scene": node.scene, + } + ) + + return predicates def decompose(self) -> tuple[SceneGraph, SceneGraph]: """ @@ -425,16 +493,19 @@ def decompose(self) -> tuple[SceneGraph, SceneGraph]: Returns: tuple[SceneGraph, SceneGraph]: A tuple containing the initial and goal scene graphs. """ + init_scene = SceneGraph( - constants=self._constants, - predicates=self._init_predicates, + constants=self.constants, + predicates=self.init_predicates, domain=self.domain, + scene=Scene.INIT, ) goal_scene = SceneGraph( - constants=self._constants, - predicates=self._goal_predicates, + constants=self.constants, + predicates=self.goal_predicates, domain=self.domain, + scene=Scene.GOAL, ) return init_scene, goal_scene @@ -452,8 +523,8 @@ def join(init: SceneGraph, goal: SceneGraph) -> "ProblemGraph": ProblemGraph: The combined problem graph. """ return ProblemGraph( - constants=init._constants, - init_predicates=init._predicates, - goal_predicates=goal._predicates, + constants=init.constants, + init_predicates=init.predicates, + goal_predicates=goal.predicates, domain=init.domain, ) diff --git a/planetarium/metric.py b/planetarium/metric.py index 03fcbd7..12fe0f6 100644 --- a/planetarium/metric.py +++ b/planetarium/metric.py @@ -62,7 +62,6 @@ def _node_matching( Returns: bool: True if nodes match, False otherwise. """ - print(source, target, "yeehaw") match (source.label, target.label): case (graph.Label.CONSTANT, graph.Label.CONSTANT): return _same_typing(source, target) and ( diff --git a/planetarium/oracle.py b/planetarium/oracle.py index 338dda7..5d21d2a 100644 --- a/planetarium/oracle.py +++ b/planetarium/oracle.py @@ -9,46 +9,116 @@ from planetarium import graph -class ReducedNode(tuple, enum.Enum): - TABLE = ("table", ("blocksworld",)) - CLEAR = ("clear", ("blocksworld", "gripper")) - ARM = ("arm", ("blocksworld",)) - ROOMS = ("rooms", ("gripper",)) - BALLS = ("balls", ("gripper",)) - GRIPPERS = ("grippers", ("gripper",)) - ROBBY = ("robby", ("gripper",)) +class ReducedNode(str, enum.Enum): + TABLE = "table" + CLEAR = "clear" + ARM = "arm" + ROOMS = "room" + BALLS = "ball" + GRIPPERS = "gripper" + ROBBY = "at-robby" + FREE = "free" + + +BlocksworldReducedNodes = { + ReducedNode.TABLE, + ReducedNode.CLEAR, + ReducedNode.ARM, +} + +GripperReducedNodes = { + ReducedNode.ROOMS, + ReducedNode.BALLS, + ReducedNode.GRIPPERS, + ReducedNode.ROBBY, + ReducedNode.FREE, +} + +ReducedNodes = { + "blocksworld": BlocksworldReducedNodes, + "gripper": GripperReducedNodes, +} + + +class ReducedSceneGraph(graph.PlanGraph): + def __init__( + self, + constants: list[dict[str, Any]], + domain: str, + scene: graph.Scene | None = None, + ): + super().__init__(constants, domain=domain) + self.scene = scene + + for e in ReducedNodes[domain]: + predicate = e.value + self.add_node( + graph.PlanGraphNode( + e, + name=predicate, + label=graph.Label.PREDICATE, + typing={predicate}, + ), + ) -class ReducedGraph(graph.PlanGraph): +class ReducedProblemGraph(graph.PlanGraph): def __init__( self, constants: list[dict[str, Any]], - predicates: list[dict[str, Any]], domain: str, ): - super().__init__(constants, predicates, domain=domain) + super().__init__(constants, domain=domain) + + for e in ReducedNodes[domain]: + predicate = e.value + self.add_node( + graph.PlanGraphNode( + e, + name=predicate, + label=graph.Label.PREDICATE, + typing={predicate}, + ), + ) + + def decompose(self) -> tuple[ReducedSceneGraph, ReducedSceneGraph]: + init = ReducedSceneGraph(self.constants, self.domain, scene=graph.Scene.INIT) + goal = ReducedSceneGraph(self.constants, self.domain, scene=graph.Scene.GOAL) + + for u, v, edge in self.edges: + edge = copy.deepcopy(edge) + if edge.scene == graph.Scene.INIT: + init.add_edge(u, v, edge) + elif edge.scene == graph.Scene.GOAL: + goal.add_edge(u, v, edge) + + return init, goal - for e in ReducedNode: - predicate, r_node_domains = e.value - if self.domain in r_node_domains: - self.add_node( - graph.PlanGraphNode( - e, - name=predicate, - label=graph.Label.PREDICATE, - typing={predicate}, - ), - ) + @staticmethod + def join(init: ReducedSceneGraph, goal: ReducedSceneGraph) -> "ReducedProblemGraph": + problem = ReducedProblemGraph(init.constants, domain=init.domain) + + for u, v, edge in init.edges: + edge = copy.deepcopy(edge) + problem.add_edge(u, v, edge) + edge.scene = graph.Scene.INIT + for u, v, edge in goal.edges: + edge = copy.deepcopy(edge) + edge.scene = graph.Scene.GOAL + problem.add_edge(u, v, edge) + + return problem def _reduce_blocksworld( - scene: graph.SceneGraph, + scene: graph.SceneGraph | graph.ProblemGraph, validate: bool = True, -) -> ReducedGraph: +) -> ReducedSceneGraph | ReducedProblemGraph: """Reduces a blocksworld scene graph to a Directed Acyclic Graph. Args: - problem (graph.SceneGraph): The scene graph to reduce. + problem (graph.SceneGraph | graph.ProblemGraph): The scene graph to + reduce. validate (bool, optional): Whether or not to validate if the reduced reprsentation is valid. Defaults to True. @@ -66,33 +136,42 @@ def _reduce_blocksworld( for node in scene.nodes: nodes[node.label].append(node) - reduced = ReducedGraph( - constants=scene._constants, - predicates=scene._predicates, - domain="blocksworld", - ) - - if "arm-empty" in scene.predicates: - reduced.add_edge( - ReducedNode.CLEAR, - ReducedNode.ARM, - graph.PlanGraphEdge(predicate="arm-empty"), + if isinstance(scene, graph.ProblemGraph): + reduced = ReducedProblemGraph(constants=scene.constants, domain="blocksworld") + elif isinstance(scene, graph.SceneGraph): + reduced = ReducedSceneGraph( + constants=scene.constants, + domain="blocksworld", + scene=scene.scene, ) + else: + raise ValueError("Scene must be a SceneGraph or ProblemGraph.") + + for pred_node in scene.predicate_nodes: + if pred_node.typing == "arm-empty": + reduced.add_edge( + ReducedNode.CLEAR, + ReducedNode.ARM, + graph.PlanGraphEdge( + predicate="arm-empty", + scene=pred_node.scene, + ), + ) pred_nodes = set() for node, obj, edge in scene.edges: pred = edge.predicate - reduced_edge = graph.PlanGraphEdge(predicate=pred) + reduced_edge = graph.PlanGraphEdge(predicate=pred, scene=edge.scene) if node in pred_nodes: continue - elif pred == "on-table": + if pred == "on-table": reduced.add_edge(obj, ReducedNode.TABLE, reduced_edge) elif pred == "clear": reduced.add_edge(ReducedNode.CLEAR, obj, reduced_edge) elif pred == "on": pos = edge.position other_obj, *_ = [ - v.node for _, v, a in scene.out_edges(node) if a.position == 1 - pos + v.node for v, a in scene.out_edges(node) if a.position == 1 - pos ] if pos == 0: reduced.add_edge(obj, other_obj, reduced_edge) @@ -101,31 +180,51 @@ def _reduce_blocksworld( pred_nodes.add(node) if validate: - if not rx.is_directed_acyclic_graph(reduced.graph): - raise ValueError("Scene graph is not a Directed Acyclic Graph.") - for node in reduced.nodes: - if ( - node.node != ReducedNode.TABLE and reduced.in_degree(node.node) > 1 - ) or (node.node != ReducedNode.CLEAR and reduced.out_degree(node.node) > 1): - raise ValueError( - f"Node {node} has multiple parents/children. (not possible in blocksworld)." - ) - if reduced.in_degree(ReducedNode.ARM) == 1: - obj = reduced.predecessors(ReducedNode.ARM)[0] - if ( - obj.node != ReducedNode.CLEAR - and reduced.in_degree(obj) == 1 - and reduced.predecessors(obj)[0].node != ReducedNode.CLEAR - ): - raise ValueError("Object on arm is connected to another object.") + if isinstance(reduced, ReducedProblemGraph): + init, goal = reduced.decompose() + _validate_blocksworld(init) + _validate_blocksworld(goal) + elif isinstance(reduced, ReducedSceneGraph): + _validate_blocksworld(reduced) return reduced +def _validate_blocksworld(scene: graph.SceneGraph): + """Validates a blocksworld scene graph. + + Args: + scene (graph.SceneGraph): The scene graph to validate. + + Raises: + ValueError: If the scene graph is not a Directed Acyclic Graph. + ValueError: If a node has multiple parents/children (not allowed in + blocksworld). + ValueError: If an object on the arm is connected to another object. + """ + if not rx.is_directed_acyclic_graph(scene.graph): + raise ValueError("Scene graph is not a Directed Acyclic Graph.") + for node in scene.nodes: + if ( + node.node != ReducedNode.TABLE and scene.in_degree(node.node) > 1 + ) or (node.node != ReducedNode.CLEAR and scene.out_degree(node.node) > 1): + raise ValueError( + f"Node {node} has multiple parents/children. (not possible in blocksworld)." + ) + if scene.in_degree(ReducedNode.ARM) == 1: + obj = scene.predecessors(ReducedNode.ARM)[0] + if ( + obj.node != ReducedNode.CLEAR + and scene.in_degree(obj) == 1 + and scene.predecessors(obj)[0].node != ReducedNode.CLEAR + ): + raise ValueError("Object on arm is connected to another object.") + + def _reduce_gripper( - scene: graph.SceneGraph, + scene: graph.SceneGraph | graph.ProblemGraph, validate: bool = True, -) -> ReducedGraph: +) -> ReducedSceneGraph | ReducedProblemGraph: """Reduces a gripper scene graph to a Directed Acyclic Graph. Args: @@ -140,22 +239,27 @@ def _reduce_gripper( for node in scene.nodes: nodes[node.label].append(node) - reduced = ReducedGraph( - constants=scene._constants, - predicates=scene._predicates, - domain="gripper", - ) + if isinstance(scene, graph.ProblemGraph): + reduced = ReducedProblemGraph(constants=scene.constants, domain="gripper") + elif isinstance(scene, graph.SceneGraph): + reduced = ReducedSceneGraph( + constants=scene.constants, + domain="gripper", + scene=scene.scene, + ) + else: + raise ValueError("Scene must be a SceneGraph or ProblemGraph.") pred_nodes = set() for node, obj, edge in scene.edges: pred = edge.predicate - reduced_edge = graph.PlanGraphEdge(predicate=pred) + reduced_edge = graph.PlanGraphEdge(predicate=pred, scene=edge.scene) if node in pred_nodes: continue elif pred == "at-robby": reduced.add_edge(ReducedNode.ROBBY, obj, reduced_edge) elif pred == "free": - reduced.add_edge(ReducedNode.CLEAR, obj, reduced_edge) + reduced.add_edge(ReducedNode.FREE, obj, reduced_edge) elif pred == "ball": reduced.add_edge(ReducedNode.BALLS, obj, reduced_edge) elif pred == "gripper": @@ -165,20 +269,30 @@ def _reduce_gripper( elif pred in {"carry", "at"}: pos = edge.position other_obj, *_ = [ - v for _, v, a in scene.out_edges(node) if a.position == 1 - pos + v for v, a in scene.out_edges(node) if a.position == 1 - pos ] if pos == 0: reduced.add_edge(obj, other_obj, reduced_edge) pred_nodes.add(node) - if validate and not rx.is_directed_acyclic_graph(reduced.graph): - raise ValueError("Scene graph is not a Directed Acyclic Graph.") + if validate: + if isinstance(reduced, ReducedProblemGraph): + init, goal = reduced.decompose() + if not rx.is_directed_acyclic_graph(init.graph): + raise ValueError("Initial scene graph is not a Directed Acyclic Graph.") + if not rx.is_directed_acyclic_graph(goal.graph): + raise ValueError("Goal scene graph is not a Directed Acyclic Graph.") + elif isinstance(reduced, ReducedSceneGraph): + if rx.is_directed_acyclic_graph(reduced.graph): + raise ValueError("Scene graph is not a Directed Acyclic Graph.") return reduced -def _inflate_blocksworld(scene: ReducedGraph) -> graph.SceneGraph: +def _inflate_blocksworld( + scene: ReducedSceneGraph | ReducedProblemGraph, +) -> graph.SceneGraph: """Respecify a blocksworld scene graph. Args: @@ -194,12 +308,13 @@ def _inflate_blocksworld(scene: ReducedGraph) -> graph.SceneGraph: if not isinstance(node.node, ReducedNode): constants.append({"name": node.node, "typing": node.typing}) - for u, v, _ in scene.edges: + for u, v, edge in scene.edges: if u.node == ReducedNode.CLEAR and v.node == ReducedNode.ARM: predicates.append( { "typing": "arm-empty", "parameters": [], + "scene": edge.scene, } ) elif u.node == ReducedNode.CLEAR: @@ -207,6 +322,7 @@ def _inflate_blocksworld(scene: ReducedGraph) -> graph.SceneGraph: { "typing": "clear", "parameters": [v.node], + "scene": edge.scene, } ) elif v.node == ReducedNode.TABLE: @@ -214,6 +330,7 @@ def _inflate_blocksworld(scene: ReducedGraph) -> graph.SceneGraph: { "typing": "on-table", "parameters": [u.node], + "scene": edge.scene, } ) elif v.node == ReducedNode.ARM: @@ -221,6 +338,7 @@ def _inflate_blocksworld(scene: ReducedGraph) -> graph.SceneGraph: { "typing": "holding", "parameters": [u.node], + "scene": edge.scene, } ) else: @@ -228,13 +346,29 @@ def _inflate_blocksworld(scene: ReducedGraph) -> graph.SceneGraph: { "typing": "on", "parameters": [u.node, v.node], + "scene": edge.scene, } ) - return graph.SceneGraph(constants, predicates, domain="blocksworld") + if isinstance(scene, ReducedProblemGraph): + return graph.ProblemGraph( + constants, + [pred for pred in predicates if pred["scene"] == graph.Scene.INIT], + [pred for pred in predicates if pred["scene"] == graph.Scene.GOAL], + domain="blocksworld", + ) + else: + return graph.SceneGraph( + constants, + predicates, + domain="blocksworld", + scene=scene.scene, + ) -def _inflate_gripper(scene: ReducedGraph) -> graph.SceneGraph: +def _inflate_gripper( + scene: ReducedSceneGraph | ReducedProblemGraph, +) -> graph.SceneGraph | graph.ProblemGraph: """Respecify a gripper scene graph. Args: @@ -256,13 +390,15 @@ def _inflate_gripper(scene: ReducedGraph) -> graph.SceneGraph: { "typing": "at-robby", "parameters": [v.node], + "scene": edge.scene, } ) - elif u.node == ReducedNode.CLEAR: + elif u.node == ReducedNode.FREE: predicates.append( { "typing": "free", "parameters": [v.node], + "scene": edge.scene, } ) elif u.node == ReducedNode.BALLS: @@ -270,6 +406,7 @@ def _inflate_gripper(scene: ReducedGraph) -> graph.SceneGraph: { "typing": "ball", "parameters": [v.node], + "scene": edge.scene, } ) elif u.node == ReducedNode.GRIPPERS: @@ -277,6 +414,7 @@ def _inflate_gripper(scene: ReducedGraph) -> graph.SceneGraph: { "typing": "gripper", "parameters": [v.node], + "scene": edge.scene, } ) elif u.node == ReducedNode.ROOMS: @@ -284,6 +422,7 @@ def _inflate_gripper(scene: ReducedGraph) -> graph.SceneGraph: { "typing": "room", "parameters": [v.node], + "scene": edge.scene, } ) else: @@ -291,14 +430,28 @@ def _inflate_gripper(scene: ReducedGraph) -> graph.SceneGraph: { "typing": edge.predicate, "parameters": [u.node, v.node], + "scene": edge.scene, } ) - return graph.SceneGraph(constants, predicates, domain="gripper") + if isinstance(scene, ReducedProblemGraph): + return graph.ProblemGraph( + constants, + [pred for pred in predicates if pred["scene"] == graph.Scene.INIT], + [pred for pred in predicates if pred["scene"] == graph.Scene.GOAL], + domain="gripper", + ) + else: + return graph.SceneGraph( + constants, + predicates, + domain="gripper", + scene=scene.scene, + ) def _blocksworld_underspecified_blocks( - scene: ReducedGraph, + scene: ReducedSceneGraph, ) -> tuple[set[str], set[str], bool]: """Finds blocks that are not fully specified. @@ -326,16 +479,27 @@ def _blocksworld_underspecified_blocks( return top_blocks, bottom_blocks, not arm_behavior_defined -def _gripper_get_typed_objects(scene: ReducedGraph): +def _gripper_get_typed_objects( + scene: ReducedSceneGraph, +) -> dict[ReducedNode, set[graph.PlanGraphNode]]: + """Get the typed objects in a gripper scene graph. + + Args: + scene (ReducedGraph): The reduced SceneGraph of a scene. + + Returns: + dict[ReducedNode, set[graph.PlanGraphNode]]: The typed objects in the + scene graph. + """ rooms = set() balls = set() grippers = set() - for _, node, _ in scene.out_edges(ReducedNode.ROOMS): + for node, _ in scene.out_edges(ReducedNode.ROOMS): rooms.add(node) - for _, node, _ in scene.out_edges(ReducedNode.BALLS): + for node, _ in scene.out_edges(ReducedNode.BALLS): balls.add(node) - for _, node, _ in scene.out_edges(ReducedNode.GRIPPERS): + for node, _ in scene.out_edges(ReducedNode.GRIPPERS): grippers.add(node) return { @@ -346,8 +510,8 @@ def _gripper_get_typed_objects(scene: ReducedGraph): def _gripper_underspecified_blocks( - init: ReducedGraph, - goal: ReducedGraph, + init: ReducedSceneGraph, + goal: ReducedSceneGraph, ) -> tuple[set[str], set[str], bool]: """Finds blocks that are not fully specified. @@ -371,7 +535,7 @@ def _gripper_underspecified_blocks( for ball in typed[ReducedNode.BALLS]: ball_edges = [ node - for _, node, _ in goal.out_edges(ball) + for node, _ in goal.out_edges(ball) if not isinstance(node, ReducedNode) ] if not ball_edges: @@ -380,7 +544,7 @@ def _gripper_underspecified_blocks( gripper_edges = [ node for node, _ in goal.in_edges(gripper) - if node == ReducedNode.CLEAR or not isinstance(node, ReducedNode) + if node == ReducedNode.FREE or not isinstance(node, ReducedNode) ] if not gripper_edges: underspecified_grippers.add(gripper) @@ -393,7 +557,7 @@ def _gripper_underspecified_blocks( def inflate( - scene: ReducedGraph, + scene: ReducedSceneGraph | ReducedProblemGraph, domain: str | None = None, ) -> graph.SceneGraph: """Inflate a reduced scene graph to a SceneGraph. @@ -419,7 +583,7 @@ def inflate( def _detached_blocks( nodesA: set[str], nodesB: set[str], - scene: ReducedGraph, + scene: ReducedSceneGraph, ) -> tuple[set[str], set[str]]: """Finds nodes that are not connected to the rest of the scene graph. @@ -439,7 +603,11 @@ def _detached_blocks( for b in nodesB: a_index = scene.nodes.index(a) b_index = scene.nodes.index(b) - if not rx.has_path(scene.graph, a_index, b_index, as_undirected=True): + if ( + not rx.has_path(scene.graph, a_index, b_index) + and not rx.has_path(scene.graph, b_index, a_index) + and a != b + ): _nodesA.discard(a) _nodesB.discard(b) @@ -447,7 +615,7 @@ def _detached_blocks( def _fully_specify_blocksworld( - scene: ReducedGraph, + scene: ReducedSceneGraph, ) -> graph.SceneGraph: """Fully specifies a blocksworld scene graph. @@ -468,13 +636,13 @@ def _fully_specify_blocksworld( scene.add_edge( ReducedNode.CLEAR, block, - graph.PlanGraphEdge(predicate="clear"), + graph.PlanGraphEdge(predicate="clear", scene=scene.scene), ) for block in bottom_blocks_: scene.add_edge( block, ReducedNode.TABLE, - graph.PlanGraphEdge(predicate="on-table"), + graph.PlanGraphEdge(predicate="on-table", scene=scene.scene), ) # handle arm @@ -482,16 +650,16 @@ def _fully_specify_blocksworld( scene.add_edge( ReducedNode.CLEAR, ReducedNode.ARM, - graph.PlanGraphEdge(predicate="arm-empty"), + graph.PlanGraphEdge(predicate="arm-empty", scene=scene.scene), ) return scene def _fully_specify_gripper( - init: ReducedGraph, - goal: ReducedGraph, -) -> ReducedGraph: + init: ReducedSceneGraph, + goal: ReducedSceneGraph, +) -> ReducedSceneGraph: """Fully specifies a gripper scene graph. Adds any missing edges to fully specify the scene graph, without adding @@ -506,6 +674,17 @@ def _fully_specify_gripper( """ scene = copy.deepcopy(goal) + # bring "typing" predicates from init to goal + typed_objects = _gripper_get_typed_objects(init) + for typing, objects in typed_objects.items(): + for obj in objects: + edge = graph.PlanGraphEdge(predicate=typing.value, scene=graph.Scene.GOAL) + edge_ = graph.PlanGraphEdge(predicate=typing.value) + if obj in scene.nodes and not ( + scene.has_edge(typing, obj, edge) or scene.has_edge(typing, obj, edge_) + ): + scene.add_edge(typing, obj, edge) + underspecified_balls, underspecified_grippers, _ = _gripper_underspecified_blocks( init, goal, @@ -514,7 +693,7 @@ def _fully_specify_gripper( if underspecified_grippers and not underspecified_balls: for gripper in underspecified_grippers: scene.add_edge( - ReducedNode.CLEAR, gripper, graph.PlanGraphEdge(predicate="free") + ReducedNode.FREE, gripper, graph.PlanGraphEdge(predicate="free", scene=scene.scene) ) return scene @@ -523,6 +702,7 @@ def _fully_specify_gripper( def fully_specify( problem: graph.ProblemGraph, domain: str | None = None, + return_reduced: bool = False, ) -> graph.ProblemGraph: """Fully specifies a goal state. @@ -531,95 +711,41 @@ def fully_specify( fully specify. domain (str | None, optional): The domain of the scene graph. Defaults to None. + return_reduced (bool, optional): Whether to return the reduced scene + graph. Defaults to False. Returns: graph.ProblemGraph: The fully specified problem graph. """ domain = domain or problem.domain - init, goal = problem.decompose() + reduced_init, reduced_goal = reduce(problem).decompose() + match domain: case "blocksworld": - reduced_goal = _reduce_blocksworld(goal) fully_specified_goal = _fully_specify_blocksworld(reduced_goal) case "gripper": - reduced_init = _reduce_gripper(init) - reduced_goal = _reduce_gripper(goal) fully_specified_goal = _fully_specify_gripper( reduced_init, reduced_goal, ) case _: raise ValueError(f"Domain {domain} not supported.") - return graph.ProblemGraph.join( - init, - inflate(fully_specified_goal, domain=domain), - ) - - -def is_fully_specified( - problem: graph.ProblemGraph, - domain: str | None = None, - is_placeholder: bool = False, -) -> bool: - """Checks if a goal state is fully specified. - - Args: - problem (graph.ProblemGraph): The problem graph with the goal state to - evaluate. - domain (str | None, optional): The domain of the scene graph. Defaults - to None. - is_placeholder (bool, optional): Whether or not every edge must be - present. Defaults to False. - - Raises: - ValueError: If a certain domain is provided but not supported. - - Returns: - bool: True if the goal state is fully specified, False otherwise. - """ - domain = domain or problem.domain - init, goal = problem.decompose() - match domain: - case "blocksworld": - reduced_goal = _reduce_blocksworld(goal) - top, bottom, arm_empty = _blocksworld_underspecified_blocks(reduced_goal) - - if is_placeholder: - if bottom.intersection(top) and arm_empty: - return False - return (not top) or (not bottom) - else: - return (not (top | bottom)) and not arm_empty - case "gripper": - reduced_init = _reduce_gripper(init) - reduced_goal = _reduce_gripper(goal) - balls, grippers, no_at_robby = _gripper_underspecified_blocks( - reduced_init, - reduced_goal, - ) - # check number of typed objects is the same as total - goal_type_check = len(goal.constants) == len( - set().union(*_gripper_get_typed_objects(reduced_goal).values()) - ) - type_check = ( - len(init.constants) - == len(set().union(*_gripper_get_typed_objects(reduced_init).values())) - and goal_type_check - ) - if is_placeholder: - return len(balls) == 0 and not no_at_robby - else: - return len(balls or grippers) == 0 and not no_at_robby and type_check - case _: - raise ValueError(f"Domain {domain} not supported.") + if return_reduced: + return ReducedProblemGraph.join(reduced_init, fully_specified_goal) + else: + init, _ = problem.decompose() + return graph.ProblemGraph.join( + init, + inflate(fully_specified_goal, domain=domain), + ) def reduce( graph: graph.SceneGraph, domain: str | None = None, validate: bool = True, -) -> ReducedGraph: +) -> ReducedSceneGraph | ReducedProblemGraph: """Reduces a scene graph to a Directed Acyclic Graph. Args: diff --git a/tests/test_graph.py b/tests/test_graph.py index ffeedd7..6e5074a 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -23,20 +23,20 @@ def test_constant_node_names(self, sgraph): Test if the names of constant nodes in the graph match the expected set. """ names = set(["p0", "p1", "f0", "f1", "f2", "f3"]) - assert all([(node in names) for node in sgraph.constants]) + assert all([(node.name in names) for node in sgraph.constant_nodes]) def test_constant_node_size(self, sgraph): """ Test if the number of constant nodes in the graph matches the expected count. """ - assert len(sgraph.constants) == 6 + assert len(sgraph.constant_nodes) == 6 def test_predicate_names(self, sgraph): """ Test if the names of predicate nodes in the graph match expected patterns. """ - for predicate in sgraph.predicates: - match predicate.split("-"): + for predicate in sgraph.predicate_nodes: + match predicate.node.split("-"): case ["above", _, _]: assert True case ["origin", _, _]: diff --git a/tests/test_metric.py b/tests/test_metric.py index 727d48d..f6e1d57 100644 --- a/tests/test_metric.py +++ b/tests/test_metric.py @@ -203,8 +203,6 @@ def test_map(self, problem_string, two_initial_problem_string): problem_graph = pddl.build(problem_string) problem_graph2 = pddl.build(two_initial_problem_string) - initial, goal = problem_graph.decompose() - assert metric.isomorphic(problem_graph, problem_graph) assert not metric.isomorphic(problem_graph, problem_graph2) diff --git a/tests/test_oracle.py b/tests/test_oracle.py index 6523af2..7a976b5 100644 --- a/tests/test_oracle.py +++ b/tests/test_oracle.py @@ -873,74 +873,32 @@ def test_fully_specified(self, blocksworld_fully_specified): Test the fully specified blocksworld problem. """ problem = pddl.build(blocksworld_fully_specified) - assert oracle.is_fully_specified(problem, is_placeholder=True) - - assert oracle.is_fully_specified( - oracle.fully_specify(problem), - domain="blocksworld", - is_placeholder=True, - ) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full def test_missing_clears(self, blocksworld_missing_clears): """ Test the fully specified blocksworld problem with missing clears. """ problem = pddl.build(blocksworld_missing_clears) - assert oracle.is_fully_specified( - problem, - domain="blocksworld", - is_placeholder=True, - ) - assert not oracle.is_fully_specified( - problem, - domain="blocksworld", - is_placeholder=False, - ) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full def test_missing_ontables(self, blocksworld_missing_ontables): """ Test the fully specified blocksworld problem with missing clears. """ problem = pddl.build(blocksworld_missing_ontables) - assert oracle.is_fully_specified(problem, is_placeholder=True) - assert not oracle.is_fully_specified( - problem, - domain="blocksworld", - is_placeholder=False, - ) - assert oracle.is_fully_specified( - oracle.fully_specify(problem), - domain="blocksworld", - is_placeholder=True, - ) - assert oracle.is_fully_specified( - oracle.fully_specify(problem), - domain="blocksworld", - is_placeholder=False, - ) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full def test_missing_ontables_and_clears(self, blocksworld_underspecified): """ Test the fully specified blocksworld problem with missing clears. """ problem = pddl.build(blocksworld_underspecified) - assert not oracle.is_fully_specified(problem, is_placeholder=True) - assert not oracle.is_fully_specified( - problem, - domain="blocksworld", - is_placeholder=False, - ) - - assert not oracle.is_fully_specified( - oracle.fully_specify(problem), - domain="blocksworld", - is_placeholder=True, - ) - assert not oracle.is_fully_specified( - oracle.fully_specify(problem), - domain="blocksworld", - is_placeholder=False, - ) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full def test_inflate( self, @@ -965,9 +923,11 @@ def test_inflate( ] for desc in descs: - init, goal = pddl.build(desc).decompose() + problem = pddl.build(desc) + init, goal = problem.decompose() assert reduce_and_inflate(init) assert reduce_and_inflate(goal) + assert reduce_and_inflate(problem) def test_invalid( self, @@ -1000,24 +960,16 @@ def test_fully_specified( Test the fully specified gripper problem. """ problem = pddl.build(gripper_fully_specified) - assert oracle.is_fully_specified(problem, is_placeholder=True) - assert oracle.is_fully_specified(problem, is_placeholder=False) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full problem = pddl.build(gripper_no_goal_types) - assert oracle.is_fully_specified(problem, is_placeholder=True) - assert oracle.is_fully_specified( - oracle.fully_specify(problem), - is_placeholder=True, - ) - assert not oracle.is_fully_specified(problem, is_placeholder=False) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full problem = pddl.build(gripper_fully_specified_not_strict) - assert oracle.is_fully_specified(problem, is_placeholder=True) - assert oracle.is_fully_specified( - oracle.fully_specify(problem), - is_placeholder=True, - ) - assert not oracle.is_fully_specified(problem, is_placeholder=False) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full def test_inflate(self, gripper_fully_specified): """ @@ -1034,12 +986,12 @@ def test_underspecified( gripper_underspecified_2, ): problem = pddl.build(gripper_underspecified_1) - assert not oracle.is_fully_specified(problem, is_placeholder=True) - assert not oracle.is_fully_specified(problem, is_placeholder=False) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full problem = pddl.build(gripper_underspecified_2) - assert not oracle.is_fully_specified(problem, is_placeholder=True) - assert not oracle.is_fully_specified(problem, is_placeholder=False) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full def test_invalid(self, gripper_invalid): problem = pddl.build(gripper_invalid) @@ -1059,11 +1011,6 @@ def test_reduce_and_inflate(self, gripper_fully_specified): reduced = oracle.reduce(goal, domain="gripper") oracle.inflate(reduced, domain="gripper-modified") - def test_fully_specified(self, gripper_fully_specified): - problem = pddl.build(gripper_fully_specified) - with pytest.raises(ValueError): - oracle.is_fully_specified(problem, domain="gripper-modified") - def test_fully_specify(self, gripper_fully_specified): problem = pddl.build(gripper_fully_specified) with pytest.raises(ValueError): diff --git a/tests/test_pddl.py b/tests/test_pddl.py index f80632b..214d34c 100644 --- a/tests/test_pddl.py +++ b/tests/test_pddl.py @@ -255,6 +255,44 @@ def wrong_move_problem_string(): ) """ +@pytest.fixture +def single_predicate_goal(): + """ + Fixture providing a sample PDDL problem definition as a string. + """ + return """ + (define (problem move) + (:domain move) + (:objects a0 a1 - object + b0 b1 - room) + + (:init + (in a0 b0) + (in a1 b1)) + + (:goal (in a0 b0)) + ) + """ + +@pytest.fixture +def not_predicate_goal(): + """ + Fixture providing a sample PDDL problem definition as a string. + """ + return """ + (define (problem move) + (:domain move) + (:objects a0 a1 - object + b0 b1 - room) + + (:init + (in a0 b0) + (in a1 b1)) + + (:goal (not + (in a0 b0))) + ) + """ @pytest.fixture def problem(problem_string): @@ -370,3 +408,16 @@ def test_edge_size(self, problem_string): modified_problem_string = f"Here is an example of a problem string that is not a PDDL problem. ```pddl\n{problem_string}\n```" graph_1, graph_2 = pddl.build(modified_problem_string).decompose() assert len(graph_1.edges) == 21 and len(graph_2.edges) == 2 + + def test_single_predicate_goal(self, single_predicate_goal): + """ + Test the size of nodes in the scene graphs built from a PDDL problem. + """ + pddl.build(single_predicate_goal).decompose() + + def test_not_predicate_goal(self, not_predicate_goal): + """ + Test the size of nodes in the scene graphs built from a PDDL problem. + """ + with pytest.raises(ValueError): + pddl.build(not_predicate_goal).decompose() \ No newline at end of file From 568da2ecc46b5dfd50da3fd1bf375579abfd9039 Mon Sep 17 00:00:00 2001 From: Max Zuo Date: Sun, 23 Jun 2024 00:33:19 -0400 Subject: [PATCH 3/5] more coverage --- planetarium/oracle.py | 12 +++++++----- tests/test_oracle.py | 37 ++++++++++++++++++++++++++++++++++++- tests/test_pddl.py | 5 ++++- 3 files changed, 47 insertions(+), 7 deletions(-) diff --git a/planetarium/oracle.py b/planetarium/oracle.py index 5d21d2a..e123576 100644 --- a/planetarium/oracle.py +++ b/planetarium/oracle.py @@ -205,9 +205,9 @@ def _validate_blocksworld(scene: graph.SceneGraph): if not rx.is_directed_acyclic_graph(scene.graph): raise ValueError("Scene graph is not a Directed Acyclic Graph.") for node in scene.nodes: - if ( - node.node != ReducedNode.TABLE and scene.in_degree(node.node) > 1 - ) or (node.node != ReducedNode.CLEAR and scene.out_degree(node.node) > 1): + if (node.node != ReducedNode.TABLE and scene.in_degree(node.node) > 1) or ( + node.node != ReducedNode.CLEAR and scene.out_degree(node.node) > 1 + ): raise ValueError( f"Node {node} has multiple parents/children. (not possible in blocksworld)." ) @@ -284,7 +284,7 @@ def _reduce_gripper( if not rx.is_directed_acyclic_graph(goal.graph): raise ValueError("Goal scene graph is not a Directed Acyclic Graph.") elif isinstance(reduced, ReducedSceneGraph): - if rx.is_directed_acyclic_graph(reduced.graph): + if not rx.is_directed_acyclic_graph(reduced.graph): raise ValueError("Scene graph is not a Directed Acyclic Graph.") return reduced @@ -693,7 +693,9 @@ def _fully_specify_gripper( if underspecified_grippers and not underspecified_balls: for gripper in underspecified_grippers: scene.add_edge( - ReducedNode.FREE, gripper, graph.PlanGraphEdge(predicate="free", scene=scene.scene) + ReducedNode.FREE, + gripper, + graph.PlanGraphEdge(predicate="free", scene=scene.scene), ) return scene diff --git a/tests/test_oracle.py b/tests/test_oracle.py index 7a976b5..8c55a9d 100644 --- a/tests/test_oracle.py +++ b/tests/test_oracle.py @@ -929,6 +929,13 @@ def test_inflate( assert reduce_and_inflate(goal) assert reduce_and_inflate(problem) + assert problem == oracle.inflate( + oracle.ReducedProblemGraph.join( + oracle.reduce(init, validate=True), + oracle.reduce(goal, validate=True), + ) + ) + def test_invalid( self, blocksworld_invalid_1, @@ -940,9 +947,12 @@ def test_invalid( blocksworld_invalid_2, blocksworld_invalid_3, ): - _, goal = pddl.build(desc).decompose() + problem = pddl.build(desc) + _, goal = problem.decompose() with pytest.raises(ValueError): oracle.reduce(goal, validate=True) + with pytest.raises(ValueError): + oracle.reduce(problem, validate=True) class TestGripperOracle: @@ -980,6 +990,29 @@ def test_inflate(self, gripper_fully_specified): assert reduce_and_inflate(init) assert reduce_and_inflate(goal) + def test_reduce_inflate( + self, + gripper_fully_specified, + gripper_no_robby, + gripper_underspecified_1, + gripper_underspecified_2, + gripper_underspecified_3, + ): + descs = [ + gripper_fully_specified, + gripper_no_robby, + gripper_underspecified_1, + gripper_underspecified_2, + gripper_underspecified_3, + ] + for desc in descs: + problem = pddl.build(desc) + init, goal = problem.decompose() + + assert reduce_and_inflate(init) + assert reduce_and_inflate(goal) + assert reduce_and_inflate(problem) + def test_underspecified( self, gripper_underspecified_1, @@ -998,6 +1031,8 @@ def test_invalid(self, gripper_invalid): _, goal = problem.decompose() with pytest.raises(ValueError): oracle.reduce(goal, validate=True) + with pytest.raises(ValueError): + oracle.reduce(problem, validate=True) class TestUnsupportedDomain: diff --git a/tests/test_pddl.py b/tests/test_pddl.py index 214d34c..588aeb7 100644 --- a/tests/test_pddl.py +++ b/tests/test_pddl.py @@ -255,6 +255,7 @@ def wrong_move_problem_string(): ) """ + @pytest.fixture def single_predicate_goal(): """ @@ -274,6 +275,7 @@ def single_predicate_goal(): ) """ + @pytest.fixture def not_predicate_goal(): """ @@ -294,6 +296,7 @@ def not_predicate_goal(): ) """ + @pytest.fixture def problem(problem_string): """ @@ -420,4 +423,4 @@ def test_not_predicate_goal(self, not_predicate_goal): Test the size of nodes in the scene graphs built from a PDDL problem. """ with pytest.raises(ValueError): - pddl.build(not_predicate_goal).decompose() \ No newline at end of file + pddl.build(not_predicate_goal).decompose() From 12fbfd8bdd6b276aa9544b1a62f205abbd2e58c6 Mon Sep 17 00:00:00 2001 From: Max Zuo Date: Sun, 23 Jun 2024 04:02:26 -0400 Subject: [PATCH 4/5] pddl rename + plotting fix with minor fixes all over --- downward.py | 12 +- evaluate.py | 6 +- finetune.py | 61 +++++++- planetarium/{pddl.py => builder.py} | 0 planetarium/graph.py | 72 ++++++++-- utils.py | 215 +++------------------------- 6 files changed, 145 insertions(+), 221 deletions(-) rename planetarium/{pddl.py => builder.py} (100%) diff --git a/downward.py b/downward.py index bd0c513..05ef8cc 100644 --- a/downward.py +++ b/downward.py @@ -1,7 +1,5 @@ # FastDownward python wrapper -from typing import Optional, Tuple - import glob import os import re @@ -9,7 +7,7 @@ import tempfile -def _get_best_plan(plan_filepath: str) -> Tuple[str, float]: +def _get_best_plan(plan_filepath: str) -> tuple[str, float]: best_cost = float("inf") best_plan = None @@ -25,8 +23,12 @@ def _get_best_plan(plan_filepath: str) -> Tuple[str, float]: def plan( - domain: str, problem: str, downward: str = "downward", alias: str = "lama", **kwargs -) -> Tuple[Optional[str], int]: + domain: str, + problem: str, + downward: str = "downward", + alias: str = "lama", + **kwargs, +) -> tuple[str | None, int]: """Find plan using FastDownward. Args: diff --git a/evaluate.py b/evaluate.py index e89b320..ded91e0 100644 --- a/evaluate.py +++ b/evaluate.py @@ -14,7 +14,7 @@ import tqdm import torch -from planetarium import pddl, graph, metric, oracle +from planetarium import builder, graph, metric, oracle import llm_planner as llmp from utils import apply_template @@ -199,7 +199,7 @@ def result(): try: # try to parse the LLM output - llm_problem_graph = pddl.build(llm_problem_pddl) + llm_problem_graph = builder.build(llm_problem_pddl) parseable = True # reduce and further validate the LLM output @@ -207,7 +207,7 @@ def result(): oracle.reduce(llm_problem_graph.decompose()[1], validate=True) valid = True - problem_graph = pddl.build(problem_pddl) + problem_graph = builder.build(problem_pddl) init, _ = problem_graph.decompose() if len(llm_problem_graph.constants) != len(problem_graph.constants): diff --git a/finetune.py b/finetune.py index db8fa33..773142e 100644 --- a/finetune.py +++ b/finetune.py @@ -1,5 +1,7 @@ +from collections import defaultdict from functools import partial import os +import sqlite3 import yaml import dotenv @@ -10,6 +12,7 @@ from torch import nn import bitsandbytes as bnb +from datasets import Dataset from peft import LoraConfig, get_peft_model from transformers import ( AutoTokenizer, @@ -23,7 +26,7 @@ import tqdm as tqdm import llm_planner as llmp -from utils import apply_template, load_dataset, strip +from utils import apply_template from accelerate import Accelerator @@ -31,6 +34,54 @@ HF_USER_TOKEN = os.getenv("HF_USER_TOKEN") +def load_dataset(config: dict) -> dict[str, Dataset]: + """Load the dataset from the configuration. + + Args: + config (dict): The dataset configuration. + + Returns: + dict[str, Dataset]: The loaded dataset. + """ + with open(config["splits_path"], "r") as f: + split_ids_cfg = yaml.safe_load(f) + + splits: set[str] = config.get("splits", {}).keys() + dataset = {split: defaultdict(list) for split in splits} + + # Connect to database + conn = sqlite3.connect(config["database_path"]) + c = conn.cursor() + + # load domains + domains = {} + c.execute("SELECT name, domain_pddl FROM domains") + for domain_name, domain_pddl in c.fetchall(): + domains[domain_name] = domain_pddl + + # load problems + for split in splits: + queries = [] + split_keys: list[str] = config["splits"][split] + for split_key in split_keys: + split_ids = split_ids_cfg + for key in split_key: + split_ids = split_ids[key] + + c.execute( + f"SELECT domain, problem_pddl, natural_language FROM problems WHERE id in ({', '.join(['?'] * len(split_ids))})", + split_ids, + ) + queries.extend(c.fetchall()) + + for domain, problem_pddl, natural_language in queries: + dataset[split]["domain"].append(domains[domain]) + dataset[split]["problem"].append(problem_pddl) + dataset[split]["natural_language"].append(natural_language) + + return {s: Dataset.from_dict(d, split=s) for s, d in dataset.items()} + + def find_all_linear_names( model: nn.Module, bits: int | None = None, @@ -62,6 +113,10 @@ def find_all_linear_names( return list(lora_module_names) +def strip(text: str, bos_token: str, eos_token: str) -> str: + return text.removeprefix(bos_token) + eos_token + + def preprocess( tokenizer: PreTrainedTokenizer, examples, @@ -130,7 +185,7 @@ def load_model(config: dict) -> tuple[PreTrainedTokenizer, PreTrainedModel]: ) else: bnb_config = None - + device_index = Accelerator().process_index device_map = {"": device_index} model = AutoModelForCausalLM.from_pretrained( @@ -139,7 +194,7 @@ def load_model(config: dict) -> tuple[PreTrainedTokenizer, PreTrainedModel]: token=HF_USER_TOKEN, torch_dtype=torch.bfloat16, quantization_config=bnb_config, - device_map=device_map + device_map=device_map, ) lora_config = LoraConfig( diff --git a/planetarium/pddl.py b/planetarium/builder.py similarity index 100% rename from planetarium/pddl.py rename to planetarium/builder.py diff --git a/planetarium/graph.py b/planetarium/graph.py index 0f58473..2754a27 100644 --- a/planetarium/graph.py +++ b/planetarium/graph.py @@ -2,6 +2,8 @@ import abc import enum +from functools import cached_property + import rustworkx as rx @@ -121,11 +123,11 @@ def __init__( def _node_lookup(self) -> dict[str, tuple[int, PlanGraphNode]]: return {node.node: (index, node) for index, node in enumerate(self.nodes)} - @property + @cached_property def nodes(self) -> list[PlanGraphNode]: return self.graph.nodes() - @property + @cached_property def edges(self) -> set[tuple[PlanGraphNode, PlanGraphNode, PlanGraphEdge]]: return [ (self.nodes[u], self.nodes[v], data) @@ -137,6 +139,16 @@ def add_node(self, node: PlanGraphNode): raise ValueError(f"Node {node} already exists in the graph.") self.graph.add_node(node) + if node.label == Label.CONSTANT: + self.__dict__.pop("constant_nodes", None) + self.__dict__.pop("constants", None) + elif node.label == Label.PREDICATE: + self.__dict__.pop("predicate_nodes", None) + self.__dict__.pop("predicates", None) + + self.__dict__.pop("nodes", None) + self.__dict__.pop("_node_lookup", None) + def has_edge( self, u: str | PlanGraphNode, @@ -173,6 +185,9 @@ def add_edge( self.graph.add_edge(u_index, v_index, edge) + self.__dict__.pop("edges", None) + self.__dict__.pop("predicates", None) + def _add_predicate( self, predicate: dict[str, Any], @@ -277,7 +292,7 @@ def _build_unique_predicate_name( """ return "-".join([predicate_name, *argument_names]) - @property + @cached_property def domain(self) -> str | None: """ Get the domain of the scene graph. @@ -287,7 +302,7 @@ def domain(self) -> str | None: """ return self._domain - @property + @cached_property def constant_nodes(self) -> list[PlanGraphNode]: """Get a list of constant nodes in the scene graph. @@ -296,14 +311,14 @@ def constant_nodes(self) -> list[PlanGraphNode]: """ return [node for node in self.nodes if node.label == Label.CONSTANT] - @property + @cached_property def constants(self) -> list[dict[str, Any]]: return [ {"name": constant.name, "typing": constant.typing} for constant in self.constant_nodes ] - @property + @cached_property def predicate_nodes(self) -> list[PlanGraphNode]: """Get a list of predicate nodes in the scene graph. @@ -428,7 +443,29 @@ def __eq__(self, other: "ProblemGraph") -> bool: and set(self.goal_predicate_nodes) == set(other.goal_predicate_nodes) ) - @property + def add_node(self, node: PlanGraphNode): + super().add_node(node) + if node.label == Label.PREDICATE: + self.__dict__.pop("init_predicate_nodes", None) + self.__dict__.pop("goal_predicate_nodes", None) + self.__dict__.pop("init_predicates", None) + self.__dict__.pop("goal_predicates", None) + + self.__dict__.pop("_decompose", None) + + def add_edge( + self, u: str | PlanGraphNode, v: str | PlanGraphNode, edge: PlanGraphEdge + ): + super().add_edge(u, v, edge) + + self.__dict__.pop("init_predicate_nodes", None) + self.__dict__.pop("goal_predicate_nodes", None) + self.__dict__.pop("init_predicates", None) + self.__dict__.pop("goal_predicates", None) + + self.__dict__.pop("_decompose", None) + + @cached_property def init_predicate_nodes(self) -> list[PlanGraphNode]: """Get a list of predicate nodes in the initial scene. @@ -441,7 +478,7 @@ def init_predicate_nodes(self) -> list[PlanGraphNode]: if node.label == Label.PREDICATE and node.scene == Scene.INIT ] - @property + @cached_property def goal_predicate_nodes(self) -> list[PlanGraphNode]: """Get a list of predicate nodes in the goal scene. @@ -454,7 +491,7 @@ def goal_predicate_nodes(self) -> list[PlanGraphNode]: if node.label == Label.PREDICATE and node.scene == Scene.GOAL ] - @property + @cached_property def init_predicates(self) -> list[dict[str, Any]]: predicates = [] for node in self.init_predicate_nodes: @@ -470,7 +507,7 @@ def init_predicates(self) -> list[dict[str, Any]]: return predicates - @property + @cached_property def goal_predicates(self) -> list[dict[str, Any]]: predicates = [] for node in self.goal_predicate_nodes: @@ -486,7 +523,10 @@ def goal_predicates(self) -> list[dict[str, Any]]: return predicates - def decompose(self) -> tuple[SceneGraph, SceneGraph]: + + + @cached_property + def _decompose(self) -> tuple[SceneGraph, SceneGraph]: """ Decompose the problem graph into initial and goal scene graphs. @@ -510,6 +550,16 @@ def decompose(self) -> tuple[SceneGraph, SceneGraph]: return init_scene, goal_scene + def decompose(self) -> tuple[SceneGraph, SceneGraph]: + """ + Decompose the problem graph into initial and goal scene graphs. + + Returns: + tuple[SceneGraph, SceneGraph]: A tuple containing the initial and goal scene graphs. + """ + + return self._decompose + @staticmethod def join(init: SceneGraph, goal: SceneGraph) -> "ProblemGraph": """ diff --git a/utils.py b/utils.py index 1d6a3a1..a2d6269 100644 --- a/utils.py +++ b/utils.py @@ -1,12 +1,5 @@ -from collections import defaultdict -import sqlite3 -import yaml - -from datasets import Dataset -import matplotlib.collections import matplotlib.pyplot as plt -import numpy as np -import rustworkx as rx +import networkx as nx from planetarium import graph, oracle import llm_planner as llmp @@ -47,147 +40,6 @@ def apply_template( ) -def strip(text: str, bos_token: str, eos_token: str) -> str: - return text.removeprefix(bos_token) + eos_token - - -def _layout(G: graph.PlanGraph, scale: float = 1.0): - """Position nodes in layers of straight lines. - - Source: https://github.com/networkx/networkx/blob/main/networkx/drawing/layout.py - - Args: - G (rx.PyDiGraph): A directed graph. - scale (float, optional): Scale factor for positions. Defaults to 1. - - Returns: - dict: A dictionary of positions keyed by node. - - """ - - center = np.zeros(2) - if len(G.nodes) == 0: - return {} - - layers = rx.topological_generations(G.graph) - - pos = None - nodes = [] - width = len(layers) - for i, layer in enumerate(layers): - height = len(layer) - xs = np.repeat(i, height) - ys = np.arange(0, height, dtype=float) - offset = ((width - 1) / 2, (height - 1) / 2) - layer_pos = np.column_stack([xs, ys]) - offset - if pos is None: - pos = layer_pos - else: - pos = np.concatenate([pos, layer_pos]) - nodes.extend(layer) - - # Rescale - pos -= pos.mean(axis=0) - lim = np.abs(pos).max() - if lim > 0: - pos *= scale / lim - pos += center - # horizontal - pos = pos[:, ::-1] # swap x and y coords - pos = dict(zip(nodes, pos)) - return pos - - -def _draw( - G: graph.PlanGraph, - pos: dict, - ax: plt.Axes, - node_size: int = 300, - node_color="#1f78b4", - node_shape="o", - alpha=None, - cmap=None, - vmin=None, - vmax=None, - linewidths=None, - edgecolors=None, - label=None, - font_size: int = 12, - font_color="k", - font_family="sans-serif", - font_weight="normal", - bbox=None, - horizontalalignment="center", - verticalalignment="center", - clip_on=True, -): - """Draw the graph G with Matplotlib. - - Source: Source: https://github.com/networkx/networkx/blob/main/networkx/drawing/nx_pylab.py - """ - xy = np.asarray([pos[v] for v in G.graph.node_indices()]) - nodes_collection = ax.scatter( - xy[:, 0], - xy[:, 1], - s=node_size, - c=node_color, - marker=node_shape, - cmap=cmap, - vmin=vmin, - vmax=vmax, - alpha=alpha, - linewidths=linewidths, - edgecolors=edgecolors, - ) - nodes_collection.set_zorder(2) # nodes go on top of edges - # Add node labels: - labels = [node.node for node in G.nodes] - for n, label in enumerate(labels): - (x, y) = pos[n] - if not isinstance(label, str): - label = str(label) # this makes "1" and 1 labeled the same - ax.text( - x, - y, - label, - size=font_size, - color=font_color, - family=font_family, - weight=font_weight, - alpha=alpha, - horizontalalignment=horizontalalignment, - verticalalignment=verticalalignment, - transform=ax.transData, - bbox=bbox, - clip_on=clip_on, - ) - - # plot edges - edge_pos = np.asarray([(pos[u], pos[v]) for u, v, _ in G.graph.edge_index_map().values()]) - - edge_collection = matplotlib.collections.LineCollection( - edge_pos, - colors="k", - linewidths=1., - antialiaseds=(1,), - linestyle="solid", - alpha=alpha, - ) - - edge_collection.set_zorder(1) # edges go behind nodes - ax.add_collection(edge_collection) - - ax.tick_params( - axis="both", - which="both", - bottom=False, - left=False, - labelbottom=False, - labelleft=False, - ) - ax.set_axis_off() - - def plot(graph: graph.PlanGraph, reduce: bool = False): """Plot a graph representation of the PDDL description. @@ -198,58 +50,23 @@ def plot(graph: graph.PlanGraph, reduce: bool = False): """ if reduce: graph = oracle.reduce(graph, validate=False) - # TODO: rx has no multipartite layout - pos = _layout(graph, scale=-1) + # rx has no plotting functionality - fig = plt.figure() - _draw(graph, pos, ax=fig.gca()) + nx_graph = nx.MultiDiGraph() + nx_graph.add_edges_from([(u.node, v.node, {"data":edge}) for u, v, edge in graph.edges]) - return fig - - -def load_dataset(config: dict) -> dict[str, Dataset]: - """Load the dataset from the configuration. - - Args: - config (dict): The dataset configuration. + for layer, nodes in enumerate(nx.topological_generations(nx_graph)): + for node in nodes: + nx_graph.nodes[node]["layer"] = layer - Returns: - dict[str, Dataset]: The loaded dataset. - """ - with open(config["splits_path"], "r") as f: - split_ids_cfg = yaml.safe_load(f) - - splits: set[str] = config.get("splits", {}).keys() - dataset = {split: defaultdict(list) for split in splits} - - # Connect to database - conn = sqlite3.connect(config["database_path"]) - c = conn.cursor() - - # load domains - domains = {} - c.execute("SELECT name, domain_pddl FROM domains") - for domain_name, domain_pddl in c.fetchall(): - domains[domain_name] = domain_pddl - - # load problems - for split in splits: - queries = [] - split_keys: list[str] = config["splits"][split] - for split_key in split_keys: - split_ids = split_ids_cfg - for key in split_key: - split_ids = split_ids[key] - - c.execute( - f"SELECT domain, problem_pddl, natural_language FROM problems WHERE id in ({', '.join(['?'] * len(split_ids))})", - split_ids, - ) - queries.extend(c.fetchall()) + pos = nx.multipartite_layout( + nx_graph, + align="horizontal", + subset_key="layer", + scale=-1, + ) - for domain, problem_pddl, natural_language in queries: - dataset[split]["domain"].append(domains[domain]) - dataset[split]["problem"].append(problem_pddl) - dataset[split]["natural_language"].append(natural_language) + fig = plt.figure() + nx.draw(nx_graph, pos=pos, ax=fig.gca(), with_labels=True) - return {s: Dataset.from_dict(d, split=s) for s, d in dataset.items()} + return fig From 62e5c609dd6b2228749cc52521c1a8eef3fef7d2 Mon Sep 17 00:00:00 2001 From: Max Zuo Date: Sun, 23 Jun 2024 04:03:58 -0400 Subject: [PATCH 5/5] added tests after rename --- tests/test_graph.py | 4 ++-- tests/test_metric.py | 26 +++++++++++++------------- tests/test_oracle.py | 34 +++++++++++++++++----------------- tests/test_pddl.py | 26 +++++++++++++------------- 4 files changed, 45 insertions(+), 45 deletions(-) diff --git a/tests/test_graph.py b/tests/test_graph.py index 6e5074a..e202a11 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -2,7 +2,7 @@ from .test_pddl import problem_string -from planetarium import pddl +from planetarium import builder @pytest.fixture @@ -10,7 +10,7 @@ def sgraph(problem_string): """ Fixture providing an SGraph instance built from a PDDL problem string. """ - return pddl.build(problem_string).decompose()[0] + return builder.build(problem_string).decompose()[0] class TestGraph: diff --git a/tests/test_metric.py b/tests/test_metric.py index f6e1d57..4aa8bf2 100644 --- a/tests/test_metric.py +++ b/tests/test_metric.py @@ -1,6 +1,6 @@ import pytest -from planetarium import graph, metric, oracle, pddl +from planetarium import builder, graph, metric, oracle # pylint: disable=unused-import from .test_pddl import ( @@ -200,16 +200,16 @@ class TestMetrics: def test_map(self, problem_string, two_initial_problem_string): """Test the mapping function on graph pairs.""" - problem_graph = pddl.build(problem_string) - problem_graph2 = pddl.build(two_initial_problem_string) + problem_graph = builder.build(problem_string) + problem_graph2 = builder.build(two_initial_problem_string) assert metric.isomorphic(problem_graph, problem_graph) assert not metric.isomorphic(problem_graph, problem_graph2) def test_validate(self, problem_string, two_initial_problem_string): """Test the validation function on graph pairs.""" - problem_graph = pddl.build(problem_string) - problem_graph2 = pddl.build(two_initial_problem_string) + problem_graph = builder.build(problem_string) + problem_graph2 = builder.build(two_initial_problem_string) assert metric.equals(problem_graph, problem_graph, is_placeholder=True) assert not metric.equals( @@ -222,8 +222,8 @@ def test_swap(self, swap_problem_string, wrong_swap_problem_string): """ Test the distance function on graph pairs. """ - swap_problem = pddl.build(swap_problem_string) - wrong_swap = pddl.build(wrong_swap_problem_string) + swap_problem = builder.build(swap_problem_string) + wrong_swap = builder.build(wrong_swap_problem_string) # Test validate assert metric.equals(swap_problem, swap_problem, is_placeholder=False) @@ -234,8 +234,8 @@ def test_move(self, move_problem_string, wrong_move_problem_string): """ Test the distance function on graph pairs. """ - move_problem = pddl.build(move_problem_string) - wrong_move = pddl.build(wrong_move_problem_string) + move_problem = builder.build(move_problem_string) + wrong_move = builder.build(wrong_move_problem_string) # Test validate assert metric.equals(move_problem, move_problem, is_placeholder=True) @@ -249,10 +249,10 @@ def test_blocksworld_equivalence( blocksworld_underspecified, ): """Test the equivalence of blocksworld problems.""" - p1 = pddl.build(blocksworld_fully_specified) - p2 = pddl.build(blocksworld_missing_clears) - p3 = pddl.build(blocksworld_missing_ontables) - p4 = pddl.build(blocksworld_underspecified) + p1 = builder.build(blocksworld_fully_specified) + p2 = builder.build(blocksworld_missing_clears) + p3 = builder.build(blocksworld_missing_ontables) + p4 = builder.build(blocksworld_underspecified) p1 = oracle.fully_specify(p1) p2 = oracle.fully_specify(p2) diff --git a/tests/test_oracle.py b/tests/test_oracle.py index 8c55a9d..2802c37 100644 --- a/tests/test_oracle.py +++ b/tests/test_oracle.py @@ -1,6 +1,6 @@ import pytest -from planetarium import graph, oracle, pddl +from planetarium import builder, graph, oracle @pytest.fixture @@ -872,7 +872,7 @@ def test_fully_specified(self, blocksworld_fully_specified): """ Test the fully specified blocksworld problem. """ - problem = pddl.build(blocksworld_fully_specified) + problem = builder.build(blocksworld_fully_specified) full = oracle.fully_specify(problem) assert oracle.fully_specify(full) == full @@ -880,7 +880,7 @@ def test_missing_clears(self, blocksworld_missing_clears): """ Test the fully specified blocksworld problem with missing clears. """ - problem = pddl.build(blocksworld_missing_clears) + problem = builder.build(blocksworld_missing_clears) full = oracle.fully_specify(problem) assert oracle.fully_specify(full) == full @@ -888,7 +888,7 @@ def test_missing_ontables(self, blocksworld_missing_ontables): """ Test the fully specified blocksworld problem with missing clears. """ - problem = pddl.build(blocksworld_missing_ontables) + problem = builder.build(blocksworld_missing_ontables) full = oracle.fully_specify(problem) assert oracle.fully_specify(full) == full @@ -896,7 +896,7 @@ def test_missing_ontables_and_clears(self, blocksworld_underspecified): """ Test the fully specified blocksworld problem with missing clears. """ - problem = pddl.build(blocksworld_underspecified) + problem = builder.build(blocksworld_underspecified) full = oracle.fully_specify(problem) assert oracle.fully_specify(full) == full @@ -923,7 +923,7 @@ def test_inflate( ] for desc in descs: - problem = pddl.build(desc) + problem = builder.build(desc) init, goal = problem.decompose() assert reduce_and_inflate(init) assert reduce_and_inflate(goal) @@ -947,7 +947,7 @@ def test_invalid( blocksworld_invalid_2, blocksworld_invalid_3, ): - problem = pddl.build(desc) + problem = builder.build(desc) _, goal = problem.decompose() with pytest.raises(ValueError): oracle.reduce(goal, validate=True) @@ -969,15 +969,15 @@ def test_fully_specified( """ Test the fully specified gripper problem. """ - problem = pddl.build(gripper_fully_specified) + problem = builder.build(gripper_fully_specified) full = oracle.fully_specify(problem) assert oracle.fully_specify(full) == full - problem = pddl.build(gripper_no_goal_types) + problem = builder.build(gripper_no_goal_types) full = oracle.fully_specify(problem) assert oracle.fully_specify(full) == full - problem = pddl.build(gripper_fully_specified_not_strict) + problem = builder.build(gripper_fully_specified_not_strict) full = oracle.fully_specify(problem) assert oracle.fully_specify(full) == full @@ -986,7 +986,7 @@ def test_inflate(self, gripper_fully_specified): Test the inflate function. """ - init, goal = pddl.build(gripper_fully_specified).decompose() + init, goal = builder.build(gripper_fully_specified).decompose() assert reduce_and_inflate(init) assert reduce_and_inflate(goal) @@ -1006,7 +1006,7 @@ def test_reduce_inflate( gripper_underspecified_3, ] for desc in descs: - problem = pddl.build(desc) + problem = builder.build(desc) init, goal = problem.decompose() assert reduce_and_inflate(init) @@ -1018,16 +1018,16 @@ def test_underspecified( gripper_underspecified_1, gripper_underspecified_2, ): - problem = pddl.build(gripper_underspecified_1) + problem = builder.build(gripper_underspecified_1) full = oracle.fully_specify(problem) assert oracle.fully_specify(full) == full - problem = pddl.build(gripper_underspecified_2) + problem = builder.build(gripper_underspecified_2) full = oracle.fully_specify(problem) assert oracle.fully_specify(full) == full def test_invalid(self, gripper_invalid): - problem = pddl.build(gripper_invalid) + problem = builder.build(gripper_invalid) _, goal = problem.decompose() with pytest.raises(ValueError): oracle.reduce(goal, validate=True) @@ -1037,7 +1037,7 @@ def test_invalid(self, gripper_invalid): class TestUnsupportedDomain: def test_reduce_and_inflate(self, gripper_fully_specified): - problem = pddl.build(gripper_fully_specified) + problem = builder.build(gripper_fully_specified) init, goal = problem.decompose() with pytest.raises(ValueError): @@ -1047,6 +1047,6 @@ def test_reduce_and_inflate(self, gripper_fully_specified): oracle.inflate(reduced, domain="gripper-modified") def test_fully_specify(self, gripper_fully_specified): - problem = pddl.build(gripper_fully_specified) + problem = builder.build(gripper_fully_specified) with pytest.raises(ValueError): oracle.fully_specify(problem, domain="gripper-modified") diff --git a/tests/test_pddl.py b/tests/test_pddl.py index 588aeb7..fa17879 100644 --- a/tests/test_pddl.py +++ b/tests/test_pddl.py @@ -1,6 +1,6 @@ import pytest -from planetarium import pddl +from planetarium import builder from pddl.parser.problem import LenientProblemParser @@ -315,14 +315,14 @@ def test_constat_name(self, problem): Test the conversion of a PDDL Constant to a dictionary with the correct name. """ constant = list(problem.objects)[0] - assert pddl._constant_to_dict(constant)["name"] == str(constant.name) + assert builder._constant_to_dict(constant)["name"] == str(constant.name) def test_constat_type(self, problem): """ Test the conversion of a PDDL Constant to a dictionary with the correct typing. """ constant = list(problem.objects)[0] - result_dict = pddl._constant_to_dict(constant) + result_dict = builder._constant_to_dict(constant) assert ( result_dict["typing"] == constant.type_tags and type(result_dict["typing"]) == set @@ -339,14 +339,14 @@ def test_predicate_name(self, problem): Test the conversion of a PDDL Predicate to a dictionary with the correct name. """ predicate = list(problem.init)[0] - assert pddl._predicate_to_dict(predicate)["typing"] == str(predicate.name) + assert builder._predicate_to_dict(predicate)["typing"] == str(predicate.name) def test_predicate_parameters(self, problem): """ Test the conversion of a PDDL Predicate to a dictionary with the correct parameters. """ predicate = list(problem.init)[0] - result_dict = pddl._predicate_to_dict(predicate) + result_dict = builder._predicate_to_dict(predicate) assert ( result_dict["parameters"] == [term.name for term in predicate.terms] and type(result_dict["parameters"]) == list @@ -362,7 +362,7 @@ def test_size(self, problem): """ Test the size of the list of constants built from a PDDL problem. """ - assert len(pddl._build_constants(problem.objects)) == len(problem.objects) + assert len(builder._build_constants(problem.objects)) == len(problem.objects) class TestBuildPredicates: @@ -374,13 +374,13 @@ def test_initial_size(self, problem): """ Test the size of the list of initial predicates built from a PDDL problem. """ - assert len(pddl._build_predicates(problem.init)) == len(problem.init) + assert len(builder._build_predicates(problem.init)) == len(problem.init) def test_goal_size(self, problem): """ Test the size of the list of goal predicates built from a PDDL problem. """ - assert len(pddl._build_predicates(problem.goal.operands)) == len( + assert len(builder._build_predicates(problem.goal.operands)) == len( problem.goal.operands ) @@ -394,14 +394,14 @@ def test_node_size(self, problem_string): """ Test the size of nodes in the scene graphs built from a PDDL problem. """ - graph_1, graph_2 = pddl.build(problem_string).decompose() + graph_1, graph_2 = builder.build(problem_string).decompose() assert len(graph_1.nodes) == 17 and len(graph_2.nodes) == 8 def test_edge_size(self, problem_string): """ Test the size of edges in the scene graphs built from a PDDL problem. """ - graph_1, graph_2 = pddl.build(problem_string).decompose() + graph_1, graph_2 = builder.build(problem_string).decompose() assert len(graph_1.edges) == 21 and len(graph_2.edges) == 2 def test_edge_size(self, problem_string): @@ -409,18 +409,18 @@ def test_edge_size(self, problem_string): Test the size of edges in the scene graphs built from a PDDL problem. """ modified_problem_string = f"Here is an example of a problem string that is not a PDDL problem. ```pddl\n{problem_string}\n```" - graph_1, graph_2 = pddl.build(modified_problem_string).decompose() + graph_1, graph_2 = builder.build(modified_problem_string).decompose() assert len(graph_1.edges) == 21 and len(graph_2.edges) == 2 def test_single_predicate_goal(self, single_predicate_goal): """ Test the size of nodes in the scene graphs built from a PDDL problem. """ - pddl.build(single_predicate_goal).decompose() + builder.build(single_predicate_goal).decompose() def test_not_predicate_goal(self, not_predicate_goal): """ Test the size of nodes in the scene graphs built from a PDDL problem. """ with pytest.raises(ValueError): - pddl.build(not_predicate_goal).decompose() + builder.build(not_predicate_goal).decompose()