Skip to content

Commit

Permalink
updated README.md + removed utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
maxzuo committed Jun 25, 2024
1 parent ac0b4ce commit f6d0d21
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 143 deletions.
42 changes: 40 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# planetarium
# planetarium🪐

Planetarium is a benchmark for assessing LLMs in translating natural language descriptions of planning problems into PDDL.
Planetarium🪐 is a [dataset](https://huggingface.co/datasets/BatsResearch/planetarium) and benchmark for assessing LLMs in translating natural language descriptions of planning problems into PDDL. We developed a robust method for comparing PDDL problem descriptions using graph isomorphism.

## Installation
To install the `planetarium` package, you can use the following command:
Expand Down Expand Up @@ -37,3 +37,41 @@ from planetarium import evaluate
evaluate.evaluate(gt_pddl_str, pred_pddl_str)
```
The supported domains are `blocksworld` and `gripper`.

## Dataset
The main page for the dataset can be found [here](https://huggingface.co/datasets/BatsResearch/planetarium).

Here is an example of how to load the dataset:
```python
from datasets import load_dataset

dataset = load_dataset("BatsResearch/planetarium")
```

You can reproduce the dataset, the splits, and a report by running the following command:
```bash
python dataset_generator.py -c dataset_config.yaml
```

By modifying the `dataset_config.yaml` file, you can change the dataset splits, the number of samples, and produce even more examples!

Here is a summary of the types of PDDL problems in the dataset:

### Dataset Report
Total number of problems: $132,037$.

#### Abstractness Split
| Init | Goal | blocksworld | gripper |
|:---:|:---:|---:|---:|
| abstract | abstract | $23,144$ | $10,632$ |
| abstract | explicit | $23,086$ | $9,518$ |
| explicit | abstract | $23,087$ | $10,313$ |
| explicit | explicit | $23,033$ | $9,224$ |
#### Size Splits (Number of Propositions in Ground Truth)
| Num. of Propositions | blocksworld | gripper |
|:---:|---:|---:|
| $0$-$20$ | $1,012$ | $379$ |
| $20$-$40$ | $10,765$ | $2,112$ |
| $40$-$60$ | $50,793$ | $9,412$ |
| $60$-$80$ | $26,316$ | $25,346$ |
| $80$-inf | $3,464$ | $2,438$ |
11 changes: 3 additions & 8 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
from planetarium import builder, graph, metric, oracle
import llm_planner as llmp

from utils import apply_template

HF_USER_TOKEN = os.getenv("HF_USER_TOKEN")


Expand Down Expand Up @@ -82,26 +80,23 @@ def plan(
context = []
for example_problem in example_problems:
context.extend(
apply_template(
example_problem,
example_problem.apply_template(
domain_prompt,
problem_prompt,
)
)

if isinstance(problem, llmp.PlanningProblem):
messages = [
apply_template(
problem,
problem.apply_template(
domain_prompt,
problem_prompt,
include_answer=False,
)
]
else:
messages = [
apply_template(
p,
p.apply_template(
domain_prompt,
problem_prompt,
include_answer=False,
Expand Down
4 changes: 1 addition & 3 deletions finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import tqdm as tqdm

import llm_planner as llmp
from utils import apply_template

from accelerate import Accelerator

Expand Down Expand Up @@ -137,8 +136,7 @@ def preprocess(
inputs = [
strip(
tokenizer.apply_chat_template(
apply_template(
llmp.PlanningProblem(nl, d, p),
llmp.PlanningProblem(nl, d, p).apply_template(
domain_prompt,
problem_prompt,
),
Expand Down
33 changes: 33 additions & 0 deletions llm_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,39 @@ def __init__(
self.domain = domain
self.problem = problem

def apply_template(
    self,
    domain_prompt: str = "",
    problem_prompt: str = "",
    include_answer: bool = True,
) -> list[dict[str, str]]:
    """Render this problem as a list of chat-style messages.

    The user message is assembled from the problem prompt, the natural
    language description, the domain prompt and the domain text. When
    requested, an assistant message holding the problem text follows it.

    Args:
        domain_prompt (str, optional): Text placed before the domain.
            Defaults to "".
        problem_prompt (str, optional): Text placed before the natural
            language description. Defaults to "".
        include_answer (bool, optional): Whether to append the assistant
            message containing the problem. Defaults to True.

    Returns:
        list[dict[str, str]]: One user message, optionally followed by one
            assistant message; each has "role" and "content" keys.
    """
    user_content = (
        f"{problem_prompt} {self.natural_language} "
        f"{domain_prompt}\n{self.domain}\n"
    )
    messages = [{"role": "user", "content": user_content}]

    # The answer is only read when it is actually requested.
    if include_answer:
        messages.append({"role": "assistant", "content": " " + self.problem})

    return messages


class Planner(abc.ABC):
@abc.abstractmethod
Expand Down
38 changes: 36 additions & 2 deletions planetarium/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import enum
from functools import cached_property

import matplotlib.pyplot as plt
import networkx as nx
import rustworkx as rx


Expand Down Expand Up @@ -360,6 +362,40 @@ def __eq__(self, other: "PlanGraph") -> bool:
and self.domain == other.domain
)

def plot(self, fig: plt.Figure | None = None) -> plt.Figure:
    """Draw the graph with its nodes arranged by topological generation.

    The edges are copied into a networkx multigraph, each node is tagged
    with the index of its topological generation, and that index is used
    as the layer for a multipartite layout.

    Args:
        fig (plt.Figure | None, optional): Figure to draw on. A new figure
            is created when None. Defaults to None.

    Returns:
        plt.Figure: The figure the graph was drawn on.
    """
    # rx has no plotting functionality, so mirror the edges into networkx.
    # Only nodes that appear in at least one edge end up in the mirror.
    edge_list = [
        (src.node, dst.node, {"data": label}) for src, dst, label in self.edges
    ]
    mirror = nx.MultiDiGraph()
    mirror.add_edges_from(edge_list)

    # Tag every node with the generation it belongs to.
    generations = nx.topological_generations(mirror)
    for depth, members in enumerate(generations):
        for member in members:
            mirror.nodes[member]["layer"] = depth

    layout = nx.multipartite_layout(
        mirror,
        subset_key="layer",
        align="horizontal",
        scale=-1,
    )

    if fig is None:
        fig = plt.figure()

    nx.draw(mirror, pos=layout, ax=fig.gca(), with_labels=True)

    return fig


class SceneGraph(PlanGraph):
"""
Expand Down Expand Up @@ -523,8 +559,6 @@ def goal_predicates(self) -> list[dict[str, Any]]:

return predicates



@cached_property
def _decompose(self) -> tuple[SceneGraph, SceneGraph]:
"""
Expand Down
Loading

0 comments on commit f6d0d21

Please sign in to comment.