Skip to content

Commit

Permalink
add evaluate and more specific error and __init__ file
Browse files Browse the repository at this point in the history
  • Loading branch information
maxzuo committed Jun 24, 2024
1 parent c415a9d commit 389c02d
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 4 deletions.
7 changes: 7 additions & 0 deletions planetarium/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
__all__ = ["builder", "downward", "graph", "metric", "oracle"]

from . import builder
from . import downward
from . import graph
from . import metric
from . import oracle
64 changes: 64 additions & 0 deletions planetarium/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import os

from pddl.parser.problem import LenientProblemParser
from pddl.formatter import problem_to_string

from planetarium import *


VALIDATE = os.getenv("VALIDATE", "Validate")


def evaluate(
source_pddl_str: str,
target_pddl_str: str,
domain_str: str,
is_placeholder: bool = False,
) -> tuple[bool, bool, bool]:
"""Evaluate two PDDL problem descriptions for equivalence.
Args:
source_pddl_str (str):
target_pddl_str (str): The second problem PDDL string.
domain_str (str): The domain PDDL string.
is_placeholder (bool, optional): Whether or not to treat the ground truth
as a "placeholder" description. Defaults to False.
Returns:
tuple: A tuple containing the following boolean elements:
- parseable: Whether or not the PDDL string is parseable.
- solveable: Whether or not the PDDL string is solveable.
- equivalent: Whether or not the PDDL strings are equivalent.
"""
parseable = False
solveable = False
equivalent = False

source_graph = builder.build(source_pddl_str)

try:
target_graph = builder.build(target_pddl_str)
parseable = True
clean_pddl_str = problem_to_string(LenientProblemParser(target_pddl_str))

solveable = downward.validate(
builder.build(domain_str),
clean_pddl_str,
oracle.plan_to_string(oracle.plan(target_graph)),
VALIDATE,
)

if source_graph == target_graph:
equivalent = True
elif source_graph.decompose()[0] != target_graph.decompose()[0]:
equivalent = False
else:
equivalent = metric.equals(
oracle.fully_specify(source_graph, return_reduced=True),
oracle.fully_specify(target_graph, return_reduced=True),
is_placeholder=is_placeholder,
)
except Exception:
pass

return parseable, solveable, equivalent
12 changes: 8 additions & 4 deletions planetarium/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ def join(init: ReducedSceneGraph, goal: ReducedSceneGraph) -> "ReducedProblemGra
return problem


class DomainNotSupportedError(Exception):
pass


def _reduce_blocksworld(
scene: graph.SceneGraph | graph.ProblemGraph,
validate: bool = True,
Expand Down Expand Up @@ -347,7 +351,7 @@ def reduce(
case "gripper":
return _reduce_gripper(graph, validate=validate)
case _:
raise ValueError(f"Domain {domain} not supported.")
raise DomainNotSupportedError(f"Domain {domain} not supported.")


def _inflate_blocksworld(
Expand Down Expand Up @@ -535,7 +539,7 @@ def inflate(
case "gripper":
return _inflate_gripper(scene)
case _:
raise ValueError(f"Domain {domain} not supported.")
raise DomainNotSupportedError(f"Domain {domain} not supported.")


def _blocksworld_underspecified_blocks(
Expand Down Expand Up @@ -795,7 +799,7 @@ def fully_specify(
reduced_goal,
)
case _:
raise ValueError(f"Domain {domain} not supported.")
raise DomainNotSupportedError(f"Domain {domain} not supported.")

if return_reduced:
return ReducedProblemGraph.join(reduced_init, fully_specified_goal)
Expand Down Expand Up @@ -954,7 +958,7 @@ def plan(problem: graph.ProblemGraph, domain: str | None = None) -> list[Action]
case "gripper":
return _plan_gripper(problem)
case _:
raise ValueError(f"Domain {domain} not supported.")
raise DomainNotSupportedError(f"Domain {domain} not supported.")


def plan_to_string(actions: list[Action]) -> str:
Expand Down

0 comments on commit 389c02d

Please sign in to comment.