diff --git a/planetarium/__init__.py b/planetarium/__init__.py index e69de29..0b019ee 100644 --- a/planetarium/__init__.py +++ b/planetarium/__init__.py @@ -0,0 +1,7 @@ +__all__ = ["builder", "downward", "graph", "metric", "oracle"] + +from . import builder +from . import downward +from . import graph +from . import metric +from . import oracle diff --git a/planetarium/evaluate.py b/planetarium/evaluate.py new file mode 100644 index 0000000..951acf1 --- /dev/null +++ b/planetarium/evaluate.py @@ -0,0 +1,64 @@ +import os + +from pddl.parser.problem import LenientProblemParser +from pddl.formatter import problem_to_string + +from planetarium import * + + +VALIDATE = os.getenv("VALIDATE", "Validate") + + +def evaluate( + source_pddl_str: str, + target_pddl_str: str, + domain_str: str, + is_placeholder: bool = False, +) -> tuple[bool, bool, bool]: + """Evaluate two PDDL problem descriptions for equivalence. + + Args: + source_pddl_str (str): + target_pddl_str (str): The second problem PDDL string. + domain_str (str): The domain PDDL string. + is_placeholder (bool, optional): Whether or not to treat the ground truth + as a "placeholder" description. Defaults to False. + + Returns: + tuple: A tuple containing the following boolean elements: + - parseable: Whether or not the PDDL string is parseable. + - solveable: Whether or not the PDDL string is solveable. + - equivalent: Whether or not the PDDL strings are equivalent. + """ + parseable = False + solveable = False + equivalent = False + + source_graph = builder.build(source_pddl_str) + + try: + target_graph = builder.build(target_pddl_str) + parseable = True + clean_pddl_str = problem_to_string(LenientProblemParser(target_pddl_str)) + + solveable = downward.validate( + builder.build(domain_str), + clean_pddl_str, + oracle.plan_to_string(oracle.plan(target_graph)), + VALIDATE, + ) + + if source_graph == target_graph: + equivalent = True + elif source_graph.decompose()[0] != target_graph.decompose()[0]: + equivalent = False + else: + equivalent = metric.equals( + oracle.fully_specify(source_graph, return_reduced=True), + oracle.fully_specify(target_graph, return_reduced=True), + is_placeholder=is_placeholder, + ) + except Exception: + pass + + return parseable, solveable, equivalent diff --git a/planetarium/oracle.py b/planetarium/oracle.py index b4c6dcd..d269d9b 100644 --- a/planetarium/oracle.py +++ b/planetarium/oracle.py @@ -136,6 +136,10 @@ def join(init: ReducedSceneGraph, goal: ReducedSceneGraph) -> "ReducedProblemGra return problem +class DomainNotSupportedError(Exception): + pass + + def _reduce_blocksworld( scene: graph.SceneGraph | graph.ProblemGraph, validate: bool = True, @@ -347,7 +351,7 @@ def reduce( case "gripper": return _reduce_gripper(graph, validate=validate) case _: - raise ValueError(f"Domain {domain} not supported.") + raise DomainNotSupportedError(f"Domain {domain} not supported.") def _inflate_blocksworld( @@ -535,7 +539,7 @@ def inflate( case "gripper": return _inflate_gripper(scene) case _: - raise ValueError(f"Domain {domain} not supported.") + raise DomainNotSupportedError(f"Domain {domain} not supported.") def _blocksworld_underspecified_blocks( @@ -795,7 +799,7 @@ def fully_specify( reduced_goal, ) case _: - raise ValueError(f"Domain {domain} not supported.") + raise DomainNotSupportedError(f"Domain {domain} not supported.") if return_reduced: return ReducedProblemGraph.join(reduced_init, fully_specified_goal) @@ -954,7 +958,7 @@ def plan(problem: graph.ProblemGraph, domain: str | None = None) -> list[Action] case "gripper": return _plan_gripper(problem) case _: - raise ValueError(f"Domain {domain} not supported.") + raise DomainNotSupportedError(f"Domain {domain} not supported.") def plan_to_string(actions: list[Action]) -> str: