Skip to content

Commit

Permalink
adjust domain and problem parsers to support all numeric fluents
Browse files Browse the repository at this point in the history
  • Loading branch information
francescofuggitti committed Oct 12, 2023
1 parent fd36af7 commit 16b6e5b
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 18 deletions.
38 changes: 32 additions & 6 deletions pddl/parser/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,20 @@
from pddl.helpers.base import assert_
from pddl.logic.base import And, ExistsCondition, ForallCondition, Imply, Not, OneOf, Or
from pddl.logic.effects import AndEffect, Forall, When
from pddl.logic.functions import Decrease
from pddl.logic.functions import Decrease, Divide
from pddl.logic.functions import EqualTo as FunctionEqualTo
from pddl.logic.functions import (
Function,
FunctionExpression,
GreaterEqualThan,
GreaterThan,
Increase,
LesserEqualThan,
LesserThan,
Minus,
NumericFunction,
NumericValue,
Plus,
Times,
TotalCost,
)
from pddl.logic.predicates import DerivedPredicate, EqualTo, Predicate
Expand All @@ -51,7 +56,7 @@ def __init__(self, *args, **kwargs):

self._constants_by_name: Dict[str, Constant] = {}
self._predicates_by_name: Dict[str, Predicate] = {}
self._functions_by_name: Dict[str, Function] = {}
self._functions_by_name: Dict[str, FunctionExpression] = {}
self._current_parameters_by_name: Dict[str, Variable] = {}
self._requirements: Set[str] = set()
self._extended_requirements: Set[str] = set()
Expand Down Expand Up @@ -347,15 +352,36 @@ def atomic_function_skeleton(self, args):
return TotalCost()
function_name = args[1]
variables = self._formula_skeleton(args)
return Function(function_name, *variables)
return NumericFunction(function_name, *variables)

def f_exp(self, args):
"""Process the 'f_exp' rule."""
if len(args) == 1:
if isinstance(args[0], (int, float)):
return NumericValue(args[0])
return args[0]
op = None
if args[1] == Symbols.MINUS.value:
op = Minus
if args[1] == Symbols.PLUS.value:
op = Plus
if args[1] == Symbols.TIMES.value:
op = Times
if args[1] == Symbols.DIVIDE.value:
op = Divide
return (
op(*args[2:-1])
if op is not None
else PDDLParsingError("Operator not recognized")
)

def f_head(self, args):
"""Process the 'f_head' rule."""
if len(args) == 1:
return args[0]
return NumericFunction(args[0])
function_name = args[1]
variables = [Variable(x, {}) for x in args[2:-1]]
return Function(function_name, *variables)
return NumericFunction(function_name, *variables)

def typed_list_name(self, args) -> Dict[name, Optional[name]]:
"""Process the 'typed_list_name' rule."""
Expand Down
52 changes: 40 additions & 12 deletions pddl/parser/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,19 @@
from pddl.exceptions import PDDLParsingError
from pddl.helpers.base import assert_
from pddl.logic.base import And, Not
from pddl.logic.functions import Divide
from pddl.logic.functions import EqualTo as FunctionEqualTo
from pddl.logic.functions import (
Function,
GreaterEqualThan,
GreaterThan,
LesserEqualThan,
LesserThan,
Metric,
Minus,
NumericFunction,
NumericValue,
Plus,
Times,
)
from pddl.logic.predicates import EqualTo, Predicate
from pddl.logic.terms import Constant, Variable
Expand Down Expand Up @@ -111,11 +116,11 @@ def init_el(self, args):
return args[0]
elif args[1] == Symbols.EQUAL.value:
if isinstance(args[2], list) and len(args[2]) == 1:
return FunctionEqualTo(*args[2], args[3])
return FunctionEqualTo(*args[2], NumericValue(args[3]))
elif not isinstance(args[2], list):
return FunctionEqualTo(args[2], args[3])
return FunctionEqualTo(args[2], NumericValue(args[3]))
else:
funcs = [FunctionEqualTo(x, args[3]) for x in args[2]]
funcs = [FunctionEqualTo(x, NumericValue(args[3])) for x in args[2]]
return funcs

def literal_name(self, args):
Expand All @@ -130,8 +135,10 @@ def literal_name(self, args):
def basic_function_term(self, args):
"""Process the 'basic_function_term' rule."""
if len(args) == 1:
return Function(args[0])
return [Function(x) for x in args[1:-1]]
return NumericFunction(args[0])
function_name = args[1]
objects = [Constant(x) for x in args[2:-1]]
return NumericFunction(function_name, *objects)

def goal(self, args):
"""Process the 'goal' rule."""
Expand Down Expand Up @@ -186,19 +193,40 @@ def atomic_formula_name(self, args):
def f_head(self, args):
"""Process the 'f_head' rule."""
if len(args) == 1:
return args[0]
return NumericFunction(args[0])
function_name = args[1]
variables = [Variable(x, {}) for x in args[2:-1]]
return Function(function_name, *variables)
return NumericFunction(function_name, *variables)

def metric_spec(self, args):
"""Process the 'metric_spec' rule."""
if isinstance(args[3], list) and len(args[3]) == 1:
return "metric", Metric(*args[3], args[2])
elif not isinstance(args[2], list):
if args[2] == Symbols.MINIMIZE.value:
return "metric", Metric(args[3], args[2])
elif args[2] == Symbols.MAXIMIZE.value:
return "metric", Metric(args[3], args[2])
else:
raise ParseError
raise PDDLParsingError(f"Unknown metric operator: {args[2]}")

def metric_f_exp(self, args):
"""Process the 'metric_f_exp' rule."""
if len(args) == 1:
if isinstance(args[0], (int, float)):
return NumericValue(args[0])
return args[0]
op = None
if args[1] == Symbols.MINUS.value:
op = Minus
if args[1] == Symbols.PLUS.value:
op = Plus
if args[1] == Symbols.TIMES.value:
op = Times
if args[1] == Symbols.DIVIDE.value:
op = Divide
return (
op(*args[2:-1])
if op is not None
else PDDLParsingError("Operator not recognized")
)


_problem_parser_lark = PROBLEM_GRAMMAR_FILE.read_text()
Expand Down

0 comments on commit 16b6e5b

Please sign in to comment.