diff --git a/pddl/parser/domain.py b/pddl/parser/domain.py index a26084d..2658e4d 100644 --- a/pddl/parser/domain.py +++ b/pddl/parser/domain.py @@ -23,15 +23,20 @@ from pddl.helpers.base import assert_ from pddl.logic.base import And, ExistsCondition, ForallCondition, Imply, Not, OneOf, Or from pddl.logic.effects import AndEffect, Forall, When -from pddl.logic.functions import Decrease +from pddl.logic.functions import Decrease, Divide from pddl.logic.functions import EqualTo as FunctionEqualTo from pddl.logic.functions import ( - Function, + FunctionExpression, GreaterEqualThan, GreaterThan, Increase, LesserEqualThan, LesserThan, + Minus, + NumericFunction, + NumericValue, + Plus, + Times, TotalCost, ) from pddl.logic.predicates import DerivedPredicate, EqualTo, Predicate @@ -51,7 +56,7 @@ def __init__(self, *args, **kwargs): self._constants_by_name: Dict[str, Constant] = {} self._predicates_by_name: Dict[str, Predicate] = {} - self._functions_by_name: Dict[str, Function] = {} + self._functions_by_name: Dict[str, FunctionExpression] = {} self._current_parameters_by_name: Dict[str, Variable] = {} self._requirements: Set[str] = set() self._extended_requirements: Set[str] = set() @@ -347,15 +352,36 @@ def atomic_function_skeleton(self, args): return TotalCost() function_name = args[1] variables = self._formula_skeleton(args) - return Function(function_name, *variables) + return NumericFunction(function_name, *variables) + + def f_exp(self, args): + """Process the 'f_exp' rule.""" + if len(args) == 1: + if isinstance(args[0], (int, float)): + return NumericValue(args[0]) + return args[0] + op = None + if args[1] == Symbols.MINUS.value: + op = Minus + if args[1] == Symbols.PLUS.value: + op = Plus + if args[1] == Symbols.TIMES.value: + op = Times + if args[1] == Symbols.DIVIDE.value: + op = Divide + return ( + op(*args[2:-1]) + if op is not None + else PDDLParsingError("Operator not recognized") + ) def f_head(self, args): """Process the 'f_head' rule.""" if len(args) == 1: - return args[0] + return NumericFunction(args[0]) function_name = args[1] variables = [Variable(x, {}) for x in args[2:-1]] - return Function(function_name, *variables) + return NumericFunction(function_name, *variables) def typed_list_name(self, args) -> Dict[name, Optional[name]]: """Process the 'typed_list_name' rule.""" diff --git a/pddl/parser/problem.py b/pddl/parser/problem.py index 5d01f81..cb1a836 100644 --- a/pddl/parser/problem.py +++ b/pddl/parser/problem.py @@ -20,14 +20,19 @@ from pddl.exceptions import PDDLParsingError from pddl.helpers.base import assert_ from pddl.logic.base import And, Not +from pddl.logic.functions import Divide from pddl.logic.functions import EqualTo as FunctionEqualTo from pddl.logic.functions import ( - Function, GreaterEqualThan, GreaterThan, LesserEqualThan, LesserThan, Metric, + Minus, + NumericFunction, + NumericValue, + Plus, + Times, ) from pddl.logic.predicates import EqualTo, Predicate from pddl.logic.terms import Constant, Variable @@ -111,11 +116,11 @@ def init_el(self, args): return args[0] elif args[1] == Symbols.EQUAL.value: if isinstance(args[2], list) and len(args[2]) == 1: - return FunctionEqualTo(*args[2], args[3]) + return FunctionEqualTo(*args[2], NumericValue(args[3])) elif not isinstance(args[2], list): - return FunctionEqualTo(args[2], args[3]) + return FunctionEqualTo(args[2], NumericValue(args[3])) else: - funcs = [FunctionEqualTo(x, args[3]) for x in args[2]] + funcs = [FunctionEqualTo(x, NumericValue(args[3])) for x in args[2]] return funcs def literal_name(self, args): @@ -130,8 +135,10 @@ def literal_name(self, args): def basic_function_term(self, args): """Process the 'basic_function_term' rule.""" if len(args) == 1: - return Function(args[0]) - return [Function(x) for x in args[1:-1]] + return NumericFunction(args[0]) + function_name = args[1] + objects = [Constant(x) for x in args[2:-1]] + return NumericFunction(function_name, *objects) def goal(self, args): """Process the 'goal' rule.""" @@ -186,19 +193,40 @@ def atomic_formula_name(self, args): def f_head(self, args): """Process the 'f_head' rule.""" if len(args) == 1: - return args[0] + return NumericFunction(args[0]) function_name = args[1] variables = [Variable(x, {}) for x in args[2:-1]] - return Function(function_name, *variables) + return NumericFunction(function_name, *variables) def metric_spec(self, args): """Process the 'metric_spec' rule.""" - if isinstance(args[3], list) and len(args[3]) == 1: - return "metric", Metric(*args[3], args[2]) - elif not isinstance(args[2], list): + if args[2] == Symbols.MINIMIZE.value: + return "metric", Metric(args[3], args[2]) + elif args[2] == Symbols.MAXIMIZE.value: return "metric", Metric(args[3], args[2]) else: - raise ParseError + raise PDDLParsingError(f"Unknown metric operator: {args[2]}") + + def metric_f_exp(self, args): + """Process the 'metric_f_exp' rule.""" + if len(args) == 1: + if isinstance(args[0], (int, float)): + return NumericValue(args[0]) + return args[0] + op = None + if args[1] == Symbols.MINUS.value: + op = Minus + if args[1] == Symbols.PLUS.value: + op = Plus + if args[1] == Symbols.TIMES.value: + op = Times + if args[1] == Symbols.DIVIDE.value: + op = Divide + return ( + op(*args[2:-1]) + if op is not None + else PDDLParsingError("Operator not recognized") + ) _problem_parser_lark = PROBLEM_GRAMMAR_FILE.read_text()