diff --git a/README.md b/README.md index b717ee6..ddedbde 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -# planetarium +# planetarium🪐 -Planetarium is a benchmark for assessing LLMs in translating natural language descriptions of planning problems into PDDL. +Planetarium🪐 is a [dataset](https://huggingface.co/datasets/BatsResearch/planetarium) and benchmark for assessing LLMs in translating natural language descriptions of planning problems into PDDL. We developed a robust method for comparing PDDL problem descriptions using graph isomorphism. ## Installation To install the `planetarium` package, you can use the following command: @@ -37,3 +37,41 @@ from planetarium import evaluate evaluate.evaluate(gt_pddl_str, pred_pddl_str) ``` The supported domains are `blocksworld` and `gripper` domains. + +## Dataset +The main page for the dataset can be found [here](https://huggingface.co/datasets/BatsResearch/planetarium). + +Here is an example of how to load the dataset: +```python +from datasets import load_dataset + +dataset = load_dataset("BatsResearch/planetarium") +``` + +You can reporduce the dataset, the splits, and a report by running the following command: +```bash +python dataset_generator.py -c dataset_config.yaml +``` + +By modifying the `dataset_config.yaml` file, you can change the dataset splits, the number of samples, and produce even more examples! + +Here is a summary of the types of PDDL problems in the dataset: + +### Dataset Report +Total number of problems: $132,037$. + +#### Abstractness Split +| Init | Goal | blocksworld | gripper | +|:---:|:---:|---:|---:| +| abstract | abstract | $23,144$ | $10,632$ | +| abstract | explicit | $23,086$ | $9,518$ | +| explicit | abstract | $23,087$ | $10,313$ | +| explicit | explicit | $23,033$ | $9,224$ | +#### Size Splits (Number of Propositions in Ground Truth) +| Num. of Propositions | blocksworld | gripper | +|:---:|---:|---:| +| $0$-$20$ | $1,012$ | $379$ | +| $20$-$40$ | $10,765$ | $2,112$ | +| $40$-$60$ | $50,793$ | $9,412$ | +| $60$-$80$ | $26,316$ | $25,346$ | +| $80$-inf | $3,464$ | $2,438$ | \ No newline at end of file diff --git a/evaluate.py b/evaluate.py index ab42d1a..0e41028 100644 --- a/evaluate.py +++ b/evaluate.py @@ -17,8 +17,6 @@ from planetarium import builder, graph, metric, oracle import llm_planner as llmp -from utils import apply_template - HF_USER_TOKEN = os.getenv("HF_USER_TOKEN") @@ -82,8 +80,7 @@ def plan( context = [] for example_problem in example_problems: context.extend( - apply_template( - example_problem, + example_problem.apply_template( domain_prompt, problem_prompt, ) @@ -91,8 +88,7 @@ def plan( if isinstance(problem, llmp.PlanningProblem): messages = [ - apply_template( - problem, + problem.apply_template( domain_prompt, problem_prompt, include_answer=False, @@ -100,8 +96,7 @@ def plan( ] else: messages = [ - apply_template( - p, + p.apply_template( domain_prompt, problem_prompt, include_answer=False, diff --git a/finetune.py b/finetune.py index 773142e..8f95616 100644 --- a/finetune.py +++ b/finetune.py @@ -26,7 +26,6 @@ import tqdm as tqdm import llm_planner as llmp -from utils import apply_template from accelerate import Accelerator @@ -137,8 +136,7 @@ def preprocess( inputs = [ strip( tokenizer.apply_chat_template( - apply_template( - llmp.PlanningProblem(nl, d, p), + llmp.PlanningProblem(nl, d, p).apply_template( domain_prompt, problem_prompt, ), diff --git a/llm_planner.py b/llm_planner.py index 3aece5a..19462a3 100644 --- a/llm_planner.py +++ b/llm_planner.py @@ -31,6 +31,39 @@ def __init__( self.domain = domain self.problem = problem + def apply_template( + self, + domain_prompt: str = "", + problem_prompt: str = "", + include_answer: bool = True, + ) -> list[dict[str, str]]: + """Apply problem template to the problem. + + Args: + domain_prompt (str, optional): How to prompt the domain. Defaults to "". + problem_prompt (str, optional): How to prompt the problem. Defaults to "". + include_answer (bool, optional): Whether to include the answer. Defaults to True. + + Returns: + list[dict[str, str]]: Problem prompt. + """ + return [ + { + "role": "user", + "content": f"{problem_prompt} {self.natural_language} " + + f"{domain_prompt}\n{self.domain}\n", + }, + ] + ( + [ + { + "role": "assistant", + "content": " " + self.problem, + }, + ] + if include_answer + else [] + ) + class Planner(abc.ABC): @abc.abstractmethod diff --git a/planetarium/graph.py b/planetarium/graph.py index 2754a27..d4bf4bd 100644 --- a/planetarium/graph.py +++ b/planetarium/graph.py @@ -4,6 +4,8 @@ import enum from functools import cached_property +import matplotlib.pyplot as plt +import networkx as nx import rustworkx as rx @@ -360,6 +362,40 @@ def __eq__(self, other: "PlanGraph") -> bool: and self.domain == other.domain ) + def plot(self, fig: plt.Figure | None = None) -> plt.Figure: + """Generate a plot of the graph, sorted by topological generation. + + Args: + fig (plt.Figure | None, optional): The figure to plot on. Defaults + to None. + + Returns: + plt.Figure: The figure containing the plot. + """ + # rx has no plotting functionality + nx_graph = nx.MultiDiGraph() + nx_graph.add_edges_from( + [(u.node, v.node, {"data": edge}) for u, v, edge in self.edges] + ) + + for layer, nodes in enumerate(nx.topological_generations(nx_graph)): + for node in nodes: + nx_graph.nodes[node]["layer"] = layer + + pos = nx.multipartite_layout( + nx_graph, + align="horizontal", + subset_key="layer", + scale=-1, + ) + + if fig is None: + fig = plt.figure() + + nx.draw(nx_graph, pos=pos, ax=fig.gca(), with_labels=True) + + return fig + class SceneGraph(PlanGraph): """ @@ -523,8 +559,6 @@ def goal_predicates(self) -> list[dict[str, Any]]: return predicates - - @cached_property def _decompose(self) -> tuple[SceneGraph, SceneGraph]: """ diff --git a/poetry.lock b/poetry.lock index c837a33..bec858a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -243,33 +243,33 @@ test = ["scipy"] [[package]] name = "black" -version = "23.12.1" +version = "24.4.2" description = "The uncompromising code formatter." optional = false python-versions = ">=3.8" files = [ - {file = "black-23.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0aaf6041986767a5e0ce663c7a2f0e9eaf21e6ff87a5f95cbf3675bfd4c41d2"}, - {file = "black-23.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c88b3711d12905b74206227109272673edce0cb29f27e1385f33b0163c414bba"}, - {file = "black-23.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a920b569dc6b3472513ba6ddea21f440d4b4c699494d2e972a1753cdc25df7b0"}, - {file = "black-23.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:3fa4be75ef2a6b96ea8d92b1587dd8cb3a35c7e3d51f0738ced0781c3aa3a5a3"}, - {file = "black-23.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8d4df77958a622f9b5a4c96edb4b8c0034f8434032ab11077ec6c56ae9f384ba"}, - {file = "black-23.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:602cfb1196dc692424c70b6507593a2b29aac0547c1be9a1d1365f0d964c353b"}, - {file = "black-23.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c4352800f14be5b4864016882cdba10755bd50805c95f728011bcb47a4afd59"}, - {file = "black-23.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:0808494f2b2df923ffc5723ed3c7b096bd76341f6213989759287611e9837d50"}, - {file = "black-23.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:25e57fd232a6d6ff3f4478a6fd0580838e47c93c83eaf1ccc92d4faf27112c4e"}, - {file = "black-23.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2d9e13db441c509a3763a7a3d9a49ccc1b4e974a47be4e08ade2a228876500ec"}, - {file = "black-23.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1bd9c210f8b109b1762ec9fd36592fdd528485aadb3f5849b2740ef17e674e"}, - {file = "black-23.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:ae76c22bde5cbb6bfd211ec343ded2163bba7883c7bc77f6b756a1049436fbb9"}, - {file = "black-23.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1fa88a0f74e50e4487477bc0bb900c6781dbddfdfa32691e780bf854c3b4a47f"}, - {file = "black-23.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a4d6a9668e45ad99d2f8ec70d5c8c04ef4f32f648ef39048d010b0689832ec6d"}, - {file = "black-23.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b18fb2ae6c4bb63eebe5be6bd869ba2f14fd0259bda7d18a46b764d8fb86298a"}, - {file = "black-23.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:c04b6d9d20e9c13f43eee8ea87d44156b8505ca8a3c878773f68b4e4812a421e"}, - {file = "black-23.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e1b38b3135fd4c025c28c55ddfc236b05af657828a8a6abe5deec419a0b7055"}, - {file = "black-23.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f0031eaa7b921db76decd73636ef3a12c942ed367d8c3841a0739412b260a54"}, - {file = "black-23.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97e56155c6b737854e60a9ab1c598ff2533d57e7506d97af5481141671abf3ea"}, - {file = "black-23.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:dd15245c8b68fe2b6bd0f32c1556509d11bb33aec9b5d0866dd8e2ed3dba09c2"}, - {file = "black-23.12.1-py3-none-any.whl", hash = "sha256:78baad24af0f033958cad29731e27363183e140962595def56423e626f4bee3e"}, - {file = "black-23.12.1.tar.gz", hash = "sha256:4ce3ef14ebe8d9509188014d96af1c456a910d5b5cbf434a09fef7e024b3d0d5"}, + {file = "black-24.4.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dd1b5a14e417189db4c7b64a6540f31730713d173f0b63e55fabd52d61d8fdce"}, + {file = "black-24.4.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8e537d281831ad0e71007dcdcbe50a71470b978c453fa41ce77186bbe0ed6021"}, + {file = "black-24.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaea3008c281f1038edb473c1aa8ed8143a5535ff18f978a318f10302b254063"}, + {file = "black-24.4.2-cp310-cp310-win_amd64.whl", hash = "sha256:7768a0dbf16a39aa5e9a3ded568bb545c8c2727396d063bbaf847df05b08cd96"}, + {file = "black-24.4.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:257d724c2c9b1660f353b36c802ccece186a30accc7742c176d29c146df6e474"}, + {file = "black-24.4.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bdde6f877a18f24844e381d45e9947a49e97933573ac9d4345399be37621e26c"}, + {file = "black-24.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e151054aa00bad1f4e1f04919542885f89f5f7d086b8a59e5000e6c616896ffb"}, + {file = "black-24.4.2-cp311-cp311-win_amd64.whl", hash = "sha256:7e122b1c4fb252fd85df3ca93578732b4749d9be076593076ef4d07a0233c3e1"}, + {file = "black-24.4.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:accf49e151c8ed2c0cdc528691838afd217c50412534e876a19270fea1e28e2d"}, + {file = "black-24.4.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:88c57dc656038f1ab9f92b3eb5335ee9b021412feaa46330d5eba4e51fe49b04"}, + {file = "black-24.4.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be8bef99eb46d5021bf053114442914baeb3649a89dc5f3a555c88737e5e98fc"}, + {file = "black-24.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:415e686e87dbbe6f4cd5ef0fbf764af7b89f9057b97c908742b6008cc554b9c0"}, + {file = "black-24.4.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bf10f7310db693bb62692609b397e8d67257c55f949abde4c67f9cc574492cc7"}, + {file = "black-24.4.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:98e123f1d5cfd42f886624d84464f7756f60ff6eab89ae845210631714f6db94"}, + {file = "black-24.4.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48a85f2cb5e6799a9ef05347b476cce6c182d6c71ee36925a6c194d074336ef8"}, + {file = "black-24.4.2-cp38-cp38-win_amd64.whl", hash = "sha256:b1530ae42e9d6d5b670a34db49a94115a64596bc77710b1d05e9801e62ca0a7c"}, + {file = "black-24.4.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:37aae07b029fa0174d39daf02748b379399b909652a806e5708199bd93899da1"}, + {file = "black-24.4.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:da33a1a5e49c4122ccdfd56cd021ff1ebc4a1ec4e2d01594fef9b6f267a9e741"}, + {file = "black-24.4.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef703f83fc32e131e9bcc0a5094cfe85599e7109f896fe8bc96cc402f3eb4b6e"}, + {file = "black-24.4.2-cp39-cp39-win_amd64.whl", hash = "sha256:b9176b9832e84308818a99a561e90aa479e73c523b3f77afd07913380ae2eab7"}, + {file = "black-24.4.2-py3-none-any.whl", hash = "sha256:d36ed1124bb81b32f8614555b34cc4259c3fbc7eec17870e8ff8ded335b58d8c"}, + {file = "black-24.4.2.tar.gz", hash = "sha256:c872b53057f000085da66a19c55d68f6f8ddcac2642392ad3a355878406fbd4d"}, ] [package.dependencies] @@ -1214,13 +1214,13 @@ files = [ [[package]] name = "ipython" -version = "8.20.0" +version = "8.25.0" description = "IPython: Productive Interactive Computing" optional = false python-versions = ">=3.10" files = [ - {file = "ipython-8.20.0-py3-none-any.whl", hash = "sha256:bc9716aad6f29f36c449e30821c9dd0c1c1a7b59ddcc26931685b87b4c569619"}, - {file = "ipython-8.20.0.tar.gz", hash = "sha256:2f21bd3fc1d51550c89ee3944ae04bbc7bc79e129ea0937da6e6c68bfdbf117a"}, + {file = "ipython-8.25.0-py3-none-any.whl", hash = "sha256:53eee7ad44df903a06655871cbab66d156a051fd86f3ec6750470ac9604ac1ab"}, + {file = "ipython-8.25.0.tar.gz", hash = "sha256:c6ed726a140b6e725b911528f80439c534fac915246af3efc39440a6b0f9d716"}, ] [package.dependencies] @@ -1229,24 +1229,26 @@ decorator = "*" exceptiongroup = {version = "*", markers = "python_version < \"3.11\""} jedi = ">=0.16" matplotlib-inline = "*" -pexpect = {version = ">4.3", markers = "sys_platform != \"win32\""} +pexpect = {version = ">4.3", markers = "sys_platform != \"win32\" and sys_platform != \"emscripten\""} prompt-toolkit = ">=3.0.41,<3.1.0" pygments = ">=2.4.0" stack-data = "*" -traitlets = ">=5" +traitlets = ">=5.13.0" +typing-extensions = {version = ">=4.6", markers = "python_version < \"3.12\""} [package.extras] -all = ["black", "curio", "docrepr", "exceptiongroup", "ipykernel", "ipyparallel", "ipywidgets", "matplotlib", "matplotlib (!=3.2.0)", "nbconvert", "nbformat", "notebook", "numpy (>=1.23)", "pandas", "pickleshare", "pytest", "pytest-asyncio (<0.22)", "qtconsole", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "trio", "typing-extensions"] +all = ["ipython[black,doc,kernel,matplotlib,nbconvert,nbformat,notebook,parallel,qtconsole]", "ipython[test,test-extra]"] black = ["black"] -doc = ["docrepr", "exceptiongroup", "ipykernel", "matplotlib", "pickleshare", "pytest", "pytest-asyncio (<0.22)", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "typing-extensions"] +doc = ["docrepr", "exceptiongroup", "intersphinx-registry", "ipykernel", "ipython[test]", "matplotlib", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "sphinxcontrib-jquery", "tomli", "typing-extensions"] kernel = ["ipykernel"] +matplotlib = ["matplotlib"] nbconvert = ["nbconvert"] nbformat = ["nbformat"] notebook = ["ipywidgets", "notebook"] parallel = ["ipyparallel"] qtconsole = ["qtconsole"] test = ["pickleshare", "pytest", "pytest-asyncio (<0.22)", "testpath"] -test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.23)", "pandas", "pickleshare", "pytest", "pytest-asyncio (<0.22)", "testpath", "trio"] +test-extra = ["curio", "ipython[test]", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.23)", "pandas", "trio"] [[package]] name = "jedi" @@ -1643,13 +1645,13 @@ dev = ["meson-python (>=0.13.1)", "numpy (>=1.25)", "pybind11 (>=2.6)", "setupto [[package]] name = "matplotlib-inline" -version = "0.1.6" +version = "0.1.7" description = "Inline Matplotlib backend for Jupyter" optional = false -python-versions = ">=3.5" +python-versions = ">=3.8" files = [ - {file = "matplotlib-inline-0.1.6.tar.gz", hash = "sha256:f887e5f10ba98e8d2b150ddcf4702c1e5f8b3a20005eb0f74bfdbd360ee6f304"}, - {file = "matplotlib_inline-0.1.6-py3-none-any.whl", hash = "sha256:f1f41aab5328aa5aaea9b16d083b128102f8712542f819fe7e6a420ff581b311"}, + {file = "matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca"}, + {file = "matplotlib_inline-0.1.7.tar.gz", hash = "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90"}, ] [package.dependencies] @@ -2433,18 +2435,18 @@ xml = ["lxml (>=4.9.2)"] [[package]] name = "parso" -version = "0.8.3" +version = "0.8.4" description = "A Python Parser" optional = false python-versions = ">=3.6" files = [ - {file = "parso-0.8.3-py2.py3-none-any.whl", hash = "sha256:c001d4636cd3aecdaf33cbb40aebb59b094be2a74c556778ef5576c175e19e75"}, - {file = "parso-0.8.3.tar.gz", hash = "sha256:8c07be290bb59f03588915921e29e8a50002acaf2cdc5fa0e0114f91709fafa0"}, + {file = "parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18"}, + {file = "parso-0.8.4.tar.gz", hash = "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d"}, ] [package.extras] -qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] -testing = ["docopt", "pytest (<6.0.0)"] +qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] +testing = ["docopt", "pytest"] [[package]] name = "pathspec" @@ -2607,18 +2609,19 @@ xmp = ["defusedxml"] [[package]] name = "platformdirs" -version = "4.1.0" -description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +version = "4.2.2" +description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." optional = false python-versions = ">=3.8" files = [ - {file = "platformdirs-4.1.0-py3-none-any.whl", hash = "sha256:11c8f37bcca40db96d8144522d925583bdb7a31f7b0e37e3ed4318400a8e2380"}, - {file = "platformdirs-4.1.0.tar.gz", hash = "sha256:906d548203468492d432bcb294d4bc2fff751bf84971fbb2c10918cc206ee420"}, + {file = "platformdirs-4.2.2-py3-none-any.whl", hash = "sha256:2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee"}, + {file = "platformdirs-4.2.2.tar.gz", hash = "sha256:38b7b51f512eed9e84a22788b4bce1de17c0adb134d6becb09836e37d8654cd3"}, ] [package.extras] -docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"] -test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"] +docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] +type = ["mypy (>=1.8)"] [[package]] name = "pluggy" @@ -2666,13 +2669,13 @@ starlette = ">=0.30.0,<1.0.0" [[package]] name = "prompt-toolkit" -version = "3.0.43" +version = "3.0.47" description = "Library for building powerful interactive command lines in Python" optional = false python-versions = ">=3.7.0" files = [ - {file = "prompt_toolkit-3.0.43-py3-none-any.whl", hash = "sha256:a11a29cb3bf0a28a387fe5122cdb649816a957cd9261dcedf8c9f1fef33eacf6"}, - {file = "prompt_toolkit-3.0.43.tar.gz", hash = "sha256:3527b7af26106cbc65a040bcc84839a3566ec1b051bb0bfe953631e704b0ff7d"}, + {file = "prompt_toolkit-3.0.47-py3-none-any.whl", hash = "sha256:0d7bfa67001d5e39d02c224b663abc33687405033a8c422d0d675a5a13361d10"}, + {file = "prompt_toolkit-3.0.47.tar.gz", hash = "sha256:1e1b29cb58080b1e69f207c893a1a7bf16d127a5c30c9d17a25a5d77792e5360"}, ] [package.dependencies] @@ -4115,18 +4118,18 @@ telegram = ["requests"] [[package]] name = "traitlets" -version = "5.14.1" +version = "5.14.3" description = "Traitlets Python configuration system" optional = false python-versions = ">=3.8" files = [ - {file = "traitlets-5.14.1-py3-none-any.whl", hash = "sha256:2e5a030e6eff91737c643231bfcf04a65b0132078dad75e4936700b213652e74"}, - {file = "traitlets-5.14.1.tar.gz", hash = "sha256:8585105b371a04b8316a43d5ce29c098575c2e477850b62b848b964f1444527e"}, + {file = "traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f"}, + {file = "traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7"}, ] [package.extras] docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] -test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"] +test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.2)", "pytest-mock", "pytest-mypy-testing"] [[package]] name = "transformers" @@ -4964,4 +4967,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "b50aadf7670601b4ee8737969a1643d6c7e84684ea9b7c90796b9c0fc438a139" +content-hash = "26a2f1b6a41b8124ff9e7eba19f5fcf7fbc9de7e339b71c71ddf16baf261e295" diff --git a/pyproject.toml b/pyproject.toml index 2fb5618..d1b2d6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,15 +15,16 @@ pddl = {git = "https://github.com/maxzuo/pddl.git"} pyyaml = "^6.0.1" jinja2 = "^3.1.4" rustworkx = "^0.14.2" +matplotlib = "^3.9.0" [tool.poetry.group.dev.dependencies] ruff = "^0.1.7" pytest = "^7.4.3" mypy = "^1.7.1" -black = {extras = ["jupyter"], version = "^23.11.0"} pytest-cov = "^4.1.0" pytest-timeout = "^2.2.0" +black = {extras = ["jupyter"], version = "^24.4.2"} [tool.poetry.group.all.dependencies] @@ -34,7 +35,6 @@ datasets = "^2.20.0" peft = "^0.11.1" trl = "^0.9.4" bitsandbytes = "^0.43.1" -matplotlib = "^3.9.0" openai = "^1.35.3" [build-system] diff --git a/utils.py b/utils.py deleted file mode 100644 index a2d6269..0000000 --- a/utils.py +++ /dev/null @@ -1,72 +0,0 @@ -import matplotlib.pyplot as plt -import networkx as nx - -from planetarium import graph, oracle -import llm_planner as llmp - - -def apply_template( - problem: llmp.PlanningProblem, - domain_prompt: str = "", - problem_prompt: str = "", - include_answer: bool = True, -) -> list[dict[str, str]]: - """Apply problem template to the problem. - - Args: - problem(llmp.PlanningProblem): The problem to apply the template to. - domain_prompt (str, optional): How to prompt the domain. Defaults to "". - problem_prompt (str, optional): How to prompt the problem. Defaults to "". - include_answer (bool, optional): Whether to include the answer. Defaults to True. - - Returns: - list[dict[str, str]]: Problem prompt. - """ - return [ - { - "role": "user", - "content": f"{problem_prompt} {problem.natural_language} " - + f"{domain_prompt}\n{problem.domain}\n", - }, - ] + ( - [ - { - "role": "assistant", - "content": " " + problem.problem, - }, - ] - if include_answer - else [] - ) - - -def plot(graph: graph.PlanGraph, reduce: bool = False): - """Plot a graph representation of the PDDL description. - - Args: - graph (graph.PlanGraph): The graph to plot. - already_reduced (bool, optional): Whether the graph is already reduced. - Defaults to False. - """ - if reduce: - graph = oracle.reduce(graph, validate=False) - # rx has no plotting functionality - - nx_graph = nx.MultiDiGraph() - nx_graph.add_edges_from([(u.node, v.node, {"data":edge}) for u, v, edge in graph.edges]) - - for layer, nodes in enumerate(nx.topological_generations(nx_graph)): - for node in nodes: - nx_graph.nodes[node]["layer"] = layer - - pos = nx.multipartite_layout( - nx_graph, - align="horizontal", - subset_key="layer", - scale=-1, - ) - - fig = plt.figure() - nx.draw(nx_graph, pos=pos, ax=fig.gca(), with_labels=True) - - return fig