diff --git a/pddl/_validation.py b/pddl/_validation.py index d5f76fd..22e42f6 100644 --- a/pddl/_validation.py +++ b/pddl/_validation.py @@ -27,7 +27,7 @@ from pddl.logic.predicates import DerivedPredicate, EqualTo from pddl.logic.terms import Term from pddl.parser.symbols import Symbols -from pddl.requirements import Requirements +from pddl.requirements import Requirements, _extend_domain_requirements def validate(condition: bool, message: str = "") -> None: @@ -193,7 +193,8 @@ def __init__( @property def has_typing(self) -> bool: """Check if the typing requirement is specified.""" - return Requirements.TYPING in self._requirements + self._extended_requirements = _extend_domain_requirements(self._requirements) + return Requirements.TYPING in self._extended_requirements def _check_typing_requirement(self, type_tags: Collection[name_type]) -> None: """Check that the typing requirement is specified.""" diff --git a/pddl/parser/domain.py b/pddl/parser/domain.py index f0cf83b..d9999bc 100644 --- a/pddl/parser/domain.py +++ b/pddl/parser/domain.py @@ -28,7 +28,7 @@ from pddl.parser import DOMAIN_GRAMMAR_FILE, PARSERS_DIRECTORY from pddl.parser.symbols import Symbols from pddl.parser.typed_list_parser import TypedListParser -from pddl.requirements import Requirements +from pddl.requirements import Requirements, _extend_domain_requirements class DomainTransformer(Transformer): @@ -72,10 +72,7 @@ def domain_def(self, args): def requirements(self, args): """Process the 'requirements' rule.""" self._requirements = {Requirements(r[1:]) for r in args[2:-1]} - - self._extended_requirements = set(self._requirements) - if Requirements.STRIPS in self._requirements: - self._extended_requirements.update(Requirements.strips_requirements()) + self._extended_requirements = _extend_domain_requirements(self._requirements) return dict(requirements=self._requirements) diff --git a/pddl/requirements.py b/pddl/requirements.py index 22622d0..91a5ebc 100644 --- a/pddl/requirements.py +++ b/pddl/requirements.py @@ -13,7 +13,7 @@ """This module contains the definition of the PDDL requirements.""" import functools from enum import Enum -from typing import Set +from typing import AbstractSet, Set from pddl.parser.symbols import RequirementSymbols as RS @@ -36,15 +36,24 @@ class Requirements(Enum): NON_DETERMINISTIC = RS.NON_DETERMINISTIC.strip() @classmethod - def strips_requirements(cls) -> Set["Requirements"]: - """Get the STRIPS requirements.""" + def quantified_precondition_requirements(cls) -> Set["Requirements"]: + """Get the quantified precondition requirements.""" return { + Requirements.UNIVERSAL_PRECONDITION, + Requirements.EXISTENTIAL_PRECONDITION, + } + + @classmethod + def adl_requirements(cls) -> Set["Requirements"]: + """Get the ADL requirements.""" + return { + Requirements.STRIPS, Requirements.TYPING, Requirements.NEG_PRECONDITION, Requirements.DIS_PRECONDITION, Requirements.EQUALITY, Requirements.CONDITIONAL_EFFECTS, - } + }.union(cls.quantified_precondition_requirements()) def __str__(self) -> str: """Get the string representation.""" @@ -60,3 +69,17 @@ def __lt__(self, other): return self.value <= other.value else: return super().__lt__(other) + + +def _extend_domain_requirements( + requirements: AbstractSet[Requirements], +) -> Set[Requirements]: + """Extend the requirements with the domain requirements.""" + extended_requirements = set(requirements) + if Requirements.QUANTIFIED_PRECONDITION in requirements: + extended_requirements.update( + Requirements.quantified_precondition_requirements() + ) + if Requirements.ADL in requirements: + extended_requirements.update(Requirements.adl_requirements()) + return extended_requirements diff --git a/tests/test_parser/test_domain.py b/tests/test_parser/test_domain.py index 5019456..c193b65 100644 --- a/tests/test_parser/test_domain.py +++ b/tests/test_parser/test_domain.py @@ -121,6 +121,30 @@ def test_types_repetition_in_typed_lists_not_allowed() -> None: DomainParser()(domain_str) +def test_typing_requirement_under_other_domain_requirements() -> None: + """Check :typing requirement does not throw error if other domain requirements that includes it are detected.""" + domain_str = dedent( + """ +(define (domain test) + (:requirements :adl) + (:types a b c) + (:predicates + (predicate1 ?x - a) + (predicate2 ?x - b) + (predicate3 ?x - c) + ) + ) + """ + ) + + domain = DomainParser()(domain_str) + assert domain.types == { + "a": None, + "b": None, + "c": None, + } + + @pytest.mark.parametrize("keyword", TEXT_SYMBOLS - {Symbols.OBJECT.value}) def test_keyword_usage_not_allowed_as_name(keyword) -> None: """Check keywords usage as names is detected and a parsing error is raised."""