From a0e172a631a73f025db4cc85941fac8c108178a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Zahradn=C3=ADk?= Date: Fri, 6 Sep 2024 00:25:40 +0200 Subject: [PATCH] Generate vertex literals --- neuralogic/dataset/logic.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/neuralogic/dataset/logic.py b/neuralogic/dataset/logic.py index 0c4017b..25b1c7e 100644 --- a/neuralogic/dataset/logic.py +++ b/neuralogic/dataset/logic.py @@ -6,6 +6,7 @@ from neuralogic.core.constructs.relation import BaseRelation from neuralogic.core.constructs.rule import Rule from neuralogic.dataset.base import BaseDataset +from neuralogic.core.constructs.factories import R DatasetEntries = Union[BaseRelation, Rule] @@ -105,7 +106,30 @@ def set_queries(self, queries: List): def generate_features(self, feature_depth: int = 1, count_groundings: bool = True): java_factory = JavaFactory() - clause = jpype.java.util.ArrayList([java_factory.to_clause(sample.example) for sample in self.samples]) + + clauses = [] + vertex_lit = R.get("__vert") + vertex_lit.predicate.special = False + vertex_lit.predicate.hidden = False + + for sample in self.samples: + vertex = set() + + for e in sample.example: + if isinstance(e, Rule): + vertex.update(self._get_constants(e.head)) + + for rel in e.body: + vertex.update(self._get_constants(rel)) + if isinstance(e, BaseRelation): + vertex.update(self._get_constants(e)) + + example = [vertex_lit(vert) for vert in vertex] + example.extend(sample.example) + + clauses.append(java_factory.to_clause(example)) + + clause = jpype.java.util.ArrayList(clauses) namespace = "cz.cvut.fel.ida.logic.features.generation" @@ -116,3 +140,6 @@ def generate_features(self, feature_depth: int = 1, count_groundings: bool = Tru clauses = [str(clause) for clause in features.features] return table, clauses + + def _get_constants(self, relation: BaseRelation): + return [term for term in relation.terms if not str(relation)[0].isupper()]