Skip to content

Commit

Permalink
Merge branch 'master' into release
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasZahradnik committed Jul 21, 2024
2 parents 40dfaa7 + bf87103 commit c66a4c0
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 57 deletions.
2 changes: 1 addition & 1 deletion neuralogic/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.16"
__version__ = "0.7.19"
4 changes: 3 additions & 1 deletion neuralogic/core/builder/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,10 @@ def __len__(self) -> int:
return len(self._groundings_list)

def neuralize(self, *, progress: bool = False):
self._to_list()

if self._groundings_list is not None:
return self._builder.neuralize(self._groundings.stream(), progress, len(self))
return self._builder.neuralize(jpype.java.util.ArrayList(self._groundings).stream(), progress, len(self))
if progress:
return self._builder.neuralize(self._groundings, progress, len(self))
return self._builder.neuralize(self._groundings, progress, 0)
30 changes: 29 additions & 1 deletion neuralogic/core/constructs/java_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jpype

from neuralogic import is_initialized, initialize
from neuralogic.core.constructs.factories import R
from neuralogic.core.constructs.term import Variable, Constant
from neuralogic.core.settings import SettingsProxy, Settings

Expand Down Expand Up @@ -307,7 +308,34 @@ def get_rule(self, rule):
else:
java_rule.setWeight(weight)

body_relation = [self.get_relation(relation, variable_factory) for relation in rule.body]
all_variables = {term for term in rule.head.terms if term is not Ellipsis and str(term)[0].isupper()}
body_relation = []
all_diff_index = []

for i, relation in enumerate(rule.body):
all_variables.update(term for term in relation.terms if term is not Ellipsis and str(term)[0].isupper())

if relation.predicate.special and relation.predicate.name == "alldiff":
found = False

for term in relation.terms:
if term is Ellipsis:
body_relation.append(R.special.alldiff(relation.terms))
all_diff_index.append(i)
found = True

break
if found:
continue

body_relation.append(self.get_relation(relation, variable_factory))

for index in all_diff_index:
terms = {term for term in body_relation[index].terms}
terms.update(term for term in all_variables)

body_relation[index] = self.get_relation(R.special.alldiff(terms), variable_factory)

body_relation_list = jpype.java.util.ArrayList(body_relation)

java_rule.setHead(self.head_atom(head_relation))
Expand Down
40 changes: 0 additions & 40 deletions neuralogic/core/constructs/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ class Rule:
__slots__ = "head", "body", "metadata"

def __init__(self, head, body):
from neuralogic.core import Relation

self.head = head
self.metadata: Optional[Metadata] = None

Expand All @@ -56,44 +54,6 @@ def __init__(self, head, body):

self.body = list(body)

if self.is_ellipsis_templated():
variable_set = {term for term in head.terms if term is not Ellipsis and str(term)[0].isupper()}

for body_atom in self.body:
if body_atom.predicate.special and body_atom.predicate.name == "alldiff":
continue

for term in body_atom.terms:
if term is not Ellipsis and str(term)[0].isupper():
variable_set.add(term)

for atom_index, body_atom in enumerate(self.body):
if not body_atom.predicate.special or body_atom.predicate.name != "alldiff":
continue

new_terms = []
found_replacement = False

for index, term in enumerate(body_atom.terms):
if term is Ellipsis:
if found_replacement:
raise NotImplementedError
found_replacement = True
new_terms.extend(variable_set)
else:
new_terms.append(term)
if found_replacement:
self.body[atom_index] = Relation.special.alldiff(*new_terms)

def is_ellipsis_templated(self) -> bool:
for body_atom in self.body:
if not body_atom.predicate.special or body_atom.predicate.name != "alldiff":
continue
for term in body_atom.terms:
if term is Ellipsis:
return True
return False

def to_str(self, _: bool = False) -> str:
return str(self)

Expand Down
Binary file modified neuralogic/jar/NeuraLogic.jar
Binary file not shown.
14 changes: 0 additions & 14 deletions tests/test_constructs.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,20 +154,6 @@ def test_rule_metadata():


def test_rules():
my_rule: Rule = R.a(V.X) <= R.special.alldiff(...)

assert len(my_rule.body[0].terms) == 1
assert my_rule.body[0].terms[0] == V.X

my_rule: Rule = R.a(V.X) <= (R.special.alldiff(...), R.b(V.Y, V.Z))
assert len(my_rule.body[0].terms) == 3

terms = sorted(my_rule.body[0].terms)

assert terms[0] == V.X
assert terms[1] == V.Y
assert terms[2] == V.Z

my_rule = R.a <= R.b

assert len(my_rule.body) == 1
Expand Down

0 comments on commit c66a4c0

Please sign in to comment.