Skip to content

Commit

Permalink
Generate vertex literals
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasZahradnik committed Sep 5, 2024
1 parent d52c28a commit a0e172a
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion neuralogic/dataset/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from neuralogic.core.constructs.relation import BaseRelation
from neuralogic.core.constructs.rule import Rule
from neuralogic.dataset.base import BaseDataset
from neuralogic.core.constructs.factories import R

DatasetEntries = Union[BaseRelation, Rule]

Expand Down Expand Up @@ -105,7 +106,30 @@ def set_queries(self, queries: List):

def generate_features(self, feature_depth: int = 1, count_groundings: bool = True):
java_factory = JavaFactory()
clause = jpype.java.util.ArrayList([java_factory.to_clause(sample.example) for sample in self.samples])

clauses = []
vertex_lit = R.get("__vert")
vertex_lit.predicate.special = False
vertex_lit.predicate.hidden = False

for sample in self.samples:
vertex = set()

for e in sample.example:
if isinstance(e, Rule):
vertex.update(self._get_constants(e.head))

for rel in e.body:
vertex.update(self._get_constants(rel))
if isinstance(e, BaseRelation):
vertex.update(self._get_constants(e))

example = [vertex_lit(vert) for vert in vertex]
example.extend(sample.example)

clauses.append(java_factory.to_clause(example))

clause = jpype.java.util.ArrayList(clauses)

namespace = "cz.cvut.fel.ida.logic.features.generation"

Expand All @@ -116,3 +140,6 @@ def generate_features(self, feature_depth: int = 1, count_groundings: bool = Tru
clauses = [str(clause) for clause in features.features]

return table, clauses

def _get_constants(self, relation: BaseRelation):
return [term for term in relation.terms if not str(relation)[0].isupper()]

0 comments on commit a0e172a

Please sign in to comment.