Skip to content

Commit 233b568

Browse files
authored
Implements partial refactor to allow for Boolean & Numeric interop (via FormulaTerm wrapper) (#96)
These changes do not complete the full SMT-like refactor previously discussed, but *do* allow for interoperability between `Formula` and `Term` objects and treats the boolean sort differently so that it can be used in numeric arithmetic (as needed for standard patterns in RDDL). There are a handful hacks that were required to preserve the FSTRIPS *Effect syntax that should be fixed in the future. The main changes include: 1. Changes the Boolean sort to be 0/1 valued and a child of the Naturals sort when the `Arithmetic` theory is included in a language (and a standalone child of `Object` otherwise) 2. Adds the `FormulaTerm` wrapper class, which is used as a container for `Formula` objects that must be treated as `Term` objects for arithmetic, etc. 3. All of `Predicate`, `Formula`, `Term`, `Function` objects are now associated with a `FirstOrderLanguage` (previously only `Function` and `Term` objects had a language property). Languages are inherited up from "subterms", when, for example, a `CompoundFormula` is constructed. This is essential for being able to construct `FormulaTerm` wrappers when needed, since `Term` objects always need an associated language. 4. Implements a set of tests (primarily focused on RDDL use cases) that adds to the existing test suite The contributing commit messages follow: * implements basic functionality for FormulaTerm wrapper function and makes the Boolean types a core builtin * completes partial implementation of Term and Formula type conversion with wrappers * implements major refactor components. includes modifications to tests to reflect API changes to FSTRIPS *Effect(s) * fixes bug where multiple writes without reset would dump repeated obj domain lists in rddl instance * implements basic RDDL writer integration test & makes it pass * completes implemenation of academic_advising rddl writer integration test (passing) * un-breaks strips *Effect building API with "`Pass` replaces `Tautology`" workaround This is a bit of an ugly workaround, but it will work for now. we have a special `Pass` type that is EXACTLY only ever used when we need to construct or check against an FSTRIPS effect that is not a conditional effect. Previously, the condition had been set as a default parameter to `Tautology` for any regular FSTRIPS effect. However, since *Effects are built outside of the context of a language, this did not work after we needed a language for Tautology and Contradiction (so that other Formula types could inherit them). There are a few more permanent approaches to this. We could either break the API and build Effects in the context of a language, or we could figure out another way to have a universal Tautology that somehow does not need to be in the context of a language. Notably this problem will STILL BE AN ISSUE if we go to a fully boolean-valued Function refactor (eliminating Predicates, etc), rather than the current partial refactor that involves wrappers between Formula and Term. * updates tests related to Term/Formula interop and arithmetic with Booleans * makes some code style fixes * re-enables a test that is now passing again after pull from upstream devel * adds option for 2018-style RDDL file format writing -- maintains default to pre-2018 * adds condition to avoid adding :numeric-fluents simply due to the Boolean sort being attached to the language * fixes stdout name issue on macos
1 parent c16ef4a commit 233b568

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+880
-282
lines changed

src/tarski/analysis/csp_schema.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
collect_effect_free_parameters
99
from ..grounding.common import StateVariableLite
1010
from ..syntax import QuantifiedFormula, Quantifier, Contradiction, CompoundFormula, Atom, CompoundTerm, \
11-
is_neg, symref, Constant, Variable, Tautology, top
11+
is_neg, symref, Constant, Variable, Tautology
1212
from ..syntax.ops import collect_unique_nodes, flatten
1313
from ..syntax.transform import to_prenex_negation_normal_form
1414

@@ -127,7 +127,7 @@ def compile_schema_csp(self, action, simplifier):
127127
if precondition is False:
128128
return None
129129
if precondition is True:
130-
precondition = top
130+
precondition = self.lang.top()
131131

132132
csp = CSPInformation()
133133
csp.parameter_index = [self.variable(p, csp, "param") for p in action.parameters]

src/tarski/fol.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
from . import errors as err
88
from .errors import UndefinedElement
9-
from .syntax import Function, Constant, Variable, Sort, inclusion_closure, Predicate, Interval
9+
from .syntax import Function, Constant, Variable, Sort, inclusion_closure, Predicate, Interval, Tautology, Contradiction
10+
from .syntax.formulas import FormulaTerm
1011
from .syntax.algebra import Matrix
1112
from . import modules
1213

@@ -160,6 +161,14 @@ def variable(self, name: str, sort: Union[Sort, str]):
160161
sort = self._retrieve_sort(sort)
161162
return Variable(name, sort)
162163

164+
#todo: [John Peterson] not ideal to have to add this just to be able to fix booleans done 2 ways
165+
def change_parent(self, sort: Sort, parent: Sort):
166+
if parent.language is not self:
167+
raise err.LanguageError("Tried to set as parent a sort from a different language")
168+
169+
self.immediate_parent[sort] = parent
170+
self.ancestor_sorts[sort].update(inclusion_closure(parent))
171+
163172
def set_parent(self, sort: Sort, parent: Sort):
164173
if parent.language is not self:
165174
raise err.LanguageError("Tried to set as parent a sort from a different language")
@@ -250,6 +259,12 @@ def _check_name_not_defined(self, name, where, exception):
250259
if name in self._global_index:
251260
raise err.DuplicateDefinition(name, self._global_index[name])
252261

262+
def top(self):
263+
return Tautology(self)
264+
265+
def bot(self):
266+
return Contradiction(self)
267+
253268
def predicate(self, name: str, *args):
254269
self._check_name_not_defined(name, self._predicates, err.DuplicatePredicateDefinition)
255270

@@ -333,6 +348,12 @@ def __str__(self):
333348
f"{len(self._functions)} functions and {len(self.constants())} constants"
334349
__repr__ = __str__
335350

351+
#todo: [John Peterson] I'm not sure if this should be here. We
352+
#need access to the language's sorts to be able to inject the
353+
#necessary special boolean sort. Reevaluate as a todo.
354+
def generate_formula_term(self, formula):
355+
return FormulaTerm(formula)
356+
336357
def register_operator_handler(self, operator, t1, t2, handler):
337358
self._operators[(operator, t1, t2)] = handler
338359

src/tarski/fstrips/fstrips.py

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from .. import theories as ths
77
from .errors import InvalidEffectError
88

9-
109
class BaseEffect:
1110
""" A base class for all FSTRIPS effects, which might have an (optional) condition. """
1211
def __init__(self, condition):

src/tarski/fstrips/manipulation/simplify.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,19 @@
77
from ...evaluators.simple import evaluate
88
from ...grounding.ops import approximate_symbol_fluency
99
from ...syntax.terms import Constant, Variable, CompoundTerm
10-
from ...syntax.formulas import CompoundFormula, QuantifiedFormula, Atom, Tautology, Contradiction, Connective, is_neg, \
11-
Quantifier, unwrap_conjunction_or_atom, is_eq_atom, land, exists
10+
from ...syntax.formulas import CompoundFormula, QuantifiedFormula, Atom, Pass, Tautology,\
11+
Contradiction, Connective, is_neg, Quantifier, unwrap_conjunction_or_atom, is_eq_atom, land, exists
1212
from ...syntax.transform.substitutions import substitute_expression
1313
from ...syntax.util import get_symbols
1414
from ...syntax.walker import FOLWalker
1515
from ...syntax.ops import flatten
1616
from ...syntax import symref
1717

1818

19-
def bool_to_expr(val):
19+
def bool_to_expr(val, lang):
2020
if not isinstance(val, bool):
2121
return val
22-
return Tautology() if val else Contradiction()
22+
return Tautology(lang) if val else Contradiction(lang)
2323

2424

2525
class Simplify:
@@ -83,7 +83,7 @@ def simplify(self, inplace=False, remove_unused_symbols=False):
8383
def simplify_action(self, action, inplace=False):
8484
simple = action if inplace else copy.deepcopy(action)
8585
simple.precondition = self.simplify_expression(simple.precondition, inplace=True)
86-
if simple.precondition in (False, Contradiction):
86+
if simple.precondition is False or isinstance(simple.precondition, Contradiction):
8787
return None
8888

8989
# Filter out those effects that are None, e.g. because they are not applicable:
@@ -107,6 +107,9 @@ def simplify_expression(self, node, inplace=True):
107107
if isinstance(node, Tautology):
108108
return True
109109

110+
if isinstance(node, Pass):
111+
return True
112+
110113
if isinstance(node, (CompoundTerm, Atom)):
111114
node.subterms = [self.simplify_expression(st) for st in node.subterms]
112115
if not self.node_can_be_statically_evaluated(node):
@@ -157,14 +160,14 @@ def simplify_effect(self, effect, inplace=True):
157160
effect = effect if inplace else copy.deepcopy(effect)
158161

159162
if isinstance(effect, (AddEffect, DelEffect)):
160-
effect.condition = bool_to_expr(self.simplify_expression(effect.condition))
163+
effect.condition = bool_to_expr(self.simplify_expression(effect.condition), self.problem.language)
161164
if isinstance(effect.condition, Contradiction):
162165
return None
163166
effect.atom = self.simplify_expression(effect.atom)
164167
return effect
165168

166169
if isinstance(effect, FunctionalEffect):
167-
effect.condition = bool_to_expr(self.simplify_expression(effect.condition))
170+
effect.condition = bool_to_expr(self.simplify_expression(effect.condition), self.problem.language)
168171
if isinstance(effect.condition, Contradiction):
169172
return None
170173
effect.lhs = self.simplify_expression(effect.lhs)

src/tarski/fstrips/representation.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .problem import Problem
66
from . import fstrips as fs
77
from ..syntax import Formula, CompoundTerm, Atom, CompoundFormula, QuantifiedFormula, is_and, is_neg, exists, symref,\
8-
VariableBinding, Constant, Tautology, land, Term
8+
VariableBinding, Constant, Tautology, land, Term, Pass
99
from ..syntax.ops import collect_unique_nodes, flatten, free_variables, all_variables
1010
from ..syntax.sorts import compute_signature_bindings
1111
from ..syntax.transform.substitutions import enumerate_substitutions
@@ -95,7 +95,7 @@ def transform_to_strips(what: Union[Problem, Action]):
9595

9696
def is_atomic_effect(eff: BaseEffect):
9797
""" An effect is atomic if it is a single, unconditional effect. """
98-
return isinstance(eff, SingleEffect) and isinstance(eff.condition, Tautology)
98+
return isinstance(eff, SingleEffect) and isinstance(eff.condition, (Tautology, Pass))
9999

100100

101101
def is_propositional_effect(eff: BaseEffect):
@@ -123,10 +123,10 @@ def compute_effect_set_conflicts(effects):
123123
if not is_atomic_effect(eff) or not is_propositional_effect(eff):
124124
raise RepresentationError(f"Don't know how to compute conflicts for effect {eff}")
125125
pol = isinstance(eff, AddEffect) # i.e. polarity will be true if add effect, false otherwise
126-
prev = polarities.get(eff.atom, None)
126+
prev = polarities.get(symref(eff.atom), None)
127127
if prev is not None and prev != pol:
128-
conflicts.add(eff.atom)
129-
polarities[eff.atom] = pol
128+
conflicts.add(symref(eff.atom))
129+
polarities[symref(eff.atom)] = pol
130130
return conflicts
131131

132132

@@ -220,9 +220,9 @@ def collect_literals_from_conjunction(phi: Formula) -> Optional[Set[Tuple[Atom,
220220

221221
def _collect_literals_from_conjunction(f, literals: Set[Tuple[Atom, bool]]):
222222
if isinstance(f, Atom):
223-
literals.add((f, True))
223+
literals.add((symref(f), True))
224224
elif is_neg(f) and isinstance(f.subformulas[0], Atom):
225-
literals.add((f.subformulas[0], False))
225+
literals.add((symref(f.subformulas[0]), False))
226226
elif is_and(f):
227227
for sub in f.subformulas:
228228
if not _collect_literals_from_conjunction(sub, literals):
@@ -465,7 +465,7 @@ def compile_action_negated_preconditions_away(action: Action, negpreds, inplace=
465465
if not isinstance(eff, SingleEffect):
466466
raise RepresentationError(f"Cannot compile away negated conditions for effect '{eff}'")
467467

468-
if not isinstance(eff.condition, Tautology):
468+
if not isinstance(eff.condition, (Tautology, Pass)):
469469
eff.condition = compile_away_formula_negated_literals(eff.condition, negpreds, inplace=True)
470470

471471
return action
@@ -567,7 +567,7 @@ def expand_universal_effect(effect):
567567
if not isinstance(effect, UniversalEffect):
568568
return [effect]
569569

570-
assert isinstance(effect.condition, Tautology) # TODO Lift this restriction
570+
assert isinstance(effect.condition, (Tautology, Pass)) # TODO Lift this restriction
571571
expanded = []
572572
for subst in enumerate_substitutions(effect.variables):
573573
for sub in effect.effects:

src/tarski/fstrips/walker.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,13 @@ def visit_effect(self, effect, inplace=True):
121121
return self.visit(effect)
122122

123123
def visit_expression(self, node, inplace=True):
124-
from ..syntax import CompoundFormula, QuantifiedFormula, Atom, Tautology, Contradiction, Constant, Variable,\
125-
CompoundTerm, IfThenElse # pylint: disable=import-outside-toplevel # Avoiding circular references
124+
# pylint: disable=import-outside-toplevel
125+
from ..syntax import CompoundFormula, QuantifiedFormula, Atom, \
126+
Tautology, Contradiction, Pass, Constant, Variable, \
127+
CompoundTerm, IfThenElse
126128
node = node if inplace else copy.deepcopy(node)
127129

128-
if isinstance(node, (Variable, Constant, Contradiction, Tautology)):
130+
if isinstance(node, (Variable, Constant, Contradiction, Tautology, Pass)):
129131
pass
130132

131133
elif isinstance(node, (CompoundTerm, Atom)):

src/tarski/io/_fstrips/common.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from ...fstrips import FunctionalEffect
33
from ...fstrips.action import AdditiveActionCost, generate_zero_action_cost
44
from ...fstrips.representation import is_typed_problem
5-
from ...syntax import Interval, CompoundTerm, Tautology, BuiltinFunctionSymbol
5+
from ...syntax import Interval, CompoundTerm, Tautology, BuiltinFunctionSymbol, Pass
66
from ... import theories
77
from ...syntax.util import get_symbols
88
from ...theories import Theory
@@ -58,7 +58,7 @@ def get_requirements_string(problem):
5858
# Let's check now whether the problem has any predicate or function symbol *other than "total-cost"* which
5959
# has some arithmetic parameter or result. If so, we add the ":numeric-fluents" requirement.
6060
for symbol in get_symbols(problem.language, type_='all', include_builtin=False):
61-
if any(isinstance(s, Interval) for s in symbol.sort) and symbol.name != 'total-cost':
61+
if any((isinstance(s, Interval) and s.name != 'Boolean') for s in symbol.sort) and symbol.name != 'total-cost':
6262
requirements.add(":numeric-fluents")
6363

6464
return requirements
@@ -113,7 +113,7 @@ def process_cost_effects(effects):
113113
def process_cost_effect(eff):
114114
""" Check if the given effect is a cost effect. If it is, return the additive cost; if it is not, return None. """
115115
if isinstance(eff, FunctionalEffect) and isinstance(eff.lhs, CompoundTerm) and eff.lhs.symbol.name == "total-cost":
116-
if not isinstance(eff.condition, Tautology):
116+
if not isinstance(eff.condition, Pass):
117117
raise TarskiError(f'Don\'t know how to process conditional cost effects such as {eff}')
118118
if not isinstance(eff.rhs, CompoundTerm) or eff.rhs.symbol.name != BuiltinFunctionSymbol.ADD:
119119
raise TarskiError(f'Don\'t know how to process non-additive cost effects such as {eff}')

src/tarski/io/_fstrips/parser/lexer.py

+34-36
Original file line numberDiff line numberDiff line change
@@ -641,45 +641,45 @@ class fstripsLexer(Lexer):
641641
modeNames = [ "DEFAULT_MODE" ]
642642

643643
literalNames = [ "<INVALID>",
644-
"'('", "'define'", "')'", "'domain'", "':requirements'", "':types'",
645-
"'-'", "'either'", "':functions'", "':constants'", "':predicates'",
646-
"':parameters'", "':constraint'", "':condition'", "':event'",
647-
"'#t'", "':derived'", "'assign'", "'*'", "'+'", "'/'", "'^'",
648-
"'max'", "'min'", "'sin'", "'cos'", "'sqrt'", "'tan'", "'acos'",
649-
"'asin'", "'atan'", "'exp'", "'abs'", "'>'", "'<'", "'='", "'>='",
650-
"'<='", "'problem'", "':domain'", "':objects'", "':bounds'",
651-
"'['", "'..'", "']'", "':goal'", "':constraints'", "'preference'",
652-
"':metric'", "'minimize'", "'maximize'", "'(total-time)'", "'is-violated'",
653-
"':terminal'", "':stage'", "'at-end'", "'always'", "'sometime'",
654-
"'within'", "'at-most-once'", "'sometime-after'", "'sometime-before'",
655-
"'always-within'", "'hold-during'", "'hold-after'", "'scale-up'",
644+
"'('", "'define'", "')'", "'domain'", "':requirements'", "':types'",
645+
"'-'", "'either'", "':functions'", "':constants'", "':predicates'",
646+
"':parameters'", "':constraint'", "':condition'", "':event'",
647+
"'#t'", "':derived'", "'assign'", "'*'", "'+'", "'/'", "'^'",
648+
"'max'", "'min'", "'sin'", "'cos'", "'sqrt'", "'tan'", "'acos'",
649+
"'asin'", "'atan'", "'exp'", "'abs'", "'>'", "'<'", "'='", "'>='",
650+
"'<='", "'problem'", "':domain'", "':objects'", "':bounds'",
651+
"'['", "'..'", "']'", "':goal'", "':constraints'", "'preference'",
652+
"':metric'", "'minimize'", "'maximize'", "'(total-time)'", "'is-violated'",
653+
"':terminal'", "':stage'", "'at-end'", "'always'", "'sometime'",
654+
"'within'", "'at-most-once'", "'sometime-after'", "'sometime-before'",
655+
"'always-within'", "'hold-during'", "'hold-after'", "'scale-up'",
656656
"'scale-down'", "'int'", "'float'", "'object'", "'number'" ]
657657

658658
symbolicNames = [ "<INVALID>",
659-
"REQUIRE_KEY", "K_AND", "K_NOT", "K_OR", "K_IMPLY", "K_EXISTS",
660-
"K_FORALL", "K_WHEN", "K_ACTION", "K_INCREASE", "K_DECREASE",
661-
"K_SCALEUP", "K_SCALEDOWN", "INT_T", "FLOAT_T", "OBJECT_T",
662-
"NUMBER_T", "NAME", "EXTNAME", "VARIABLE", "NUMBER", "LINE_COMMENT",
659+
"REQUIRE_KEY", "K_AND", "K_NOT", "K_OR", "K_IMPLY", "K_EXISTS",
660+
"K_FORALL", "K_WHEN", "K_ACTION", "K_INCREASE", "K_DECREASE",
661+
"K_SCALEUP", "K_SCALEDOWN", "INT_T", "FLOAT_T", "OBJECT_T",
662+
"NUMBER_T", "NAME", "EXTNAME", "VARIABLE", "NUMBER", "LINE_COMMENT",
663663
"WHITESPACE", "K_INIT", "K_PRECONDITION", "K_EFFECT" ]
664664

665-
ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6",
666-
"T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13",
667-
"T__14", "T__15", "T__16", "T__17", "T__18", "T__19",
668-
"T__20", "T__21", "T__22", "T__23", "T__24", "T__25",
669-
"T__26", "T__27", "T__28", "T__29", "T__30", "T__31",
670-
"T__32", "T__33", "T__34", "T__35", "T__36", "T__37",
671-
"T__38", "T__39", "T__40", "T__41", "T__42", "T__43",
672-
"T__44", "T__45", "T__46", "T__47", "T__48", "T__49",
673-
"T__50", "T__51", "T__52", "T__53", "T__54", "T__55",
674-
"T__56", "T__57", "T__58", "T__59", "T__60", "T__61",
675-
"T__62", "T__63", "T__64", "REQUIRE_KEY", "K_AND", "K_NOT",
676-
"K_OR", "K_IMPLY", "K_EXISTS", "K_FORALL", "K_WHEN", "K_ACTION",
677-
"K_INCREASE", "K_DECREASE", "K_SCALEUP", "K_SCALEDOWN",
678-
"INT_T", "FLOAT_T", "OBJECT_T", "NUMBER_T", "NAME", "EXTNAME",
679-
"DIGIT", "LETTER", "ANY_CHAR_WO_HYPHEN", "ANY_CHAR", "VARIABLE",
680-
"NUMBER", "LINE_COMMENT", "WHITESPACE", "K_INIT", "K_PRECONDITION",
681-
"K_EFFECT", "A", "B", "C", "D", "E", "F", "G", "H", "I",
682-
"J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T",
665+
ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6",
666+
"T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13",
667+
"T__14", "T__15", "T__16", "T__17", "T__18", "T__19",
668+
"T__20", "T__21", "T__22", "T__23", "T__24", "T__25",
669+
"T__26", "T__27", "T__28", "T__29", "T__30", "T__31",
670+
"T__32", "T__33", "T__34", "T__35", "T__36", "T__37",
671+
"T__38", "T__39", "T__40", "T__41", "T__42", "T__43",
672+
"T__44", "T__45", "T__46", "T__47", "T__48", "T__49",
673+
"T__50", "T__51", "T__52", "T__53", "T__54", "T__55",
674+
"T__56", "T__57", "T__58", "T__59", "T__60", "T__61",
675+
"T__62", "T__63", "T__64", "REQUIRE_KEY", "K_AND", "K_NOT",
676+
"K_OR", "K_IMPLY", "K_EXISTS", "K_FORALL", "K_WHEN", "K_ACTION",
677+
"K_INCREASE", "K_DECREASE", "K_SCALEUP", "K_SCALEDOWN",
678+
"INT_T", "FLOAT_T", "OBJECT_T", "NUMBER_T", "NAME", "EXTNAME",
679+
"DIGIT", "LETTER", "ANY_CHAR_WO_HYPHEN", "ANY_CHAR", "VARIABLE",
680+
"NUMBER", "LINE_COMMENT", "WHITESPACE", "K_INIT", "K_PRECONDITION",
681+
"K_EFFECT", "A", "B", "C", "D", "E", "F", "G", "H", "I",
682+
"J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T",
683683
"U", "V", "W", "X", "Y", "Z" ]
684684

685685
grammarFileName = "fstrips.g4"
@@ -690,5 +690,3 @@ def __init__(self, input=None, output:TextIO = sys.stdout):
690690
self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache())
691691
self._actions = None
692692
self._predicates = None
693-
694-

src/tarski/io/_fstrips/parser/listener.py

-2
Original file line numberDiff line numberDiff line change
@@ -1041,5 +1041,3 @@ def enterAlternativeAlwaysConstraint(self, ctx:fstripsParser.AlternativeAlwaysCo
10411041
# Exit a parse tree produced by fstripsParser#AlternativeAlwaysConstraint.
10421042
def exitAlternativeAlwaysConstraint(self, ctx:fstripsParser.AlternativeAlwaysConstraintContext):
10431043
pass
1044-
1045-

0 commit comments

Comments
 (0)