diff --git a/evaluate.py b/evaluate.py index 3dc6d6b..ab42d1a 100644 --- a/evaluate.py +++ b/evaluate.py @@ -14,7 +14,7 @@ import tqdm import torch -from planetarium import pddl, graph, metric, oracle +from planetarium import builder, graph, metric, oracle import llm_planner as llmp from utils import apply_template @@ -197,7 +197,7 @@ def result(): try: # try to parse the LLM output - llm_problem_graph = pddl.build(llm_problem_pddl) + llm_problem_graph = builder.build(llm_problem_pddl) parseable = True # reduce and further validate the LLM output @@ -205,10 +205,10 @@ def result(): oracle.reduce(llm_problem_graph.decompose()[1], validate=True) valid = True - problem_graph = pddl.build(problem_pddl) + problem_graph = builder.build(problem_pddl) init, _ = problem_graph.decompose() - if len(llm_problem_graph._constants) != len(problem_graph._constants): + if len(llm_problem_graph.constants) != len(problem_graph.constants): resolved = True return result() @@ -253,8 +253,8 @@ def full_equivalence( bool: True if the scene graphs are equivalent, False otherwise. """ return metric.equals( - oracle.fully_specify(source), - oracle.fully_specify(target), + oracle.fully_specify(source, return_reduced=True), + oracle.fully_specify(target, return_reduced=True), is_placeholder=is_placeholder, ) diff --git a/finetune.py b/finetune.py index db8fa33..773142e 100644 --- a/finetune.py +++ b/finetune.py @@ -1,5 +1,7 @@ +from collections import defaultdict from functools import partial import os +import sqlite3 import yaml import dotenv @@ -10,6 +12,7 @@ from torch import nn import bitsandbytes as bnb +from datasets import Dataset from peft import LoraConfig, get_peft_model from transformers import ( AutoTokenizer, @@ -23,7 +26,7 @@ import tqdm as tqdm import llm_planner as llmp -from utils import apply_template, load_dataset, strip +from utils import apply_template from accelerate import Accelerator @@ -31,6 +34,54 @@ HF_USER_TOKEN = os.getenv("HF_USER_TOKEN") +def load_dataset(config: dict) -> dict[str, Dataset]: + """Load the dataset from the configuration. + + Args: + config (dict): The dataset configuration. + + Returns: + dict[str, Dataset]: The loaded dataset. 
+ """ + with open(config["splits_path"], "r") as f: + split_ids_cfg = yaml.safe_load(f) + + splits: set[str] = config.get("splits", {}).keys() + dataset = {split: defaultdict(list) for split in splits} + + # Connect to database + conn = sqlite3.connect(config["database_path"]) + c = conn.cursor() + + # load domains + domains = {} + c.execute("SELECT name, domain_pddl FROM domains") + for domain_name, domain_pddl in c.fetchall(): + domains[domain_name] = domain_pddl + + # load problems + for split in splits: + queries = [] + split_keys: list[str] = config["splits"][split] + for split_key in split_keys: + split_ids = split_ids_cfg + for key in split_key: + split_ids = split_ids[key] + + c.execute( + f"SELECT domain, problem_pddl, natural_language FROM problems WHERE id in ({', '.join(['?'] * len(split_ids))})", + split_ids, + ) + queries.extend(c.fetchall()) + + for domain, problem_pddl, natural_language in queries: + dataset[split]["domain"].append(domains[domain]) + dataset[split]["problem"].append(problem_pddl) + dataset[split]["natural_language"].append(natural_language) + + return {s: Dataset.from_dict(d, split=s) for s, d in dataset.items()} + + def find_all_linear_names( model: nn.Module, bits: int | None = None, @@ -62,6 +113,10 @@ def find_all_linear_names( return list(lora_module_names) +def strip(text: str, bos_token: str, eos_token: str) -> str: + return text.removeprefix(bos_token) + eos_token + + def preprocess( tokenizer: PreTrainedTokenizer, examples, @@ -130,7 +185,7 @@ def load_model(config: dict) -> tuple[PreTrainedTokenizer, PreTrainedModel]: ) else: bnb_config = None - + device_index = Accelerator().process_index device_map = {"": device_index} model = AutoModelForCausalLM.from_pretrained( @@ -139,7 +194,7 @@ def load_model(config: dict) -> tuple[PreTrainedTokenizer, PreTrainedModel]: token=HF_USER_TOKEN, torch_dtype=torch.bfloat16, quantization_config=bnb_config, - device_map=device_map + device_map=device_map, ) lora_config = LoraConfig( diff --git a/planetarium/pddl.py b/planetarium/builder.py similarity index 100% rename from planetarium/pddl.py rename to planetarium/builder.py diff --git a/planetarium/graph.py b/planetarium/graph.py index 595c5a0..2754a27 100644 --- a/planetarium/graph.py +++ b/planetarium/graph.py @@ -1,7 +1,10 @@ +from typing import Any, Iterable + import abc import enum -import networkx as nx -import typing +from functools import cached_property + +import rustworkx as rx class Label(str, enum.Enum): @@ -9,9 +12,78 @@ class Label(str, enum.Enum): PREDICATE = "predicate" -class PlanGraph(nx.MultiDiGraph, metaclass=abc.ABCMeta): +class Scene(str, enum.Enum): + INIT = "init" + GOAL = "goal" + + +class PlanGraphNode: + def __init__( + self, + node: str, + name: str, + label: Label, + typing: str | None = None, + scene: Scene | None = None, + ): + self.node = node + self.name = name + self.label = label + self.typing = typing + self.scene = scene + + def __eq__(self, other: "PlanGraphNode") -> bool: + return ( + isinstance(other, PlanGraphNode) + and self.node == other.node + and self.name == other.name + and self.label == other.label + and self.typing == other.typing + and self.scene == other.scene + ) + + def __hash__(self) -> int: + return hash((self.name, self.label, (*sorted(self.typing),), self.scene)) + + def __repr__(self) -> str: + return f"PlanGraphNode(node={self.node}, name={self.name}, label={self.label}, typing={self.typing}, scene={self.scene})" + + def __str__(self) -> str: + return f"PlanGraphNode(node={self.node}, name={self.name}, 
label={self.label}, typing={self.typing}, scene={self.scene})" + + +class PlanGraphEdge: + def __init__( + self, + predicate: str, + position: int | None = None, + scene: Scene | None = None, + ): + self.predicate = predicate + self.position = position + self.scene = scene + + def __eq__(self, other: "PlanGraphEdge") -> bool: + return ( + isinstance(other, PlanGraphEdge) + and self.predicate == other.predicate + and self.position == other.position + and self.scene == other.scene + ) + + def __hash__(self) -> int: + return hash((self.predicate, self.position, self.scene)) + + def __repr__(self) -> str: + return f"PlanGraphEdge(predicate={self.predicate}, position={self.position}, scene={self.scene})" + + def __str__(self) -> str: + return f"PlanGraphEdge(predicate={self.predicate}, position={self.position}, scene={self.scene})" + + +class PlanGraph(metaclass=abc.ABCMeta): """ - Subclass of nx.MultiDiGraph representing a scene graph. + Subclass of rx.PyDiGraph representing a scene graph. Attributes: constants (property): A dictionary of constant nodes in the scene graph. @@ -21,8 +93,7 @@ class PlanGraph(nx.MultiDiGraph, metaclass=abc.ABCMeta): def __init__( self, - constants: list[dict[str, typing.Any]], - predicates: list[dict[str, typing.Any]], + constants: list[dict[str, Any]], domain: str | None = None, ): """ @@ -30,64 +101,190 @@ def __init__( Parameters: constants (list): List of dictionaries representing constants. - predicates (list): List of dictionaries representing predicates. domain (str, optional): The domain of the scene graph. Defaults to None. """ - super().__init__() - self._constants = constants - self._predicates = predicates self._domain = domain + self.graph = rx.PyDiGraph() for constant in constants: self.add_node( - constant["name"], - name=constant["name"], - typing=constant["typing"], - label=Label.CONSTANT, + PlanGraphNode( + constant["name"], + name=constant["name"], + label=Label.CONSTANT, + typing=constant["typing"], + ) ) - def _add_predicate(self, predicate: dict[str, typing.Any], **kwargs): + @property + def _node_lookup(self) -> dict[str, tuple[int, PlanGraphNode]]: + return {node.node: (index, node) for index, node in enumerate(self.nodes)} + + @cached_property + def nodes(self) -> list[PlanGraphNode]: + return self.graph.nodes() + + @cached_property + def edges(self) -> set[tuple[PlanGraphNode, PlanGraphNode, PlanGraphEdge]]: + return [ + (self.nodes[u], self.nodes[v], data) + for u, v, data in self.graph.edge_index_map().values() + ] + + def add_node(self, node: PlanGraphNode): + if node in self.nodes: + raise ValueError(f"Node {node} already exists in the graph.") + self.graph.add_node(node) + + if node.label == Label.CONSTANT: + self.__dict__.pop("constant_nodes", None) + self.__dict__.pop("constants", None) + elif node.label == Label.PREDICATE: + self.__dict__.pop("predicate_nodes", None) + self.__dict__.pop("predicates", None) + + self.__dict__.pop("nodes", None) + self.__dict__.pop("_node_lookup", None) + + def has_edge( + self, + u: str | PlanGraphNode, + v: str | PlanGraphNode, + edge: PlanGraphEdge | None = None, + ) -> bool: + if isinstance(u, PlanGraphNode): + u_index = self.nodes.index(u) + else: + u_index, _ = self._node_lookup[u] + + if isinstance(v, PlanGraphNode): + v_index = self.nodes.index(v) + else: + v_index, _ = self._node_lookup[v] + + if edge: + return (u_index, v_index, edge) in self.graph.edge_index_map().values() + else: + return self.graph.has_edge(u_index, v_index) + + def add_edge( + self, u: str | PlanGraphNode, v: str | 
PlanGraphNode, edge: PlanGraphEdge + ): + if isinstance(u, PlanGraphNode): + u_index = self.nodes.index(u) + else: + u_index, _ = self._node_lookup[u] + + if isinstance(v, PlanGraphNode): + v_index = self.nodes.index(v) + else: + v_index, _ = self._node_lookup[v] + + self.graph.add_edge(u_index, v_index, edge) + + self.__dict__.pop("edges", None) + self.__dict__.pop("predicates", None) + + def _add_predicate( + self, + predicate: dict[str, Any], + scene: Scene | None = None, + ): """ Add a predicate to the plan graph. Parameters: predicate (dict): A dictionary representing the predicate. + scene (Scene, optional): The scene in which the predicate occurs. """ predicate_name = self._build_unique_predicate_name( predicate_name=predicate["typing"], argument_names=predicate["parameters"], ) self.add_node( - predicate_name, - name=predicate_name, - typing=predicate["typing"], - label=Label.PREDICATE, + PlanGraphNode( + predicate_name, + name=predicate_name, + label=Label.PREDICATE, + typing=predicate["typing"], + scene=scene, + ) ) for position, parameter_name in enumerate(predicate["parameters"]): - if parameter_name not in self.constants: + if parameter_name not in [node.name for node in self.constant_nodes]: raise ValueError(f"Parameter {parameter_name} not found in constants.") self.add_edge( predicate_name, parameter_name, - position=position, - predicate=predicate["typing"], - **kwargs, + PlanGraphEdge( + predicate=predicate["typing"], + position=position, + scene=scene, + ), ) + def in_degree(self, node: str | PlanGraphNode) -> int: + if isinstance(node, PlanGraphNode): + return self.graph.in_degree(self.nodes.index(node)) + else: + return self.graph.in_degree(self._node_lookup[node][0]) + + def out_degree(self, node: str | PlanGraphNode) -> int: + if isinstance(node, PlanGraphNode): + return self.graph.out_degree(self.nodes.index(node)) + else: + return self.graph.out_degree(self._node_lookup[node][0]) + + def predecessors(self, node: str | PlanGraphNode) -> list[PlanGraphNode]: + if isinstance(node, PlanGraphNode): + preds = self.graph.predecessors(self.nodes.index(node)) + else: + preds = self.graph.predecessors(self._node_lookup[node][0]) + + return preds + + def successors(self, node: str | PlanGraphNode) -> list[PlanGraphNode]: + if isinstance(node, PlanGraphNode): + succs = self.graph.successors(self.nodes.index(node)) + else: + succs = self.graph.successors(self._node_lookup[node][0]) + + return [self.nodes[succ] for succ in succs] + + def in_edges( + self, node: str | PlanGraphNode + ) -> list[tuple[PlanGraphNode, PlanGraphEdge]]: + if isinstance(node, PlanGraphNode): + edges = self.graph.in_edges(self.nodes.index(node)) + else: + edges = self.graph.in_edges(self._node_lookup[node][0]) + + return [(self.nodes[u], edge) for u, _, edge in edges] + + def out_edges( + self, node: str | PlanGraphNode + ) -> list[tuple[PlanGraphNode, PlanGraphEdge]]: + if isinstance(node, PlanGraphNode): + edges = self.graph.out_edges(self.nodes.index(node)) + else: + edges = self.graph.out_edges(self._node_lookup[node][0]) + + return [(self.nodes[v], edge) for _, v, edge in edges] + @staticmethod def _build_unique_predicate_name( - predicate_name: str, argument_names: typing.Iterable[str] + predicate_name: str, argument_names: Iterable[str] ) -> str: """ Build a unique name for a predicate based on its name and argument names. Parameters: predicate_name (str): The name of the predicate. 
- argument_names (typing.Iterable[str]): Sequence of argument names + argument_names (Iterable[str]): Sequence of argument names for the predicate. Returns: @@ -95,7 +292,7 @@ def _build_unique_predicate_name( """ return "-".join([predicate_name, *argument_names]) - @property + @cached_property def domain(self) -> str | None: """ Get the domain of the scene graph. @@ -105,34 +302,68 @@ def domain(self) -> str | None: """ return self._domain - @property - def constants(self) -> dict: + @cached_property + def constant_nodes(self) -> list[PlanGraphNode]: + """Get a list of constant nodes in the scene graph. + + Returns: + list[PlanGraphNode]: A list of constant nodes. """ - Get a dictionary of constant nodes in the scene graph. + return [node for node in self.nodes if node.label == Label.CONSTANT] + + @cached_property + def constants(self) -> list[dict[str, Any]]: + return [ + {"name": constant.name, "typing": constant.typing} + for constant in self.constant_nodes + ] + + @cached_property + def predicate_nodes(self) -> list[PlanGraphNode]: + """Get a list of predicate nodes in the scene graph. Returns: - dict: A dictionary containing constant nodes. + list[PlanGraphNode]: A list of predicate nodes. """ - return dict( - filter(lambda node: node[1]["label"] == Label.CONSTANT, self.nodes.items()) - ) + return [node for node in self.nodes if node.label == Label.PREDICATE] @property - def predicates(self) -> dict: + def predicates(self) -> list[dict[str, Any]]: + predicates = [] + for node in self.predicate_nodes: + edges = self.out_edges(node) + edges.sort(key=lambda x: x[1].position) + predicates.append( + { + "typing": node.typing, + "parameters": [obj_node.name for obj_node, _ in edges], + "scene": node.scene, + } + ) + + return predicates + + def __eq__(self, other: "PlanGraph") -> bool: """ - Get a dictionary of predicate nodes in the scene graph. + Check if two plan graphs are equal. + + Parameters: + other (PlanGraph): The other plan graph to compare. Returns: - dict: A dictionary containing predicate nodes. + bool: True if the plan graphs are equal, False otherwise. """ - return dict( - filter(lambda node: node[1]["label"] == Label.PREDICATE, self.nodes.items()) + return ( + isinstance(other, PlanGraph) + and set(self.nodes) == set(other.nodes) + and set(self.edges) == set(other.edges) + and self.domain == other.domain ) class SceneGraph(PlanGraph): """ - Subclass of nx.MultiDiGraph representing a scene graph. + Subclass of PlanGraph representing a scene graph. Attributes: constants (property): A dictionary of constant nodes in the scene graph. @@ -142,9 +373,10 @@ class SceneGraph(PlanGraph): def __init__( self, - constants: list[dict[str, typing.Any]], - predicates: list[dict[str, typing.Any]], + constants: list[dict[str, Any]], + predicates: list[dict[str, Any]], domain: str | None = None, + scene: Scene | None = None, ): """ Initialize the SceneGraph instance. @@ -154,17 +386,20 @@ def __init__( predicates (list): List of dictionaries representing predicates. domain (str, optional): The domain of the scene graph. Defaults to None. + scene (str, optional): The scene of the scene graph. """ - super().__init__(constants, predicates, domain=domain) + super().__init__(constants, domain=domain) + + self.scene = scene for predicate in predicates: - self._add_predicate(predicate) + self._add_predicate(predicate, scene=scene) class ProblemGraph(PlanGraph): """ - Subclass of nx.MultiDiGraph representing a scene graph. + Subclass of PlanGraph representing a scene graph. 
Attributes: constants (property): A dictionary of constant nodes in the scene graph. @@ -175,9 +410,9 @@ class ProblemGraph(PlanGraph): def __init__( self, - constants: list[dict[str, typing.Any]], - init_predicates: list[dict[str, typing.Any]], - goal_predicates: list[dict[str, typing.Any]], + constants: list[dict[str, Any]], + init_predicates: list[dict[str, Any]], + goal_predicates: list[dict[str, Any]], domain: str | None = None, ): """ @@ -192,72 +427,139 @@ def __init__( domain (str, optional): The domain of the scene graph. Defaults to None. """ - - super().__init__(constants, init_predicates + goal_predicates, domain=domain) - - self._init_predicates = init_predicates - self._goal_predicates = goal_predicates + super().__init__(constants, domain=domain) for scene, predicates in ( - ("init", init_predicates), - ("goal", goal_predicates), + (Scene.INIT, init_predicates), + (Scene.GOAL, goal_predicates), ): for predicate in predicates: self._add_predicate(predicate, scene=scene) - @property - def init_predicates(self) -> dict: - """ - Get a dictionary of predicate nodes in the initial scene graph. + def __eq__(self, other: "ProblemGraph") -> bool: + return ( + super().__eq__(other) + and set(self.init_predicate_nodes) == set(other.init_predicate_nodes) + and set(self.goal_predicate_nodes) == set(other.goal_predicate_nodes) + ) + + def add_node(self, node: PlanGraphNode): + super().add_node(node) + if node.label == Label.PREDICATE: + self.__dict__.pop("init_predicate_nodes", None) + self.__dict__.pop("goal_predicate_nodes", None) + self.__dict__.pop("init_predicates", None) + self.__dict__.pop("goal_predicates", None) + + self.__dict__.pop("_decompose", None) + + def add_edge( + self, u: str | PlanGraphNode, v: str | PlanGraphNode, edge: PlanGraphEdge + ): + super().add_edge(u, v, edge) + + self.__dict__.pop("init_predicate_nodes", None) + self.__dict__.pop("goal_predicate_nodes", None) + self.__dict__.pop("init_predicates", None) + self.__dict__.pop("goal_predicates", None) + + self.__dict__.pop("_decompose", None) + + @cached_property + def init_predicate_nodes(self) -> list[PlanGraphNode]: + """Get a list of predicate nodes in the initial scene. Returns: - dict: A dictionary containing predicate nodes. + list[PlanGraphNode]: A list of predicate nodes in the initial scene. """ - return dict( - filter( - lambda node: node[1]["label"] == Label.PREDICATE - and node[1]["scene"] == "init", - self.nodes.items(), - ) - ) + return [ + node + for node in self.nodes + if node.label == Label.PREDICATE and node.scene == Scene.INIT + ] - @property - def goal_predicates(self) -> dict: - """ - Get a dictionary of predicate nodes in the initial scene graph. + @cached_property + def goal_predicate_nodes(self) -> list[PlanGraphNode]: + """Get a list of predicate nodes in the goal scene. Returns: - dict: A dictionary containing predicate nodes. + list[PlanGraphNode]: A list of predicate nodes in the goal scene. 
""" - return dict( - filter( - lambda node: node[1]["label"] == Label.PREDICATE - and node[1]["scene"] == "goal", - self.nodes.items(), + return [ + node + for node in self.nodes + if node.label == Label.PREDICATE and node.scene == Scene.GOAL + ] + + @cached_property + def init_predicates(self) -> list[dict[str, Any]]: + predicates = [] + for node in self.init_predicate_nodes: + edges = self.out_edges(node) + edges.sort(key=lambda x: x[1].position) + predicates.append( + { + "typing": node.typing, + "parameters": [obj_node.name for obj_node, _ in edges], + "scene": node.scene, + } ) - ) - def decompose(self) -> tuple[SceneGraph, SceneGraph]: + return predicates + + @cached_property + def goal_predicates(self) -> list[dict[str, Any]]: + predicates = [] + for node in self.goal_predicate_nodes: + edges = self.out_edges(node) + edges.sort(key=lambda x: x[1].position) + predicates.append( + { + "typing": node.typing, + "parameters": [obj_node.name for obj_node, _ in edges], + "scene": node.scene, + } + ) + + return predicates + + + + @cached_property + def _decompose(self) -> tuple[SceneGraph, SceneGraph]: """ Decompose the problem graph into initial and goal scene graphs. Returns: tuple[SceneGraph, SceneGraph]: A tuple containing the initial and goal scene graphs. """ + init_scene = SceneGraph( - constants=self._constants, - predicates=self._init_predicates, + constants=self.constants, + predicates=self.init_predicates, domain=self.domain, + scene=Scene.INIT, ) goal_scene = SceneGraph( - constants=self._constants, - predicates=self._goal_predicates, + constants=self.constants, + predicates=self.goal_predicates, domain=self.domain, + scene=Scene.GOAL, ) return init_scene, goal_scene + def decompose(self) -> tuple[SceneGraph, SceneGraph]: + """ + Decompose the problem graph into initial and goal scene graphs. + + Returns: + tuple[SceneGraph, SceneGraph]: A tuple containing the initial and goal scene graphs. + """ + + return self._decompose + @staticmethod def join(init: SceneGraph, goal: SceneGraph) -> "ProblemGraph": """ @@ -271,8 +573,8 @@ def join(init: SceneGraph, goal: SceneGraph) -> "ProblemGraph": ProblemGraph: The combined problem graph. """ return ProblemGraph( - constants=init._constants, - init_predicates=init._predicates, - goal_predicates=goal._predicates, + constants=init.constants, + init_predicates=init.predicates, + goal_predicates=goal.predicates, domain=init.domain, ) diff --git a/planetarium/metric.py b/planetarium/metric.py index 1f5e506..d1201a7 100644 --- a/planetarium/metric.py +++ b/planetarium/metric.py @@ -1,206 +1,109 @@ -from typing import Any, Callable - import functools -import networkx as nx -import time - -from planetarium.graph import Label, SceneGraph, ProblemGraph +import rustworkx as rx -Node = dict[str, Any] +from planetarium import graph -def _preserves_mapping(source: Node, target: Node, mapping: dict) -> bool: +def _preserves_mapping( + source: graph.PlanGraphNode, + target: graph.PlanGraphNode, + mapping: dict, +) -> bool: """ Check if a mapping is preserved between the nodes. Parameters: - source (Node): The source node. - target (Node): The target node. + source (graph.PlanGraphNode): The source node. + target (graph.PlanGraphNode): The target node. mapping (dict): The mapping between node names. Returns: bool: True if the mapping preserves names, False otherwise. 
""" return ( - source["label"] == Label.CONSTANT - and target["label"] == Label.CONSTANT - and mapping[source["name"]] == target["name"] + source.label == graph.Label.CONSTANT + and target.label == graph.Label.CONSTANT + and mapping[source.name] == target.name ) -def _same_typing(source: Node, target: Node) -> bool: +def _same_typing(source: graph.PlanGraphNode, target: graph.PlanGraphNode) -> bool: """ Check if the typing of two nodes is the same. Parameters: - source (Node): The source node. - target (Node): The target node. + source (graph.PlanGraphNode): The source node. + target (graph.PlanGraphNode): The target node. Returns: bool: True if typings are the same, False otherwise. """ return ( - source["label"] == Label.CONSTANT - and target["label"] == Label.CONSTANT - and source["typing"] == target["typing"] + source.label == graph.Label.CONSTANT + and target.label == graph.Label.CONSTANT + and source.typing == target.typing ) -def _matching(source: Node, target: Node, mapping: dict | None) -> bool: +def _node_matching( + source: graph.PlanGraphNode, + target: graph.PlanGraphNode, + mapping: dict | None, +) -> bool: """ Check if two nodes match based on their labels, positions, and typings. Parameters: - source (Node): The source node. - target (Node): The target node. + source (graph.PlanGraphNode): The source node. + target (graph.PlanGraphNode): The target node. mapping (dict | None): The mapping between node names. Returns: bool: True if nodes match, False otherwise. """ - match (source["label"], target["label"]): - case (Label.CONSTANT, Label.CONSTANT): + match (source.label, target.label): + case (graph.Label.CONSTANT, graph.Label.CONSTANT): return _same_typing(source, target) and ( _preserves_mapping(source, target, mapping) if mapping else True ) - case (Label.PREDICATE, Label.PREDICATE): + case (graph.Label.PREDICATE, graph.Label.PREDICATE): # type of predicate should be the same as well - return source["typing"] == target["typing"] + return source.typing == target.typing case _: return False -def _map( - source: SceneGraph, - target: SceneGraph, - mapping: dict | None = None, -) -> list[dict]: - """ - Find all valid isomorphic mappings between nodes of two scene graphs. - - Parameters: - source (SceneGraph): The source scene graph. - target (SceneGraph): The target scene graph. - mapping (dict | None): The initial mapping between node names. - - Returns: - list: A list of dictionaries representing valid mappings. - """ - if not nx.faster_could_be_isomorphic(source, target): - return [] - - matching = functools.partial(_matching, mapping=mapping) - mapper = nx.isomorphism.MultiDiGraphMatcher( - source, - target, - node_match=matching, - edge_match=nx.isomorphism.categorical_edge_match( - ["position", "predicate"], - default=[-1, ""], - ), - ) - - if mapper.is_isomorphic(): - mapper.initialize() - return mapper.match() - else: - return [] - - -def _distance( - source: SceneGraph, target: SceneGraph, mapping: dict | None = None -) -> int: - """ - Calculate the graph edit distance between two scene graphs. - - Parameters: - source (SceneGraph): The source scene graph. - target (SceneGraph): The target scene graph. - mapping (dict | None): The initial mapping between node names. - - Returns: - int: The graph edit distance. 
- """ - matching = functools.partial(_matching, mapping=mapping) - - return nx.graph_edit_distance( - source, - target, - node_match=matching, - edge_match=nx.isomorphism.categorical_edge_match( - ["position", "predicate"], - default=[-1, ""], - ), - ) - - -def _minimal_mappings( - source: SceneGraph, - target: SceneGraph, - timeout: float | None = None, -) -> tuple[list, float, bool]: +def _edge_matching( + source: graph.PlanGraphEdge, + target: graph.PlanGraphEdge, + attributes: dict[str, str | int | graph.Scene, graph.Label] = {}, +) -> bool: """ - Calculate the graph edit distance between two scene graphs. + Check if two edges match based on their attributes. Parameters: - source (SceneGraph): The source scene graph. - target (SceneGraph): The target scene graph. - max_attempts (int): The maximum number of edit path iterations to - consider. + source (graph.PlanGraphEdge): The source edge. + target (graph.PlanGraphEdge): The target edge. + attributes (dict): The attributes to match. Returns: - tuple: - - list: A list of list of tuples representing valid mappings. - - float: The graph edit distance. - - bool: True if the timeout was reached, False otherwise. + bool: True if edges match, False otherwise. """ - start_time = time.perf_counter() - - def timed_out() -> bool: - return bool(timeout and time.perf_counter() - start_time > timeout) + def _getattr(obj, attr): + v = getattr(obj, attr, attributes[attr]) + if v is None: + v = attributes[attr] + return v - # try isomorphism first: - iso_mappings = _map(source, target) - if iso_mappings: - return [[(k, v) for k, v in m.items()] for m in iso_mappings], 0.0, False + return all(_getattr(source, attr) == _getattr(target, attr) for attr in attributes) - # if it is not isomorphic, try edit distance: - edit_path_gen = nx.similarity.optimize_edit_paths( - source, - target, - node_match=nx.isomorphism.categorical_node_match( - ["label", "typing"], - default=["", ""], - ), - edge_match=nx.isomorphism.categorical_edge_match( - ["position", "predicate"], - default=[-1, ""], - ), - strictly_decreasing=False, - timeout=timeout, - ) - - paths: list[Any] = [] - bestcost = float("inf") - for vertex_path, _, cost in edit_path_gen: - if bestcost != float("inf") and cost < bestcost: - paths = [] - - paths.append(vertex_path) - bestcost = cost - if timed_out(): - break - - return paths, bestcost, timed_out() - -def map( - source: ProblemGraph | SceneGraph, - target: ProblemGraph | SceneGraph, +def isomorphic( + source: graph.ProblemGraph | graph.SceneGraph, + target: graph.ProblemGraph | graph.SceneGraph, mapping: dict | None = None, - return_mappings: bool = False, -) -> list[dict] | bool: +) -> bool: """ Find all valid isomorphic mappings between nodes of two scene graphs. @@ -208,37 +111,27 @@ def map( source (ProblemGraph): The source problem graph. target (ProblemGraph): The target problem graph. mapping (dict | None): The initial mapping between node names. - return_mappings (bool): If True, the function will return a list of - dictionaries representing valid mappings. If False, the function - will return a boolean indicating if there is a valid mapping. Returns: - list | bool: A list of dictionaries representing valid mappings or a - boolean indicating if there is a valid mapping. + bool: True if there is a valid mapping, False otherwise. 
""" - if not nx.faster_could_be_isomorphic(source, target): - return [] if return_mappings else False - - matching = functools.partial(_matching, mapping=mapping) - mapper = nx.isomorphism.MultiDiGraphMatcher( - source, - target, - node_match=matching, - edge_match=nx.isomorphism.categorical_edge_match( - ["scene", "position", "predicate"], - default=["", -1, ""], - ), + node_matching = functools.partial(_node_matching, mapping=mapping) + edge_matching = functools.partial( + _edge_matching, + attributes={"position": -1, "predicate": "", "scene": None}, ) - if return_mappings: - return list(mapper.isomorphisms_iter()) if mapper.is_isomorphic() else [] - else: - return mapper.is_isomorphic() + return rx.is_isomorphic( + source.graph, + target.graph, + node_matcher=node_matching, + edge_matcher=edge_matching, + ) def equals( - source: ProblemGraph, - target: ProblemGraph, + source: graph.ProblemGraph, + target: graph.ProblemGraph, is_placeholder: bool = False, ) -> bool: """ @@ -255,110 +148,18 @@ def equals( Returns: bool: True if there is a valid mapping, False otherwise. """ + if source == target: + return True if not is_placeholder: - return nx.utils.graphs_equal(source, target) or bool(map(source, target)) + return isomorphic(source, target) else: source_init, source_goal = source.decompose() target_init, target_goal = target.decompose() - if nx.utils.graphs_equal( - source_init, - target_init, - ) and nx.utils.graphs_equal( - source_goal, - target_goal, - ): + if source_init == target_init and source_goal == target_goal: return True - valid_init = bool(map(source_init, target_init)) - valid_goal = bool(map(source_goal, target_goal)) + valid_init = isomorphic(source_init, target_init) + valid_goal = isomorphic(source_goal, target_goal) return valid_init and valid_goal - - -def distance( - initial_source: SceneGraph, - initial_target: SceneGraph, - goal_source: SceneGraph, - goal_target: SceneGraph, - timeout: float | None = None, -) -> tuple[float, bool, float, bool]: - """ - Calculate the graph edit distance between initial and goal scene graphs. - - Parameters: - initial_source (SceneGraph): The initial source scene graph. - initial_target (SceneGraph): The initial target scene graph. - goal_source (SceneGraph): The goal source scene graph. - goal_target (SceneGraph): The goal target scene graph. - timeout (float | None): The maximum number of seconds to spend. - - Returns: - tuple: - - float: The graph edit distance between the two initial scenes. - - bool: True if the timeout was reached, False otherwise while - calculating initial scene distance. - - float: The graph edit distance between the two goal scenes. - - bool: True if the timeout was reached, False otherwise while - calculating goal scene distance. - """ - - start_time = time.perf_counter() - - def timed_out() -> bool: - return bool(timeout and time.perf_counter() - start_time > timeout) - - def mapping_to_fn(mapping: list) -> Callable[[dict, dict], bool]: - """ - Convert a mapping to a matching function. - - Parameters: - mapping (list): The mapping between node names. - - Returns: - callable: A matching function. 
- """ - map_dict = {k: v for k, v in mapping} - - def matching(source: Node, target: Node) -> bool: - return ( - source["name"] not in map_dict - or map_dict.get(source["name"]) == target["name"] - ) - - return matching - - if equals( - ProblemGraph.join(initial_source, goal_source), - ProblemGraph.join(initial_target, goal_target), - is_placeholder=False, - ): - return 0.0, False, 0.0, False - - minimal_mappings, init_dist, approx_init_dist = _minimal_mappings( - initial_source, - initial_target, - timeout=timeout, - ) - - goal_dist = float("inf") - for mapping in minimal_mappings: - # use the mapping from the initial graph - edit_path_gen = nx.similarity.optimize_edit_paths( - goal_source, - goal_target, - node_match=mapping_to_fn(mapping), - edge_match=nx.isomorphism.categorical_edge_match( - ["position", "predicate"], - default=[-1, ""], - ), - timeout=timeout, - ) - - for _, _, cost in edit_path_gen: - if cost < goal_dist: - goal_dist = cost - if timed_out(): - break - - return init_dist, approx_init_dist, goal_dist, timed_out() diff --git a/planetarium/oracle.py b/planetarium/oracle.py index 4ba4e1c..e123576 100644 --- a/planetarium/oracle.py +++ b/planetarium/oracle.py @@ -1,33 +1,124 @@ +from typing import Any + from collections import defaultdict import copy import enum -import networkx as nx +import rustworkx as rx from planetarium import graph -class ReductionNode(tuple, enum.Enum): - TABLE = ("table", ("blocksworld",)) - CLEAR = ("clear", ("blocksworld", "gripper")) - ARM = ("arm", ("blocksworld",)) - ROOMS = ("rooms", ("gripper",)) - BALLS = ("balls", ("gripper",)) - GRIPPERS = ("grippers", ("gripper",)) - ROBBY = ("robby", ("gripper",)) +class ReducedNode(str, enum.Enum): + TABLE = "table" + CLEAR = "clear" + ARM = "arm" + ROOMS = "room" + BALLS = "ball" + GRIPPERS = "gripper" + ROBBY = "at-robby" + FREE = "free" + + +BlocksworldReducedNodes = { + ReducedNode.TABLE, + ReducedNode.CLEAR, + ReducedNode.ARM, +} + +GripperReducedNodes = { + ReducedNode.ROOMS, + ReducedNode.BALLS, + ReducedNode.GRIPPERS, + ReducedNode.ROBBY, + ReducedNode.FREE, +} + +ReducedNodes = { + "blocksworld": BlocksworldReducedNodes, + "gripper": GripperReducedNodes, +} + + +class ReducedSceneGraph(graph.PlanGraph): + def __init__( + self, + constants: list[dict[str, Any]], + domain: str, + scene: graph.Scene | None = None, + ): + super().__init__(constants, domain=domain) + self.scene = scene + + for e in ReducedNodes[domain]: + predicate = e.value + self.add_node( + graph.PlanGraphNode( + e, + name=predicate, + label=graph.Label.PREDICATE, + typing={predicate}, + ), + ) + + +class ReducedProblemGraph(graph.PlanGraph): + def __init__( + self, + constants: list[dict[str, Any]], + domain: str, + ): + super().__init__(constants, domain=domain) + + for e in ReducedNodes[domain]: + predicate = e.value + self.add_node( + graph.PlanGraphNode( + e, + name=predicate, + label=graph.Label.PREDICATE, + typing={predicate}, + ), + ) + + def decompose(self) -> tuple[ReducedSceneGraph, ReducedSceneGraph]: + init = ReducedSceneGraph(self.constants, self.domain, scene=graph.Scene.INIT) + goal = ReducedSceneGraph(self.constants, self.domain, scene=graph.Scene.GOAL) + + for u, v, edge in self.edges: + edge = copy.deepcopy(edge) + if edge.scene == graph.Scene.INIT: + init.add_edge(u, v, edge) + elif edge.scene == graph.Scene.GOAL: + goal.add_edge(u, v, edge) + + return init, goal + @staticmethod + def join(init: ReducedSceneGraph, goal: ReducedSceneGraph) -> "ReducedProblemGraph": + problem = 
ReducedProblemGraph(init.constants, domain=init.domain) -REDUCTION_NODES = [e.value for e in ReductionNode] + for u, v, edge in init.edges: + edge = copy.deepcopy(edge) + problem.add_edge(u, v, edge) + edge.scene = graph.Scene.INIT + for u, v, edge in goal.edges: + edge = copy.deepcopy(edge) + edge.scene = graph.Scene.GOAL + problem.add_edge(u, v, edge) + + return problem def _reduce_blocksworld( - scene: graph.SceneGraph, + scene: graph.SceneGraph | graph.ProblemGraph, validate: bool = True, -) -> tuple[nx.MultiDiGraph, nx.MultiDiGraph]: +) -> ReducedSceneGraph | ReducedProblemGraph: """Reduces a blocksworld scene graph to a Directed Acyclic Graph. Args: - problem (graph.SceneGraph): The scene graph to reduce. + problem (graph.SceneGraph | graph.ProblemGraph): The scene graph to + reduce. validate (bool, optional): Whether or not to validate if the reduced reprsentation is valid. Defaults to True. @@ -38,86 +129,102 @@ def _reduce_blocksworld( blocksworld) and if validate is True. Returns: - nx.MultiDiGraph: The reduced scene graph. + ReducedGraph: The reduced problem graph. """ nodes = defaultdict(list) - for node, node_attr in scene.nodes(data=True): - nodes[node_attr["label"]].append((node, node_attr)) - - reduced = nx.MultiDiGraph() - for e in ReductionNode: - predicate, r_node_domains = e.value - if "blocksworld" in r_node_domains: - reduced.add_node( - e, - name=predicate, - label=graph.Label.PREDICATE, - typing={predicate}, - ) - - for obj, attrs in nodes[graph.Label.CONSTANT]: - reduced.add_node(obj, **attrs) - if "arm-empty" in scene.predicates: - reduced.add_edge( - ReductionNode.CLEAR, - ReductionNode.ARM, - pred="arm-empty", + for node in scene.nodes: + nodes[node.label].append(node) + + if isinstance(scene, graph.ProblemGraph): + reduced = ReducedProblemGraph(constants=scene.constants, domain="blocksworld") + elif isinstance(scene, graph.SceneGraph): + reduced = ReducedSceneGraph( + constants=scene.constants, + domain="blocksworld", + scene=scene.scene, ) + else: + raise ValueError("Scene must be a SceneGraph or ProblemGraph.") + + for pred_node in scene.predicate_nodes: + if pred_node.typing == "arm-empty": + reduced.add_edge( + ReducedNode.CLEAR, + ReducedNode.ARM, + graph.PlanGraphEdge( + predicate="arm-empty", + scene=pred_node.scene, + ), + ) pred_nodes = set() - for node, obj, edge_attr in scene.edges(data=True): - pred = edge_attr["predicate"] + for node, obj, edge in scene.edges: + pred = edge.predicate + reduced_edge = graph.PlanGraphEdge(predicate=pred, scene=edge.scene) if node in pred_nodes: continue - elif pred == "on-table": - reduced.add_edge(obj, ReductionNode.TABLE, pred="on-table") + if pred == "on-table": + reduced.add_edge(obj, ReducedNode.TABLE, reduced_edge) elif pred == "clear": - reduced.add_edge(ReductionNode.CLEAR, obj, pred="clear") + reduced.add_edge(ReducedNode.CLEAR, obj, reduced_edge) elif pred == "on": - pos = edge_attr["position"] + pos = edge.position other_obj, *_ = [ - v - for _, v, a in scene.out_edges(node, data=True) - if a["position"] == 1 - pos + v.node for v, a in scene.out_edges(node) if a.position == 1 - pos ] if pos == 0: - reduced.add_edge(obj, other_obj, pred="on") + reduced.add_edge(obj, other_obj, reduced_edge) elif pred == "holding": - reduced.add_edge(obj, ReductionNode.ARM, pred="holding") - + reduced.add_edge(obj, ReducedNode.ARM, reduced_edge) pred_nodes.add(node) if validate: - if not nx.is_directed_acyclic_graph(reduced): - raise ValueError("Scene graph is not a Directed Acyclic Graph.") - for node in reduced.nodes: - if 
(node != ReductionNode.TABLE and reduced.in_degree(node) > 1) or ( - node != ReductionNode.CLEAR and reduced.out_degree(node) > 1 - ): - raise ValueError( - f"Node {node} has multiple parents/children. (not possible in blocksworld)." - ) - if reduced.in_degree(ReductionNode.ARM) == 1: - obj = next(reduced.predecessors(ReductionNode.ARM)) - if ( - obj != ReductionNode.CLEAR - and reduced.in_degree(obj) == 1 - and next(reduced.predecessors(obj)) != ReductionNode.CLEAR - ): - raise ValueError("Object on arm is connected to another object.") - - reduced._domain = scene._domain - reduced._constants = scene._constants - reduced._predicates = scene._predicates + if isinstance(reduced, ReducedProblemGraph): + init, goal = reduced.decompose() + _validate_blocksworld(init) + _validate_blocksworld(goal) + elif isinstance(reduced, ReducedSceneGraph): + _validate_blocksworld(reduced) return reduced +def _validate_blocksworld(scene: graph.SceneGraph): + """Validates a blocksworld scene graph. + + Args: + scene (graph.SceneGraph): The scene graph to validate. + + Raises: + ValueError: If the scene graph is not a Directed Acyclic Graph. + ValueError: If a node has multiple parents/children (not allowed in + blocksworld). + ValueError: If an object on the arm is connected to another object. + """ + if not rx.is_directed_acyclic_graph(scene.graph): + raise ValueError("Scene graph is not a Directed Acyclic Graph.") + for node in scene.nodes: + if (node.node != ReducedNode.TABLE and scene.in_degree(node.node) > 1) or ( + node.node != ReducedNode.CLEAR and scene.out_degree(node.node) > 1 + ): + raise ValueError( + f"Node {node} has multiple parents/children. (not possible in blocksworld)." + ) + if scene.in_degree(ReducedNode.ARM) == 1: + obj = scene.predecessors(ReducedNode.ARM)[0] + if ( + obj.node != ReducedNode.CLEAR + and scene.in_degree(obj) == 1 + and scene.predecessors(obj)[0].node != ReducedNode.CLEAR + ): + raise ValueError("Object on arm is connected to another object.") + + def _reduce_gripper( - scene: graph.SceneGraph, + scene: graph.SceneGraph | graph.ProblemGraph, validate: bool = True, -) -> nx.MultiDiGraph: +) -> ReducedSceneGraph | ReducedProblemGraph: """Reduces a gripper scene graph to a Directed Acyclic Graph. Args: @@ -126,68 +233,70 @@ def _reduce_gripper( reprsentation is valid and a DAG. Defaults to True. Returns: - nx.SceneGraph: The reduced problem graph. + ReducedGraph: The reduced problem graph. 
""" nodes = defaultdict(list) - for node, node_attr in scene.nodes(data=True): - nodes[node_attr["label"]].append((node, node_attr)) - - reduced = nx.MultiDiGraph() - for e in ReductionNode: - predicate, r_node_domains = e.value - if "gripper" in r_node_domains: - reduced.add_node( - e, - name=predicate, - label=graph.Label.PREDICATE, - typing={predicate}, - ) - - for obj, attrs in nodes[graph.Label.CONSTANT]: - reduced.add_node(obj, **copy.deepcopy(attrs)) + for node in scene.nodes: + nodes[node.label].append(node) + + if isinstance(scene, graph.ProblemGraph): + reduced = ReducedProblemGraph(constants=scene.constants, domain="gripper") + elif isinstance(scene, graph.SceneGraph): + reduced = ReducedSceneGraph( + constants=scene.constants, + domain="gripper", + scene=scene.scene, + ) + else: + raise ValueError("Scene must be a SceneGraph or ProblemGraph.") pred_nodes = set() - for node, obj, edge_attr in scene.edges(data=True): - pred = edge_attr["predicate"] + for node, obj, edge in scene.edges: + pred = edge.predicate + reduced_edge = graph.PlanGraphEdge(predicate=pred, scene=edge.scene) if node in pred_nodes: continue elif pred == "at-robby": - reduced.add_edge(ReductionNode.ROBBY, obj, pred=pred) + reduced.add_edge(ReducedNode.ROBBY, obj, reduced_edge) elif pred == "free": - reduced.add_edge(ReductionNode.CLEAR, obj, pred=pred) + reduced.add_edge(ReducedNode.FREE, obj, reduced_edge) elif pred == "ball": - reduced.add_edge(ReductionNode.BALLS, obj, pred=pred) + reduced.add_edge(ReducedNode.BALLS, obj, reduced_edge) elif pred == "gripper": - reduced.add_edge(ReductionNode.GRIPPERS, obj, pred=pred) + reduced.add_edge(ReducedNode.GRIPPERS, obj, reduced_edge) elif pred == "room": - reduced.add_edge(ReductionNode.ROOMS, obj, pred=pred) + reduced.add_edge(ReducedNode.ROOMS, obj, reduced_edge) elif pred in {"carry", "at"}: - pos = edge_attr["position"] + pos = edge.position other_obj, *_ = [ - v - for _, v, a in scene.out_edges(node, data=True) - if a["position"] == 1 - pos + v for v, a in scene.out_edges(node) if a.position == 1 - pos ] if pos == 0: - reduced.add_edge(obj, other_obj, pred=pred) + reduced.add_edge(obj, other_obj, reduced_edge) pred_nodes.add(node) - if validate and not nx.is_directed_acyclic_graph(reduced): - raise ValueError("Scene graph is not a Directed Acyclic Graph.") - - reduced._domain = scene._domain - reduced._constants = scene._constants - reduced._predicates = scene._predicates + if validate: + if isinstance(reduced, ReducedProblemGraph): + init, goal = reduced.decompose() + if not rx.is_directed_acyclic_graph(init.graph): + raise ValueError("Initial scene graph is not a Directed Acyclic Graph.") + if not rx.is_directed_acyclic_graph(goal.graph): + raise ValueError("Goal scene graph is not a Directed Acyclic Graph.") + elif isinstance(reduced, ReducedSceneGraph): + if not rx.is_directed_acyclic_graph(reduced.graph): + raise ValueError("Scene graph is not a Directed Acyclic Graph.") return reduced -def _inflate_blocksworld(scene: nx.MultiDiGraph) -> graph.SceneGraph: +def _inflate_blocksworld( + scene: ReducedSceneGraph | ReducedProblemGraph, +) -> graph.SceneGraph: """Respecify a blocksworld scene graph. Args: - scene (nx.MultiDiGraph): The reduced SceneGraph of a scene. + scene (ReducedGraph): The reduced SceneGraph of a scene. Returns: graph.SceneGraph: The respecified scene graph. 
@@ -195,55 +304,75 @@ def _inflate_blocksworld(scene: nx.MultiDiGraph) -> graph.SceneGraph: constants = [] predicates = [] - for node, attrs in scene.nodes(data=True): - if not isinstance(node, ReductionNode): - constants.append({"name": node, "typing": attrs["typing"]}) + for node in scene.nodes: + if not isinstance(node.node, ReducedNode): + constants.append({"name": node.node, "typing": node.typing}) - for u, v, _ in scene.edges: - if u == ReductionNode.CLEAR and v == ReductionNode.ARM: + for u, v, edge in scene.edges: + if u.node == ReducedNode.CLEAR and v.node == ReducedNode.ARM: predicates.append( { "typing": "arm-empty", "parameters": [], + "scene": edge.scene, } ) - elif u == ReductionNode.CLEAR: + elif u.node == ReducedNode.CLEAR: predicates.append( { "typing": "clear", - "parameters": [v], + "parameters": [v.node], + "scene": edge.scene, } ) - elif v == ReductionNode.TABLE: + elif v.node == ReducedNode.TABLE: predicates.append( { "typing": "on-table", - "parameters": [u], + "parameters": [u.node], + "scene": edge.scene, } ) - elif v == ReductionNode.ARM: + elif v.node == ReducedNode.ARM: predicates.append( { "typing": "holding", - "parameters": [u], + "parameters": [u.node], + "scene": edge.scene, } ) else: predicates.append( { "typing": "on", - "parameters": [u, v], + "parameters": [u.node, v.node], + "scene": edge.scene, } ) - return graph.SceneGraph(constants, predicates, domain="blocksworld") + if isinstance(scene, ReducedProblemGraph): + return graph.ProblemGraph( + constants, + [pred for pred in predicates if pred["scene"] == graph.Scene.INIT], + [pred for pred in predicates if pred["scene"] == graph.Scene.GOAL], + domain="blocksworld", + ) + else: + return graph.SceneGraph( + constants, + predicates, + domain="blocksworld", + scene=scene.scene, + ) -def _inflate_gripper(scene: nx.MultiDiGraph) -> graph.SceneGraph: +def _inflate_gripper( + scene: ReducedSceneGraph | ReducedProblemGraph, +) -> graph.SceneGraph | graph.ProblemGraph: """Respecify a gripper scene graph. Args: - scene (nx.MultiDiGraph): The reduced SceneGraph of a scene. + scene (ReducedGraph): The reduced SceneGraph of a scene. Returns: graph.SceneGraph: The respecified scene graph. 
@@ -251,64 +380,83 @@ def _inflate_gripper(scene: nx.MultiDiGraph) -> graph.SceneGraph: constants = [] predicates = [] - for node, attrs in scene.nodes(data=True): - if not isinstance(node, ReductionNode): - constants.append({"name": node, "typing": attrs["typing"]}) + for node in scene.nodes: + if not isinstance(node.node, ReducedNode): + constants.append({"name": node.node, "typing": node.typing}) - for u, v, attr in scene.edges(data=True): - if u == ReductionNode.ROBBY: + for u, v, edge in scene.edges: + if u.node == ReducedNode.ROBBY: predicates.append( { "typing": "at-robby", - "parameters": [v], + "parameters": [v.node], + "scene": edge.scene, } ) - elif u == ReductionNode.CLEAR: + elif u.node == ReducedNode.FREE: predicates.append( { "typing": "free", - "parameters": [v], + "parameters": [v.node], + "scene": edge.scene, } ) - elif u == ReductionNode.BALLS: + elif u.node == ReducedNode.BALLS: predicates.append( { "typing": "ball", - "parameters": [v], + "parameters": [v.node], + "scene": edge.scene, } ) - elif u == ReductionNode.GRIPPERS: + elif u.node == ReducedNode.GRIPPERS: predicates.append( { "typing": "gripper", - "parameters": [v], + "parameters": [v.node], + "scene": edge.scene, } ) - elif u == ReductionNode.ROOMS: + elif u.node == ReducedNode.ROOMS: predicates.append( { "typing": "room", - "parameters": [v], + "parameters": [v.node], + "scene": edge.scene, } ) else: predicates.append( { - "typing": attr["pred"], - "parameters": [u, v], + "typing": edge.predicate, + "parameters": [u.node, v.node], + "scene": edge.scene, } ) - return graph.SceneGraph(constants, predicates, domain="gripper") + if isinstance(scene, ReducedProblemGraph): + return graph.ProblemGraph( + constants, + [pred for pred in predicates if pred["scene"] == graph.Scene.INIT], + [pred for pred in predicates if pred["scene"] == graph.Scene.GOAL], + domain="gripper", + ) + else: + return graph.SceneGraph( + constants, + predicates, + domain="gripper", + scene=scene.scene, + ) def _blocksworld_underspecified_blocks( - scene: nx.MultiDiGraph, + scene: ReducedSceneGraph, ) -> tuple[set[str], set[str], bool]: """Finds blocks that are not fully specified. Args: - scene (nx.MultiDiGraph): The reduced SceneGraph of a scene. + scene (ReducedGraph): The reduced SceneGraph of a scene. Returns: tuple[set[str], set[str], bool]: The set of blocks that are not fully @@ -318,44 +466,58 @@ def _blocksworld_underspecified_blocks( """ top_blocks = set() bottom_blocks = set() - held_block = next(scene.predecessors(ReductionNode.ARM), None) - for node, attrs in scene.nodes(data=True, default=None): - if attrs.get("label") == graph.Label.CONSTANT: + arm_behavior_defined = scene.in_degree(ReducedNode.ARM) > 0 + held_block = ( + scene.predecessors(ReducedNode.ARM)[0] if arm_behavior_defined else None + ) + for node in scene.nodes: + if node.label == graph.Label.CONSTANT: if not scene.in_edges(node) and node != held_block: top_blocks.add(node) if not scene.out_edges(node): bottom_blocks.add(node) - return top_blocks, bottom_blocks, held_block is None + return top_blocks, bottom_blocks, not arm_behavior_defined + +def _gripper_get_typed_objects( + scene: ReducedSceneGraph, +) -> dict[ReducedNode, set[graph.PlanGraphNode]]: + """Get the typed objects in a gripper scene graph. + + Args: + scene (ReducedGraph): The reduced SceneGraph of a scene. -def _gripper_get_typed_objects(scene: nx.MultiDiGraph): + Returns: + dict[ReducedNode, set[graph.PlanGraphNode]]: The typed objects in the + scene graph. 
+ """ rooms = set() balls = set() grippers = set() - for _, node in scene.out_edges(ReductionNode.ROOMS): + for node, _ in scene.out_edges(ReducedNode.ROOMS): rooms.add(node) - for _, node in scene.out_edges(ReductionNode.BALLS): + for node, _ in scene.out_edges(ReducedNode.BALLS): balls.add(node) - for _, node in scene.out_edges(ReductionNode.GRIPPERS): + for node, _ in scene.out_edges(ReducedNode.GRIPPERS): grippers.add(node) return { - ReductionNode.ROOMS: rooms, - ReductionNode.BALLS: balls, - ReductionNode.GRIPPERS: grippers, + ReducedNode.ROOMS: rooms, + ReducedNode.BALLS: balls, + ReducedNode.GRIPPERS: grippers, } def _gripper_underspecified_blocks( - init: nx.MultiDiGraph, - goal: nx.MultiDiGraph, + init: ReducedSceneGraph, + goal: ReducedSceneGraph, ) -> tuple[set[str], set[str], bool]: """Finds blocks that are not fully specified. Args: - init (nx.MultiDiGraph): The reduced SceneGraph of the initial scene. - goal (nx.MultiDiGraph): The reduced SceneGraph of the goal scene. + init (ReducedGraph): The reduced SceneGraph of the initial scene. + goal (ReducedGraph): The reduced SceneGraph of the goal scene. Returns: tuple[set[str], set[str]]: The set of blocks that are not fully @@ -370,19 +532,19 @@ def _gripper_underspecified_blocks( underspecified_balls = set() underspecified_grippers = set() - for ball in typed[ReductionNode.BALLS]: + for ball in typed[ReducedNode.BALLS]: ball_edges = [ node - for _, node in goal.out_edges(ball) - if not isinstance(node, ReductionNode) + for node, _ in goal.out_edges(ball) + if not isinstance(node, ReducedNode) ] if not ball_edges: underspecified_balls.add(ball) - for gripper in typed[ReductionNode.GRIPPERS]: + for gripper in typed[ReducedNode.GRIPPERS]: gripper_edges = [ node for node, _ in goal.in_edges(gripper) - if node == ReductionNode.CLEAR or not isinstance(node, ReductionNode) + if node == ReducedNode.FREE or not isinstance(node, ReducedNode) ] if not gripper_edges: underspecified_grippers.add(gripper) @@ -390,18 +552,18 @@ def _gripper_underspecified_blocks( return ( underspecified_balls, underspecified_grippers, - goal.out_degree(ReductionNode.ROBBY) == 0, + goal.out_degree(ReducedNode.ROBBY) == 0, ) def inflate( - scene: nx.MultiDiGraph, + scene: ReducedSceneGraph | ReducedProblemGraph, domain: str | None = None, ) -> graph.SceneGraph: """Inflate a reduced scene graph to a SceneGraph. Args: - scene (nx.MultiDiGraph): The reduced scene graph to respecify. + scene (ReducedGraph): The reduced scene graph to respecify. domain (str | None, optional): The domain of the scene graph. Defaults to None. @@ -421,14 +583,14 @@ def inflate( def _detached_blocks( nodesA: set[str], nodesB: set[str], - scene: nx.MultiDiGraph, + scene: ReducedSceneGraph, ) -> tuple[set[str], set[str]]: """Finds nodes that are not connected to the rest of the scene graph. Args: nodesA (set[str]): The set of nodes to check. nodesB (set[str]): The set of nodes to check against. - scene (nx.MultiDiGraph): The scene graph to check against. + scene (ReducedGraph): The scene graph to check against. 
Returns: tuple[set[str], set[str]]: The set of nodes that are not connected to @@ -439,7 +601,13 @@ def _detached_blocks( for a in nodesA: for b in nodesB: - if not nx.has_path(scene, a, b) and not nx.has_path(scene, b, a): + a_index = scene.nodes.index(a) + b_index = scene.nodes.index(b) + if ( + not rx.has_path(scene.graph, a_index, b_index) + and not rx.has_path(scene.graph, b_index, a_index) + and a != b + ): _nodesA.discard(a) _nodesB.discard(b) @@ -447,7 +615,7 @@ def _detached_blocks( def _fully_specify_blocksworld( - scene: nx.MultiDiGraph, + scene: ReducedSceneGraph, ) -> graph.SceneGraph: """Fully specifies a blocksworld scene graph. @@ -455,7 +623,7 @@ def _fully_specify_blocksworld( edges that change the problem represented by the graph. Args: - scene (nx.MultiDiGraph): The reduced SceneGraph of a scene. + scene (ReducedGraph): The reduced SceneGraph of a scene. Returns: SceneGraph: The fully specified scene graph. @@ -465,35 +633,58 @@ def _fully_specify_blocksworld( top_blocks_, bottom_blocks_ = _detached_blocks(top_blocks, bottom_blocks, scene) for block in top_blocks_: - scene.add_edge(ReductionNode.CLEAR, block, pred="clear") + scene.add_edge( + ReducedNode.CLEAR, + block, + graph.PlanGraphEdge(predicate="clear", scene=scene.scene), + ) for block in bottom_blocks_: - scene.add_edge(block, ReductionNode.TABLE, pred="on-table") + scene.add_edge( + block, + ReducedNode.TABLE, + graph.PlanGraphEdge(predicate="on-table", scene=scene.scene), + ) # handle arm if arm_empty and not (top_blocks & bottom_blocks): - scene.add_edge(ReductionNode.CLEAR, ReductionNode.ARM, pred="arm-empty") + scene.add_edge( + ReducedNode.CLEAR, + ReducedNode.ARM, + graph.PlanGraphEdge(predicate="arm-empty", scene=scene.scene), + ) return scene def _fully_specify_gripper( - init: nx.MultiDiGraph, - goal: nx.MultiDiGraph, -) -> nx.MultiDiGraph: + init: ReducedSceneGraph, + goal: ReducedSceneGraph, +) -> ReducedSceneGraph: """Fully specifies a gripper scene graph. Adds any missing edges to fully specify the scene graph, without adding edges that change the problem represented by the graph. Args: - init (nx.MultiDiGraph): The reduced SceneGraph of the initial scene. - goal (nx.MultiDiGraph): The reduced SceneGraph of the goal scene. + init (ReducedGraph): The reduced SceneGraph of the initial scene. + goal (ReducedGraph): The reduced SceneGraph of the goal scene. Returns: SceneGraph: The fully specified scene graph. 
""" scene = copy.deepcopy(goal) + # bring "typing" predicates from init to goal + typed_objects = _gripper_get_typed_objects(init) + for typing, objects in typed_objects.items(): + for obj in objects: + edge = graph.PlanGraphEdge(predicate=typing.value, scene=graph.Scene.GOAL) + edge_ = graph.PlanGraphEdge(predicate=typing.value) + if obj in scene.nodes and not ( + scene.has_edge(typing, obj, edge) or scene.has_edge(typing, obj, edge_) + ): + scene.add_edge(typing, obj, edge) + underspecified_balls, underspecified_grippers, _ = _gripper_underspecified_blocks( init, goal, @@ -501,7 +692,11 @@ def _fully_specify_gripper( if underspecified_grippers and not underspecified_balls: for gripper in underspecified_grippers: - scene.add_edge(ReductionNode.CLEAR, gripper, pred="free") + scene.add_edge( + ReducedNode.FREE, + gripper, + graph.PlanGraphEdge(predicate="free", scene=scene.scene), + ) return scene @@ -509,6 +704,7 @@ def _fully_specify_gripper( def fully_specify( problem: graph.ProblemGraph, domain: str | None = None, + return_reduced: bool = False, ) -> graph.ProblemGraph: """Fully specifies a goal state. @@ -517,95 +713,41 @@ def fully_specify( fully specify. domain (str | None, optional): The domain of the scene graph. Defaults to None. + return_reduced (bool, optional): Whether to return the reduced scene + graph. Defaults to False. Returns: graph.ProblemGraph: The fully specified problem graph. """ domain = domain or problem.domain - init, goal = problem.decompose() + reduced_init, reduced_goal = reduce(problem).decompose() + match domain: case "blocksworld": - reduced_goal = _reduce_blocksworld(goal) fully_specified_goal = _fully_specify_blocksworld(reduced_goal) case "gripper": - reduced_init = _reduce_gripper(init) - reduced_goal = _reduce_gripper(goal) fully_specified_goal = _fully_specify_gripper( reduced_init, reduced_goal, ) case _: raise ValueError(f"Domain {domain} not supported.") - return graph.ProblemGraph.join( - init, - inflate(fully_specified_goal, domain=domain), - ) - - -def is_fully_specified( - problem: graph.ProblemGraph, - domain: str | None = None, - is_placeholder: bool = False, -) -> bool: - """Checks if a goal state is fully specified. - Args: - problem (graph.ProblemGraph): The problem graph with the goal state to - evaluate. - domain (str | None, optional): The domain of the scene graph. Defaults - to None. - is_placeholder (bool, optional): Whether or not every edge must be - present. Defaults to False. - - Raises: - ValueError: If a certain domain is provided but not supported. - - Returns: - bool: True if the goal state is fully specified, False otherwise. 
- """ - domain = domain or problem.domain - init, goal = problem.decompose() - match domain: - case "blocksworld": - reduced_goal = _reduce_blocksworld(goal) - top, bottom, arm_empty = _blocksworld_underspecified_blocks(reduced_goal) - - if is_placeholder: - if bottom.intersection(top) and arm_empty: - return False - return (not top) or (not bottom) - else: - return (not (top | bottom)) and not arm_empty - case "gripper": - reduced_init = _reduce_gripper(init) - reduced_goal = _reduce_gripper(goal) - - balls, grippers, no_at_robby = _gripper_underspecified_blocks( - reduced_init, - reduced_goal, - ) - # check number of typed objects is the same as total - goal_type_check = len(goal.constants) == len( - set().union(*_gripper_get_typed_objects(reduced_goal).values()) - ) - type_check = ( - len(init.constants) - == len(set().union(*_gripper_get_typed_objects(reduced_init).values())) - and goal_type_check - ) - if is_placeholder: - return len(balls) == 0 and not no_at_robby - else: - return len(balls or grippers) == 0 and not no_at_robby and type_check - case _: - raise ValueError(f"Domain {domain} not supported.") + if return_reduced: + return ReducedProblemGraph.join(reduced_init, fully_specified_goal) + else: + init, _ = problem.decompose() + return graph.ProblemGraph.join( + init, + inflate(fully_specified_goal, domain=domain), + ) def reduce( graph: graph.SceneGraph, domain: str | None = None, validate: bool = True, -) -> nx.MultiDiGraph: +) -> ReducedSceneGraph | ReducedProblemGraph: """Reduces a scene graph to a Directed Acyclic Graph. Args: @@ -619,7 +761,7 @@ def reduce( ValueError: If a certain domain is provided but not supported. Returns: - nx.MultiDiGraph: The reduced scene graph. + ReducedGraph: The reduced problem graph. """ domain = domain or graph.domain match domain: diff --git a/poetry.lock b/poetry.lock index 056bead..452cdf0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3472,6 +3472,84 @@ files = [ {file = "ruff-0.1.11.tar.gz", hash = "sha256:f9d4d88cb6eeb4dfe20f9f0519bd2eaba8119bde87c3d5065c541dbae2b5a2cb"}, ] +[[package]] +name = "rustworkx" +version = "0.14.2" +description = "A python graph library implemented in Rust" +optional = false +python-versions = ">=3.8" +files = [ + {file = "rustworkx-0.14.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:a28a972dc7e0faf03f9f90c5be89328af8a71e609f311840e1a6abc6385edb79"}, + {file = "rustworkx-0.14.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:50e682b8fd2f11f9e99c309a01f7ed88a09ad32cda35b92c49835b1c9536ec65"}, + {file = "rustworkx-0.14.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6e1c3cf3d265835429074a1ecaa8f9bff327b188e1496a120bf8be8260a46453"}, + {file = "rustworkx-0.14.2-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a22c02f74bf391b48ae92f633083d068055f3ed85050e35fe6cda967ff8a825"}, + {file = "rustworkx-0.14.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:996bad21eacbe124dd1e6abca47dd69ade9db0d4df5dd29197694f5d8e0a8258"}, + {file = "rustworkx-0.14.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:95c4647461f05fd9f99bae52002a929e8628d4e5a2e732dbfd7abd00ae5257b7"}, + {file = "rustworkx-0.14.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:829444876bba1940fa3109998f3b6c9184256d91eea5f0e09d9e9f8f26bb4704"}, + {file = "rustworkx-0.14.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:987b430dce1351a0c761bd6eedb8f6999f48983c9d4b06bf4b0b9dc45d08be8d"}, + {file = "rustworkx-0.14.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:18ef16f9b6b4f1c0d458fde3f213b78436ac810d61cae60385696b411aa80e1d"}, + {file = "rustworkx-0.14.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fe30f1e22e69cbab4182d0017e21c345bf75f142a7b66a828227dd3c654d524c"}, + {file = "rustworkx-0.14.2-cp310-cp310-win32.whl", hash = "sha256:c1fe9f9ed18e270074d3632f6c70cc75c461535d9e76db39d1c0ab712bf64a7a"}, + {file = "rustworkx-0.14.2-cp310-cp310-win_amd64.whl", hash = "sha256:271b36412421d622e9e8cd27e2c6e1bd356e452f979edd41bb32d308df936f47"}, + {file = "rustworkx-0.14.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:c52e34ff4b08d1eaedd2ec906bca4317f4f852b36e4615d372b1ff2bb435ff26"}, + {file = "rustworkx-0.14.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c97dc0cf7efef033ce50fa570887f97896b0f449c841ec3b127ecb70b3c16c84"}, + {file = "rustworkx-0.14.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:114bec1606ae31c089ecf52aa511551c545c6ce0746d3e8766082ad450377a2c"}, + {file = "rustworkx-0.14.2-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:950fa4ffc1691081587c87c4e869a8f5c7d0672d35ce1ba7c69f758f90bfe8c0"}, + {file = "rustworkx-0.14.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2134aa9c2065ab6c934017b6909e224e860003eb5dbaa5d2c4e87fff1187459a"}, + {file = "rustworkx-0.14.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bec8f1f1a6fed3ffbf5348a2b9d700f0b840fed2faa6a5198838d0fa9674a781"}, + {file = "rustworkx-0.14.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b8046991499df7aa984b3d9092e4f013597901c919aaf6fa43147e8550685734"}, + {file = "rustworkx-0.14.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:acb4256fba2c4f5c4ec009f383623b6a7c0a2dbeed1b529d22a193113927364a"}, + {file = "rustworkx-0.14.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:be7be125f9313b58829f7202a66dc166b61bf3c4bbe0c509b8d6902ed0d2da45"}, + {file = "rustworkx-0.14.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5d00f87fce0e48c6d7af4b63ee635188178e91462b52ac900d36ec3184ce92fc"}, + {file = "rustworkx-0.14.2-cp311-cp311-win32.whl", hash = "sha256:521e0f432a94ac9a4c92f30a746b971f7e49476fd128d83d94d4b15a2c17245d"}, + {file = "rustworkx-0.14.2-cp311-cp311-win_amd64.whl", hash = "sha256:8fd20776c0f543340ef96450ba5d9d670b8d74396315f7191303a392844271e0"}, + {file = "rustworkx-0.14.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:7bb37e877653ae4b4d505fc7e5f7847ae06e6822b91cec56e9e851941a6a0ae7"}, + {file = "rustworkx-0.14.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:230808e3878236464ac00001d8b440382aa6230f0073554ec627580863e380cc"}, + {file = "rustworkx-0.14.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9a7cb7103ba88e12e3dd8e3b28365cbe971a8c158c1ee770646b2f3fd5cedab0"}, + {file = "rustworkx-0.14.2-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2637d0e496f34bac45f926b0aa12fb2e143581208f29a424cfb0eb5a7b5c3bfa"}, + {file = "rustworkx-0.14.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:692f78ee7f7a60d9c7082a5a26b4eefb697526f195172798389d7009510d84f3"}, + {file = "rustworkx-0.14.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:91d5513e93b7c10fbce954771a74fc86d551eb33b9eb318eaa35d7668f9929da"}, + {file = "rustworkx-0.14.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:26076523a1c43e903c633f2375afac28fdbb83b9668bee00fae24d8c672bf6c9"}, + {file = "rustworkx-0.14.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6de5e2df15c415dfb6e5cb7175239d0862568cb10d028f451d358d101be5d8bf"}, + {file = "rustworkx-0.14.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:21c86c240628abc2123d7d1317647073a738bdfe143c55728261b66bc32806e2"}, + {file = "rustworkx-0.14.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0aa0277b931ca3fdfae07f8999b6a63dc9b89622b2fab820fa6bd95dd1e2e2eb"}, + {file = "rustworkx-0.14.2-cp312-cp312-win32.whl", hash = "sha256:fdc632673d4cd7f1cffe8ce13ea17dc361cf9d0d9f37dfa0888d94bdd5e6c159"}, + {file = "rustworkx-0.14.2-cp312-cp312-win_amd64.whl", hash = "sha256:47768f985f32ac1cd807af816fbd5f6e2433889793afdd838891ae516a95c8a6"}, + {file = "rustworkx-0.14.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:bfeee5a5be9eb71635a7897a6d2c034b1c01bf876fd15007b8bd4c6eaa8921e2"}, + {file = "rustworkx-0.14.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:4a20434c77f3daab043ab2f96386b5da871ebf15a5495f9ad5b916c3edf03e5c"}, + {file = "rustworkx-0.14.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d856549e874e064af136f2ce304eb896d32d8865c3e98f8d9e83b577f4c57f1d"}, + {file = "rustworkx-0.14.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2521a223fb5aab2a14351205456d02bd851e0ec6b0c028f5598fe14f292e881b"}, + {file = "rustworkx-0.14.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc9e7718eee8295cd5c11a5cf1c0fc7772e9c1dcc3d110edba4c77aad47e7f07"}, + {file = "rustworkx-0.14.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c23ef82b1d373e07c280b8b6927dbad3953597e34c752e14843ac3df722a621"}, + {file = "rustworkx-0.14.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a46e0d1398138a75fb909369ffe6dfdcec6bab4d21794e80a9abf45fd2823f68"}, + {file = "rustworkx-0.14.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c16fb941e8f48aea96ee38471a1ae770ec68623864a9b0e4760aabf82c41fc2b"}, + {file = "rustworkx-0.14.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:fa92a97e5d35c6901553a812f31ca18305922c0ef06c2d7a9d20fbcc0769b4d1"}, + {file = "rustworkx-0.14.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2e14b2956f2d06f5bb196bcd95f73008245eb6ffa9ee08f86ec369acf0cc04be"}, + {file = "rustworkx-0.14.2-cp38-cp38-win32.whl", hash = "sha256:816d33f69f4189376e1bb8132dea1deef1cd019b25bd281f01b7f394fcadbdad"}, + {file = "rustworkx-0.14.2-cp38-cp38-win_amd64.whl", hash = "sha256:4163f9c2c2d2158e053b30a39f74b0382b4c5a8a43f192c13b736e200b5e2025"}, + {file = "rustworkx-0.14.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a9b55a8f97799b159da96087176a0e97679dca0b6b5a14b3140aeda7e1050777"}, + {file = "rustworkx-0.14.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:edb2d67870e41d5a1e16288bca0758580fb6961e8b4dfc337557bdaab81ff016"}, + {file = "rustworkx-0.14.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f058bdb50c5b0be731b96ffd789c6cec2a99e7f757a57763b2cc56004ed95af6"}, + {file = "rustworkx-0.14.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ff900cb6ae2d4028ffe5a3075cfefa21b14929270844b172595e6de0d2f183eb"}, + {file = 
"rustworkx-0.14.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:630177a80c68823fb2dd94733298377bd52c2ce3f66758ea0a63966fc2d7c08f"}, + {file = "rustworkx-0.14.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6a79c177a9e4f1c623554e01319fcb7b2a062ae26def7b85dc1f0539b7cdd874"}, + {file = "rustworkx-0.14.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff956ee6c8224b8225478bb72103d4fc6dd4a247c066da30927776e1b05690"}, + {file = "rustworkx-0.14.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ae61a4c58186b4e428947b92ac2aa0557bcd5071fe8102a542c4337f64091766"}, + {file = "rustworkx-0.14.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5a96b6f96e1bb4e8ee337618d8af0a1aec16c2eda6ffd9968e16d161850d1e77"}, + {file = "rustworkx-0.14.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:49bc143729e0d64a51b0ec6d745665f067116db78ce958d5cbe0389e69c6e73c"}, + {file = "rustworkx-0.14.2-cp39-cp39-win32.whl", hash = "sha256:f11e858a1804d5e276d18d6fc197f797adf5da82cd3382550abeef50196c5a7e"}, + {file = "rustworkx-0.14.2-cp39-cp39-win_amd64.whl", hash = "sha256:b55e75ea35a225d6b0afbdd449665e3b907684347be6a38648bdbfd50e177bf0"}, + {file = "rustworkx-0.14.2.tar.gz", hash = "sha256:bd649322c0649b71fa18cc70a9af027b549560415fa860d6894736029c277b13"}, +] + +[package.dependencies] +numpy = ">=1.16.0,<2" + +[package.extras] +all = ["matplotlib (>=3.0)", "pillow (>=5.4)"] +graphviz = ["pillow (>=5.4)"] +mpl = ["matplotlib (>=3.0)"] + [[package]] name = "safetensors" version = "0.4.3" @@ -4924,4 +5002,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "9967f5b885e1badb3307def97646855281bd21b9f2a3fa832b145398f77caf09" +content-hash = "b173a7cdf59110fe74010e490f44f0e9a671bc2049b55f03b413c02e47708de8" diff --git a/pyproject.toml b/pyproject.toml index 5ec07de..486d50a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ datasets = "^2.20.0" peft = "^0.11.1" trl = "^0.9.4" bitsandbytes = "^0.43.1" +rustworkx = "^0.14.2" [tool.poetry.group.dev.dependencies] diff --git a/tests/test_graph.py b/tests/test_graph.py index ffeedd7..e202a11 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -2,7 +2,7 @@ from .test_pddl import problem_string -from planetarium import pddl +from planetarium import builder @pytest.fixture @@ -10,7 +10,7 @@ def sgraph(problem_string): """ Fixture providing an SGraph instance built from a PDDL problem string. """ - return pddl.build(problem_string).decompose()[0] + return builder.build(problem_string).decompose()[0] class TestGraph: @@ -23,20 +23,20 @@ def test_constant_node_names(self, sgraph): Test if the names of constant nodes in the graph match the expected set. """ names = set(["p0", "p1", "f0", "f1", "f2", "f3"]) - assert all([(node in names) for node in sgraph.constants]) + assert all([(node.name in names) for node in sgraph.constant_nodes]) def test_constant_node_size(self, sgraph): """ Test if the number of constant nodes in the graph matches the expected count. """ - assert len(sgraph.constants) == 6 + assert len(sgraph.constant_nodes) == 6 def test_predicate_names(self, sgraph): """ Test if the names of predicate nodes in the graph match expected patterns. 
""" - for predicate in sgraph.predicates: - match predicate.split("-"): + for predicate in sgraph.predicate_nodes: + match predicate.node.split("-"): case ["above", _, _]: assert True case ["origin", _, _]: diff --git a/tests/test_metric.py b/tests/test_metric.py index 6d9de9d..4aa8bf2 100644 --- a/tests/test_metric.py +++ b/tests/test_metric.py @@ -1,6 +1,6 @@ import pytest -from planetarium import graph, metric, oracle, pddl +from planetarium import builder, graph, metric, oracle # pylint: disable=unused-import from .test_pddl import ( @@ -39,32 +39,44 @@ class TestConstantMatching: @pytest.fixture def source(self): """Fixture for a valid source constant.""" - return {"name": "o1", "typing": ["t1", "t2"], "label": graph.Label.CONSTANT} + return graph.PlanGraphNode( + "o1", "o1", typing=["t1", "t2"], label=graph.Label.CONSTANT + ) @pytest.fixture def target(self): """Fixture for a valid target constant.""" - return {"name": "c1", "typing": ["t1", "t2"], "label": graph.Label.CONSTANT} + return graph.PlanGraphNode( + "c1", "c1", typing=["t1", "t2"], label=graph.Label.CONSTANT + ) @pytest.fixture def source_incorrect_label(self): """Fixture for a source constant with an incorrect label.""" - return {"name": "o1", "typing": ["t1", "t2"], "label": graph.Label.PREDICATE} + return graph.PlanGraphNode( + "o1", "o1", typing=["t1", "t2"], label=graph.Label.PREDICATE + ) @pytest.fixture def target_incorrect_label(self): """Fixture for a target constant with an incorrect label.""" - return {"name": "c1", "typing": ["t1", "t2"], "label": graph.Label.PREDICATE} + return graph.PlanGraphNode( + "c1", "c1", typing=["t1", "t2"], label=graph.Label.PREDICATE + ) @pytest.fixture def source_incorrect_typing(self): """Fixture for a source constant with incorrect typing.""" - return {"name": "o1", "typing": ["ty1", "ty2"], "label": graph.Label.CONSTANT} + return graph.PlanGraphNode( + "o1", "o1", typing=["ty1", "ty2"], label=graph.Label.CONSTANT + ) @pytest.fixture def target_incorrect_typing(self): """Fixture for a target constant with incorrect typing.""" - return {"name": "c1", "typing": ["ty1", "ty2"], "label": graph.Label.CONSTANT} + return graph.PlanGraphNode( + "c1", "c1", typing=["ty1", "ty2"], label=graph.Label.CONSTANT + ) @pytest.fixture def mapping(self): @@ -78,8 +90,8 @@ def mapping_incorrect(self): def test_correct_matching(self, source, target, mapping): """Test correct matching between source and target constants.""" - assert metric._matching(source, target, None) - assert metric._matching(source, target, mapping) + assert metric._node_matching(source, target, None) + assert metric._node_matching(source, target, mapping) assert metric._same_typing(source, target) assert metric._preserves_mapping(source, target, mapping) @@ -88,10 +100,10 @@ def test_incorrect_label( self, source, target, source_incorrect_label, target_incorrect_label, mapping ): """Test incorrect label matching between source and target constants.""" - assert not metric._matching(source, target_incorrect_label, None) - assert not metric._matching(source_incorrect_label, target, None) - assert not metric._matching(source, target_incorrect_label, mapping) - assert not metric._matching(source_incorrect_label, target, mapping) + assert not metric._node_matching(source, target_incorrect_label, None) + assert not metric._node_matching(source_incorrect_label, target, None) + assert not metric._node_matching(source, target_incorrect_label, mapping) + assert not metric._node_matching(source_incorrect_label, target, mapping) assert not 
metric._preserves_mapping(source, target_incorrect_label, mapping) assert not metric._preserves_mapping(source_incorrect_label, target, mapping) @@ -103,10 +115,10 @@ def test_incorrect_typing( self, source, target, source_incorrect_typing, target_incorrect_typing, mapping ): """Test incorrect typing between source and target constants.""" - assert not metric._matching(source, target_incorrect_typing, None) - assert not metric._matching(source_incorrect_typing, target, None) - assert not metric._matching(source, target_incorrect_typing, mapping) - assert not metric._matching(source_incorrect_typing, target, mapping) + assert not metric._node_matching(source, target_incorrect_typing, None) + assert not metric._node_matching(source_incorrect_typing, target, None) + assert not metric._node_matching(source, target_incorrect_typing, mapping) + assert not metric._node_matching(source_incorrect_typing, target, mapping) assert metric._preserves_mapping(source, target_incorrect_typing, mapping) assert metric._preserves_mapping(source_incorrect_typing, target, mapping) @@ -116,7 +128,7 @@ def test_incorrect_typing( def test_incorrect_mapping(self, source, target, mapping_incorrect): """Test incorrect mapping between source and target constants.""" - assert not metric._matching(source, target, mapping_incorrect) + assert not metric._node_matching(source, target, mapping_incorrect) assert not metric._preserves_mapping(source, target, mapping_incorrect) @@ -128,42 +140,46 @@ class TestPredicateMatching: @pytest.fixture def source(self): """Fixture for a valid source predicate node.""" - return { - "name": "f-a1-a2", - "typing": "f", - "label": graph.Label.PREDICATE, - } + return graph.PlanGraphNode( + "f-a1-a2", + "f-a1-a2", + typing="f", + label=graph.Label.PREDICATE, + ) @pytest.fixture def target(self): """Fixture for a valid target predicate node.""" - return { - "name": "f-a1-a2", - "typing": "f", - "label": graph.Label.PREDICATE, - } + return graph.PlanGraphNode( + "f-a1-a2", + "f-a1-a2", + typing="f", + label=graph.Label.PREDICATE, + ) @pytest.fixture def source_incorrect_label(self): """Fixture for a source predicate node with an incorrect label.""" - return { - "name": "f-a1-a2", - "typing": "f", - "label": graph.Label.CONSTANT, - } + return graph.PlanGraphNode( + "f-a1-a2", + "f-a1-a2", + typing="f", + label=graph.Label.CONSTANT, + ) @pytest.fixture def target_incorrect_label(self): """Fixture for a target predicate node with an incorrect label.""" - return { - "name": "f-a1-a2", - "typing": "f", - "label": graph.Label.CONSTANT, - } + return graph.PlanGraphNode( + "f-a1-a2", + "f-a1-a2", + typing="f", + label=graph.Label.CONSTANT, + ) def test_correct_matching(self, source, target): """Test correct matching between source and target predicate nodes.""" - assert metric._matching(source, target, None) + assert metric._node_matching(source, target, None) def test_incorrect_label( self, @@ -173,8 +189,8 @@ def test_incorrect_label( target_incorrect_label, ): """Test incorrect label matching between source and target predicate nodes.""" - assert not metric._matching(source, target_incorrect_label, None) - assert not metric._matching(source_incorrect_label, target, None) + assert not metric._node_matching(source, target_incorrect_label, None) + assert not metric._node_matching(source_incorrect_label, target, None) class TestMetrics: @@ -184,22 +200,16 @@ class TestMetrics: def test_map(self, problem_string, two_initial_problem_string): """Test the mapping function on graph pairs.""" - problem_graph 
= pddl.build(problem_string) - problem_graph2 = pddl.build(two_initial_problem_string) - - initial, goal = problem_graph.decompose() - - assert metric._map(initial, initial) != [] - assert metric._map(goal, goal) != [] - assert metric._map(initial, goal) == [] + problem_graph = builder.build(problem_string) + problem_graph2 = builder.build(two_initial_problem_string) - assert metric.map(problem_graph, problem_graph, return_mappings=True) != [] - assert metric.map(problem_graph, problem_graph2, return_mappings=True) == [] + assert metric.isomorphic(problem_graph, problem_graph) + assert not metric.isomorphic(problem_graph, problem_graph2) def test_validate(self, problem_string, two_initial_problem_string): """Test the validation function on graph pairs.""" - problem_graph = pddl.build(problem_string) - problem_graph2 = pddl.build(two_initial_problem_string) + problem_graph = builder.build(problem_string) + problem_graph2 = builder.build(two_initial_problem_string) assert metric.equals(problem_graph, problem_graph, is_placeholder=True) assert not metric.equals( @@ -208,109 +218,29 @@ def test_validate(self, problem_string, two_initial_problem_string): is_placeholder=True, ) - def test_distance_isomorphic(self, problem_string, renamed_problem_string): - """ - Test the distance function on graph pairs, considering only isomorphic cases. - """ - initial, goal = pddl.build(problem_string).decompose() - initial2, goal2 = pddl.build(renamed_problem_string).decompose() - - # Limiting the test to isomorphic cases to optimize execution time. - assert metric._distance(initial, initial) == 0.0 - assert metric.distance( - initial, - initial2, - goal, - goal2, - timeout=15.0, - ) == (0.0, False, 0.0, False) - - def test_distance(self, problem_string, wrong_problem_string): - """ - Test the distance function on graph pairs. - """ - initial, goal = pddl.build(problem_string).decompose() - initial2, goal2 = pddl.build(wrong_problem_string).decompose() - - assert metric._distance(initial, initial) == 0.0 - assert metric.distance(initial, initial2, goal, goal2) == ( - 0.0, - False, - 2.0, - False, - ) - - def test_wrong_initial_scene( - self, - problem_string, - wrong_initial_problem_string, - ): - """ - Test the distance function on graph pairs. - """ - initial, goal = pddl.build(problem_string).decompose() - initial2, goal2 = pddl.build(wrong_initial_problem_string).decompose() - - # This function should timeout, so the value will be an approximation - assert metric._distance(initial, initial) == 0.0 - - init_distance, approx_init, _, _ = metric.distance( - initial, - initial2, - goal, - goal2, - timeout=2.0, - ) - - assert init_distance < 25.0 - assert approx_init - def test_swap(self, swap_problem_string, wrong_swap_problem_string): """ Test the distance function on graph pairs. 
""" - swap_problem = pddl.build(swap_problem_string) - initial, goal = swap_problem.decompose() - wrong_swap = pddl.build(wrong_swap_problem_string) - initial2, goal2 = wrong_swap.decompose() + swap_problem = builder.build(swap_problem_string) + wrong_swap = builder.build(wrong_swap_problem_string) # Test validate assert metric.equals(swap_problem, swap_problem, is_placeholder=False) assert not metric.equals(swap_problem, wrong_swap, is_placeholder=False) assert metric.equals(swap_problem, wrong_swap, is_placeholder=True) - assert metric._distance(initial, initial) == 0.0 - assert metric.distance( - initial, - initial2, - goal, - goal2, - timeout=15.0, - ) == (0.0, False, 2.0, False) - def test_move(self, move_problem_string, wrong_move_problem_string): """ Test the distance function on graph pairs. """ - move_problem = pddl.build(move_problem_string) - initial, goal = move_problem.decompose() - wrong_move = pddl.build(wrong_move_problem_string) - initial2, goal2 = wrong_move.decompose() + move_problem = builder.build(move_problem_string) + wrong_move = builder.build(wrong_move_problem_string) # Test validate assert metric.equals(move_problem, move_problem, is_placeholder=True) assert not metric.equals(move_problem, wrong_move, is_placeholder=True) - # Limiting the test to isomorphic cases to optimize execution time. - assert metric._distance(initial, initial) == 0.0 - assert metric.distance( - initial, - initial2, - goal, - goal2, - timeout=15.0, - ) == (2.0, False, 0.0, False) - def test_blocksworld_equivalence( self, blocksworld_fully_specified, @@ -319,10 +249,10 @@ def test_blocksworld_equivalence( blocksworld_underspecified, ): """Test the equivalence of blocksworld problems.""" - p1 = pddl.build(blocksworld_fully_specified) - p2 = pddl.build(blocksworld_missing_clears) - p3 = pddl.build(blocksworld_missing_ontables) - p4 = pddl.build(blocksworld_underspecified) + p1 = builder.build(blocksworld_fully_specified) + p2 = builder.build(blocksworld_missing_clears) + p3 = builder.build(blocksworld_missing_ontables) + p4 = builder.build(blocksworld_underspecified) p1 = oracle.fully_specify(p1) p2 = oracle.fully_specify(p2) diff --git a/tests/test_oracle.py b/tests/test_oracle.py index b7065d8..2802c37 100644 --- a/tests/test_oracle.py +++ b/tests/test_oracle.py @@ -1,8 +1,6 @@ import pytest -from planetarium import graph, oracle, pddl - -import networkx as nx +from planetarium import builder, graph, oracle @pytest.fixture @@ -851,7 +849,7 @@ def gripper_missing_typing(): """ -def reduce_and_respecify(scene: graph.SceneGraph) -> bool: +def reduce_and_inflate(scene: graph.SceneGraph) -> bool: """Respecify a scene and check if it is equal to the original. Args: @@ -862,7 +860,7 @@ def reduce_and_respecify(scene: graph.SceneGraph) -> bool: """ reduced = oracle.reduce(scene, domain=scene.domain) respecified = oracle.inflate(reduced, domain=scene.domain) - return nx.utils.graphs_equal(scene, respecified) + return scene == respecified class TestBlocksworldOracle: @@ -874,77 +872,35 @@ def test_fully_specified(self, blocksworld_fully_specified): """ Test the fully specified blocksworld problem. 
""" - problem = pddl.build(blocksworld_fully_specified) - assert oracle.is_fully_specified(problem, is_placeholder=True) - - assert oracle.is_fully_specified( - oracle.fully_specify(problem), - domain="blocksworld", - is_placeholder=True, - ) + problem = builder.build(blocksworld_fully_specified) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full def test_missing_clears(self, blocksworld_missing_clears): """ Test the fully specified blocksworld problem with missing clears. """ - problem = pddl.build(blocksworld_missing_clears) - assert oracle.is_fully_specified( - problem, - domain="blocksworld", - is_placeholder=True, - ) - assert not oracle.is_fully_specified( - problem, - domain="blocksworld", - is_placeholder=False, - ) + problem = builder.build(blocksworld_missing_clears) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full def test_missing_ontables(self, blocksworld_missing_ontables): """ Test the fully specified blocksworld problem with missing clears. """ - problem = pddl.build(blocksworld_missing_ontables) - assert oracle.is_fully_specified(problem, is_placeholder=True) - assert not oracle.is_fully_specified( - problem, - domain="blocksworld", - is_placeholder=False, - ) - assert oracle.is_fully_specified( - oracle.fully_specify(problem), - domain="blocksworld", - is_placeholder=True, - ) - assert oracle.is_fully_specified( - oracle.fully_specify(problem), - domain="blocksworld", - is_placeholder=False, - ) + problem = builder.build(blocksworld_missing_ontables) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full def test_missing_ontables_and_clears(self, blocksworld_underspecified): """ Test the fully specified blocksworld problem with missing clears. """ - problem = pddl.build(blocksworld_underspecified) - assert not oracle.is_fully_specified(problem, is_placeholder=True) - assert not oracle.is_fully_specified( - problem, - domain="blocksworld", - is_placeholder=False, - ) - - assert not oracle.is_fully_specified( - oracle.fully_specify(problem), - domain="blocksworld", - is_placeholder=True, - ) - assert not oracle.is_fully_specified( - oracle.fully_specify(problem), - domain="blocksworld", - is_placeholder=False, - ) + problem = builder.build(blocksworld_underspecified) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full - def test_respecify( + def test_inflate( self, blocksworld_fully_specified, blocksworld_missing_clears, @@ -954,7 +910,7 @@ def test_respecify( blocksworld_holding, ): """ - Test the respecify function. + Test the inflate function. 
""" descs = [ @@ -967,9 +923,18 @@ def test_respecify( ] for desc in descs: - init, goal = pddl.build(desc).decompose() - assert reduce_and_respecify(init) - assert reduce_and_respecify(goal) + problem = builder.build(desc) + init, goal = problem.decompose() + assert reduce_and_inflate(init) + assert reduce_and_inflate(goal) + assert reduce_and_inflate(problem) + + assert problem == oracle.inflate( + oracle.ReducedProblemGraph.join( + oracle.reduce(init, validate=True), + oracle.reduce(goal, validate=True), + ) + ) def test_invalid( self, @@ -982,9 +947,12 @@ def test_invalid( blocksworld_invalid_2, blocksworld_invalid_3, ): - _, goal = pddl.build(desc).decompose() + problem = builder.build(desc) + _, goal = problem.decompose() with pytest.raises(ValueError): oracle.reduce(goal, validate=True) + with pytest.raises(ValueError): + oracle.reduce(problem, validate=True) class TestGripperOracle: @@ -1001,58 +969,75 @@ def test_fully_specified( """ Test the fully specified gripper problem. """ - problem = pddl.build(gripper_fully_specified) - assert oracle.is_fully_specified(problem, is_placeholder=True) - assert oracle.is_fully_specified(problem, is_placeholder=False) - - problem = pddl.build(gripper_no_goal_types) - assert oracle.is_fully_specified(problem, is_placeholder=True) - assert oracle.is_fully_specified( - oracle.fully_specify(problem), - is_placeholder=True, - ) - assert not oracle.is_fully_specified(problem, is_placeholder=False) + problem = builder.build(gripper_fully_specified) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full - problem = pddl.build(gripper_fully_specified_not_strict) - assert oracle.is_fully_specified(problem, is_placeholder=True) - assert oracle.is_fully_specified( - oracle.fully_specify(problem), - is_placeholder=True, - ) - assert not oracle.is_fully_specified(problem, is_placeholder=False) + problem = builder.build(gripper_no_goal_types) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full + + problem = builder.build(gripper_fully_specified_not_strict) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full - def test_respecify(self, gripper_fully_specified): + def test_inflate(self, gripper_fully_specified): """ - Test the respecify function. + Test the inflate function. 
""" - init, goal = pddl.build(gripper_fully_specified).decompose() - assert reduce_and_respecify(init) - assert reduce_and_respecify(goal) + init, goal = builder.build(gripper_fully_specified).decompose() + assert reduce_and_inflate(init) + assert reduce_and_inflate(goal) + + def test_reduce_inflate( + self, + gripper_fully_specified, + gripper_no_robby, + gripper_underspecified_1, + gripper_underspecified_2, + gripper_underspecified_3, + ): + descs = [ + gripper_fully_specified, + gripper_no_robby, + gripper_underspecified_1, + gripper_underspecified_2, + gripper_underspecified_3, + ] + for desc in descs: + problem = builder.build(desc) + init, goal = problem.decompose() + + assert reduce_and_inflate(init) + assert reduce_and_inflate(goal) + assert reduce_and_inflate(problem) def test_underspecified( self, gripper_underspecified_1, gripper_underspecified_2, ): - problem = pddl.build(gripper_underspecified_1) - assert not oracle.is_fully_specified(problem, is_placeholder=True) - assert not oracle.is_fully_specified(problem, is_placeholder=False) + problem = builder.build(gripper_underspecified_1) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full - problem = pddl.build(gripper_underspecified_2) - assert not oracle.is_fully_specified(problem, is_placeholder=True) - assert not oracle.is_fully_specified(problem, is_placeholder=False) + problem = builder.build(gripper_underspecified_2) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full def test_invalid(self, gripper_invalid): - problem = pddl.build(gripper_invalid) + problem = builder.build(gripper_invalid) _, goal = problem.decompose() with pytest.raises(ValueError): oracle.reduce(goal, validate=True) + with pytest.raises(ValueError): + oracle.reduce(problem, validate=True) class TestUnsupportedDomain: - def test_reduce_and_respecify(self, gripper_fully_specified): - problem = pddl.build(gripper_fully_specified) + def test_reduce_and_inflate(self, gripper_fully_specified): + problem = builder.build(gripper_fully_specified) init, goal = problem.decompose() with pytest.raises(ValueError): @@ -1061,12 +1046,7 @@ def test_reduce_and_respecify(self, gripper_fully_specified): reduced = oracle.reduce(goal, domain="gripper") oracle.inflate(reduced, domain="gripper-modified") - def test_fully_specified(self, gripper_fully_specified): - problem = pddl.build(gripper_fully_specified) - with pytest.raises(ValueError): - oracle.is_fully_specified(problem, domain="gripper-modified") - def test_fully_specify(self, gripper_fully_specified): - problem = pddl.build(gripper_fully_specified) + problem = builder.build(gripper_fully_specified) with pytest.raises(ValueError): oracle.fully_specify(problem, domain="gripper-modified") diff --git a/tests/test_pddl.py b/tests/test_pddl.py index b55648c..fa17879 100644 --- a/tests/test_pddl.py +++ b/tests/test_pddl.py @@ -1,6 +1,6 @@ import pytest -from planetarium import pddl +from planetarium import builder from pddl.parser.problem import LenientProblemParser @@ -256,6 +256,47 @@ def wrong_move_problem_string(): """ +@pytest.fixture +def single_predicate_goal(): + """ + Fixture providing a sample PDDL problem definition as a string. + """ + return """ + (define (problem move) + (:domain move) + (:objects a0 a1 - object + b0 b1 - room) + + (:init + (in a0 b0) + (in a1 b1)) + + (:goal (in a0 b0)) + ) + """ + + +@pytest.fixture +def not_predicate_goal(): + """ + Fixture providing a sample PDDL problem definition as a string. 
+    """
+    return """
+    (define (problem move)
+        (:domain move)
+        (:objects a0 a1 - object
+                  b0 b1 - room)
+
+        (:init
+            (in a0 b0)
+            (in a1 b1))
+
+        (:goal (not
+            (in a0 b0)))
+    )
+    """
+
+
 @pytest.fixture
 def problem(problem_string):
     """
@@ -274,14 +315,14 @@ def test_constat_name(self, problem):
         Test the conversion of a PDDL Constant to a dictionary with the correct name.
         """
         constant = list(problem.objects)[0]
-        assert pddl._constant_to_dict(constant)["name"] == str(constant.name)
+        assert builder._constant_to_dict(constant)["name"] == str(constant.name)
 
     def test_constat_type(self, problem):
         """
         Test the conversion of a PDDL Constant to a dictionary with the correct typing.
         """
         constant = list(problem.objects)[0]
-        result_dict = pddl._constant_to_dict(constant)
+        result_dict = builder._constant_to_dict(constant)
         assert (
             result_dict["typing"] == constant.type_tags
             and type(result_dict["typing"]) == set
@@ -298,14 +339,14 @@ def test_predicate_name(self, problem):
         Test the conversion of a PDDL Predicate to a dictionary with the correct name.
         """
         predicate = list(problem.init)[0]
-        assert pddl._predicate_to_dict(predicate)["typing"] == str(predicate.name)
+        assert builder._predicate_to_dict(predicate)["typing"] == str(predicate.name)
 
     def test_predicate_parameters(self, problem):
         """
         Test the conversion of a PDDL Predicate to a dictionary with the correct parameters.
         """
         predicate = list(problem.init)[0]
-        result_dict = pddl._predicate_to_dict(predicate)
+        result_dict = builder._predicate_to_dict(predicate)
         assert (
             result_dict["parameters"] == [term.name for term in predicate.terms]
             and type(result_dict["parameters"]) == list
@@ -321,7 +362,7 @@ def test_size(self, problem):
         """
         Test the size of the list of constants built from a PDDL problem.
         """
-        assert len(pddl._build_constants(problem.objects)) == len(problem.objects)
+        assert len(builder._build_constants(problem.objects)) == len(problem.objects)
 
 
 class TestBuildPredicates:
@@ -333,13 +374,13 @@ def test_initial_size(self, problem):
         """
         Test the size of the list of initial predicates built from a PDDL problem.
         """
-        assert len(pddl._build_predicates(problem.init)) == len(problem.init)
+        assert len(builder._build_predicates(problem.init)) == len(problem.init)
 
     def test_goal_size(self, problem):
         """
         Test the size of the list of goal predicates built from a PDDL problem.
         """
-        assert len(pddl._build_predicates(problem.goal.operands)) == len(
+        assert len(builder._build_predicates(problem.goal.operands)) == len(
             problem.goal.operands
         )
 
@@ -353,12 +394,33 @@ def test_node_size(self, problem_string):
         """
         Test the size of nodes in the scene graphs built from a PDDL problem.
         """
-        graph_1, graph_2 = pddl.build(problem_string).decompose()
+        graph_1, graph_2 = builder.build(problem_string).decompose()
         assert len(graph_1.nodes) == 17 and len(graph_2.nodes) == 8
 
     def test_edge_size(self, problem_string):
         """
         Test the size of edges in the scene graphs built from a PDDL problem.
         """
-        graph_1, graph_2 = pddl.build(problem_string).decompose()
+        graph_1, graph_2 = builder.build(problem_string).decompose()
         assert len(graph_1.edges) == 21 and len(graph_2.edges) == 2
+
+    def test_edge_size_wrapped_pddl(self, problem_string):
+        """
+        Test the size of edges when the PDDL problem is wrapped in surrounding text.
+        """
+        modified_problem_string = f"Here is an example of a problem string that is not a PDDL problem. ```pddl\n{problem_string}\n```"
+        graph_1, graph_2 = builder.build(modified_problem_string).decompose()
+        assert len(graph_1.edges) == 21 and len(graph_2.edges) == 2
+
+    def test_single_predicate_goal(self, single_predicate_goal):
+        """
+        Test that a problem whose goal is a single predicate can be built.
+        """
+        builder.build(single_predicate_goal).decompose()
+
+    def test_not_predicate_goal(self, not_predicate_goal):
+        """
+        Test that a negated goal predicate raises a ValueError.
+        """
+        with pytest.raises(ValueError):
+            builder.build(not_predicate_goal).decompose()
diff --git a/utils.py b/utils.py
index 0f80f20..a2d6269 100644
--- a/utils.py
+++ b/utils.py
@@ -1,8 +1,3 @@
-from collections import defaultdict
-import sqlite3
-import yaml
-
-from datasets import Dataset
 import matplotlib.pyplot as plt
 import networkx as nx
 
@@ -45,78 +40,33 @@ def apply_template(
     )
 
 
-def strip(text: str, bos_token: str, eos_token: str) -> str:
-    return text.removeprefix(bos_token) + eos_token
-
-
-def plot(graph: graph.SceneGraph, already_reduced: bool = False):
+def plot(graph: graph.PlanGraph, reduce: bool = False):
     """Plot a graph representation of the PDDL description.
 
     Args:
-        graph (nx.MultiDiGraph): The graph to plot.
+        graph (graph.PlanGraph): The graph to plot.
+        reduce (bool, optional): Whether to reduce the graph before plotting.
+            Defaults to False.
     """
-    if not already_reduced:
+    if reduce:
         graph = oracle.reduce(graph, validate=False)
 
-    for layer, nodes in enumerate(nx.topological_generations(graph)):
+    # rx has no plotting functionality
+
+    nx_graph = nx.MultiDiGraph()
+    nx_graph.add_edges_from([(u.node, v.node, {"data": edge}) for u, v, edge in graph.edges])
+
+    for layer, nodes in enumerate(nx.topological_generations(nx_graph)):
         for node in nodes:
-            graph.nodes[node]["layer"] = layer
+            nx_graph.nodes[node]["layer"] = layer
+
     pos = nx.multipartite_layout(
-        graph,
+        nx_graph,
         align="horizontal",
         subset_key="layer",
         scale=-1,
     )
 
     fig = plt.figure()
-
-    nx.draw(graph, pos=pos, ax=fig.gca(), with_labels=True)
+    nx.draw(nx_graph, pos=pos, ax=fig.gca(), with_labels=True)
 
     return fig
-
-
-def load_dataset(config: dict) -> dict[str, Dataset]:
-    """Load the dataset from the configuration.
-
-    Args:
-        config (dict): The dataset configuration.
-
-    Returns:
-        dict[str, Dataset]: The loaded dataset.
-    """
-    with open(config["splits_path"], "r") as f:
-        split_ids_cfg = yaml.safe_load(f)
-
-    splits: set[str] = config.get("splits", {}).keys()
-    dataset: dict[str, dict[str, list]] = {split: defaultdict(list) for split in splits}
-
-    # Connect to database
-    conn = sqlite3.connect(config["database_path"])
-    c = conn.cursor()
-
-    # load domains
-    domains = {}
-    c.execute("SELECT name, domain_pddl FROM domains")
-    for domain_name, domain_pddl in c.fetchall():
-        domains[domain_name] = domain_pddl
-
-    # load problems
-    for split in splits:
-        queries = []
-        split_keys: list[str] = config["splits"][split]
-        for split_key in split_keys:
-            split_ids = split_ids_cfg
-            for key in split_key:
-                split_ids = split_ids[key]
-
-            c.execute(
-                f"SELECT domain, problem_pddl, natural_language FROM problems WHERE id in ({', '.join(['?'] * len(split_ids))})",
-                split_ids,
-            )
-            queries.extend(c.fetchall())
-
-        for domain, problem_pddl, natural_language in queries:
-            dataset[split]["domain"].append(domains[domain])
-            dataset[split]["problem"].append(problem_pddl)
-            dataset[split]["natural_language"].append(natural_language)
-
-    return {s: Dataset.from_dict(d, split=s) for s, d in dataset.items()}
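
Taken together, this patch renames planetarium.pddl to planetarium.builder, moves the plan graphs from networkx onto rustworkx-backed ReducedSceneGraph/ReducedProblemGraph classes, removes oracle.is_fully_specified in favour of checking that oracle.fully_specify is idempotent, and adds a return_reduced flag to fully_specify. A minimal usage sketch of the updated API follows; it mirrors the calls exercised in the rewritten tests, but the inline blocksworld problem string is illustrative only (it is not taken from this diff and assumes the domain's usual on-table/clear/arm-empty predicates).

from planetarium import builder, metric, oracle

# Hypothetical problem string for illustration; any PDDL problem accepted by
# builder.build for a supported domain (blocksworld or gripper) would do.
PROBLEM = """
(define (problem two-blocks)
    (:domain blocksworld)
    (:objects b1 b2)
    (:init (arm-empty) (clear b1) (on-table b1) (clear b2) (on-table b2))
    (:goal (and (arm-empty) (clear b1) (on-table b1) (clear b2) (on-table b2)))
)
"""

problem_graph = builder.build(PROBLEM)  # formerly pddl.build(...)

# reduce() now returns a rustworkx-backed reduced graph and, with
# validate=True, raises ValueError on malformed scenes; inflate() is its
# inverse, as the reduce_and_inflate round-trip tests above assert.
reduced = oracle.reduce(problem_graph, validate=True)
assert problem_graph == oracle.inflate(reduced)

# is_fully_specified() is gone: the tests instead assert that fully_specify()
# is a fixed point and compare problems with metric.equals().
full = oracle.fully_specify(problem_graph)
assert oracle.fully_specify(full) == full
assert metric.equals(full, full, is_placeholder=True)

# return_reduced=True skips the final inflate() and yields a ReducedProblemGraph.
reduced_full = oracle.fully_specify(problem_graph, return_reduced=True)

The reduce/inflate round trip and the fully_specify fixed-point check are the same invariants the updated oracle tests rely on, so the sketch should track the intended post-patch API surface.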