diff --git a/README.md b/README.md
index 1e5a178c..df76085e 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@
[![Tweet](https://img.shields.io/twitter/url?style=social&url=https%3A%2F%2Fgithub.com%2FLukasZahradnik%2FPyNeuraLogic)](https://twitter.com/intent/tweet?text=Check%20out:&url=https%3A%2F%2Fgithub.com%2FLukasZahradnik%2FPyNeuraLogic)
-[Documentation](https://pyneuralogic.readthedocs.io/en/latest/) | [Examples](#-examples) | [Papers](#-papers)
+[Documentation](https://pyneuralogic.readthedocs.io/en/latest/) · [Examples](#-examples) · [Papers](#-papers) · [Report Bug](https://github.com/LukasZahradnik/PyNeuraLogic/issues/new?assignees=&labels=bug&projects=&template=bug_report.yaml&title=%5B%F0%9F%90%9B+Bug+Report%5D%3A+) · [Request Feature](https://github.com/LukasZahradnik/PyNeuraLogic/issues/new?assignees=&labels=enhancement&projects=&template=feature_request.yaml&title=%5B%E2%9C%A8+Feature+Request%5D%3A+)
PyNeuraLogic lets you use Python to write **Differentiable Logic Programs**
@@ -36,13 +36,13 @@ Many things! For instance - ever heard of [Graph Neural Networks](https://distil
Or, a bit more 'formally':
```logtalk
-Relation.message2(Var.X) <= (Relation.message1(Var.Y), Relation.edge(Var.Y, Var.X))
+R.msg2(Var.X) <= (R.msg1(V.Y), R.edge(V.Y, V.X))
```
...and that's the actual _code_! Now for a classic learnable GNN layer, you'll want to add some weights, such as
```logtalk
-Relation.message2(Var.X)[5,10] <= (Relation.message1(Var.Y)[10,20], Relation.edge(Var.Y, Var.X))
+R.msg2(Var.X)[5,10] <= (R.msg1(V.Y)[10,20], R.edge(V.Y, V.X))
```
-to project your `[20,1]` input node embeddings ('message1') through a learnable ``[10,20]`` layer before the aggregation, and subsequently a `[5,10]` layer after the aggregation.
+to project your `[20,1]` input node embeddings ('msg1') through a learnable `[10,20]` layer before the aggregation, and subsequently a `[5,10]` layer after the aggregation.
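As a quick sanity check of these dimensions, here is a minimal NumPy sketch (illustrative only; this is not the PyNeuraLogic API, which learns these weights for you):

```python
import numpy as np

x = np.ones((20, 1))    # a 'msg1' input node embedding of shape [20, 1]
w1 = np.ones((10, 20))  # the learnable [10, 20] layer applied before aggregation
w2 = np.ones((5, 10))   # the learnable [5, 10] layer applied after aggregation

print((w2 @ (w1 @ x)).shape)  # -> (5, 1), the resulting 'msg2' embedding
```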
@@ -50,11 +50,29 @@ to project your `[20,1]` input node embeddings ('message1') through a learnable
If you don't like the default settings, you can of course [specify](https://pyneuralogic.readthedocs.io/en/latest/language.html) various additional details, such as the particular aggregation and activation functions
```logtalk
-(R.message2(V.X)[5,10] <= (R.message1(V.Y)[10,20], R.edge(V.Y, V.X))) | [Transformation.RELU, Aggregation.AVG]
+(R.msg2(V.X)[5,10] <= (R.msg1(V.Y)[10,20], R.edge(V.Y, V.X))) | [Transformation.RELU, Aggregation.AVG]
```
to instantiate the classic GCN layer specification, which you can directly train now!
+```mermaid
+graph TD;
+ edge10[/"edge(1, 0)"\]-->RuleNeuron1("msg2(0) <= msg1(1), edge(1, 0).");
+ msg1[/"msg1(1)"\]-- w_1 -->RuleNeuron1;
+
+ edge00[/"edge(0, 0)"\]-->RuleNeuron2("msg2(0) <= msg1(0), edge(0, 0).");
+ msg0[/"msg1(0)"\]-- w_1 -->RuleNeuron2;
+
+ edge30[/"edge(3, 0)"\]-->RuleNeuron3("msg2(0) <= msg1(3), edge(3, 0).");
+ msg3[/"msg1(3)"\]-- w_1 -->RuleNeuron3;
+
+    RuleNeuron1-- ReLU -->AggregationNeuron[["Rules Aggregation (Average)"]]
+    RuleNeuron2-- ReLU -->AggregationNeuron
+    RuleNeuron3-- ReLU -->AggregationNeuron
+
+ AggregationNeuron-- w_2 -->OutputNeuron[\"Output Neuron (Tanh)"/]
+
+```
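Concretely, such a rule is added to a `Template`, which is then compiled into a trainable model. A minimal sketch, assuming the standard `Template.build(Settings())` API from the documentation (training data omitted):

```python
from neuralogic.core import Template, Settings, R, V, Transformation, Aggregation

template = Template()
template += (R.msg2(V.X)[5, 10] <= (R.msg1(V.Y)[10, 20], R.edge(V.Y, V.X))) | [
    Transformation.RELU,
    Aggregation.AVG,
]

model = template.build(Settings())  # compile the template into a trainable model
```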
### How is it different from other GNN frameworks?
@@ -85,7 +103,7 @@ We hope you'll find the framework useful in designing _your own_ deep **relation
Please let us know if you need some guidance or would like to cooperate!
-## 💡 Getting started
+## 🚀 Getting started
### Installation
@@ -106,7 +124,20 @@ Python >= 3.8
Java >= 1.8
```
-In case you want to use visualization provided in the library, it is required to have [Graphviz](https://graphviz.org/download/) installed.
+> [!TIP]
+>
+> To use the visualization features provided in the library, you also need to have [Graphviz](https://graphviz.org/download/) installed.
+
+
+
+## 📦 Predefined Modules
+
+PyNeuraLogic has a set of predefined modules to get you started with your experiments quickly!
+It contains, for example, predefined modules for the following (a minimal usage sketch follows the list):
+
+- Graph Neural Networks (GCNConv, SAGEConv, GINConv, RGCNConv, ...)
+- Meta graphs and meta paths (MetaConv, MAGNN, ...)
+- Transformer, LSTM, GRU, RNN, [...and more!](https://pyneuralogic.readthedocs.io/en/latest/zoo.html)
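A predefined module is appended to a template just like a hand-written rule. A minimal sketch, assuming the `GCNConv(in_channels, out_channels, output_name, feature_name, edge_name)` signature from the module zoo documentation (the relation names `h1`, `h0`, and `edge` are illustrative):

```python
from neuralogic.core import Template
from neuralogic.nn.module import GCNConv

template = Template()
# One GCN layer: reads 'h0' node features over 'edge' relations and writes 'h1'.
template += GCNConv(in_channels=16, out_channels=8, output_name="h1", feature_name="h0", edge_name="edge")
```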
## 🔬 Examples
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LukasZahradnik/PyNeuraLogic/blob/master/examples/SimpleXOR.ipynb) [Simple XOR example](https://github.com/LukasZahradnik/PyNeuraLogic/blob/master/examples/SimpleXOR.ipynb)
@@ -124,18 +155,6 @@ In case you want to use visualization provided in the library, it is required to
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LukasZahradnik/PyNeuraLogic/blob/master/examples/DistinguishingNonRegularGraphs.ipynb) [Distinguishing non-regular graphs](https://github.com/LukasZahradnik/PyNeuraLogic/blob/master/examples/DistinguishingNonRegularGraphs.ipynb)
-
-
-
-## ๐ฆ Predefined Modules
-
-PyNeuraLogic has a set of predefined modules to get you quickly started with your experimenting!
-It contains, for example, predefined modules for:
-
-- Graph Neural Networks (GNNConv, SAGEConv, GINConv, RGCNConv, ...)
-- Meta graphs and meta paths (MetaConv, MAGNN, ...)
-- Transformer, LSTM, GRU, RNN, [...and more!](https://pyneuralogic.readthedocs.io/en/latest/zoo.html)
-
## 📝 Papers
- [Beyond Graph Neural Networks with Lifted Relational Neural Networks](https://arxiv.org/abs/2007.06286) Machine Learning Journal, 2021
diff --git a/benchmarks/pyneuralogic_benchmark.py b/benchmarks/pyneuralogic_benchmark.py
index 3e0de113..61578261 100644
--- a/benchmarks/pyneuralogic_benchmark.py
+++ b/benchmarks/pyneuralogic_benchmark.py
@@ -20,22 +20,14 @@
def gcn(activation: Transformation, output_size: int, num_features: int, dim: int = 10):
template = Template()
- template += (R.atom_embed(V.X)[dim, num_features] <= R.node_feature(V.X)) | [Transformation.IDENTITY]
- template += R.atom_embed / 1 | [Transformation.IDENTITY]
+ template += R.atom_embed(V.X)[dim, num_features] <= R.node_feature(V.X)
- template += (R.l1_embed(V.X)[dim, dim] <= (R.atom_embed(V.Y), R._edge(V.Y, V.X))) | [
- Aggregation.SUM,
- Transformation.IDENTITY,
- ]
+ template += (R.l1_embed(V.X)[dim, dim] <= (R.atom_embed(V.Y), R._edge(V.Y, V.X))) | [Aggregation.SUM]
template += R.l1_embed / 1 | [Transformation.RELU]
- template += (R.l2_embed(V.X)[dim, dim] <= (R.l1_embed(V.Y), R._edge(V.Y, V.X))) | [
- Aggregation.SUM,
- Transformation.IDENTITY,
- ]
- template += R.l2_embed / 1 | [Transformation.IDENTITY]
+ template += (R.l2_embed(V.X)[dim, dim] <= (R.l1_embed(V.Y), R._edge(V.Y, V.X))) | [Aggregation.SUM]
- template += (R.predict[output_size, dim] <= R.l2_embed(V.X)) | [Aggregation.AVG, Transformation.IDENTITY]
+ template += (R.predict[output_size, dim] <= R.l2_embed(V.X)) | [Aggregation.AVG]
template += R.predict / 0 | [activation]
return template
@@ -44,65 +36,47 @@ def gcn(activation: Transformation, output_size: int, num_features: int, dim: in
def gin(activation: Transformation, output_size: int, num_features: int, dim: int = 10):
template = Template()
- template += (R.atom_embed(V.X)[dim, num_features] <= R.node_feature(V.X)) | [Transformation.IDENTITY]
- template += R.atom_embed / 1 | [Transformation.IDENTITY]
+ template += R.atom_embed(V.X)[dim, num_features] <= R.node_feature(V.X)
template += (R.l1_embed(V.X) <= (R.atom_embed(V.Y), R._edge(V.Y, V.X))) | [Aggregation.SUM, Transformation.IDENTITY]
- template += (R.l1_embed(V.X) <= R.atom_embed(V.X)) | [Transformation.IDENTITY]
- template += R.l1_embed / 1 | [Transformation.IDENTITY]
+ template += R.l1_embed(V.X) <= R.atom_embed(V.X)
template += (R.l1_mlp_embed(V.X)[dim, dim] <= R.l1_embed(V.X)[dim, dim]) | [Transformation.RELU]
template += R.l1_mlp_embed / 1 | [Transformation.RELU]
# --
- template += (R.l2_embed(V.X) <= (R.l1_mlp_embed(V.Y), R._edge(V.Y, V.X))) | [
- Aggregation.SUM,
- Transformation.IDENTITY,
- ]
- template += (R.l2_embed(V.X) <= R.l1_mlp_embed(V.X)) | [Transformation.IDENTITY]
- template += R.l2_embed / 1 | [Transformation.IDENTITY]
+ template += (R.l2_embed(V.X) <= (R.l1_mlp_embed(V.Y), R._edge(V.Y, V.X))) | [Aggregation.SUM]
+ template += R.l2_embed(V.X) <= R.l1_mlp_embed(V.X)
template += (R.l2_mlp_embed(V.X)[dim, dim] <= R.l2_embed(V.X)[dim, dim]) | [Transformation.RELU]
template += R.l2_mlp_embed / 1 | [Transformation.RELU]
# --
- template += (R.l3_embed(V.X) <= (R.l2_mlp_embed(V.Y), R._edge(V.Y, V.X))) | [
- Aggregation.SUM,
- Transformation.IDENTITY,
- ]
- template += (R.l3_embed(V.X) <= R.l2_mlp_embed(V.X)) | [Transformation.IDENTITY]
- template += R.l3_embed / 1 | [Transformation.IDENTITY]
+ template += (R.l3_embed(V.X) <= (R.l2_mlp_embed(V.Y), R._edge(V.Y, V.X))) | [Aggregation.SUM]
+ template += R.l3_embed(V.X) <= R.l2_mlp_embed(V.X)
template += (R.l3_mlp_embed(V.X)[dim, dim] <= R.l3_embed(V.X)[dim, dim]) | [Transformation.RELU]
template += R.l3_mlp_embed / 1 | [Transformation.RELU]
# --
- template += (R.l4_embed(V.X) <= (R.l3_mlp_embed(V.Y), R._edge(V.Y, V.X))) | [
- Aggregation.SUM,
- Transformation.IDENTITY,
- ]
- template += (R.l4_embed(V.X) <= R.l3_mlp_embed(V.X)) | [Transformation.IDENTITY]
- template += R.l4_embed / 1 | [Transformation.IDENTITY]
+ template += (R.l4_embed(V.X) <= (R.l3_mlp_embed(V.Y), R._edge(V.Y, V.X))) | [Aggregation.SUM]
+ template += R.l4_embed(V.X) <= R.l3_mlp_embed(V.X)
template += (R.l4_mlp_embed(V.X)[dim, dim] <= R.l4_embed(V.X)[dim, dim]) | [Transformation.RELU]
template += R.l4_mlp_embed / 1 | [Transformation.RELU]
# --
- template += (R.l5_embed(V.X) <= (R.l4_mlp_embed(V.Y), R._edge(V.Y, V.X))) | [
- Aggregation.SUM,
- Transformation.IDENTITY,
- ]
- template += (R.l5_embed(V.X) <= R.l4_mlp_embed(V.X)) | [Transformation.IDENTITY]
- template += R.l5_embed / 1 | [Transformation.IDENTITY]
+ template += (R.l5_embed(V.X) <= (R.l4_mlp_embed(V.Y), R._edge(V.Y, V.X))) | [Aggregation.SUM]
+ template += R.l5_embed(V.X) <= R.l4_mlp_embed(V.X)
template += (R.l5_mlp_embed(V.X)[dim, dim] <= R.l5_embed(V.X)[dim, dim]) | [Transformation.RELU]
template += R.l5_mlp_embed / 1 | [Transformation.RELU]
- template += (R.predict[output_size, dim] <= R.l1_mlp_embed(V.X)) | [Aggregation.AVG, Transformation.IDENTITY]
- template += (R.predict[output_size, dim] <= R.l2_mlp_embed(V.X)) | [Aggregation.AVG, Transformation.IDENTITY]
- template += (R.predict[output_size, dim] <= R.l3_mlp_embed(V.X)) | [Aggregation.AVG, Transformation.IDENTITY]
- template += (R.predict[output_size, dim] <= R.l4_mlp_embed(V.X)) | [Aggregation.AVG, Transformation.IDENTITY]
- template += (R.predict[output_size, dim] <= R.l5_mlp_embed(V.X)) | [Aggregation.AVG, Transformation.IDENTITY]
+ template += (R.predict[output_size, dim] <= R.l1_mlp_embed(V.X)) | [Aggregation.AVG]
+ template += (R.predict[output_size, dim] <= R.l2_mlp_embed(V.X)) | [Aggregation.AVG]
+ template += (R.predict[output_size, dim] <= R.l3_mlp_embed(V.X)) | [Aggregation.AVG]
+ template += (R.predict[output_size, dim] <= R.l4_mlp_embed(V.X)) | [Aggregation.AVG]
+ template += (R.predict[output_size, dim] <= R.l5_mlp_embed(V.X)) | [Aggregation.AVG]
template += R.predict / 0 | [activation]
@@ -112,24 +86,16 @@ def gin(activation: Transformation, output_size: int, num_features: int, dim: in
def gsage(activation: Transformation, output_size: int, num_features: int, dim: int = 10):
template = Template()
- template += (R.atom_embed(V.X)[dim, num_features] <= R.node_feature(V.X)) | [Transformation.IDENTITY]
- template += R.atom_embed / 1 | [Transformation.IDENTITY]
+ template += R.atom_embed(V.X)[dim, num_features] <= R.node_feature(V.X)
- template += (R.l1_embed(V.X)[dim, dim] <= R.atom_embed(V.X)) | [Transformation.IDENTITY]
- template += (R.l1_embed(V.X)[dim, dim] <= (R.atom_embed(V.Y), R._edge(V.Y, V.X))) | [
- Aggregation.AVG,
- Transformation.IDENTITY,
- ]
+ template += R.l1_embed(V.X)[dim, dim] <= R.atom_embed(V.X)
+ template += (R.l1_embed(V.X)[dim, dim] <= (R.atom_embed(V.Y), R._edge(V.Y, V.X))) | [Aggregation.AVG]
template += R.l1_embed / 1 | [Transformation.RELU]
- template += (R.l2_embed(V.X)[dim, dim] <= R.l1_embed(V.X)) | [Transformation.IDENTITY]
- template += (R.l2_embed(V.X)[dim, dim] <= (R.l1_embed(V.Y), R._edge(V.Y, V.X))) | [
- Aggregation.AVG,
- Transformation.IDENTITY,
- ]
- template += R.l2_embed / 1 | [Transformation.IDENTITY]
+ template += R.l2_embed(V.X)[dim, dim] <= R.l1_embed(V.X)
+ template += (R.l2_embed(V.X)[dim, dim] <= (R.l1_embed(V.Y), R._edge(V.Y, V.X))) | [Aggregation.AVG]
- template += (R.predict[output_size, dim] <= R.l2_embed(V.X)) | [Aggregation.AVG, Transformation.IDENTITY]
+ template += (R.predict[output_size, dim] <= R.l2_embed(V.X)) | [Aggregation.AVG]
template += R.predict / 0 | [activation]
return template
diff --git a/examples/datasets/horses.py b/examples/datasets/horses.py
index b2edb15e..44654fe4 100644
--- a/examples/datasets/horses.py
+++ b/examples/datasets/horses.py
@@ -1,4 +1,4 @@
-from neuralogic.core import Relation, Template, Var, Const
+from neuralogic.core import Relation, Template, Var, Const, Transformation
from neuralogic.dataset import Dataset
@@ -9,9 +9,11 @@
template.add_rules(
[
- Relation.foal(Var.X)[1, ] <= (Relation.parent(Var.X, Var.Y), Relation.horse(Var.Y)), # todo gusta: maybe renaming Atom -> Predicate would correspond to reality more naturally?
- Relation.foal(Var.X)[1, ] <= (Relation.sibling(Var.X, Var.Y), Relation.horse(Var.Y)),
- Relation.negFoal(Var.X)[1, ] <= Relation.foal(Var.X),
+ (Relation.foal(Var.X)[1, ] <= (Relation.parent(Var.X, Var.Y), Relation.horse(Var.Y))) | [Transformation.TANH],
+ (Relation.foal(Var.X)[1, ] <= (Relation.sibling(Var.X, Var.Y), Relation.horse(Var.Y))) | [Transformation.TANH],
+ (Relation.negFoal(Var.X)[1, ] <= Relation.foal(Var.X)) | [Transformation.TANH],
+ Relation.foal / 1 | [Transformation.TANH],
+ Relation.negFoal / 1 | [Transformation.TANH],
]
)
diff --git a/examples/datasets/multiple_examples_no_order_trains.py b/examples/datasets/multiple_examples_no_order_trains.py
index cf4a707d..8d9ab3d7 100644
--- a/examples/datasets/multiple_examples_no_order_trains.py
+++ b/examples/datasets/multiple_examples_no_order_trains.py
@@ -1,7 +1,7 @@
from typing import List
from examples.datasets.data.train_example_data import train_example_data
-from neuralogic.core import Relation, Template, Var, Const
+from neuralogic.core import Relation, Template, Var, Const, Transformation
from neuralogic.dataset import Dataset
@@ -20,18 +20,30 @@
Y = Var.Y # todo gusta: this is a good trick, I would use it in more places, and similarly make shortcuts for the Atom/Predicate factories (e.g. P.)
+meta = [Transformation.TANH]
+
template.add_rules(
[
- *[Relation.shape(Y) <= Relation.shape(Y, s)[1, ] for s in shapes],
- *[Relation.length(Y) <= Relation.length(Y, s)[1, ] for s in [Const.short, Const.long]],
- *[Relation.sides(Y) <= Relation.sides(Y, s)[1, ] for s in [Const.not_double, Const.double]],
- *[Relation.roof(Y) <= Relation.roof(Y, s)[1, ] for s in roofs],
- *[Relation.wheels(Y) <= Relation.wheels(Y, s)[1, ] for s in [2, 3]],
- *[Relation.loadnum(Y) <= Relation.loadnum(Y, s)[1, ] for s in [0, 1, 2, 3]],
- *[Relation.loadshape(Y) <= Relation.loadshape(Y, s)[1, ] for s in loadshapes],
- Relation.vagon(Y) <= (atom(Y)[1, ] for atom in vagon_atoms),
- Relation.train <= Relation.vagon(Y)[1, ],
- Relation.direction <= Relation.train[1, ],
+ *[(Relation.shape(Y) <= Relation.shape(Y, s)[1, ]) | meta for s in shapes],
+ *[(Relation.length(Y) <= Relation.length(Y, s)[1, ]) | meta for s in [Const.short, Const.long]],
+ *[(Relation.sides(Y) <= Relation.sides(Y, s)[1, ]) | meta for s in [Const.not_double, Const.double]],
+ *[(Relation.roof(Y) <= Relation.roof(Y, s)[1, ]) | meta for s in roofs],
+ *[(Relation.wheels(Y) <= Relation.wheels(Y, s)[1, ]) | meta for s in [2, 3]],
+ *[(Relation.loadnum(Y) <= Relation.loadnum(Y, s)[1, ]) | meta for s in [0, 1, 2, 3]],
+ *[(Relation.loadshape(Y) <= Relation.loadshape(Y, s)[1, ]) | meta for s in loadshapes],
+ (Relation.vagon(Y) <= (atom(Y)[1, ] for atom in vagon_atoms)) | meta,
+ (Relation.train <= Relation.vagon(Y)[1, ]) | meta,
+ (Relation.direction <= Relation.train[1, ]) | meta,
+ Relation.shape / 1 | meta,
+ Relation.length / 1 | meta,
+ Relation.sides / 1 | meta,
+ Relation.roof / 1 | meta,
+ Relation.wheels / 1 | meta,
+ Relation.loadnum / 1 | meta,
+ Relation.loadshape / 1 | meta,
+ Relation.vagon / 1 | meta,
+ Relation.train / 0 | meta,
+ Relation.direction / 0 | meta,
]
)
diff --git a/examples/datasets/multiple_examples_trains.py b/examples/datasets/multiple_examples_trains.py
index 7be9813e..ea0b152c 100644
--- a/examples/datasets/multiple_examples_trains.py
+++ b/examples/datasets/multiple_examples_trains.py
@@ -1,7 +1,7 @@
from typing import List
from examples.datasets.data.train_example_data import train_example_data
-from neuralogic.core import Relation, Template, Var, Const
+from neuralogic.core import Relation, Template, Var, Const, Transformation
from neuralogic.dataset import Dataset
@@ -19,18 +19,30 @@
Y = Var.Y
+meta = [Transformation.TANH]
+
template.add_rules(
[
- *[Relation.shape(Y) <= Relation.shape(Y, s)[1, ] for s in shapes],
- *[Relation.length(Y) <= Relation.length(Y, s)[1, ] for s in [Const.short, Const.long]],
- *[Relation.sides(Y) <= Relation.sides(Y, s)[1, ] for s in [Const.not_double, Const.double]],
- *[Relation.roof(Y) <= Relation.roof(Y, s)[1, ] for s in roofs],
- *[Relation.wheels(Y) <= Relation.wheels(Y, s)[1, ] for s in [2, 3]],
- *[Relation.loadnum(Y) <= Relation.loadnum(Y, s)[1, ] for s in [0, 1, 2, 3]],
- *[Relation.loadshape(Y) <= Relation.loadshape(Y, s)[1, ] for s in loadshapes],
- Relation.vagon(Y) <= (atom(Y)[1, ] for atom in vagon_atoms),
- *[Relation.train <= Relation.vagon(i)[1, ] for i in [1, 2, 3, 4]],
- Relation.direction <= Relation.train[1, ],
+ *[(Relation.shape(Y) <= Relation.shape(Y, s)[1, ]) | meta for s in shapes],
+ *[(Relation.length(Y) <= Relation.length(Y, s)[1, ]) | meta for s in [Const.short, Const.long]],
+ *[(Relation.sides(Y) <= Relation.sides(Y, s)[1, ]) | meta for s in [Const.not_double, Const.double]],
+ *[(Relation.roof(Y) <= Relation.roof(Y, s)[1, ]) | meta for s in roofs],
+ *[(Relation.wheels(Y) <= Relation.wheels(Y, s)[1, ]) | meta for s in [2, 3]],
+ *[(Relation.loadnum(Y) <= Relation.loadnum(Y, s)[1, ]) | meta for s in [0, 1, 2, 3]],
+ *[(Relation.loadshape(Y) <= Relation.loadshape(Y, s)[1, ]) | meta for s in loadshapes],
+ (Relation.vagon(Y) <= (atom(Y)[1, ] for atom in vagon_atoms)) | meta,
+ *[(Relation.train <= Relation.vagon(i)[1, ]) | meta for i in [1, 2, 3, 4]],
+ (Relation.direction <= Relation.train[1, ]) | meta,
+ Relation.shape / 1 | meta,
+ Relation.length / 1 | meta,
+ Relation.sides / 1 | meta,
+ Relation.roof / 1 | meta,
+ Relation.wheels / 1 | meta,
+ Relation.loadnum / 1 | meta,
+ Relation.loadshape / 1 | meta,
+ Relation.vagon / 1 | meta,
+ Relation.train / 0 | meta,
+ Relation.direction / 0 | meta,
]
)
diff --git a/examples/datasets/naive_trains.py b/examples/datasets/naive_trains.py
index ce72d7e7..e1aa8241 100644
--- a/examples/datasets/naive_trains.py
+++ b/examples/datasets/naive_trains.py
@@ -1,6 +1,6 @@
from examples.datasets.data.train_example_data import train_example_data
-from neuralogic.core import Relation, Template, Var, Const
+from neuralogic.core import Relation, Template, Var, Const, Transformation
from neuralogic.dataset import Dataset
@@ -19,18 +19,30 @@
X = Var.X
Y = Var.Y
+meta = [Transformation.TANH]
+
template.add_rules(
[
- *[Relation.shape(X, Y) <= Relation.shape(X, Y, s)[1, ] for s in shapes],
- *[Relation.length(X, Y) <= Relation.length(X, Y, s)[1, ] for s in [Const.short, Const.long]],
- *[Relation.sides(X, Y) <= Relation.sides(X, Y, s)[1, ] for s in [Const.not_double, Const.double]],
- *[Relation.roof(X, Y) <= Relation.roof(X, Y, s)[1, ] for s in roofs],
- *[Relation.wheels(X, Y) <= Relation.wheels(X, Y, s)[1, ] for s in [2, 3]],
- *[Relation.loadnum(X, Y) <= Relation.loadnum(X, Y, s)[1, ] for s in [0, 1, 2, 3]],
- *[Relation.loadshape(X, Y) <= Relation.loadshape(X, Y, s)[1, ] for s in loadshapes],
- Relation.vagon(X, Y) <= (atom(X, Y)[1, ] for atom in vagon_atoms),
- *[Relation.train(X) <= Relation.vagon(X, i)[1, ] for i in [1, 2, 3, 4]],
- Relation.direction(X) <= Relation.train(X)[1, ],
+ *[(Relation.shape(X, Y) <= Relation.shape(X, Y, s)[1, ]) | meta for s in shapes],
+ *[(Relation.length(X, Y) <= Relation.length(X, Y, s)[1, ]) | meta for s in [Const.short, Const.long]],
+ *[(Relation.sides(X, Y) <= Relation.sides(X, Y, s)[1, ]) | meta for s in [Const.not_double, Const.double]],
+ *[(Relation.roof(X, Y) <= Relation.roof(X, Y, s)[1, ]) | meta for s in roofs],
+ *[(Relation.wheels(X, Y) <= Relation.wheels(X, Y, s)[1, ]) | meta for s in [2, 3]],
+ *[(Relation.loadnum(X, Y) <= Relation.loadnum(X, Y, s)[1, ]) | meta for s in [0, 1, 2, 3]],
+ *[(Relation.loadshape(X, Y) <= Relation.loadshape(X, Y, s)[1, ]) | meta for s in loadshapes],
+ (Relation.vagon(X, Y) <= (atom(X, Y)[1, ] for atom in vagon_atoms)) | meta,
+ *[(Relation.train(X) <= Relation.vagon(X, i)[1, ]) | meta for i in [1, 2, 3, 4]],
+ (Relation.direction(X) <= Relation.train(X)[1, ]) | meta,
+ Relation.shape / 2 | meta,
+ Relation.length / 2 | meta,
+ Relation.sides / 2 | meta,
+ Relation.roof / 2 | meta,
+ Relation.wheels / 2 | meta,
+ Relation.loadnum / 2 | meta,
+ Relation.loadshape / 2 | meta,
+ Relation.vagon / 2 | meta,
+ Relation.train / 1 | meta,
+ Relation.direction / 1 | meta,
]
)
diff --git a/examples/datasets/naive_xor.py b/examples/datasets/naive_xor.py
index a6001cc1..a592a2d2 100644
--- a/examples/datasets/naive_xor.py
+++ b/examples/datasets/naive_xor.py
@@ -1,4 +1,4 @@
-from neuralogic.core import Relation, Template
+from neuralogic.core import Relation, Template, Transformation
from neuralogic.dataset import Dataset, Sample
@@ -8,10 +8,12 @@
# fmt: off
# hidden<1-8> :- {1} a, {1} b.
-template.add_rules([Relation.get(f"hidden{i}") <= (Relation.a[1, ], Relation.b[1, ]) for i in range(1, 9)])
+template.add_rules([(Relation.get(f"hidden{i}") <= (Relation.a[1, ], Relation.b[1, ])) | [Transformation.TANH] for i in range(1, 9)])
+template.add_rules([Relation.get(f"hidden{i}") / 0 | [Transformation.TANH] for i in range(1, 9)])
# {1} xor :- hidden<1-8>.
-template.add_rules([Relation.xor[1, ] <= Relation.get(f"hidden{i}") for i in range(1, 9)])
+template.add_rules([(Relation.xor[1, ] <= Relation.get(f"hidden{i}")) | [Transformation.TANH] for i in range(1, 9)])
+template.add_rule(Relation.xor / 0 | [Transformation.TANH])
dataset.add_samples(
[ # Add 4 examples
diff --git a/examples/datasets/vectorized_xor.py b/examples/datasets/vectorized_xor.py
index f26986c4..1a811cb4 100644
--- a/examples/datasets/vectorized_xor.py
+++ b/examples/datasets/vectorized_xor.py
@@ -1,4 +1,4 @@
-from neuralogic.core import Relation, Template
+from neuralogic.core import Relation, Template, Transformation
from neuralogic.dataset import Dataset, Sample
@@ -7,7 +7,8 @@
template = Template()
-template.add_rule(Relation.xor[1, 8] <= Relation.xy[8, 2]) # Add template rule
+template.add_rule((Relation.xor[1, 8] <= Relation.xy[8, 2]) | [Transformation.TANH]) # Add template rule
+template.add_rule(Relation.xor / 0 | [Transformation.TANH])
dataset.add_samples(
[ # Add 4 examples
diff --git a/neuralogic/core/__init__.py b/neuralogic/core/__init__.py
index e9117524..9f4687d5 100644
--- a/neuralogic/core/__init__.py
+++ b/neuralogic/core/__init__.py
@@ -5,7 +5,7 @@
from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.settings import Settings, SettingsProxy
from neuralogic.core.enums import Grounder
-from neuralogic.core.constructs.function import Transformation, Aggregation, Combination
+from neuralogic.core.constructs.function import Transformation, Aggregation, Combination, F
__all__ = [
@@ -15,6 +15,7 @@
"C",
"Relation",
"R",
+ "F",
"Rule",
"RuleBody",
"Template",
diff --git a/neuralogic/core/constructs/function/__init__.py b/neuralogic/core/constructs/function/__init__.py
index 35e6af01..a7855ff2 100644
--- a/neuralogic/core/constructs/function/__init__.py
+++ b/neuralogic/core/constructs/function/__init__.py
@@ -1,32 +1,5 @@
-from neuralogic.core.constructs.function.concat import ConcatComb, Concat
-from neuralogic.core.constructs.function.function import Transformation, Combination, Aggregation, Function
-from neuralogic.core.constructs.function.reshape import Reshape
-from neuralogic.core.constructs.function.slice import Slice
-from neuralogic.core.constructs.function.softmax import Softmax
+from neuralogic.core.constructs.function.enum import Transformation, Combination, Aggregation, F
+from neuralogic.core.constructs.function.function import Function
+from neuralogic.core.constructs.function.function_container import FContainer
-_special_namings = {"LEAKY_RELU": "LEAKYRELU", "TRANSP": "TRANSPOSE"}
-
-for function_name in Transformation.__annotations__:
- setattr(Transformation, function_name, Transformation(_special_namings.get(function_name, function_name)))
-
-
-Transformation.SLICE = Slice("SLICE")
-Transformation.RESHAPE = Reshape("RESHAPE")
-
-
-for function_name in Combination.__annotations__:
- setattr(Combination, function_name, Combination(function_name))
-
-
-Combination.CONCAT = ConcatComb("CONCAT")
-
-
-for function_name in Aggregation.__annotations__:
- setattr(Aggregation, function_name, Aggregation(function_name))
-
-
-Aggregation.CONCAT = Concat("CONCAT")
-Aggregation.SOFTMAX = Softmax("SOFTMAX")
-
-
-__all__ = ["Transformation", "Combination", "Aggregation", "Function"]
+__all__ = ["Transformation", "Combination", "Aggregation", "F", "Function", "FContainer"]
diff --git a/neuralogic/core/constructs/function/concat.py b/neuralogic/core/constructs/function/concat.py
index a5e85f04..9fa1ec7e 100644
--- a/neuralogic/core/constructs/function/concat.py
+++ b/neuralogic/core/constructs/function/concat.py
@@ -1,9 +1,9 @@
import jpype
-from neuralogic.core.constructs.function.function import Aggregation, Combination
+from neuralogic.core.constructs.function.function import AggregationFunction, CombinationFunction
-class ConcatComb(Combination):
+class ConcatCombination(CombinationFunction):
__slots__ = ("axis",)
def __init__(
@@ -15,9 +15,9 @@ def __init__(
super().__init__(name)
self.axis = axis
- def __call__(self, entity=None, *, axis: int = -1):
- concat = ConcatComb(self.name, axis=axis)
- return Combination.__call__(concat, entity)
+ def __call__(self, *relations, axis: int = -1):
+ concat = ConcatCombination(self.name, axis=axis)
+ return CombinationFunction.__call__(concat, *relations)
def is_parametrized(self) -> bool:
return self.axis != -1
@@ -31,7 +31,7 @@ def __str__(self):
return f"concat(axis={self.axis})"
-class Concat(Aggregation):
+class ConcatAggregation(AggregationFunction):
__slots__ = ("axis",)
def __init__(
@@ -43,9 +43,9 @@ def __init__(
super().__init__(name)
self.axis = axis
- def __call__(self, entity=None, *, axis: int = -1):
- concat = Concat(self.name, axis=axis)
- return Aggregation.__call__(concat, entity)
+ def __call__(self, *, axis: int = -1):
+ concat = ConcatAggregation(self.name, axis=axis)
+ return AggregationFunction.__call__(concat)
def is_parametrized(self) -> bool:
return self.axis != -1
diff --git a/neuralogic/core/constructs/function/enum.py b/neuralogic/core/constructs/function/enum.py
new file mode 100644
index 00000000..9244f091
--- /dev/null
+++ b/neuralogic/core/constructs/function/enum.py
@@ -0,0 +1,117 @@
+from neuralogic.core.constructs.function.concat import ConcatAggregation, ConcatCombination
+from neuralogic.core.constructs.function.function import (
+ TransformationFunction,
+ CombinationFunction,
+ AggregationFunction,
+)
+from neuralogic.core.constructs.function.reshape import Reshape
+from neuralogic.core.constructs.function.slice import Slice
+from neuralogic.core.constructs.function.softmax import SoftmaxAggregation
+
+
+class Transformation:
+ # Element wise
+ SIGMOID: TransformationFunction = TransformationFunction("SIGMOID")
+ TANH: TransformationFunction = TransformationFunction("TANH")
+ SIGNUM: TransformationFunction = TransformationFunction("SIGNUM")
+ RELU: TransformationFunction = TransformationFunction("RELU", namespace="transformation.elementwise.ReLu")
+ LEAKY_RELU: TransformationFunction = TransformationFunction(
+ "LEAKYRELU", namespace="transformation.elementwise.LeakyReLu"
+ )
+ LUKASIEWICZ: TransformationFunction = TransformationFunction(
+ "LUKASIEWICZ", namespace="transformation.elementwise.LukasiewiczSigmoid"
+ )
+ EXP: TransformationFunction = TransformationFunction("EXP", namespace="transformation.elementwise.Exponentiation")
+ SQRT: TransformationFunction = TransformationFunction("SQRT", namespace="transformation.elementwise.SquareRoot")
+ INVERSE: TransformationFunction = TransformationFunction("INVERSE")
+ REVERSE: TransformationFunction = TransformationFunction("REVERSE")
+ LOG: TransformationFunction = TransformationFunction("LOG", namespace="transformation.elementwise.Logarithm")
+
+ # Transformation
+ IDENTITY: TransformationFunction = TransformationFunction("IDENTITY", namespace="transformation.join.{name}")
+ TRANSP: TransformationFunction = TransformationFunction("TRANSPOSE", namespace="transformation.join.Transposition")
+ SOFTMAX: TransformationFunction = TransformationFunction("SOFTMAX", namespace="transformation.join.{name}")
+ SPARSEMAX: TransformationFunction = TransformationFunction("SPARSEMAX", namespace="transformation.join.{name}")
+ NORM: TransformationFunction = TransformationFunction("NORM", namespace="transformation.join.Normalization")
+ SLICE: Slice = Slice("SLICE")
+ RESHAPE: Reshape = Reshape("RESHAPE")
+
+
+class Combination:
+ # Aggregation
+ AVG: CombinationFunction = CombinationFunction("AVG", namespace="aggregation.Average")
+ MAX: CombinationFunction = CombinationFunction("MAX", can_flatten=True, namespace="aggregation.Maximum")
+ MIN: CombinationFunction = CombinationFunction("MIN", can_flatten=True, namespace="aggregation.Minimum")
+ SUM: CombinationFunction = CombinationFunction(
+ "SUM", operator="+", can_flatten=True, namespace="aggregation.{name}"
+ )
+ COUNT: CombinationFunction = CombinationFunction("COUNT", namespace="aggregation.{name}")
+
+ # Combination
+ PRODUCT: CombinationFunction = CombinationFunction("PRODUCT", operator="@")
+ ELPRODUCT: CombinationFunction = CombinationFunction(
+ "ELPRODUCT", operator="*", can_flatten=True, namespace="combination.ElementProduct"
+ )
+ SOFTMAX: CombinationFunction = CombinationFunction("SOFTMAX")
+ SPARSEMAX: CombinationFunction = CombinationFunction("SPARSEMAX")
+ CROSSSUM: CombinationFunction = CombinationFunction("CROSSSUM")
+ CONCAT: ConcatCombination = ConcatCombination("CONCAT")
+ COSSIM: CombinationFunction = CombinationFunction("COSSIM")
+
+
+class Aggregation:
+ AVG: AggregationFunction = AggregationFunction("AVG")
+ MAX: AggregationFunction = AggregationFunction("MAX")
+ MIN: AggregationFunction = AggregationFunction("MIN")
+ SUM: AggregationFunction = AggregationFunction("SUM")
+ COUNT: AggregationFunction = AggregationFunction("COUNT")
+ CONCAT: ConcatAggregation = ConcatAggregation("CONCAT")
+ SOFTMAX: SoftmaxAggregation = SoftmaxAggregation("SOFTMAX")
+
+
+class F:
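+ """Shorthand functional aliases for the Transformation, Combination, and Aggregation enums (e.g. F.relu)."""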
+ # Element wise
+ sigmoid: TransformationFunction = Transformation.SIGMOID
+ tanh: TransformationFunction = Transformation.TANH
+ signum: TransformationFunction = Transformation.SIGNUM
+ relu: TransformationFunction = Transformation.RELU
+ leaky_relu: TransformationFunction = Transformation.LEAKY_RELU
+ lukasiewicz: TransformationFunction = Transformation.LUKASIEWICZ
+ exp: TransformationFunction = Transformation.EXP
+ sqrt: TransformationFunction = Transformation.SQRT
+ inverse: TransformationFunction = Transformation.INVERSE
+ reverse: TransformationFunction = Transformation.REVERSE
+ log: TransformationFunction = Transformation.LOG
+
+ # Transformation
+ identity: TransformationFunction = Transformation.IDENTITY
+ transp: TransformationFunction = Transformation.TRANSP
+ softmax: TransformationFunction = Transformation.SOFTMAX
+ sparsemax: TransformationFunction = Transformation.SPARSEMAX
+ norm: TransformationFunction = Transformation.NORM
+ slice: Slice = Transformation.SLICE
+ reshape: Reshape = Transformation.RESHAPE
+
+ # Combination
+ avg: CombinationFunction = Combination.AVG
+ max: CombinationFunction = Combination.MAX
+ min: CombinationFunction = Combination.MIN
+ sum: CombinationFunction = Combination.SUM
+ count: CombinationFunction = Combination.COUNT
+
+ product: CombinationFunction = Combination.PRODUCT
+ elproduct: CombinationFunction = Combination.ELPRODUCT
+ softmax_comb: CombinationFunction = Combination.SOFTMAX
+ sparsemax_comb: CombinationFunction = Combination.SPARSEMAX
+ crossum: CombinationFunction = Combination.CROSSSUM
+ concat: ConcatCombination = Combination.CONCAT
+ cossim: CombinationFunction = Combination.COSSIM
+
+ # Aggregation
+ avg_agg: AggregationFunction = Aggregation.AVG
+ max_agg: AggregationFunction = Aggregation.MAX
+ min_agg: AggregationFunction = Aggregation.MIN
+ sum_agg: AggregationFunction = Aggregation.SUM
+ count_agg: AggregationFunction = Aggregation.COUNT
+ concat_agg: ConcatAggregation = Aggregation.CONCAT
+ softmax_agg: SoftmaxAggregation = Aggregation.SOFTMAX
diff --git a/neuralogic/core/constructs/function/function.py b/neuralogic/core/constructs/function/function.py
index 0a2eaf21..f3c36b6c 100644
--- a/neuralogic/core/constructs/function/function.py
+++ b/neuralogic/core/constructs/function/function.py
@@ -1,8 +1,16 @@
+from typing import Optional
+
+import jpype
+
+
class Function:
- __slots__ = ("name",)
+ __slots__ = "name", "operator", "can_flatten", "namespace"
- def __init__(self, name: str):
+ def __init__(self, name: str, *, namespace: str = "", operator: Optional[str] = None, can_flatten: bool = False):
self.name: str = name.lower()
+ self.operator: Optional[str] = operator
+ self.can_flatten = can_flatten
+ self.namespace = namespace
def __str__(self):
return self.name
@@ -13,82 +21,71 @@ def wrap(self, content: str) -> str:
def pretty_str(self) -> str:
return str(self).capitalize()
- def __call__(self, *args, **kwargs):
- if len(args) == 0 or args[0] is None:
+ def __call__(self, *args):
+ if len(args) == 0:
return self
raise NotImplementedError
def is_parametrized(self) -> bool:
return False
- def get(self):
- raise NotImplementedError
-
def rule_head_dependant(self) -> bool:
return False
def process_head(self, head) -> "Function":
pass
+ def get(self):
+ name = "".join(s.capitalize() for s in self.name.split("_"))
+ formatted_namespace = self.namespace.format(name=name)
+
+ return jpype.JClass(f"cz.cvut.fel.ida.algebra.functions.{formatted_namespace}")()
+
+
+class TransformationFunction(Function):
+ def __init__(
+ self,
+ name: str,
+ *,
+ namespace: str = "transformation.elementwise.{name}",
+ operator: Optional[str] = None,
+ can_flatten: bool = False,
+ ):
+ super().__init__(name, namespace=namespace, operator=operator, can_flatten=can_flatten)
-class Transformation(Function):
- # Element wise
- SIGMOID: "Transformation"
- TANH: "Transformation"
- SIGNUM: "Transformation"
- RELU: "Transformation"
- LEAKY_RELU: "Transformation"
- LUKASIEWICZ: "Transformation"
- EXP: "Transformation"
- SQRT: "Transformation"
- INVERSE: "Transformation"
- REVERSE: "Transformation"
- LOG: "Transformation"
-
- # Transformation
- IDENTITY: "Transformation"
- TRANSP: "Transformation"
- SOFTMAX: "Transformation"
- SPARSEMAX: "Transformation"
- NORM: "Transformation"
- SLICE: "Transformation"
- RESHAPE: "Transformation"
-
- def __call__(self, *args, **kwargs):
- from neuralogic.core.constructs import relation
-
- if len(args) == 0 or args[0] is None:
+ def __call__(self, relation: Optional = None, **kwargs):
+ from neuralogic.core.constructs import relation as rel
+ from neuralogic.core.constructs.function.function_container import FContainer
+
+ if relation is None:
return self
- arg = args[0]
- if isinstance(arg, relation.BaseRelation):
- return arg.attach_activation_function(self)
- raise NotImplementedError
+ if isinstance(relation, rel.BaseRelation) and not isinstance(relation, rel.WeightedRelation):
+ if relation.negated or relation.function is not None:
+ return FContainer((relation,), self)
+ return relation.attach_activation_function(self)
+ return FContainer(relation, self)
+
+class CombinationFunction(Function):
+ def __init__(
+ self,
+ name: str,
+ *,
+ namespace: str = "combination.{name}",
+ operator: Optional[str] = None,
+ can_flatten: bool = False,
+ ):
+ super().__init__(name, namespace=namespace, operator=operator, can_flatten=can_flatten)
-class Combination(Function):
- # Aggregation
- AVG: "Combination"
- MAX: "Combination"
- MIN: "Combination"
- SUM: "Combination"
- COUNT: "Combination"
-
- # Combination
- PRODUCT: "Combination"
- ELPRODUCT: "Combination"
- SOFTMAX: "Combination"
- SPARSEMAX: "Combination"
- CROSSSUM: "Combination"
- CONCAT: "Combination"
- COSSIM: "Combination"
-
-
-class Aggregation(Function):
- AVG: "Aggregation"
- MAX: "Aggregation"
- MIN: "Aggregation"
- SUM: "Aggregation"
- COUNT: "Aggregation"
- CONCAT: "Aggregation"
- SOFTMAX: "Aggregation"
+ def __call__(self, *relations):
+ from neuralogic.core.constructs.function.function_container import FContainer
+
+ if len(relations) == 0:
+ return self
+ return FContainer(relations, self)
+
+
+class AggregationFunction(Function):
+ def get(self):
+ raise NotImplementedError
diff --git a/neuralogic/core/constructs/function/function_container.py b/neuralogic/core/constructs/function/function_container.py
new file mode 100644
index 00000000..359f8c9c
--- /dev/null
+++ b/neuralogic/core/constructs/function/function_container.py
@@ -0,0 +1,114 @@
+from typing import Dict
+
+import jpype
+
+from neuralogic.core.constructs.function.enum import Combination
+from neuralogic.core.constructs.function.function_graph import FunctionGraph
+from neuralogic.core.constructs.function.function import Function
+
+
+class FContainer:
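+ """An expression tree of relations combined by functions; built up via the +, *, and @ operator overloads."""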
+ __slots__ = "nodes", "function"
+
+ def __init__(self, nodes, function: Function):
+ self.function = function
+ self.nodes = nodes if not self.function.can_flatten else self.get_flattened_nodes(nodes, function)
+
+ @staticmethod
+ def get_flattened_nodes(nodes, function: Function):
+ new_nodes = []
+ for node in nodes:
+ if not isinstance(node, FContainer):
+ new_nodes.append(node)
+ continue
+
+ if node.function.name == function.name:
+ new_nodes.extend(node.nodes)
+ else:
+ new_nodes.append(node)
+ return tuple(new_nodes)
+
+ def __add__(self, other):
+ return FContainer((self, other), Combination.SUM)
+
+ def __mul__(self, other):
+ return FContainer((self, other), Combination.ELPRODUCT)
+
+ def __matmul__(self, other):
+ return FContainer((self, other), Combination.PRODUCT)
+
+ def __str__(self):
+ if self.function.operator is not None:
+ return f" {self.function.operator} ".join(
+ node.to_str(True) if isinstance(node, FContainer) else node.to_str() for node in self.nodes
+ )
+
+ args = ", ".join(node.to_str() for node in self.nodes)
+
+ if args:
+ return f"{self.function}({args})"
+ return f"{self.function}"
+
+ @property
+ def name(self):
+ args = ", ".join(str(node.function) for node in self.nodes if isinstance(node, FContainer))
+
+ if args:
+ return f"{self.function}({args})"
+ return f"{self.function}"
+
+ def __iter__(self):
+ for node in self.nodes:
+ if isinstance(node, FContainer):
+ for a in node:
+ yield a
+ else:
+ yield node
+
+ def to_function(self) -> Function:
+ graph = self._get_function_node({}, 0)
+ return FunctionGraph(name=self.name, function_graph=graph)
+
+ def _get_function_node(self, input_counter: Dict[int, int], start_index: int = 0):
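+ """Recursively build a Java FunctionGraphNode, assigning a distinct input index to each leaf relation."""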
+ from neuralogic.core.constructs.relation import BaseRelation
+
+ next_indices = [-1] * len(self.nodes)
+ next_nodes = [None] * len(self.nodes)
+
+ for i, node in enumerate(self.nodes):
+ if isinstance(node, FContainer):
+ next_node = node._get_function_node(input_counter)
+ if next_node is None:
+ continue
+ next_nodes[i] = next_node
+ elif isinstance(node, BaseRelation):
+ idx = id(node)
+
+ if node.predicate.hidden or node.predicate.special or node.predicate.name.startswith("_"):
+ continue
+
+ if idx not in input_counter:
+ input_counter[idx] = len(input_counter) + start_index
+ next_indices[i] = input_counter[idx]
+ else:
+ raise ValueError(f"{node} of type {type(node)} inside of body function is not supported")
+
+ filtered_next_node = []
+ filtered_next_indices = []
+
+ for i, (node, index) in enumerate(zip(next_nodes, next_indices)):
+ if node is not None or index != -1:
+ filtered_next_node.append(node)
+ filtered_next_indices.append(index)
+
+ if not filtered_next_node or not filtered_next_indices:
+ return None
+
+ class_name = "cz.cvut.fel.ida.algebra.functions.combination.FunctionGraph.FunctionGraphNode"
+
+ return jpype.JClass(class_name)(self.function.get(), filtered_next_node, filtered_next_indices)
+
+ def to_str(self, parentheses_wrap: bool = False):
+ if parentheses_wrap and self.function.operator is not None:
+ return f"({self})"
+ return self.__str__()
diff --git a/neuralogic/core/constructs/function/function_graph.py b/neuralogic/core/constructs/function/function_graph.py
new file mode 100644
index 00000000..fea5d9a6
--- /dev/null
+++ b/neuralogic/core/constructs/function/function_graph.py
@@ -0,0 +1,27 @@
+import jpype
+
+from neuralogic.core.constructs.function.function import Function
+
+
+class FunctionGraph(Function):
+ __slots__ = ("function_graph",)
+
+ def __init__(
+ self,
+ name: str,
+ *,
+ function_graph,
+ ):
+ super().__init__(name)
+ self.function_graph = function_graph
+
+ def is_parametrized(self) -> bool:
+ return True
+
+ def get(self):
+ return jpype.JClass("cz.cvut.fel.ida.algebra.functions.combination.FunctionGraph")(
+ self.name, self.function_graph
+ )
+
+ def __str__(self):
+ return self.name
diff --git a/neuralogic/core/constructs/function/reshape.py b/neuralogic/core/constructs/function/reshape.py
index 949d4981..d14a2e2a 100644
--- a/neuralogic/core/constructs/function/reshape.py
+++ b/neuralogic/core/constructs/function/reshape.py
@@ -1,11 +1,11 @@
-from typing import Union, Tuple
+from typing import Union, Tuple, Optional
import jpype
-from neuralogic.core.constructs.function.function import Transformation
+from neuralogic.core.constructs.function.function import TransformationFunction
-class Reshape(Transformation):
+class Reshape(TransformationFunction):
__slots__ = ("shape",)
def __init__(
@@ -22,12 +22,12 @@ def __init__(
def __call__(
self,
- entity=None,
+ relation: Optional = None,
*,
- shape: Union[None, Tuple[int, int], int],
+ shape: Union[None, Tuple[int, int], int] = None,
):
reshape = Reshape(self.name, shape=shape)
- return Transformation.__call__(reshape, entity)
+ return TransformationFunction.__call__(reshape, relation)
def is_parametrized(self) -> bool:
return True
diff --git a/neuralogic/core/constructs/function/slice.py b/neuralogic/core/constructs/function/slice.py
index b33d8143..33692102 100644
--- a/neuralogic/core/constructs/function/slice.py
+++ b/neuralogic/core/constructs/function/slice.py
@@ -1,11 +1,11 @@
-from typing import Union, Tuple
+from typing import Union, Tuple, Optional
import jpype
-from neuralogic.core.constructs.function.function import Transformation
+from neuralogic.core.constructs.function.function import TransformationFunction
-class Slice(Transformation):
+class Slice(TransformationFunction):
__slots__ = ("rows", "cols")
def __init__(
@@ -28,13 +28,13 @@ def __init__(
def __call__(
self,
- entity=None,
+ relation: Optional = None,
*,
rows: Union[type(Ellipsis), Tuple[int, int]] = ...,
cols: Union[type(Ellipsis), Tuple[int, int]] = ...,
):
slice = Slice(self.name, rows=rows, cols=cols)
- return Transformation.__call__(slice, entity)
+ return TransformationFunction.__call__(slice, relation)
def is_parametrized(self) -> bool:
return True
diff --git a/neuralogic/core/constructs/function/softmax.py b/neuralogic/core/constructs/function/softmax.py
index 570e487e..85d09b7d 100644
--- a/neuralogic/core/constructs/function/softmax.py
+++ b/neuralogic/core/constructs/function/softmax.py
@@ -2,10 +2,10 @@
import jpype
-from neuralogic.core.constructs.function.function import Aggregation
+from neuralogic.core.constructs.function.function import AggregationFunction
-class Softmax(Aggregation):
+class SoftmaxAggregation(AggregationFunction):
__slots__ = ("agg_terms", "var_terms")
def __init__(
@@ -18,9 +18,9 @@ def __init__(
self.term_indices = agg_terms
self.agg_terms = agg_terms
- def __call__(self, entity=None, *, agg_terms: Sequence[int] = None):
- softmax = Softmax(self.name, agg_terms=agg_terms)
- return Aggregation.__call__(softmax, entity)
+ def __call__(self, *, agg_terms: Sequence[int] = None):
+ softmax = SoftmaxAggregation(self.name, agg_terms=agg_terms)
+ return AggregationFunction.__call__(softmax)
def is_parametrized(self) -> bool:
return self.agg_terms is not None
@@ -36,7 +36,7 @@ def __str__(self):
def rule_head_dependant(self) -> bool:
return self.agg_terms is not None
- def process_head(self, head) -> "Softmax":
+ def process_head(self, head) -> "SoftmaxAggregation":
term_indices = []
for agg_term in set(self.agg_terms):
@@ -48,7 +48,7 @@ def process_head(self, head) -> "Softmax":
term_indices.append(i)
break
- aggregation = Softmax(self.name, agg_terms=self.agg_terms)
+ aggregation = SoftmaxAggregation(self.name, agg_terms=self.agg_terms)
aggregation.term_indices = term_indices
return aggregation
diff --git a/neuralogic/core/constructs/java_objects.py b/neuralogic/core/constructs/java_objects.py
index ac5a7467..a9ca18fd 100644
--- a/neuralogic/core/constructs/java_objects.py
+++ b/neuralogic/core/constructs/java_objects.py
@@ -6,6 +6,10 @@
from neuralogic import is_initialized, initialize
from neuralogic.core.constructs.factories import R
+from neuralogic.core.constructs.function.enum import Combination
+from neuralogic.core.constructs.function.function import CombinationFunction
+from neuralogic.core.constructs.metadata import Metadata
+from neuralogic.core.constructs.function import FContainer
from neuralogic.core.constructs.term import Variable, Constant
from neuralogic.core.settings import SettingsProxy, Settings
@@ -70,6 +74,21 @@ def get_value(self, weight):
raise ValueError(f"Cannot create weight from type {type(weight)}, value {weight}")
+def _is_body_flat(body: FContainer):
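+ """Return True if the body is a single combination function applied directly to relations (no nesting)."""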
+ if not isinstance(body.function, CombinationFunction):
+ return False
+
+ for node in body.nodes:
+ if isinstance(node, FContainer):
+ return False
+ return True
+
+
+def _flatten_rule_body(body, metadata: Metadata):
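+ """Wrap a plain rule body in its combination function (SUM unless the metadata specifies otherwise)."""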
+ combination = Combination.SUM if metadata is None or metadata.combination is None else metadata.combination
+ return combination(*body)
+
+
class JavaFactory:
def __init__(self, settings: Optional[SettingsProxy] = None):
from neuralogic.core.constructs.rule import Rule
@@ -318,13 +337,26 @@ def get_rule(self, rule):
else:
java_rule.setWeight(weight)
+ body = rule.body
+ if not isinstance(body, FContainer) and rule._contains_function_container():
+ body = _flatten_rule_body(body, rule.metadata)
+
+ contains_refs = False
all_variables = {term for term in rule.head.terms if term is not Ellipsis and str(term)[0].isupper()}
body_relation = []
all_diff_index = []
+ processed_relations = {}
+ is_fcontainer = isinstance(body, FContainer)
for i, relation in enumerate(rule.body):
all_variables.update(term for term in relation.terms if term is not Ellipsis and str(term)[0].isupper())
+ if is_fcontainer:
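+ # The same relation object may occur several times in a function tree; ground it only once.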
+ if id(relation) in processed_relations:
+ contains_refs = True
+ continue
+ processed_relations[id(relation)] = True
+
if relation.predicate.special and relation.predicate.name == "alldiff":
found = False
@@ -358,7 +390,17 @@ def get_rule(self, rule):
if rule.metadata is not None:
java_rule.allowDuplicitGroundings = bool(rule.metadata.duplicit_grounding)
- java_rule.setMetadata(self.get_metadata(rule.metadata, self.rule_metadata))
+ metadata = rule.metadata
+ if isinstance(body, FContainer):
+ metadata = metadata.copy() if metadata is not None else Metadata()
+
+ if not contains_refs and _is_body_flat(body):
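+ # A flat body maps directly onto the rule's combination function; nested bodies need a full function graph.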
+ metadata.combination = body.function
+ body = body.nodes
+ else:
+ metadata.combination = body.to_function()
+
+ java_rule.setMetadata(self.get_metadata(metadata, self.rule_metadata))
return java_rule
diff --git a/neuralogic/core/constructs/metadata.py b/neuralogic/core/constructs/metadata.py
index 5b965245..8e5408b9 100644
--- a/neuralogic/core/constructs/metadata.py
+++ b/neuralogic/core/constructs/metadata.py
@@ -1,6 +1,11 @@
from typing import Union, Iterable, Callable, Optional
-from neuralogic.core.constructs.function import Transformation, Combination, Aggregation, Function
+from neuralogic.core.constructs.function.function import (
+ Function,
+ AggregationFunction,
+ TransformationFunction,
+ CombinationFunction,
+)
class Metadata:
@@ -9,9 +14,9 @@ class Metadata:
def __init__(
self,
learnable: bool = None,
- transformation: Union[str, Transformation, Combination] = None,
- combination: Union[str, Combination] = None,
- aggregation: Union[str, Aggregation] = None,
+ transformation: Union[str, TransformationFunction, CombinationFunction] = None,
+ combination: Union[str, CombinationFunction] = None,
+ aggregation: Union[str, AggregationFunction] = None,
duplicit_grounding: Optional[bool] = None,
):
self.learnable = learnable
@@ -27,11 +32,11 @@ def from_iterable(iterable: Iterable) -> "Metadata":
for entry in iterable:
if isinstance(entry, Callable) and not isinstance(entry, Function):
entry = entry()
- if isinstance(entry, Aggregation):
+ if isinstance(entry, AggregationFunction):
metadata.aggregation = entry
- elif isinstance(entry, Transformation):
+ elif isinstance(entry, TransformationFunction):
metadata.transformation = entry
- elif isinstance(entry, Combination):
+ elif isinstance(entry, CombinationFunction):
metadata.combination = entry
else:
raise ValueError(f"Invalid entry for metadata: {entry}")
diff --git a/neuralogic/core/constructs/relation.py b/neuralogic/core/constructs/relation.py
index eb214fba..bdf01635 100644
--- a/neuralogic/core/constructs/relation.py
+++ b/neuralogic/core/constructs/relation.py
@@ -2,9 +2,11 @@
import numpy as np
+from neuralogic.core.constructs.function.enum import Combination, Transformation
+from neuralogic.core.constructs.function import FContainer
from neuralogic.core.constructs.predicate import Predicate
from neuralogic.core.constructs import rule, factories
-from neuralogic.core.constructs.function import Transformation, Combination
+from neuralogic.core.constructs.function.function import TransformationFunction, CombinationFunction
class BaseRelation:
@@ -14,7 +16,7 @@ def __init__(
self,
predicate: Predicate,
terms=None,
- function: Union[Transformation, Combination] = None,
+ function: Union[TransformationFunction, CombinationFunction] = None,
negated: bool = False,
):
self.predicate = predicate
@@ -50,7 +52,7 @@ def __invert__(self) -> "BaseRelation":
def T(self) -> "BaseRelation":
return self.attach_activation_function(Transformation.TRANSP)
- def attach_activation_function(self, function: Union[Transformation, Combination]):
+ def attach_activation_function(self, function: Union[TransformationFunction, CombinationFunction]):
if self.negated:
raise ValueError(f"Cannot attach function to negated relation {self}")
relation = self.__copy__()
@@ -126,12 +128,26 @@ def __and__(self, other) -> rule.RuleBody:
return rule.RuleBody(self, other)
raise NotImplementedError
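+ # The arithmetic operators build function containers: + -> SUM, * -> ELPRODUCT, @ -> PRODUCT (matmul).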
+ def __add__(self, other):
+ return FContainer((self, other), Combination.SUM)
+
+ def __mul__(self, other):
+ return FContainer((self, other), Combination.ELPRODUCT)
+
+ def __matmul__(self, other):
+ return FContainer((self, other), Combination.PRODUCT)
+
class WeightedRelation(BaseRelation):
__slots__ = "weight", "weight_name", "is_fixed"
def __init__(
- self, weight, predicate: Predicate, fixed=False, terms=None, function: Union[Transformation, Combination] = None
+ self,
+ weight,
+ predicate: Predicate,
+ fixed=False,
+ terms=None,
+ function: Union[TransformationFunction, CombinationFunction] = None,
):
super().__init__(predicate, terms, function, False)
diff --git a/neuralogic/core/constructs/rule.py b/neuralogic/core/constructs/rule.py
index cd559694..60dc20b7 100644
--- a/neuralogic/core/constructs/rule.py
+++ b/neuralogic/core/constructs/rule.py
@@ -1,5 +1,6 @@
from typing import Iterable, Optional
+from neuralogic.core.constructs.function import FContainer
from neuralogic.core.constructs.metadata import Metadata
@@ -13,7 +14,7 @@ def __init__(self, lit1, lit2):
def __and__(self, other):
from neuralogic.core.constructs.relation import BaseRelation
- if isinstance(other, BaseRelation):
+ if isinstance(other, (BaseRelation, FContainer)):
self.literals.append(other)
return self
raise NotImplementedError
@@ -52,13 +53,24 @@ def __init__(self, head, body):
if not isinstance(body, Iterable):
body = [body]
- self.body = list(body)
+ self.body = body
+
+ if not isinstance(self.body, FContainer):
+ self.body = list(self.body)
+
+ def _contains_function_container(self):
+ for lit in self.body:
+ if isinstance(lit, FContainer):
+ return True
+ return False
def to_str(self, _: bool = False) -> str:
return str(self)
def __str__(self) -> str:
metadata = "" if self.metadata is None is None else f" {self.metadata}"
+ if isinstance(self.body, FContainer):
+ return f"{self.head.to_str()} :- {self.body.to_str()}.{metadata}"
return f"{self.head.to_str()} :- {', '.join(atom.to_str() for atom in self.body)}.{metadata}"
def __repr__(self) -> str:
diff --git a/neuralogic/core/settings/__init__.py b/neuralogic/core/settings/__init__.py
index 9f558c5f..41f6ef86 100644
--- a/neuralogic/core/settings/__init__.py
+++ b/neuralogic/core/settings/__init__.py
@@ -5,7 +5,6 @@
from neuralogic.nn.init import Initializer, Uniform
from neuralogic.nn.loss import MSE, ErrorFunction
from neuralogic.core.settings.settings_proxy import SettingsProxy
-from neuralogic.core.constructs.function import Transformation, Combination, Aggregation
from neuralogic.optim import Optimizer, Adam
@@ -18,11 +17,6 @@ def __init__(
epochs: int = 3000,
error_function: ErrorFunction = MSE(),
initializer: Initializer = Uniform(),
- rule_transformation: Transformation = Transformation.TANH,
- rule_combination: Combination = Combination.SUM,
- rule_aggregation: Aggregation = Aggregation.AVG,
- relation_transformation: Transformation = Transformation.TANH,
- relation_combination: Combination = Combination.SUM,
iso_value_compression: bool = True,
chain_pruning: bool = True,
prune_only_identities: bool = False,
@@ -98,46 +92,6 @@ def initializer(self) -> Initializer:
def initializer(self, initializer: Initializer):
self._update("initializer", initializer)
- @property
- def relation_transformation(self) -> Transformation:
- return self.params["relation_transformation"]
-
- @relation_transformation.setter
- def relation_transformation(self, value: Transformation):
- self._update("relation_transformation", value)
-
- @property
- def relation_combination(self) -> Combination:
- return self.params["relation_combination"]
-
- @relation_combination.setter
- def relation_combination(self, value: Combination):
- self._update("relation_combination", value)
-
- @property
- def rule_transformation(self) -> Transformation:
- return self.params["rule_transformation"]
-
- @rule_transformation.setter
- def rule_transformation(self, value: Transformation):
- self._update("rule_transformation", value)
-
- @property
- def rule_combination(self) -> Combination:
- return self.params["rule_combination"]
-
- @rule_combination.setter
- def rule_combination(self, value: Combination):
- self._update("rule_combination", value)
-
- @property
- def rule_aggregation(self) -> Aggregation:
- return self.params["rule_aggregation"]
-
- @rule_aggregation.setter
- def rule_aggregation(self, value: Aggregation):
- self._update("rule_aggregation", value)
-
def create_proxy(self) -> SettingsProxy:
proxy = SettingsProxy(**self.params)
self._proxies.add(proxy)
diff --git a/neuralogic/core/settings/settings_proxy.py b/neuralogic/core/settings/settings_proxy.py
index a462190e..765f5f68 100644
--- a/neuralogic/core/settings/settings_proxy.py
+++ b/neuralogic/core/settings/settings_proxy.py
@@ -20,11 +20,6 @@ def __init__(
epochs: int,
error_function: ErrorFunction,
initializer: Initializer,
- rule_transformation: Transformation,
- rule_combination: Combination,
- rule_aggregation: Aggregation,
- relation_transformation: Transformation,
- relation_combination: Combination,
iso_value_compression: bool,
chain_pruning: bool,
prune_only_identities: bool,
@@ -44,6 +39,12 @@ def __init__(
for key, value in params.items():
self.__setattr__(key, value)
+ self.rule_transformation = Transformation.IDENTITY
+ self.relation_transformation = Transformation.IDENTITY
+ self.rule_combination = Combination.SUM
+ self.relation_combination = Combination.SUM
+ self.rule_aggregation = Aggregation.AVG
+
self.settings.debugExporting = False
self.settings.exportBlocks = []
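With the global `rule_*`/`relation_*` function settings gone (the proxy now pins identity/SUM/AVG internally), per-rule functions belong in the rule metadata instead. A minimal sketch of the replacement pattern, assuming the usual top-level `Settings` import:

```python
from neuralogic.core import Settings
from neuralogic.core.constructs.factories import R, V
from neuralogic.core.constructs.function import Aggregation, Transformation

# Functions are now attached to the rule itself rather than set globally:
rule = (R.h(V.X)[4, 4] <= R.feature(V.Y)) | [Transformation.RELU, Aggregation.AVG]

# Settings no longer accepts rule_transformation, rule_aggregation, etc.
settings = Settings(epochs=100)
```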
diff --git a/neuralogic/jar/NeuraLogic.jar b/neuralogic/jar/NeuraLogic.jar
index 67476ae3..ab984d66 100644
Binary files a/neuralogic/jar/NeuraLogic.jar and b/neuralogic/jar/NeuraLogic.jar differ
diff --git a/neuralogic/nn/functional.py b/neuralogic/nn/functional.py
deleted file mode 100644
index 17bef508..00000000
--- a/neuralogic/nn/functional.py
+++ /dev/null
@@ -1,209 +0,0 @@
-from typing import Union, Tuple, Sequence
-
-from neuralogic.core.constructs.relation import BaseRelation
-from neuralogic.core.constructs.function import Transformation, Combination, Function, Aggregation
-
-
-dot_type = type(Ellipsis)
-
-
-# Transformation
-
-
-def sigmoid(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Transformation.SIGMOID(entity)
-
-
-def tanh(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Transformation.TANH(entity)
-
-
-def signum(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Transformation.SIGNUM(entity)
-
-
-def relu(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Transformation.RELU(entity)
-
-
-def leaky_relu(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Transformation.LEAKY_RELU(entity)
-
-
-def lukasiewicz(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Transformation.LUKASIEWICZ(entity)
-
-
-def exp(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Transformation.EXP(entity)
-
-
-def sqrt(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Transformation.SQRT(entity)
-
-
-def inverse(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Transformation.INVERSE(entity)
-
-
-def reverse(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Transformation.REVERSE(entity)
-
-
-def log(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Transformation.LOG(entity)
-
-
-def identity(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Transformation.IDENTITY(entity)
-
-
-def transp(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Transformation.TRANSP(entity)
-
-
-def softmax(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Transformation.SOFTMAX(entity)
-
-
-def sparsemax(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Transformation.SPARSEMAX(entity)
-
-
-def norm(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Transformation.NORM(entity)
-
-
-def slice(
- entity: Union[BaseRelation, Function] = None,
- *,
- rows: Union[type(Ellipsis), Tuple[int, int]] = ...,
- cols: Union[type(Ellipsis), Tuple[int, int]] = ...,
-) -> Union[BaseRelation, Function]:
- r"""
- Slices a value into a new value that is created by taking values on specified rows and columns.
-
- Rows and Cols coordinates are specified either as ``...``, which means all rows/cols or by a tuple of two
- elements ``[from_index, to_index]``.
-
- Parameters
- ----------
-
- entity : Union[BaseRelation, Function]
- Relation to apply the function on. Default: ``None``
- rows : Union[type(Ellipsis), Tuple[int, int]]
- Default: ``...``
- cols : Union[type(Ellipsis), Tuple[int, int]]
- Default: ``...``
- """
- return Transformation.SLICE(entity, rows=rows, cols=cols)
-
-
-def reshape(
- entity: Union[BaseRelation, Function] = None,
- *,
- shape: Union[None, Tuple[int, int], int],
-) -> Union[BaseRelation, Function]:
- r"""
- Change the shape/type of the value to a new shape. The shape can be either ``None``, int, or a tuple of two ints.
-
- * If ``None``, the underlying value will be converted to a scalar. E.g., a matrix value of one element ``[[1]]``
- will be converted to scalar ``1``.
-
- * If int, then the value will be converted to scalar (if the int is ``0``) or to a column vector.
-
- * If a tuple of two ints, the value will be converted to a scalar if the tuple is ``(0, 0)``. Into a row vector
- if the shape is ``(len, 0)`` or a column vector for shape ``(0, len)``. For other tuples ``(n, m)``,
- the value will be reshaped to matrix :math:`n \times m`.
-
- Parameters
- ----------
-
- entity : Union[BaseRelation, Function]
- Relation to apply the function on. Default: ``None``
- shape : Union[None, Tuple[int, int], int]
- The new shape of the value
- """
- return Transformation.RESHAPE(entity, shape=shape)
-
-
-# Combination
-
-
-def max_comb(entity: BaseRelation = None) -> Union[BaseRelation, Function]:
- return Combination.MAX(entity)
-
-
-def min_comb(entity: BaseRelation = None) -> Union[BaseRelation, Function]:
- return Combination.MIN(entity)
-
-
-def avg_comb(entity: BaseRelation = None) -> Union[BaseRelation, Function]:
- return Combination.AVG(entity)
-
-
-def sum_comb(entity: BaseRelation = None) -> Union[BaseRelation, Function]:
- return Combination.SUM(entity)
-
-
-def count_comb(entity: BaseRelation = None) -> Union[BaseRelation, Function]:
- return Combination.COUNT(entity)
-
-
-def product_comb(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Combination.PRODUCT(entity)
-
-
-def elproduct_comb(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Combination.ELPRODUCT(entity)
-
-
-def softmax_comb(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Combination.SOFTMAX(entity)
-
-
-def sparsemax_comb(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Combination.SPARSEMAX(entity)
-
-
-def crosssum_comb(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Combination.CROSSSUM(entity)
-
-
-def concat_comb(entity: Union[BaseRelation, Function] = None, *, axis: int = -1) -> Union[BaseRelation, Function]:
- return Combination.CONCAT(entity, axis=axis)
-
-
-def cossim_comb(entity: Union[BaseRelation, Function] = None) -> Union[BaseRelation, Function]:
- return Combination.COSSIM(entity)
-
-
-# Aggregations
-
-
-def max(entity: BaseRelation = None) -> Union[BaseRelation, Function]:
- return Aggregation.MAX(entity)
-
-
-def min(entity: BaseRelation = None) -> Union[BaseRelation, Function]:
- return Aggregation.MIN(entity)
-
-
-def avg(entity: BaseRelation = None) -> Union[BaseRelation, Function]:
- return Aggregation.AVG(entity)
-
-
-def sum(entity: BaseRelation = None) -> Union[BaseRelation, Function]:
- return Aggregation.SUM(entity)
-
-
-def count(entity: BaseRelation = None) -> Union[BaseRelation, Function]:
- return Aggregation.COUNT(entity)
-
-
-def concat(entity: BaseRelation = None, *, axis: int = -1) -> Union[BaseRelation, Function]:
- return Aggregation.CONCAT(entity, axis=axis)
-
-
-def softmax_agg(entity: BaseRelation = None, *, agg_terms: Sequence[str] = None) -> Union[BaseRelation, Function]:
- return Aggregation.SOFTMAX(entity, agg_terms=agg_terms)
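Every wrapper in the deleted module was a one-line call into the corresponding callable `Function` enum, so call sites can apply the enum members directly. A sketch of the equivalents:

```python
from neuralogic.core.constructs.factories import R, V
from neuralogic.core.constructs.function import Aggregation, Transformation

body_atom = Transformation.SIGMOID(R.feature(V.X))      # was functional.sigmoid(...)
softmax_over_y = Aggregation.SOFTMAX(agg_terms=["Y"])   # was functional.softmax_agg(...)
```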
diff --git a/neuralogic/nn/module/general/attention.py b/neuralogic/nn/module/general/attention.py
index d7eba8f0..35202d33 100644
--- a/neuralogic/nn/module/general/attention.py
+++ b/neuralogic/nn/module/general/attention.py
@@ -59,12 +59,11 @@ def __call__(self):
dk_rel = R.get(f"{self.output_name}__dk")
dot_rel = R.get(f"{self.output_name}__dot")
- metadata = [Combination.PRODUCT, Transformation.IDENTITY, Aggregation.SOFTMAX(agg_terms=["Y"])]
- out_metadata = [Combination.PRODUCT, Aggregation.SUM, Transformation.IDENTITY]
+ metadata = [Combination.PRODUCT, Aggregation.SOFTMAX(agg_terms=["Y"])]
+ out_metadata = [Combination.PRODUCT, Aggregation.SUM]
attention_product_rules = [
(dot_rel(h_terms) <= (dk_rel, R.get(self.key_name)(k_terms).T, R.get(self.query_name)(q_terms))) | metadata,
- dot_rel / (self.arity + 1) | [Transformation.IDENTITY],
]
if self.mask_name is not None:
@@ -74,7 +73,6 @@ def __call__(self):
dk_rel[d_k].fixed(),
*attention_product_rules,
(R.get(self.output_name)(q_terms) <= (dot_rel(h_terms), R.get(self.value_name)(k_terms))) | out_metadata,
- R.get(self.output_name) / self.arity | [Transformation.IDENTITY],
]
@@ -159,12 +157,7 @@ def __call__(self):
attention.arity += 1
attention_concat = []
- multihead_rules = [
- q_proj / (self.arity + 1) | [Transformation.IDENTITY],
- k_proj / (self.arity + 1) | [Transformation.IDENTITY],
- v_proj / (self.arity + 1) | [Transformation.IDENTITY],
- output_rel / self.arity | [Transformation.IDENTITY],
- ]
+ multihead_rules = []
for i in range(self.num_heads):
meta = [Transformation.SLICE(rows=(i * size, (i + 1) * size))]
@@ -173,19 +166,13 @@ def __call__(self):
multihead_rules.append((k_proj(i, *terms) <= R.get(self.keys)(terms)[k_weight:dim, self.kdim]) | meta)
attention_concat.append(R.get(attention_name)(i, *terms))
- multihead_rules.append(
- (output_rel(terms)[dim, dim] <= attention_concat) | [Transformation.IDENTITY, Combination.CONCAT]
- )
+ multihead_rules.append((output_rel(terms)[dim, dim] <= attention_concat) | [Combination.CONCAT])
else:
multihead_rules = [
- (q_proj(terms)[q_weight:dim, dim] <= R.get(self.queries)(terms)) | [Transformation.IDENTITY],
- q_proj / self.arity | [Transformation.IDENTITY],
- (v_proj(terms)[v_weight:dim, self.vdim] <= R.get(self.values)(terms)) | [Transformation.IDENTITY],
- v_proj / self.arity | [Transformation.IDENTITY],
- (k_proj(terms)[k_weight:dim, self.kdim] <= R.get(self.keys)(terms)) | [Transformation.IDENTITY],
- k_proj / self.arity | [Transformation.IDENTITY],
- (output_rel(terms)[dim, dim] <= R.get(attention_name)(terms)) | [Transformation.IDENTITY],
- output_rel / self.arity | [Transformation.IDENTITY],
+ q_proj(terms)[q_weight:dim, dim] <= R.get(self.queries)(terms),
+ v_proj(terms)[v_weight:dim, self.vdim] <= R.get(self.values)(terms),
+ k_proj(terms)[k_weight:dim, self.kdim] <= R.get(self.keys)(terms),
+ output_rel(terms)[dim, dim] <= R.get(attention_name)(terms),
]
return [*attention(), *multihead_rules]
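The same pattern repeats throughout the module updates below: explicit `Transformation.IDENTITY` entries and the per-predicate identity rules are now redundant, because identity is the proxy-level default. Schematically, for the attention metadata above:

```python
from neuralogic.core.constructs.function import Aggregation, Combination

# before: [Combination.PRODUCT, Transformation.IDENTITY, Aggregation.SOFTMAX(agg_terms=["Y"])]
metadata = [Combination.PRODUCT, Aggregation.SOFTMAX(agg_terms=["Y"])]
```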
diff --git a/neuralogic/nn/module/general/gru.py b/neuralogic/nn/module/general/gru.py
index 86b015d0..09d839e1 100644
--- a/neuralogic/nn/module/general/gru.py
+++ b/neuralogic/nn/module/general/gru.py
@@ -98,16 +98,11 @@ def __call__(self):
return [
*r(),
*z(),
- n_helper | [Transformation.IDENTITY, Combination.ELPRODUCT],
- n_helper.head.predicate | [Transformation.IDENTITY],
+ n_helper | [Combination.ELPRODUCT],
n | [Transformation.TANH],
- n.head.predicate | [Transformation.IDENTITY],
- h_left | [Transformation.IDENTITY, Combination.ELPRODUCT],
- h_left.head.predicate | [Transformation.IDENTITY],
- h_right | [Transformation.IDENTITY, Combination.ELPRODUCT],
- h_right.head.predicate | [Transformation.IDENTITY],
- h | [Transformation.IDENTITY],
- h.head.predicate | [Transformation.IDENTITY],
+ h_left | [Combination.ELPRODUCT],
+ h_right | [Combination.ELPRODUCT],
+ h,
]
@@ -253,6 +248,6 @@ def __call__(self):
terms = [f"X{i}" for i in range(self.arity)]
return [
- (R.get(self.output_name)([*terms, 0]) <= R.get(self.hidden_0_name)(terms)) | [Transformation.IDENTITY],
+ R.get(self.output_name)([*terms, 0]) <= R.get(self.hidden_0_name)(terms),
*recursive_cell(),
]
diff --git a/neuralogic/nn/module/general/linear.py b/neuralogic/nn/module/general/linear.py
index d6fdd285..4edd1742 100644
--- a/neuralogic/nn/module/general/linear.py
+++ b/neuralogic/nn/module/general/linear.py
@@ -1,3 +1,4 @@
+from neuralogic.core.constructs.function.function import TransformationFunction
from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation
from neuralogic.core.constructs.factories import R
@@ -49,7 +50,7 @@ class Linear(Module):
Output (head) predicate name of the module.
input_name : str
Input name.
- activation : Transformation
+ activation : TransformationFunction
Activation function of the output.
Default: ``Transformation.IDENTITY``
arity : int
@@ -62,7 +63,7 @@ def __init__(
out_channels: int,
output_name: str,
input_name: str,
- activation: Transformation = Transformation.IDENTITY,
+ activation: TransformationFunction = Transformation.IDENTITY,
arity: int = 1,
):
self.output_name = output_name
@@ -79,6 +80,6 @@ def __call__(self):
head = R.get(self.output_name)(terms)[self.out_channels, self.in_channels]
return [
- (head <= R.get(self.input_name)(terms)) | [Transformation.IDENTITY],
+ head <= R.get(self.input_name)(terms),
R.get(self.output_name) / len(terms) | Metadata(transformation=self.activation),
]
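The `Linear` rule body thus stays unannotated, and only the output predicate carries the activation. A usage sketch; the `in_channels` keyword and the channel sizes are illustrative, not taken from this hunk:

```python
from neuralogic.core.constructs.function import Transformation
from neuralogic.nn.module.general.linear import Linear

# __call__ yields the module's rules: one plain weighted rule plus the
# predicate metadata carrying the activation.
rules = Linear(in_channels=8, out_channels=4, output_name="h1",
               input_name="h0", activation=Transformation.RELU)()
```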
diff --git a/neuralogic/nn/module/general/lstm.py b/neuralogic/nn/module/general/lstm.py
index b1cf5a8a..3f50f5a6 100644
--- a/neuralogic/nn/module/general/lstm.py
+++ b/neuralogic/nn/module/general/lstm.py
@@ -94,15 +94,11 @@ def __call__(self):
*f(),
*o(),
*g(),
- c_left | [Transformation.IDENTITY, Combination.ELPRODUCT],
- c_right | [Transformation.IDENTITY, Combination.ELPRODUCT],
- c_left.head.predicate | [Transformation.IDENTITY],
- c_right.head.predicate | [Transformation.IDENTITY],
- c | [Transformation.IDENTITY],
- (R.get(c_name)([*terms, 0]) <= R.get(self.cell_state_0_name)(terms)) | [Transformation.IDENTITY],
- c.head.predicate | [Transformation.IDENTITY],
- h | [Transformation.IDENTITY, Combination.ELPRODUCT],
- h.head.predicate | [Transformation.IDENTITY],
+ c_left | [Combination.ELPRODUCT],
+ c_right | [Combination.ELPRODUCT],
+ c,
+ R.get(c_name)([*terms, 0]) <= R.get(self.cell_state_0_name)(terms),
+ h | [Combination.ELPRODUCT],
]
@@ -188,6 +184,6 @@ def __call__(self):
terms = [f"X{i}" for i in range(self.arity)]
return [
- (R.get(self.output_name)([*terms, 0]) <= R.get(self.hidden_0_name)(terms)) | [Transformation.IDENTITY],
+ R.get(self.output_name)([*terms, 0]) <= R.get(self.hidden_0_name)(terms),
*recursive_cell(),
]
diff --git a/neuralogic/nn/module/general/mlp.py b/neuralogic/nn/module/general/mlp.py
index e7a8e89e..7fc912e0 100644
--- a/neuralogic/nn/module/general/mlp.py
+++ b/neuralogic/nn/module/general/mlp.py
@@ -1,5 +1,6 @@
from typing import List, Union, Sequence
+from neuralogic.core.constructs.function.function import TransformationFunction
from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation
from neuralogic.core.constructs.factories import R
@@ -18,7 +19,7 @@ class MLP(Module):
Output (head) predicate name of the module.
input_name : str
Input name.
- activation : Union[Transformation, List[Transformation]]
+ activation : Union[TransformationFunction, List[TransformationFunction]]
Activation function of all layers or list of activations for each layer.
Default: ``Transformation.RELU``
arity : int
@@ -30,14 +31,14 @@ def __init__(
units: List[int],
output_name: str,
input_name: str,
- activation: Union[Transformation, List[Transformation]] = Transformation.RELU,
+ activation: Union[TransformationFunction, List[TransformationFunction]] = Transformation.RELU,
arity: int = 1,
):
self.output_name = output_name
self.input_name = input_name
self.units = units
- self.activation: Union[Transformation, List[Transformation]] = activation
+ self.activation: Union[TransformationFunction, List[TransformationFunction]] = activation
self.arity = arity
def __call__(self):
@@ -49,7 +50,7 @@ def __call__(self):
if isinstance(self.activation, Sequence):
metadata = [Metadata(transformation=act) for act in self.activation]
- metadata.extend([Metadata(transformation=Transformation.IDENTITY)] * (iters - len(metadata)))
+ metadata.extend([Metadata()] * (iters - len(metadata)))
else:
metadata = [Metadata(transformation=self.activation)] * (iters + 1)
@@ -67,7 +68,7 @@ def __call__(self):
if index < len(self.activation):
body_metadata = Metadata(transformation=self.activation[index])
else:
- body_metadata = [Transformation.IDENTITY]
+ body_metadata = []
if index + 2 < len(self.units):
in_channels, out_channels = self.units[index + 1], self.units[index + 2]
diff --git a/neuralogic/nn/module/general/pooling.py b/neuralogic/nn/module/general/pooling.py
index 79904b34..f2a9e376 100644
--- a/neuralogic/nn/module/general/pooling.py
+++ b/neuralogic/nn/module/general/pooling.py
@@ -1,5 +1,6 @@
+from neuralogic.core.constructs.function.function import AggregationFunction
from neuralogic.core.constructs.metadata import Metadata
-from neuralogic.core.constructs.function import Transformation, Aggregation
+from neuralogic.core.constructs.function import Aggregation
from neuralogic.core.constructs.factories import R
from neuralogic.nn.module.module import Module
@@ -48,7 +49,7 @@ class Pooling(Module):
Output (head) predicate name of the module.
input_name : str
Input name.
- aggregation : Aggregation
+ aggregation : AggregationFunction
Aggregation function.
input_arity : int
Arity of the input predicate ``input_name``. Default: ``1``
@@ -58,7 +59,7 @@ def __init__(
self,
output_name: str,
input_name: str,
- aggregation: Aggregation,
+ aggregation: AggregationFunction,
input_arity: int = 1,
):
self.output_name = output_name
@@ -68,11 +69,10 @@ def __init__(
self.aggregation = aggregation
def __call__(self):
- metadata = Metadata(transformation=Transformation.IDENTITY, aggregation=self.aggregation)
+ metadata = Metadata(aggregation=self.aggregation)
return [
(R.get(self.output_name) <= R.get(self.input_name)(f"X{i}" for i in range(self.input_arity))) | metadata,
- R.get(self.output_name) / 0 | [Transformation.IDENTITY],
]
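Likewise for the pooling readout, which now reduces to a single aggregation rule. A sketch with illustrative predicate names:

```python
from neuralogic.core.constructs.function import Aggregation
from neuralogic.nn.module.general.pooling import Pooling

# One readout rule aggregating node_embed(X0) into graph_out:
rules = Pooling("graph_out", "node_embed", Aggregation.AVG)()
```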
diff --git a/neuralogic/nn/module/general/positional_encoding.py b/neuralogic/nn/module/general/positional_encoding.py
index 8c8dd535..4b636213 100644
--- a/neuralogic/nn/module/general/positional_encoding.py
+++ b/neuralogic/nn/module/general/positional_encoding.py
@@ -2,7 +2,7 @@
import numpy as np
-from neuralogic.core.constructs.function import Transformation, Combination
+from neuralogic.core.constructs.function import Combination
from neuralogic.core.constructs.factories import R
from neuralogic.nn.module.module import Module
@@ -48,9 +48,6 @@ def __call__(self):
else:
rules = [pe_rel(*terms, i)[row].fixed() for i, row in enumerate(pe)]
- rules.append(
- (out_rel(all_terms) <= (pe_rel(all_terms), in_rel(all_terms))) | [Transformation.IDENTITY, Combination.SUM]
- )
+ rules.append((out_rel(all_terms) <= (pe_rel(all_terms), in_rel(all_terms))) | [Combination.SUM])
- rules.append(out_rel / self.arity | [Transformation.IDENTITY])
return rules
diff --git a/neuralogic/nn/module/general/rnn.py b/neuralogic/nn/module/general/rnn.py
index a2e9c465..24515c3b 100644
--- a/neuralogic/nn/module/general/rnn.py
+++ b/neuralogic/nn/module/general/rnn.py
@@ -1,3 +1,4 @@
+from neuralogic.core.constructs.function.function import TransformationFunction
from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation
from neuralogic.core.constructs.factories import R, V
@@ -20,7 +21,7 @@ class RNNCell(Module):
Input feature predicate name to get features from.
hidden_input_name : str
Predicate name to get hidden state from.
- activation : Transformation
+ activation : TransformationFunction
Activation function.
Default: ``Transformation.TANH``
arity : int
@@ -34,7 +35,7 @@ def __init__(
output_name: str,
input_name: str,
hidden_input_name: str,
- activation: Transformation = Transformation.TANH,
+ activation: TransformationFunction = Transformation.TANH,
arity: int = 1,
):
self.input_size = input_size
@@ -60,7 +61,6 @@ def __call__(self):
return [
rnn_rule | Metadata(transformation=self.activation),
- output / (self.arity + 1) | [Transformation.IDENTITY],
]
@@ -106,7 +106,7 @@ class RNN(Module):
Input feature predicate name to get features from.
hidden_0_name : str
Predicate name to get initial hidden state from.
- activation : Transformation
+ activation : TransformationFunction
Activation function.
Default: ``Transformation.TANH``
arity : int
@@ -120,7 +120,7 @@ def __init__(
output_name: str,
input_name: str,
hidden_0_name: str,
- activation: Transformation = Transformation.TANH,
+ activation: TransformationFunction = Transformation.TANH,
arity: int = 1,
):
self.input_size = input_size
@@ -147,6 +147,6 @@ def __call__(self):
terms = [f"X{i}" for i in range(self.arity)]
return [
- (R.get(self.output_name)([*terms, 0]) <= R.get(self.hidden_0_name)(terms)) | [Transformation.IDENTITY],
+ R.get(self.output_name)([*terms, 0]) <= R.get(self.hidden_0_name)(terms),
*recursive_cell(),
]
diff --git a/neuralogic/nn/module/general/rvnn.py b/neuralogic/nn/module/general/rvnn.py
index 519d4ee6..d6ebcb65 100644
--- a/neuralogic/nn/module/general/rvnn.py
+++ b/neuralogic/nn/module/general/rvnn.py
@@ -1,3 +1,4 @@
+from neuralogic.core.constructs.function.function import TransformationFunction, AggregationFunction
from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation, Aggregation
from neuralogic.core.constructs.factories import R, V
@@ -30,10 +31,10 @@ class RvNN(Module):
max_children : int
Maximum number of children (specify which -ary tree will be considered).
Default: ``2``
- activation : Transformation
+ activation : TransformationFunction
Activation function of all layers.
Default: ``Transformation.TANH``
- aggregation : Aggregation
+ aggregation : AggregationFunction
Aggregation function of a layer.
Default: ``Aggregation.SUM``
arity : int
@@ -47,8 +48,8 @@ def __init__(
input_name: str,
parent_map_name: str,
max_children: int = 2,
- activation: Transformation = Transformation.TANH,
- aggregation: Aggregation = Aggregation.SUM,
+ activation: TransformationFunction = Transformation.TANH,
+ aggregation: AggregationFunction = Aggregation.SUM,
arity: int = 1,
):
self.input_size = input_size
@@ -73,7 +74,6 @@ def __call__(self):
rules = [
(output_rel(head_terms) <= (input_rel(head_terms), parent_map_rel(V.P))) | metadata,
- output_rel / len(head_terms) | [Transformation.IDENTITY],
]
body = []
diff --git a/neuralogic/nn/module/general/transformer.py b/neuralogic/nn/module/general/transformer.py
index 291012d7..17c3516f 100644
--- a/neuralogic/nn/module/general/transformer.py
+++ b/neuralogic/nn/module/general/transformer.py
@@ -143,16 +143,13 @@ def __call__(self):
return [
*attention(),
(R.get(norm_name)(terms) <= (R.get(attn_name)(terms), R.get(data_name)(terms))) | [Transformation.NORM],
- R.get(norm_name) / self.arity | [Transformation.IDENTITY],
*mlp(),
(output_rel(terms) <= (R.get(norm_name)(terms), R.get(mlp_name)(terms))) | [Transformation.NORM],
- output_rel / self.arity | [Transformation.IDENTITY],
]
return [
*attention(),
(output_rel(terms) <= (R.get(attn_name)(terms), R.get(data_name)(terms))) | [Transformation.NORM],
- output_rel / self.arity | [Transformation.IDENTITY],
]
diff --git a/neuralogic/nn/module/gnn/appnp.py b/neuralogic/nn/module/gnn/appnp.py
index db90ad57..6d2ff9a8 100644
--- a/neuralogic/nn/module/gnn/appnp.py
+++ b/neuralogic/nn/module/gnn/appnp.py
@@ -1,3 +1,4 @@
+from neuralogic.core.constructs.function.function import TransformationFunction, AggregationFunction
from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation, Aggregation
from neuralogic.core.constructs.factories import R, V
@@ -73,10 +74,10 @@ class APPNPConv(Module):
Number of iterations
alpha : float
Teleport probability
- activation : Transformation
+ activation : TransformationFunction
Activation function of the output.
Default: ``Transformation.IDENTITY``
- aggregation : Aggregation
+ aggregation : AggregationFunction
Aggregation function of nodes' neighbors.
Default: ``Aggregation.SUM``
@@ -89,8 +90,8 @@ def __init__(
edge_name: str,
k: int,
alpha: float,
- activation: Transformation = Transformation.IDENTITY,
- aggregation: Aggregation = Aggregation.SUM,
+ activation: TransformationFunction = Transformation.IDENTITY,
+ aggregation: AggregationFunction = Aggregation.SUM,
):
self.output_name = output_name
self.feature_name = feature_name
@@ -104,7 +105,7 @@ def __init__(
def __call__(self):
head = R.get(self.output_name)(V.I)
- metadata = Metadata(transformation=Transformation.IDENTITY, aggregation=self.aggregation)
+ metadata = Metadata(aggregation=self.aggregation)
edge = R.get(self.edge_name)
feature = R.get(self.feature_name)
@@ -120,8 +121,6 @@ def __call__(self):
(k_head <= (R.get(f"{self.output_name}__{k - 1}")(V.J)[1 - self.alpha].fixed(), edge(V.J, V.I)))
| metadata
)
- rules.append(R.get(f"{self.output_name}__{k}") / 1 | Metadata(transformation=Transformation.IDENTITY))
-
if self.k == 1:
output_rule = head <= (feature(V.J)[1 - self.alpha].fixed(), edge(V.J, V.I))
else:
diff --git a/neuralogic/nn/module/gnn/gatv2.py b/neuralogic/nn/module/gnn/gatv2.py
index 13d94cf8..aaea79e4 100644
--- a/neuralogic/nn/module/gnn/gatv2.py
+++ b/neuralogic/nn/module/gnn/gatv2.py
@@ -1,3 +1,4 @@
+from neuralogic.core.constructs.function.function import TransformationFunction
from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation, Aggregation, Combination
from neuralogic.core.constructs.factories import R, V
@@ -23,7 +24,7 @@ class GATv2Conv(Module):
Edge predicate name to use for neighborhood relations.
share_weights : bool
Share weights in attention. Default: ``False``
- activation : Transformation
+ activation : TransformationFunction
Activation function of the output.
Default: ``Transformation.IDENTITY``
@@ -37,7 +38,7 @@ def __init__(
feature_name: str,
edge_name: str,
share_weights: bool = False,
- activation: Transformation = Transformation.IDENTITY,
+ activation: TransformationFunction = Transformation.IDENTITY,
):
self.output_name = output_name
self.feature_name = feature_name
@@ -55,9 +56,7 @@ def __call__(self):
attention = R.get(f"{self.output_name}__attention")
attention_metadata = Metadata(transformation=Transformation.LEAKY_RELU)
- metadata = Metadata(
- transformation=Transformation.IDENTITY, aggregation=Aggregation.SUM, combination=Combination.PRODUCT
- )
+ metadata = Metadata(aggregation=Aggregation.SUM, combination=Combination.PRODUCT)
head = R.get(self.output_name)
feature = R.get(self.feature_name)
diff --git a/neuralogic/nn/module/gnn/gcn.py b/neuralogic/nn/module/gnn/gcn.py
index 1a5265f5..9d6f18c2 100644
--- a/neuralogic/nn/module/gnn/gcn.py
+++ b/neuralogic/nn/module/gnn/gcn.py
@@ -1,5 +1,6 @@
from typing import Optional
+from neuralogic.core.constructs.function.function import TransformationFunction, AggregationFunction
from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation, Aggregation, Combination
from neuralogic.core.constructs.factories import R, V
@@ -24,10 +25,10 @@ class GCNConv(Module):
Feature predicate name to get features from.
edge_name : str
Edge predicate name to use for neighborhood relations.
- activation : Transformation
+ activation : TransformationFunction
Activation function of the output.
Default: ``Transformation.IDENTITY``
- aggregation : Aggregation
+ aggregation : AggregationFunction
Aggregation function of nodes' neighbors.
Default: ``Aggregation.SUM``
add_self_loops : Optional[bool]
@@ -45,8 +46,8 @@ def __init__(
output_name: str,
feature_name: str,
edge_name: str,
- activation: Transformation = Transformation.IDENTITY,
- aggregation: Aggregation = Aggregation.SUM,
+ activation: TransformationFunction = Transformation.IDENTITY,
+ aggregation: AggregationFunction = Aggregation.SUM,
add_self_loops: Optional[bool] = None,
normalize: bool = True,
):
@@ -68,11 +69,7 @@ def __init__(
def __call__(self):
head = R.get(self.output_name)(V.I)[self.out_channels, self.in_channels]
- metadata = Metadata(
- transformation=Transformation.IDENTITY, aggregation=self.aggregation, combination=Combination.PRODUCT
- )
-
- id_metadata = Metadata(transformation=Transformation.IDENTITY)
+ metadata = Metadata(aggregation=self.aggregation, combination=Combination.PRODUCT)
edge = R.get(self.edge_name)
edge_count = R.get(f"{self.output_name}__edge_count")
@@ -86,12 +83,11 @@ def __call__(self):
self_loops = [
edge(V.I, V.I)[1.0].fixed(),
- (edge(V.I, V.J) <= (R.get(self.edge_name)(V.I, V.J))) | id_metadata,
- edge / 2 | id_metadata,
+ edge(V.I, V.J) <= (R.get(self.edge_name)(V.I, V.J)),
]
if self.normalize:
- count_metadata = Metadata(transformation=Transformation.IDENTITY, aggregation=Aggregation.COUNT)
+ count_metadata = Metadata(aggregation=Aggregation.COUNT)
body = [R.get(self.feature_name)(V.J), edge(V.J, V.I), Transformation.SQRT(edge_count(V.J, V.I))]
normalization = [
diff --git a/neuralogic/nn/module/gnn/gen.py b/neuralogic/nn/module/gnn/gen.py
index 9467cb03..880616fc 100644
--- a/neuralogic/nn/module/gnn/gen.py
+++ b/neuralogic/nn/module/gnn/gen.py
@@ -1,5 +1,6 @@
from typing import Optional
+from neuralogic.core.constructs.function.function import AggregationFunction
from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation, Aggregation, Combination
from neuralogic.core.constructs.factories import R, V
@@ -24,7 +25,7 @@ class GENConv(Module):
Feature predicate name to get features from.
edge_name : str
Edge predicate name to use for neighborhood relations.
- aggregation : Aggregation
+ aggregation : AggregationFunction
The aggregation function.
Default: ``Aggregation.SOFTMAX``
num_layers : int
@@ -51,7 +52,7 @@ def __init__(
output_name: str,
feature_name: str,
edge_name: str,
- aggregation: Aggregation = Aggregation.SOFTMAX,
+ aggregation: AggregationFunction = Aggregation.SOFTMAX,
num_layers: int = 2,
expansion: int = 2,
eps: float = 1e-7,
@@ -88,11 +89,7 @@ def __call__(self):
e_proj = []
if self.edge_dim is not None and self.out_channels != self.edge_dim:
e = R.get(f"{self.output_name}__gen_edge_proj")
- e_proj = [
- (e(V.I, V.J)[self.out_channels, self.edge_dim] <= R.get(self.edge_name)(V.I, V.J))
- | Metadata(transformation=Transformation.IDENTITY),
- e / 2 | Metadata(transformation=Transformation.IDENTITY),
- ]
+ e_proj = [e(V.I, V.J)[self.out_channels, self.edge_dim] <= R.get(self.edge_name)(V.I, V.J)]
channels = [self.out_channels]
for _ in range(self.num_layers - 1):
@@ -112,14 +109,8 @@ def __call__(self):
*e_proj,
(feat_sum(V.I, V.J) <= (j_feat, e(V.J, V.I)))
| Metadata(transformation=Transformation.RELU, combination=Combination.SUM),
- feat_sum / 2 | Metadata(transformation=Transformation.IDENTITY),
(feat_agg(V.I) <= (feat_sum(V.I, V.J), eps))
- | Metadata(
- transformation=Transformation.IDENTITY, aggregation=self.aggregation, combination=Combination.SUM
- ),
- feat_agg / 1 | Metadata(transformation=Transformation.IDENTITY),
- (out(V.I) <= (i_feat, feat_agg(V.I)))
- | Metadata(transformation=Transformation.IDENTITY, combination=Combination.SUM),
- out / 1 | Metadata(transformation=Transformation.IDENTITY),
+ | Metadata(aggregation=self.aggregation, combination=Combination.SUM),
+ (out(V.I) <= (i_feat, feat_agg(V.I))) | Metadata(combination=Combination.SUM),
*mlp(),
]
diff --git a/neuralogic/nn/module/gnn/gin.py b/neuralogic/nn/module/gnn/gin.py
index 82fff829..d1008f51 100644
--- a/neuralogic/nn/module/gnn/gin.py
+++ b/neuralogic/nn/module/gnn/gin.py
@@ -1,3 +1,4 @@
+from neuralogic.core.constructs.function.function import TransformationFunction, AggregationFunction
from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation, Aggregation
from neuralogic.core.constructs.factories import R, V
@@ -12,8 +13,8 @@ def __init__(
output_name: str,
feature_name: str,
edge_name: str,
- activation: Transformation = Transformation.IDENTITY,
- aggregation: Aggregation = Aggregation.SUM,
+ activation: TransformationFunction = Transformation.IDENTITY,
+ aggregation: AggregationFunction = Aggregation.SUM,
):
self.output_name = output_name
self.feature_name = feature_name
@@ -29,12 +30,11 @@ def __call__(self):
head = R.get(self.output_name)(V.I)[self.out_channels, self.in_channels]
embed = R.get(f"embed__{self.output_name}")
- metadata = Metadata(transformation=Transformation.IDENTITY, aggregation=self.aggregation)
+ metadata = Metadata(aggregation=self.aggregation)
return [
(head <= (R.get(self.feature_name)(V.J), R.get(self.edge_name)(V.J, V.I))) | metadata,
(embed(V.I) <= R.get(self.feature_name)(V.I)) | metadata,
(head <= embed(V.I)[self.in_channels, self.in_channels]) | Metadata(transformation=self.activation),
- embed / 1 | Metadata(transformation=Transformation.IDENTITY),
R.get(self.output_name) / 1 | Metadata(transformation=self.activation),
]
diff --git a/neuralogic/nn/module/gnn/gine.py b/neuralogic/nn/module/gnn/gine.py
index d7e312ca..3bf85a3f 100644
--- a/neuralogic/nn/module/gnn/gine.py
+++ b/neuralogic/nn/module/gnn/gine.py
@@ -68,23 +68,13 @@ def __call__(self):
e_proj = []
if self.edge_dim is not None:
e = R.get(f"{self.nn_name}__gine_edge_proj")
- e_proj = [
- (e(V.I, V.J)[self.in_channels, self.edge_dim] <= R.get(self.edge_name)(V.I, V.J))
- | Metadata(transformation=Transformation.IDENTITY),
- e / 2 | Metadata(transformation=Transformation.IDENTITY),
- ]
+ e_proj = [e(V.I, V.J)[self.in_channels, self.edge_dim] <= R.get(self.edge_name)(V.I, V.J)]
return [
*e_proj,
(feat_sum(V.I, V.J) <= (x(V.J), e(V.J, V.I)))
| Metadata(transformation=Transformation.RELU, combination=Combination.SUM),
- feat_sum / 2 | Metadata(transformation=Transformation.IDENTITY),
- (feat_agg(V.I) <= feat_sum(V.I, V.J))
- | Metadata(transformation=Transformation.IDENTITY, aggregation=Aggregation.SUM),
- feat_agg / 1 | Metadata(transformation=Transformation.IDENTITY),
- (out(V.I) <= (x_eps, feat_agg(V.I)))
- | Metadata(transformation=Transformation.IDENTITY, combination=Combination.SUM),
- out / 1 | Metadata(transformation=Transformation.IDENTITY),
- (R.get(self.nn_name)(V.I) <= out(V.I)) | Metadata(transformation=Transformation.IDENTITY),
- R.get(self.nn_name) / 1 | Metadata(transformation=Transformation.IDENTITY),
+ (feat_agg(V.I) <= feat_sum(V.I, V.J)) | Metadata(aggregation=Aggregation.SUM),
+ (out(V.I) <= (x_eps, feat_agg(V.I))) | Metadata(combination=Combination.SUM),
+ R.get(self.nn_name)(V.I) <= out(V.I),
]
diff --git a/neuralogic/nn/module/gnn/gsage.py b/neuralogic/nn/module/gnn/gsage.py
index 7941b9f1..5a1c7863 100644
--- a/neuralogic/nn/module/gnn/gsage.py
+++ b/neuralogic/nn/module/gnn/gsage.py
@@ -1,3 +1,4 @@
+from neuralogic.core.constructs.function.function import TransformationFunction, AggregationFunction
from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation, Aggregation
from neuralogic.core.constructs.factories import R, V
@@ -35,10 +36,10 @@ class SAGEConv(Module):
Feature predicate name to get features from.
edge_name : str
Edge predicate name to use for neighborhood relations.
- activation : Transformation
+ activation : TransformationFunction
Activation function of the output.
Default: ``Transformation.IDENTITY``
- aggregation : Aggregation
+ aggregation : AggregationFunction
Aggregation function of nodes' neighbors.
Default: ``Aggregation.AVG``
@@ -51,8 +52,8 @@ def __init__(
output_name: str,
feature_name: str,
edge_name: str,
- activation: Transformation = Transformation.IDENTITY,
- aggregation: Aggregation = Aggregation.AVG,
+ activation: TransformationFunction = Transformation.IDENTITY,
+ aggregation: AggregationFunction = Aggregation.AVG,
):
self.output_name = output_name
self.feature_name = feature_name
@@ -66,7 +67,7 @@ def __init__(
def __call__(self):
head = R.get(self.output_name)(V.I)[self.out_channels, self.in_channels]
- metadata = Metadata(transformation=Transformation.IDENTITY, aggregation=self.aggregation)
+ metadata = Metadata(aggregation=self.aggregation)
return [
(head <= (R.get(self.feature_name)(V.J), R.get(self.edge_name)(V.J, V.I))) | metadata,
diff --git a/neuralogic/nn/module/gnn/res_gated.py b/neuralogic/nn/module/gnn/res_gated.py
index c855015e..606952ce 100644
--- a/neuralogic/nn/module/gnn/res_gated.py
+++ b/neuralogic/nn/module/gnn/res_gated.py
@@ -1,3 +1,4 @@
+from neuralogic.core.constructs.function.function import AggregationFunction, TransformationFunction
from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation, Aggregation, Combination
from neuralogic.core.constructs.factories import R, V
@@ -63,13 +64,13 @@ class ResGatedGraphConv(Module):
Feature predicate name to get features from.
edge_name : str
Edge predicate name to use for neighborhood relations.
- gating_activation : Transformation
+ gating_activation : TransformationFunction
Gating activation function.
Default: ``Transformation.SIGMOID``
- activation : Transformation
+ activation : TransformationFunction
Activation function of the output.
Default: ``Transformation.IDENTITY``
- aggregation : Aggregation
+ aggregation : AggregationFunction
Aggregation function of nodes' neighbors.
Default: ``Aggregation.SUM``
@@ -82,9 +83,9 @@ def __init__(
output_name: str,
feature_name: str,
edge_name: str,
- gating_activation: Transformation = Transformation.SIGMOID,
- activation: Transformation = Transformation.IDENTITY,
- aggregation: Aggregation = Aggregation.SUM,
+ gating_activation: TransformationFunction = Transformation.SIGMOID,
+ activation: TransformationFunction = Transformation.IDENTITY,
+ aggregation: AggregationFunction = Aggregation.SUM,
):
self.output_name = output_name
self.feature_name = feature_name
@@ -103,14 +104,12 @@ def __call__(self):
gate = R.get(f"{self.output_name}__gate")
w = self.out_channels, self.in_channels
- prod_metadata = Metadata(
- combination=Combination.ELPRODUCT, transformation=Transformation.IDENTITY, aggregation=self.aggregation
- )
+ prod_metadata = Metadata(combination=Combination.ELPRODUCT, aggregation=self.aggregation)
return [
- (gate(V.I, V.J) <= (feature(V.I)[w], feature(V.J)[w])) | [Transformation.IDENTITY],
+ gate(V.I, V.J) <= (feature(V.I)[w], feature(V.J)[w]),
gate / 2 | Metadata(transformation=self.gating_activation),
- (head <= feature(V.I)[w]) | [Transformation.IDENTITY],
+ head <= feature(V.I)[w],
(head <= (gate(V.I, V.J), feature(V.J)[w], R.get(self.edge_name)(V.J, V.I))) | prod_metadata,
R.get(self.output_name) / 1 | Metadata(transformation=self.activation),
]
diff --git a/neuralogic/nn/module/gnn/rgcn.py b/neuralogic/nn/module/gnn/rgcn.py
index 27bc4414..718dc4c8 100644
--- a/neuralogic/nn/module/gnn/rgcn.py
+++ b/neuralogic/nn/module/gnn/rgcn.py
@@ -1,5 +1,6 @@
from typing import List, Optional
+from neuralogic.core.constructs.function.function import TransformationFunction, AggregationFunction
from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation, Aggregation
from neuralogic.core.constructs.factories import R, V
@@ -84,10 +85,10 @@ class RGCNConv(Module):
are used instead.
relations : List[str]
List of relations' names
- activation : Transformation
+ activation : TransformationFunction
Activation function of the output.
Default: ``Transformation.IDENTITY``
- aggregation : Aggregation
+ aggregation : AggregationFunction
Aggregation function of nodes' neighbors.
Default: ``Aggregation.SUM``
@@ -101,8 +102,8 @@ def __init__(
feature_name: str,
edge_name: Optional[str],
relations: List[str],
- activation: Transformation = Transformation.IDENTITY,
- aggregation: Aggregation = Aggregation.AVG,
+ activation: TransformationFunction = Transformation.IDENTITY,
+ aggregation: AggregationFunction = Aggregation.AVG,
):
self.output_name = output_name
self.feature_name = feature_name
@@ -118,7 +119,7 @@ def __init__(
def __call__(self):
head = R.get(self.output_name)(V.I)
- metadata = Metadata(transformation=Transformation.IDENTITY, aggregation=self.aggregation)
+ metadata = Metadata(aggregation=self.aggregation)
feature = R.get(self.feature_name)(V.J)[self.out_channels, self.in_channels]
if self.edge_name is not None:
diff --git a/neuralogic/nn/module/gnn/sg.py b/neuralogic/nn/module/gnn/sg.py
index ce549a5e..8227aa91 100644
--- a/neuralogic/nn/module/gnn/sg.py
+++ b/neuralogic/nn/module/gnn/sg.py
@@ -1,3 +1,4 @@
+from neuralogic.core.constructs.function.function import TransformationFunction, AggregationFunction
from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation, Aggregation
from neuralogic.core.constructs.factories import R, V
@@ -59,10 +60,10 @@ class SGConv(Module):
k : int
Number of hops.
Default: ``1``
- activation : Transformation
+ activation : TransformationFunction
Activation function of the output.
Default: ``Transformation.IDENTITY``
- aggregation : Aggregation
+ aggregation : AggregationFunction
Aggregation function of nodes' neighbors.
Default: ``Aggregation.SUM``
@@ -76,8 +77,8 @@ def __init__(
feature_name: str,
edge_name: str,
k: int = 1,
- activation: Transformation = Transformation.IDENTITY,
- aggregation: Aggregation = Aggregation.SUM,
+ activation: TransformationFunction = Transformation.IDENTITY,
+ aggregation: AggregationFunction = Aggregation.SUM,
):
self.output_name = output_name
self.feature_name = feature_name
@@ -93,9 +94,7 @@ def __init__(
def __call__(self):
head = R.get(self.output_name)(V.I0)[self.out_channels, self.in_channels]
- metadata = Metadata(
- transformation=Transformation.IDENTITY, aggregation=self.aggregation, duplicit_grounding=True
- )
+ metadata = Metadata(aggregation=self.aggregation, duplicit_grounding=True)
edge = R.get(self.edge_name)
feature = R.get(self.feature_name)
diff --git a/neuralogic/nn/module/gnn/tag.py b/neuralogic/nn/module/gnn/tag.py
index f18f0482..10231d94 100644
--- a/neuralogic/nn/module/gnn/tag.py
+++ b/neuralogic/nn/module/gnn/tag.py
@@ -1,3 +1,4 @@
+from neuralogic.core.constructs.function.function import TransformationFunction, AggregationFunction
from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation, Aggregation
from neuralogic.core.constructs.factories import R, V
@@ -64,10 +65,10 @@ class TAGConv(Module):
k : int
Number of hops.
Default: ``2``
- activation : Transformation
+ activation : TransformationFunction
Activation function of the output.
Default: ``Transformation.IDENTITY``
- aggregation : Aggregation
+ aggregation : AggregationFunction
Aggregation function of nodes' neighbors.
Default: ``Aggregation.SUM``
@@ -81,8 +82,8 @@ def __init__(
feature_name: str,
edge_name: str,
k: int = 2,
- activation: Transformation = Transformation.IDENTITY,
- aggregation: Aggregation = Aggregation.SUM,
+ activation: TransformationFunction = Transformation.IDENTITY,
+ aggregation: AggregationFunction = Aggregation.SUM,
):
self.output_name = output_name
self.feature_name = feature_name
@@ -96,7 +97,7 @@ def __init__(
self.aggregation = aggregation
def __call__(self):
- metadata = Metadata(transformation=Transformation.IDENTITY, aggregation=self.aggregation)
+ metadata = Metadata(aggregation=self.aggregation)
head = R.get(self.output_name)
feature = R.get(self.feature_name)
edge = R.get(self.edge_name)
diff --git a/neuralogic/nn/module/meta/magnn.py b/neuralogic/nn/module/meta/magnn.py
index d3ecc8fa..a88e8427 100644
--- a/neuralogic/nn/module/meta/magnn.py
+++ b/neuralogic/nn/module/meta/magnn.py
@@ -1,5 +1,6 @@
from typing import Optional, List
+from neuralogic.core.constructs.function.function import AggregationFunction, TransformationFunction
from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation, Aggregation
from neuralogic.core.constructs.factories import R, V
@@ -36,9 +37,12 @@ class MAGNNMean(Module):
Metapath type predicate name. If none, ``meta_paths`` will be used instead.
meta_paths : List[str]
Name of types forming a single metapath.
- activation : Transformation
+ activation : TransformationFunction
Activation function of the output.
Default: ``Transformation.SIGMOID``
+ aggregation : AggregationFunction
+ Aggregation function of the output.
+ Default: ``Aggregation.SUM``
"""
def __init__(
@@ -48,8 +52,8 @@ def __init__(
relation_name: str,
type_name: Optional[str],
meta_paths: List[str],
- activation: Transformation = Transformation.SIGMOID,
- aggregation: Aggregation = Aggregation.SUM,
+ activation: TransformationFunction = Transformation.SIGMOID,
+ aggregation: AggregationFunction = Aggregation.SUM,
):
self.output_name = output_name
self.feature_name = feature_name
@@ -61,7 +65,7 @@ def __init__(
self.activation = activation
def __call__(self):
- metadata = Metadata(duplicit_grounding=True, transformation=Transformation.IDENTITY)
+ metadata = Metadata(duplicit_grounding=True)
length = len(self.meta_paths)
feature = R.get(self.feature_name)
relation = R.get(self.relation_name)
@@ -122,6 +126,9 @@ class MAGNNLinear(MAGNNMean):
activation : Transformation
Activation function of the output.
Default: ``Transformation.SIGMOID``
+ aggregation : AggregationFunction
+ Aggregation function of the output.
+ Default: ``Aggregation.SUM``
"""
def __init__(
@@ -133,8 +140,8 @@ def __init__(
relation_name: str,
type_name: Optional[str],
meta_paths: List[str],
- activation: Transformation = Transformation.SIGMOID,
- aggregation: Aggregation = Aggregation.SUM,
+ activation: TransformationFunction = Transformation.SIGMOID,
+ aggregation: AggregationFunction = Aggregation.SUM,
):
super().__init__(output_name, feature_name, relation_name, type_name, meta_paths, activation, aggregation)
self.in_channels = in_channels
diff --git a/neuralogic/nn/module/meta/meta.py b/neuralogic/nn/module/meta/meta.py
index c9f5b6ca..2f1a828e 100644
--- a/neuralogic/nn/module/meta/meta.py
+++ b/neuralogic/nn/module/meta/meta.py
@@ -1,5 +1,6 @@
from typing import List, Optional
+from neuralogic.core.constructs.function.function import TransformationFunction, AggregationFunction
from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation, Aggregation
from neuralogic.core.constructs.factories import R, V
@@ -35,10 +36,10 @@ class MetaConv(Module):
Role predicate name to use for role relations. When :code:`None`, elements from :code:`roles` are used instead.
roles : List[str]
List of relations' names
- activation : Transformation
+ activation : TransformationFunction
Activation function of the output.
Default: ``Transformation.SIGMOID``
- aggregation : Aggregation
+ aggregation : AggregationFunction
Aggregation function of nodes' neighbors.
Default: ``Aggregation.AVG``
@@ -52,8 +53,8 @@ def __init__(
feature_name: str,
role_name: Optional[str],
roles: List[str],
- activation: Transformation = Transformation.SIGMOID,
- aggregation: Aggregation = Aggregation.AVG,
+ activation: TransformationFunction = Transformation.SIGMOID,
+ aggregation: AggregationFunction = Aggregation.AVG,
):
self.output_name = output_name
self.feature_name = feature_name
@@ -71,24 +72,19 @@ def __call__(self):
head = R.get(self.output_name)(V.I)
role_head = R.get(f"{self.output_name}__roles")
- metadata = Metadata(transformation=Transformation.IDENTITY, aggregation=self.aggregation)
+ metadata = Metadata(aggregation=self.aggregation)
feature = R.get(self.feature_name)(V.J)[self.out_channels, self.in_channels]
if self.role_name is not None:
role_rules = [
- ((role_head(V.I, role) <= (feature, R.get(self.role_name)(V.J, role, V.I))) | [Transformation.IDENTITY])
- for role in self.roles
+ (role_head(V.I, role) <= (feature, R.get(self.role_name)(V.J, role, V.I))) for role in self.roles
]
else:
- role_rules = [
- ((role_head(V.I, role) <= (feature, R.get(role)(V.J, V.I))) | [Transformation.IDENTITY])
- for role in self.roles
- ]
+ role_rules = [(role_head(V.I, role) <= (feature, R.get(role)(V.J, V.I))) for role in self.roles]
return [
(head <= role_head(V.I, V.R)) | metadata,
(head <= R.get(self.feature_name)(V.I)[self.out_channels, self.in_channels]) | metadata,
*role_rules,
R.get(self.output_name) / 1 | Metadata(transformation=self.activation),
- role_head / 2 | [Transformation.IDENTITY],
]
diff --git a/neuralogic/utils/data/datasets/molecules/atomEmbeddings3.txt b/neuralogic/utils/data/datasets/molecules/atomEmbeddings3.txt
index b3c725e4..932144b1 100644
--- a/neuralogic/utils/data/datasets/molecules/atomEmbeddings3.txt
+++ b/neuralogic/utils/data/datasets/molecules/atomEmbeddings3.txt
@@ -1,98 +1,98 @@
-{3,1} atom_embed(A) :- sb(A).
-{3,1} atom_embed(A) :- b(A).
-{3,1} atom_embed(A) :- c_28(A).
-{3,1} atom_embed(A) :- ru(A).
-{3,1} atom_embed(A) :- rh(A).
-{3,1} atom_embed(A) :- br(A).
-{3,1} atom_embed(A) :- n_ar(A).
-{3,1} atom_embed(A) :- fe(A).
-{3,1} atom_embed(A) :- o_42(A).
-{3,1} atom_embed(A) :- pb(A).
-{3,1} atom_embed(A) :- c_194(A).
-{3,1} atom_embed(A) :- ni(A).
-{3,1} atom_embed(A) :- cl_93(A).
-{3,1} atom_embed(A) :- c(A).
-{3,1} atom_embed(A) :- in(A).
-{3,1} atom_embed(A) :- c_10(A).
-{3,1} atom_embed(A) :- p(A).
-{3,1} atom_embed(A) :- si(A).
-{3,1} atom_embed(A) :- ca(A).
-{3,1} atom_embed(A) :- c_19(A).
-{3,1} atom_embed(A) :- c_230(A).
-{3,1} atom_embed(A) :- pt(A).
-{3,1} atom_embed(A) :- v(A).
-{3,1} atom_embed(A) :- i(A).
-{3,1} atom_embed(A) :- n_31(A).
-{3,1} atom_embed(A) :- br_94(A).
-{3,1} atom_embed(A) :- n_am(A).
-{3,1} atom_embed(A) :- h(A).
-{3,1} atom_embed(A) :- au(A).
-{3,1} atom_embed(A) :- k(A).
-{3,1} atom_embed(A) :- cd(A).
-{3,1} atom_embed(A) :- n(A).
-{3,1} atom_embed(A) :- n_38(A).
-{3,1} atom_embed(A) :- h_3(A).
-{3,1} atom_embed(A) :- n_1(A).
-{3,1} atom_embed(A) :- c_25(A).
-{3,1} atom_embed(A) :- o_51(A).
-{3,1} atom_embed(A) :- ag(A).
-{3,1} atom_embed(A) :- bi(A).
-{3,1} atom_embed(A) :- c_22(A).
-{3,1} atom_embed(A) :- o_2(A).
-{3,1} atom_embed(A) :- o_3(A).
-{3,1} atom_embed(A) :- c_2(A).
-{3,1} atom_embed(A) :- o_40(A).
-{3,1} atom_embed(A) :- o(A).
-{3,1} atom_embed(A) :- o_52(A).
-{3,1} atom_embed(A) :- c_14(A).
-{3,1} atom_embed(A) :- s_o2(A).
-{3,1} atom_embed(A) :- n_32(A).
-{3,1} atom_embed(A) :- n_2(A).
-{3,1} atom_embed(A) :- f_92(A).
-{3,1} atom_embed(A) :- as(A).
-{3,1} atom_embed(A) :- sn(A).
-{3,1} atom_embed(A) :- c_3(A).
-{3,1} atom_embed(A) :- o_41(A).
-{3,1} atom_embed(A) :- h_1(A).
-{3,1} atom_embed(A) :- c_16(A).
-{3,1} atom_embed(A) :- n_35(A).
-{3,1} atom_embed(A) :- ba(A).
-{3,1} atom_embed(A) :- c_232(A).
-{3,1} atom_embed(A) :- cr(A).
-{3,1} atom_embed(A) :- c_21(A).
-{3,1} atom_embed(A) :- c_26(A).
-{3,1} atom_embed(A) :- n_36(A).
-{3,1} atom_embed(A) :- c_ar(A).
-{3,1} atom_embed(A) :- s_o(A).
-{3,1} atom_embed(A) :- n_4(A).
-{3,1} atom_embed(A) :- i_95(A).
-{3,1} atom_embed(A) :- er(A).
-{3,1} atom_embed(A) :- o_49(A).
-{3,1} atom_embed(A) :- ge(A).
-{3,1} atom_embed(A) :- zn(A).
-{3,1} atom_embed(A) :- se(A).
-{3,1} atom_embed(A) :- c_195(A).
-{3,1} atom_embed(A) :- n_pl3(A).
-{3,1} atom_embed(A) :- te(A).
-{3,1} atom_embed(A) :- p_3(A).
-{3,1} atom_embed(A) :- s(A).
-{3,1} atom_embed(A) :- co(A).
-{3,1} atom_embed(A) :- o_45(A).
-{3,1} atom_embed(A) :- na(A).
-{3,1} atom_embed(A) :- o_co2(A).
-{3,1} atom_embed(A) :- h_8(A).
-{3,1} atom_embed(A) :- s_2(A).
-{3,1} atom_embed(A) :- hg(A).
-{3,1} atom_embed(A) :- eu(A).
-{3,1} atom_embed(A) :- n_34(A).
-{3,1} atom_embed(A) :- s_3(A).
-{3,1} atom_embed(A) :- cu(A).
-{3,1} atom_embed(A) :- mn(A).
-{3,1} atom_embed(A) :- n_3(A).
-{3,1} atom_embed(A) :- c_27(A).
-{3,1} atom_embed(A) :- c_29(A).
-{3,1} atom_embed(A) :- c_1(A).
-{3,1} atom_embed(A) :- pd(A).
-{3,1} atom_embed(A) :- cl(A).
-{3,1} atom_embed(A) :- f(A).
-{3,1} atom_embed(A) :- o_50(A).
+{3,1} atom_embed(A) :- sb(A). [activation=tanh]
+{3,1} atom_embed(A) :- b(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_28(A). [activation=tanh]
+{3,1} atom_embed(A) :- ru(A). [activation=tanh]
+{3,1} atom_embed(A) :- rh(A). [activation=tanh]
+{3,1} atom_embed(A) :- br(A). [activation=tanh]
+{3,1} atom_embed(A) :- n_ar(A). [activation=tanh]
+{3,1} atom_embed(A) :- fe(A). [activation=tanh]
+{3,1} atom_embed(A) :- o_42(A). [activation=tanh]
+{3,1} atom_embed(A) :- pb(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_194(A). [activation=tanh]
+{3,1} atom_embed(A) :- ni(A). [activation=tanh]
+{3,1} atom_embed(A) :- cl_93(A). [activation=tanh]
+{3,1} atom_embed(A) :- c(A). [activation=tanh]
+{3,1} atom_embed(A) :- in(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_10(A). [activation=tanh]
+{3,1} atom_embed(A) :- p(A). [activation=tanh]
+{3,1} atom_embed(A) :- si(A). [activation=tanh]
+{3,1} atom_embed(A) :- ca(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_19(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_230(A). [activation=tanh]
+{3,1} atom_embed(A) :- pt(A). [activation=tanh]
+{3,1} atom_embed(A) :- v(A). [activation=tanh]
+{3,1} atom_embed(A) :- i(A). [activation=tanh]
+{3,1} atom_embed(A) :- n_31(A). [activation=tanh]
+{3,1} atom_embed(A) :- br_94(A). [activation=tanh]
+{3,1} atom_embed(A) :- n_am(A). [activation=tanh]
+{3,1} atom_embed(A) :- h(A). [activation=tanh]
+{3,1} atom_embed(A) :- au(A). [activation=tanh]
+{3,1} atom_embed(A) :- k(A). [activation=tanh]
+{3,1} atom_embed(A) :- cd(A). [activation=tanh]
+{3,1} atom_embed(A) :- n(A). [activation=tanh]
+{3,1} atom_embed(A) :- n_38(A). [activation=tanh]
+{3,1} atom_embed(A) :- h_3(A). [activation=tanh]
+{3,1} atom_embed(A) :- n_1(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_25(A). [activation=tanh]
+{3,1} atom_embed(A) :- o_51(A). [activation=tanh]
+{3,1} atom_embed(A) :- ag(A). [activation=tanh]
+{3,1} atom_embed(A) :- bi(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_22(A). [activation=tanh]
+{3,1} atom_embed(A) :- o_2(A). [activation=tanh]
+{3,1} atom_embed(A) :- o_3(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_2(A). [activation=tanh]
+{3,1} atom_embed(A) :- o_40(A). [activation=tanh]
+{3,1} atom_embed(A) :- o(A). [activation=tanh]
+{3,1} atom_embed(A) :- o_52(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_14(A). [activation=tanh]
+{3,1} atom_embed(A) :- s_o2(A). [activation=tanh]
+{3,1} atom_embed(A) :- n_32(A). [activation=tanh]
+{3,1} atom_embed(A) :- n_2(A). [activation=tanh]
+{3,1} atom_embed(A) :- f_92(A). [activation=tanh]
+{3,1} atom_embed(A) :- as(A). [activation=tanh]
+{3,1} atom_embed(A) :- sn(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_3(A). [activation=tanh]
+{3,1} atom_embed(A) :- o_41(A). [activation=tanh]
+{3,1} atom_embed(A) :- h_1(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_16(A). [activation=tanh]
+{3,1} atom_embed(A) :- n_35(A). [activation=tanh]
+{3,1} atom_embed(A) :- ba(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_232(A). [activation=tanh]
+{3,1} atom_embed(A) :- cr(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_21(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_26(A). [activation=tanh]
+{3,1} atom_embed(A) :- n_36(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_ar(A). [activation=tanh]
+{3,1} atom_embed(A) :- s_o(A). [activation=tanh]
+{3,1} atom_embed(A) :- n_4(A). [activation=tanh]
+{3,1} atom_embed(A) :- i_95(A). [activation=tanh]
+{3,1} atom_embed(A) :- er(A). [activation=tanh]
+{3,1} atom_embed(A) :- o_49(A). [activation=tanh]
+{3,1} atom_embed(A) :- ge(A). [activation=tanh]
+{3,1} atom_embed(A) :- zn(A). [activation=tanh]
+{3,1} atom_embed(A) :- se(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_195(A). [activation=tanh]
+{3,1} atom_embed(A) :- n_pl3(A). [activation=tanh]
+{3,1} atom_embed(A) :- te(A). [activation=tanh]
+{3,1} atom_embed(A) :- p_3(A). [activation=tanh]
+{3,1} atom_embed(A) :- s(A). [activation=tanh]
+{3,1} atom_embed(A) :- co(A). [activation=tanh]
+{3,1} atom_embed(A) :- o_45(A). [activation=tanh]
+{3,1} atom_embed(A) :- na(A). [activation=tanh]
+{3,1} atom_embed(A) :- o_co2(A). [activation=tanh]
+{3,1} atom_embed(A) :- h_8(A). [activation=tanh]
+{3,1} atom_embed(A) :- s_2(A). [activation=tanh]
+{3,1} atom_embed(A) :- hg(A). [activation=tanh]
+{3,1} atom_embed(A) :- eu(A). [activation=tanh]
+{3,1} atom_embed(A) :- n_34(A). [activation=tanh]
+{3,1} atom_embed(A) :- s_3(A). [activation=tanh]
+{3,1} atom_embed(A) :- cu(A). [activation=tanh]
+{3,1} atom_embed(A) :- mn(A). [activation=tanh]
+{3,1} atom_embed(A) :- n_3(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_27(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_29(A). [activation=tanh]
+{3,1} atom_embed(A) :- c_1(A). [activation=tanh]
+{3,1} atom_embed(A) :- pd(A). [activation=tanh]
+{3,1} atom_embed(A) :- cl(A). [activation=tanh]
+{3,1} atom_embed(A) :- f(A). [activation=tanh]
+{3,1} atom_embed(A) :- o_50(A). [activation=tanh]
diff --git a/neuralogic/utils/data/datasets/molecules/bondEmbeddings3.txt b/neuralogic/utils/data/datasets/molecules/bondEmbeddings3.txt
index 5d6ed5ab..b79ed73c 100644
--- a/neuralogic/utils/data/datasets/molecules/bondEmbeddings3.txt
+++ b/neuralogic/utils/data/datasets/molecules/bondEmbeddings3.txt
@@ -1,12 +1,12 @@
-{3,1} bond_embed(B) :- b_4(B).
-{3,1} bond_embed(B) :- b_3(B).
-{3,1} bond_embed(B) :- b_doublebond(B).
-{3,1} bond_embed(B) :- b_5(B).
-{3,1} bond_embed(B) :- b_ar(B).
-{3,1} bond_embed(B) :- b_triplebond(B).
-{3,1} bond_embed(B) :- b_2(B).
-{3,1} bond_embed(B) :- b_7(B).
-{3,1} bond_embed(B) :- b_resonantbond(B).
-{3,1} bond_embed(B) :- b_singlebond(B).
-{3,1} bond_embed(B) :- b_1(B).
-{3,1} bond_embed(B) :- b_am(B).
+{3,1} bond_embed(B) :- b_4(B). [activation=tanh]
+{3,1} bond_embed(B) :- b_3(B). [activation=tanh]
+{3,1} bond_embed(B) :- b_doublebond(B). [activation=tanh]
+{3,1} bond_embed(B) :- b_5(B). [activation=tanh]
+{3,1} bond_embed(B) :- b_ar(B). [activation=tanh]
+{3,1} bond_embed(B) :- b_triplebond(B). [activation=tanh]
+{3,1} bond_embed(B) :- b_2(B). [activation=tanh]
+{3,1} bond_embed(B) :- b_7(B). [activation=tanh]
+{3,1} bond_embed(B) :- b_resonantbond(B). [activation=tanh]
+{3,1} bond_embed(B) :- b_singlebond(B). [activation=tanh]
+{3,1} bond_embed(B) :- b_1(B). [activation=tanh]
+{3,1} bond_embed(B) :- b_am(B). [activation=tanh]
diff --git a/neuralogic/utils/data/datasets/molecules/mutagenesis/template.txt_merged.txt b/neuralogic/utils/data/datasets/molecules/mutagenesis/template.txt_merged.txt
deleted file mode 100644
index f54342e9..00000000
--- a/neuralogic/utils/data/datasets/molecules/mutagenesis/template.txt_merged.txt
+++ /dev/null
@@ -1,50 +0,0 @@
-{3,1} atom_embed(A) :- c_26(A).
-{3,1} atom_embed(A) :- c_27(A).
-{3,1} atom_embed(A) :- c_25(A).
-{3,1} atom_embed(A) :- c_28(A).
-{3,1} atom_embed(A) :- c_29(A).
-{3,1} atom_embed(A) :- o_49(A).
-{3,1} atom_embed(A) :- br_94(A).
-{3,1} atom_embed(A) :- o_42(A).
-{3,1} atom_embed(A) :- o_45(A).
-{3,1} atom_embed(A) :- o_41(A).
-{3,1} atom_embed(A) :- o_40(A).
-{3,1} atom_embed(A) :- i_95(A).
-{3,1} atom_embed(A) :- f_92(A).
-{3,1} atom_embed(A) :- h_1(A).
-{3,1} atom_embed(A) :- h_3(A).
-{3,1} atom_embed(A) :- c_10(A).
-{3,1} atom_embed(A) :- c_14(A).
-{3,1} atom_embed(A) :- c_194(A).
-{3,1} atom_embed(A) :- c_195(A).
-{3,1} atom_embed(A) :- c_16(A).
-{3,1} atom_embed(A) :- h_8(A).
-{3,1} atom_embed(A) :- c_19(A).
-{3,1} atom_embed(A) :- c_230(A).
-{3,1} atom_embed(A) :- c_232(A).
-{3,1} atom_embed(A) :- o_50(A).
-{3,1} atom_embed(A) :- n_36(A).
-{3,1} atom_embed(A) :- o_52(A).
-{3,1} atom_embed(A) :- n_35(A).
-{3,1} atom_embed(A) :- n_34(A).
-{3,1} atom_embed(A) :- o_51(A).
-{3,1} atom_embed(A) :- n_32(A).
-{3,1} atom_embed(A) :- n_31(A).
-{3,1} atom_embed(A) :- cl_93(A).
-{3,1} atom_embed(A) :- c_21(A).
-{3,1} atom_embed(A) :- c_22(A).
-{3,1} atom_embed(A) :- n_38(A).
-atom_embed/1 {3,1}
-{3,1} bond_embed(B) :- b_1(B).
-{3,1} bond_embed(B) :- b_2(B).
-{3,1} bond_embed(B) :- b_3(B).
-{3,1} bond_embed(B) :- b_4(B).
-{3,1} bond_embed(B) :- b_5(B).
-{3,1} bond_embed(B) :- b_7(B).
-bond_embed/1 {3,1}
-
-l1_embed(X) :- {3,3} atom_embed(X), {3,3} atom_embed(Y), bond(X,Y,B), bond_embed(B).
-l2_embed(X) :- {3,3} l1_embed(X), {3,3} l1_embed(Y), bond(X,Y,B), bond_embed(B).
-l3_embed(X) :- {3,3} l2_embed(X), {3,3} l2_embed(Y), bond(X,Y,B), bond_embed(B).
-
-{1,3} predict :- l3_embed(X).
diff --git a/neuralogic/utils/data/datasets/molecules/mutagenesis/templates/embeddings.txt b/neuralogic/utils/data/datasets/molecules/mutagenesis/templates/embeddings.txt
index c9fe1480..304da92c 100644
--- a/neuralogic/utils/data/datasets/molecules/mutagenesis/templates/embeddings.txt
+++ b/neuralogic/utils/data/datasets/molecules/mutagenesis/templates/embeddings.txt
@@ -1,16 +1,18 @@
-{3,1} atom_embed(A) :- c(A).
-{3,1} atom_embed(A) :- o(A).
-{3,1} atom_embed(A) :- br(A).
-{3,1} atom_embed(A) :- i(A).
-{3,1} atom_embed(A) :- f(A).
-{3,1} atom_embed(A) :- h(A).
-{3,1} atom_embed(A) :- n(A).
-{3,1} atom_embed(A) :- cl(A).
-atom_embed/1 {3,1}
-{3,1} bond_embed(B) :- b_1(B).
-{3,1} bond_embed(B) :- b_2(B).
-{3,1} bond_embed(B) :- b_3(B).
-{3,1} bond_embed(B) :- b_4(B).
-{3,1} bond_embed(B) :- b_5(B).
-{3,1} bond_embed(B) :- b_7(B).
-bond_embed/1 {3,1}
+{3,1} atom_embed(A) :- c(A). [activation=tanh]
+{3,1} atom_embed(A) :- o(A). [activation=tanh]
+{3,1} atom_embed(A) :- br(A). [activation=tanh]
+{3,1} atom_embed(A) :- i(A). [activation=tanh]
+{3,1} atom_embed(A) :- f(A). [activation=tanh]
+{3,1} atom_embed(A) :- h(A). [activation=tanh]
+{3,1} atom_embed(A) :- n(A). [activation=tanh]
+{3,1} atom_embed(A) :- cl(A). [activation=tanh]
+atom_embed / 1 {3,1}
+{3,1} bond_embed(B) :- b_1(B). [activation=tanh]
+{3,1} bond_embed(B) :- b_2(B). [activation=tanh]
+{3,1} bond_embed(B) :- b_3(B). [activation=tanh]
+{3,1} bond_embed(B) :- b_4(B). [activation=tanh]
+{3,1} bond_embed(B) :- b_5(B). [activation=tanh]
+{3,1} bond_embed(B) :- b_7(B). [activation=tanh]
+bond_embed / 1 {3,1}
+bond_embed / 1 [activation=tanh]
+atom_embed / 1 [activation=tanh]
diff --git a/neuralogic/utils/data/datasets/molecules/mutagenesis/templates/template.txt b/neuralogic/utils/data/datasets/molecules/mutagenesis/templates/template.txt
index 72f4c4f6..6e0c9a50 100644
--- a/neuralogic/utils/data/datasets/molecules/mutagenesis/templates/template.txt
+++ b/neuralogic/utils/data/datasets/molecules/mutagenesis/templates/template.txt
@@ -1,7 +1,12 @@
import ./embeddings.txt
-l1_embed(X) :- {3,3} atom_embed(X), {3,3} atom_embed(Y), bond(X,Y,B), bond_embed(B).
-l2_embed(X) :- {3,3} l1_embed(X), {3,3} l1_embed(Y), bond(X,Y,B), bond_embed(B).
-l3_embed(X) :- {3,3} l2_embed(X), {3,3} l2_embed(Y), bond(X,Y,B), bond_embed(B).
+l1_embed(X) :- {3,3} atom_embed(X), {3,3} atom_embed(Y), bond(X,Y,B), bond_embed(B). [activation=tanh]
+l2_embed(X) :- {3,3} l1_embed(X), {3,3} l1_embed(Y), bond(X,Y,B), bond_embed(B). [activation=tanh]
+l3_embed(X) :- {3,3} l2_embed(X), {3,3} l2_embed(Y), bond(X,Y,B), bond_embed(B). [activation=tanh]
-{1,3} predict :- l3_embed(X).
+l1_embed / 1 [activation=tanh]
+l2_embed / 1 [activation=tanh]
+l3_embed / 1 [activation=tanh]
+
+{1,3} predict :- l3_embed(X). [activation=tanh]
+predict / 0 [activation=tanh]
diff --git a/neuralogic/utils/data/datasets/simple/trains/template.txt b/neuralogic/utils/data/datasets/simple/trains/template.txt
index d23312e3..ea4667b9 100644
--- a/neuralogic/utils/data/datasets/simple/trains/template.txt
+++ b/neuralogic/utils/data/datasets/simple/trains/template.txt
@@ -1,34 +1,41 @@
-shape(X, Y) :- {1} shape(X, Y, ellipse).
-shape(X, Y) :- {1} shape(X, Y, rectangle).
-shape(X, Y) :- {1} shape(X, Y, bucket).
-shape(X, Y) :- {1} shape(X, Y, hexagon).
-shape(X, Y) :- {1} shape(X, Y, u_shaped).
-
-length(X, Y) :- {1} length(X, Y, short).
-length(X, Y) :- {1} length(X, Y, long).
-
-sides(X, Y) :- {1} sides(X, Y, not_double).
-sides(X, Y) :- {1} sides(X, Y, double).
-
-roof(X, Y) :- {1} roof(X, Y, jagged).
-roof(X, Y) :- {1} roof(X, Y, arc).
-roof(X, Y) :- {1} roof(X, Y, none).
-roof(X, Y) :- {1} roof(X, Y, flat).
-roof(X, Y) :- {1} roof(X, Y, peaked).
-
-wheels(X, Y) :- {1} wheels(X, Y, 2).
-wheels(X, Y) :- {1} wheels(X, Y, 3).
-
-loadnum(X, Y) :- {1} loadnum(X, Y, 0).
-loadnum(X, Y) :- {1} loadnum(X, Y, 1).
-loadnum(X, Y) :- {1} loadnum(X, Y, 2).
-loadnum(X, Y) :- {1} loadnum(X, Y, 3).
-
-loadshape(X, Y) :- {1} loadshape(X, Y, hexagon).
-loadshape(X, Y) :- {1} loadshape(X, Y, triangle).
-loadshape(X, Y) :- {1} loadshape(X, Y, diamond).
-loadshape(X, Y) :- {1} loadshape(X, Y, rectangle).
-loadshape(X, Y) :- {1} loadshape(X, Y, circle).
+shape(X, Y) :- {1} shape(X, Y, ellipse). [activation=tanh]
+shape(X, Y) :- {1} shape(X, Y, rectangle). [activation=tanh]
+shape(X, Y) :- {1} shape(X, Y, bucket). [activation=tanh]
+shape(X, Y) :- {1} shape(X, Y, hexagon). [activation=tanh]
+shape(X, Y) :- {1} shape(X, Y, u_shaped). [activation=tanh]
+shape / 2 [activation=tanh]
+
+length(X, Y) :- {1} length(X, Y, short). [activation=tanh]
+length(X, Y) :- {1} length(X, Y, long). [activation=tanh]
+length / 2 [activation=tanh]
+
+sides(X, Y) :- {1} sides(X, Y, not_double). [activation=tanh]
+sides(X, Y) :- {1} sides(X, Y, double). [activation=tanh]
+sides / 2 [activation=tanh]
+
+roof(X, Y) :- {1} roof(X, Y, jagged). [activation=tanh]
+roof(X, Y) :- {1} roof(X, Y, arc). [activation=tanh]
+roof(X, Y) :- {1} roof(X, Y, none). [activation=tanh]
+roof(X, Y) :- {1} roof(X, Y, flat). [activation=tanh]
+roof(X, Y) :- {1} roof(X, Y, peaked). [activation=tanh]
+roof / 2 [activation=tanh]
+
+wheels(X, Y) :- {1} wheels(X, Y, 2). [activation=tanh]
+wheels(X, Y) :- {1} wheels(X, Y, 3). [activation=tanh]
+wheels / 2 [activation=tanh]
+
+loadnum(X, Y) :- {1} loadnum(X, Y, 0). [activation=tanh]
+loadnum(X, Y) :- {1} loadnum(X, Y, 1). [activation=tanh]
+loadnum(X, Y) :- {1} loadnum(X, Y, 2). [activation=tanh]
+loadnum(X, Y) :- {1} loadnum(X, Y, 3). [activation=tanh]
+loadnum / 2 [activation=tanh]
+
+loadshape(X, Y) :- {1} loadshape(X, Y, hexagon). [activation=tanh]
+loadshape(X, Y) :- {1} loadshape(X, Y, triangle). [activation=tanh]
+loadshape(X, Y) :- {1} loadshape(X, Y, diamond). [activation=tanh]
+loadshape(X, Y) :- {1} loadshape(X, Y, rectangle). [activation=tanh]
+loadshape(X, Y) :- {1} loadshape(X, Y, circle). [activation=tanh]
+loadshape / 2 [activation=tanh]
vagon(X, Y) :-
{1} shape(X, Y),
@@ -37,11 +44,14 @@ vagon(X, Y) :-
{1} wheels(X, Y),
{1} loadnum(X, Y),
{1} loadshape(X, Y),
- {1} roof(X, Y).
+ {1} roof(X, Y). [activation=tanh]
+vagon / 2 [activation=tanh]
-train(X) :- {1} vagon(X, 1).
-train(X) :- {1} vagon(X, 2).
-train(X) :- {1} vagon(X, 3).
-train(X) :- {1} vagon(X, 4).
+train(X) :- {1} vagon(X, 1). [activation=tanh]
+train(X) :- {1} vagon(X, 2). [activation=tanh]
+train(X) :- {1} vagon(X, 3). [activation=tanh]
+train(X) :- {1} vagon(X, 4). [activation=tanh]
+train / 1 [activation=tanh]
-direction(X) :- {1} train(X).
+direction(X) :- {1} train(X). [activation=tanh]
+direction / 1 [activation=tanh]
diff --git a/neuralogic/utils/data/datasets/simple/xor/naive/template.txt b/neuralogic/utils/data/datasets/simple/xor/naive/template.txt
index 0b879108..9d598ccb 100644
--- a/neuralogic/utils/data/datasets/simple/xor/naive/template.txt
+++ b/neuralogic/utils/data/datasets/simple/xor/naive/template.txt
@@ -1,17 +1,27 @@
-hidden1 :- {1} a, {1} b.
-hidden2 :- {1} a, {1} b.
-hidden3 :- {1} a, {1} b.
-hidden4 :- {1} a, {1} b.
-hidden5 :- {1} a, {1} b.
-hidden6 :- {1} a, {1} b.
-hidden7 :- {1} a, {1} b.
-hidden8 :- {1} a, {1} b.
+hidden1 :- {1} a, {1} b. [activation=tanh]
+hidden2 :- {1} a, {1} b. [activation=tanh]
+hidden3 :- {1} a, {1} b. [activation=tanh]
+hidden4 :- {1} a, {1} b. [activation=tanh]
+hidden5 :- {1} a, {1} b. [activation=tanh]
+hidden6 :- {1} a, {1} b. [activation=tanh]
+hidden7 :- {1} a, {1} b. [activation=tanh]
+hidden8 :- {1} a, {1} b. [activation=tanh]
-{1} xor :- hidden1.
-{1} xor :- hidden2.
-{1} xor :- hidden3.
-{1} xor :- hidden4.
-{1} xor :- hidden5.
-{1} xor :- hidden6.
-{1} xor :- hidden7.
-{1} xor :- hidden8.
+{1} xor :- hidden1. [activation=tanh]
+{1} xor :- hidden2. [activation=tanh]
+{1} xor :- hidden3. [activation=tanh]
+{1} xor :- hidden4. [activation=tanh]
+{1} xor :- hidden5. [activation=tanh]
+{1} xor :- hidden6. [activation=tanh]
+{1} xor :- hidden7. [activation=tanh]
+{1} xor :- hidden8. [activation=tanh]
+
+xor / 0 [activation=tanh]
+hidden1 / 0 [activation=tanh]
+hidden2 / 0 [activation=tanh]
+hidden3 / 0 [activation=tanh]
+hidden4 / 0 [activation=tanh]
+hidden5 / 0 [activation=tanh]
+hidden6 / 0 [activation=tanh]
+hidden7 / 0 [activation=tanh]
+hidden8 / 0 [activation=tanh]
diff --git a/neuralogic/utils/data/datasets/simple/xor/vectorized/template.txt b/neuralogic/utils/data/datasets/simple/xor/vectorized/template.txt
index 9db67b99..7e69ec10 100644
--- a/neuralogic/utils/data/datasets/simple/xor/vectorized/template.txt
+++ b/neuralogic/utils/data/datasets/simple/xor/vectorized/template.txt
@@ -1 +1,2 @@
-{1,8} xor :- {8,2} xy.
+{1,8} xor :- {8,2} xy. [activation=tanh]
+xor / 0 [activation=tanh]
diff --git a/tests/test_evaluation_inference_engine.py b/tests/test_evaluation_inference_engine.py
index c9d93be0..58ed1e95 100644
--- a/tests/test_evaluation_inference_engine.py
+++ b/tests/test_evaluation_inference_engine.py
@@ -135,15 +135,12 @@ def test_evaluation_inference_engine_london_shortest_path() -> None:
R.connected(C.leicester_square, C.charing_cross, C.northern)[-7],
]
- metadata = Metadata(aggregation=Aggregation.MAX, transformation=Transformation.IDENTITY)
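+    # Identity appears to be the default transformation now, so only the MAX aggregation is set.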
+ metadata = Metadata(aggregation=Aggregation.MAX)
template += [
(R.shortest(V.X, V.Y, C.first) <= R.connected(V.X, V.Y, V.L)) | metadata,
(R.shortest(V.X, V.Y, C.second) <= (R.connected(V.X, V.Z, V.L), R.shortest(V.Z, V.Y, V.D))) | metadata,
(R.shortest_path(V.X, V.Y) <= R.shortest(V.X, V.Y, V.D)) | metadata,
- R.shortest / 3 | Metadata(transformation=Transformation.IDENTITY),
- R.connected / 3 | Metadata(transformation=Transformation.IDENTITY),
- R.shortest_path / 2 | Metadata(transformation=Transformation.IDENTITY),
]
engine = EvaluationInferenceEngine(template)
diff --git a/tests/test_function.py b/tests/test_function.py
index afb784c5..e9dbb9d4 100644
--- a/tests/test_function.py
+++ b/tests/test_function.py
@@ -3,8 +3,8 @@
import torch
import numpy as np
-import neuralogic.nn.functional as F
from neuralogic.core import Template, R, Settings
+from neuralogic.core import F
from neuralogic.dataset import Dataset, Sample
@@ -25,8 +25,7 @@ def test_transformation_body_function(torch_fun, fun):
torch_result = torch_fun(data).detach().numpy().round(3)
template = Template()
- template += (R.h <= fun(R.input)) | [F.identity]
- template += R.h / 0 | [F.identity]
+ template += R.h <= fun(R.input)
model = template.build(Settings(iso_value_compression=False, chain_pruning=False))
dataset = Dataset([Sample(R.h, R.input[data.tolist()])])
@@ -56,8 +55,7 @@ def test_slice_function():
)
template = Template()
- template += (R.h <= F.slice(R.input, rows=(1, 3))) | [F.identity]
- template += R.h / 0 | [F.identity]
+ template += R.h <= F.slice(R.input, rows=(1, 3))
model = template.build(Settings(iso_value_compression=False, chain_pruning=False))
dataset = Dataset([Sample(R.h, [R.input[data]])])
@@ -69,7 +67,6 @@ def test_slice_function():
template = Template()
template += (R.h <= R.input) | [F.slice(rows=(1, 3))]
- template += R.h / 0 | [F.identity]
model = template.build(Settings(iso_value_compression=False, chain_pruning=False))
dataset = Dataset(Sample(R.h, [R.input[data]]))
@@ -80,7 +77,7 @@ def test_slice_function():
assert np.allclose(res, results)
template = Template()
- template += (R.h <= R.input) | [F.identity]
+ template += R.h <= R.input
template += R.h / 0 | [F.slice(rows=(1, 3))]
model = template.build(Settings(iso_value_compression=False, chain_pruning=False))
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 59ad11c3..e482af58 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -18,10 +18,10 @@ def test_rgcnconv():
template += RGCNConv(1, 2, "h1", "h0", "_edge", ["a", "b", "c"])
template_str = str(template).split("\n")
- assert template_str[0] == "h1(I) :- {2, 1} h0(I). [transformation=identity, aggregation=avg]"
- assert template_str[1] == "h1(I) :- {2, 1} h0(J), *edge(J, a, I). [transformation=identity, aggregation=avg]"
- assert template_str[2] == "h1(I) :- {2, 1} h0(J), *edge(J, b, I). [transformation=identity, aggregation=avg]"
- assert template_str[3] == "h1(I) :- {2, 1} h0(J), *edge(J, c, I). [transformation=identity, aggregation=avg]"
+ assert template_str[0] == "h1(I) :- {2, 1} h0(I). [aggregation=avg]"
+ assert template_str[1] == "h1(I) :- {2, 1} h0(J), *edge(J, a, I). [aggregation=avg]"
+ assert template_str[2] == "h1(I) :- {2, 1} h0(J), *edge(J, b, I). [aggregation=avg]"
+ assert template_str[3] == "h1(I) :- {2, 1} h0(J), *edge(J, c, I). [aggregation=avg]"
assert template_str[4] == "h1/1 [transformation=identity]"
@@ -31,10 +31,10 @@ def test_rgcnconv_relations_edge_replace():
template += RGCNConv(1, 2, "h1", "h0", None, ["a", "b", "c"], Transformation.SIGMOID)
template_str = str(template).split("\n")
- assert template_str[0] == "h1(I) :- {2, 1} h0(I). [transformation=identity, aggregation=avg]"
- assert template_str[1] == "h1(I) :- {2, 1} h0(J), a(J, I). [transformation=identity, aggregation=avg]"
- assert template_str[2] == "h1(I) :- {2, 1} h0(J), b(J, I). [transformation=identity, aggregation=avg]"
- assert template_str[3] == "h1(I) :- {2, 1} h0(J), c(J, I). [transformation=identity, aggregation=avg]"
+ assert template_str[0] == "h1(I) :- {2, 1} h0(I). [aggregation=avg]"
+ assert template_str[1] == "h1(I) :- {2, 1} h0(J), a(J, I). [aggregation=avg]"
+ assert template_str[2] == "h1(I) :- {2, 1} h0(J), b(J, I). [aggregation=avg]"
+ assert template_str[3] == "h1(I) :- {2, 1} h0(J), c(J, I). [aggregation=avg]"
assert template_str[4] == "h1/1 [transformation=sigmoid]"
@@ -45,16 +45,14 @@ def test_gcnconv():
template_str = str(template).split("\n")
assert template_str[0] == "<1.0> h1__edge(I, I)."
- assert template_str[1] == "h1__edge(I, J) :- edge(I, J). [transformation=identity]"
- assert template_str[2] == "h1__edge/2 [transformation=identity]"
- assert template_str[3] == "h1__edge_count(I, J) :- h1__edge(J, X). [transformation=identity, aggregation=count]"
- assert template_str[4] == "h1__edge_count(I, J) :- h1__edge(I, X). [transformation=identity, aggregation=count]"
- assert template_str[5] == "h1__edge_count/2 [transformation=inverse, combination=product]"
+ assert template_str[1] == "h1__edge(I, J) :- edge(I, J)."
+ assert template_str[2] == "h1__edge_count(I, J) :- h1__edge(J, X). [aggregation=count]"
+ assert template_str[3] == "h1__edge_count(I, J) :- h1__edge(I, X). [aggregation=count]"
+ assert template_str[4] == "h1__edge_count/2 [transformation=inverse, combination=product]"
assert (
- template_str[6]
- == "{2, 1} h1(I) :- h0(J), h1__edge(J, I), sqrt(h1__edge_count(J, I)). [transformation=identity, combination=product, aggregation=sum]"
+ template_str[5]
+ == "{2, 1} h1(I) :- h0(J), h1__edge(J, I), sqrt(h1__edge_count(J, I)). [combination=product, aggregation=sum]"
)
- assert template_str[7] == "h1/1 [transformation=identity]"
def test_sageconv():
@@ -63,9 +61,8 @@ def test_sageconv():
template += SAGEConv(1, 2, "h1", "h0", "_edge")
template_str = str(template).split("\n")
- assert template_str[0] == "{2, 1} h1(I) :- h0(J), *edge(J, I). [transformation=identity, aggregation=avg]"
- assert template_str[1] == "{2, 1} h1(I) :- h0(I). [transformation=identity, aggregation=avg]"
- assert template_str[2] == "h1/1 [transformation=identity]"
+ assert template_str[0] == "{2, 1} h1(I) :- h0(J), *edge(J, I). [aggregation=avg]"
+ assert template_str[1] == "{2, 1} h1(I) :- h0(I). [aggregation=avg]"
def test_tagconv():
@@ -74,14 +71,13 @@ def test_tagconv():
template += TAGConv(1, 2, "h1", "h0", "_edge")
template_str = str(template).split("\n")
- zero_hop = "{2, 1} h1(I0) :- h0(I0). [transformation=identity, aggregation=sum]"
- sec_hop = "{2, 1} h1(I0) :- h0(I1), *edge(I1, I0). [transformation=identity, aggregation=sum]"
- hop = "{2, 1} h1(I0) :- h0(I2), *edge(I1, I0), *edge(I2, I1). [transformation=identity, aggregation=sum]"
+ zero_hop = "{2, 1} h1(I0) :- h0(I0). [aggregation=sum]"
+ sec_hop = "{2, 1} h1(I0) :- h0(I1), *edge(I1, I0). [aggregation=sum]"
+ hop = "{2, 1} h1(I0) :- h0(I2), *edge(I1, I0), *edge(I2, I1). [aggregation=sum]"
assert template_str[0] == zero_hop
assert template_str[1] == sec_hop
assert template_str[2] == hop
- assert template_str[3] == "h1/1 [transformation=identity]"
template = Template()
@@ -105,9 +101,10 @@ def test_gatv2conv():
assert template_str[0] == attention
assert template_str[1] == "h1__attention/2 [transformation=softmax]"
- h1_rule = "h1(I) :- h1__attention(I, J), $h1__right={2, 1} h0(J), *edge(J, I). [transformation=identity, combination=product, aggregation=sum]"
+ h1_rule = (
+ "h1(I) :- h1__attention(I, J), $h1__right={2, 1} h0(J), *edge(J, I). [combination=product, aggregation=sum]"
+ )
assert template_str[2] == h1_rule
- assert template_str[3] == "h1/1 [transformation=identity]"
template = Template()
@@ -120,9 +117,10 @@ def test_gatv2conv():
assert template_str[0] == attention
assert template_str[1] == "h1__attention/2 [transformation=softmax]"
- h1_rule = "h1(I) :- h1__attention(I, J), $h1__right={2, 1} h0(J), *edge(J, I). [transformation=identity, combination=product, aggregation=sum]"
+ h1_rule = (
+ "h1(I) :- h1__attention(I, J), $h1__right={2, 1} h0(J), *edge(J, I). [combination=product, aggregation=sum]"
+ )
assert template_str[2] == h1_rule
- assert template_str[3] == "h1/1 [transformation=identity]"
def test_sgconv():
@@ -130,19 +128,17 @@ def test_sgconv():
template += SGConv(1, 2, "h1", "h0", "_edge", k=2)
template_str = str(template).split("\n")
- rule = "{2, 1} h1(I0) :- h0(I2), *edge(I1, I0), *edge(I2, I1). [transformation=identity, aggregation=sum, duplicit_grounding=True]"
+ rule = "{2, 1} h1(I0) :- h0(I2), *edge(I1, I0), *edge(I2, I1). [aggregation=sum, duplicit_grounding=True]"
assert template_str[0] == rule
- assert template_str[1] == "h1/1 [transformation=identity]"
template = Template()
template += SGConv(1, 2, "h1", "h0", "_edge")
template_str = str(template).split("\n")
- rule = "{2, 1} h1(I0) :- h0(I1), *edge(I1, I0). [transformation=identity, aggregation=sum, duplicit_grounding=True]"
+ rule = "{2, 1} h1(I0) :- h0(I1), *edge(I1, I0). [aggregation=sum, duplicit_grounding=True]"
assert template_str[0] == rule
- assert template_str[1] == "h1/1 [transformation=identity]"
def test_appnp():
@@ -151,26 +147,22 @@ def test_appnp():
template += APPNPConv("h1", "h0", "_edge", 1, 0.1)
template_str = str(template).split("\n")
- assert template_str[0] == "h1(I) :- <0.1> h0(I). [transformation=identity, aggregation=sum]"
- assert template_str[1] == "h1(I) :- <0.9> h0(J), *edge(J, I). [transformation=identity, aggregation=sum]"
- assert template_str[2] == "h1/1 [transformation=identity]"
+ assert template_str[0] == "h1(I) :- <0.1> h0(I). [aggregation=sum]"
+ assert template_str[1] == "h1(I) :- <0.9> h0(J), *edge(J, I). [aggregation=sum]"
template = Template()
template += APPNPConv("h1", "h0", "_edge", 3, 0.1)
template_str = str(template).split("\n")
- assert template_str[0] == "h1__1(I) :- <0.1> h0(I). [transformation=identity, aggregation=sum]"
- assert template_str[1] == "h1__1(I) :- <0.9> h0(J), *edge(J, I). [transformation=identity, aggregation=sum]"
- assert template_str[2] == "h1__1/1 [transformation=identity]"
+ assert template_str[0] == "h1__1(I) :- <0.1> h0(I). [aggregation=sum]"
+ assert template_str[1] == "h1__1(I) :- <0.9> h0(J), *edge(J, I). [aggregation=sum]"
- assert template_str[3] == "h1__2(I) :- <0.1> h0(I). [transformation=identity, aggregation=sum]"
- assert template_str[4] == "h1__2(I) :- <0.9> h1__1(J), *edge(J, I). [transformation=identity, aggregation=sum]"
- assert template_str[5] == "h1__2/1 [transformation=identity]"
+ assert template_str[2] == "h1__2(I) :- <0.1> h0(I). [aggregation=sum]"
+ assert template_str[3] == "h1__2(I) :- <0.9> h1__1(J), *edge(J, I). [aggregation=sum]"
- assert template_str[6] == "h1(I) :- <0.1> h0(I). [transformation=identity, aggregation=sum]"
- assert template_str[7] == "h1(I) :- <0.9> h1__2(J), *edge(J, I). [transformation=identity, aggregation=sum]"
- assert template_str[8] == "h1/1 [transformation=identity]"
+ assert template_str[4] == "h1(I) :- <0.1> h0(I). [aggregation=sum]"
+ assert template_str[5] == "h1(I) :- <0.9> h1__2(J), *edge(J, I). [aggregation=sum]"
def test_res_gated():
@@ -179,10 +171,9 @@ def test_res_gated():
template += ResGatedGraphConv(1, 2, "h1", "h0", "edge")
template_str = str(template).split("\n")
- rule = "h1(I) :- h1__gate(I, J), {2, 1} h0(J), edge(J, I). [transformation=identity, combination=elproduct, aggregation=sum]"
+ rule = "h1(I) :- h1__gate(I, J), {2, 1} h0(J), edge(J, I). [combination=elproduct, aggregation=sum]"
- assert template_str[0] == "h1__gate(I, J) :- {2, 1} h0(I), {2, 1} h0(J). [transformation=identity]"
+ assert template_str[0] == "h1__gate(I, J) :- {2, 1} h0(I), {2, 1} h0(J)."
assert template_str[1] == "h1__gate/2 [transformation=sigmoid]"
- assert template_str[2] == "h1(I) :- {2, 1} h0(I). [transformation=identity]"
+ assert template_str[2] == "h1(I) :- {2, 1} h0(I)."
assert template_str[3] == rule
- assert template_str[4] == "h1/1 [transformation=identity]"
diff --git a/tests/test_recurrent_modules.py b/tests/test_recurrent_modules.py
index 92ea1a77..0a205bf6 100644
--- a/tests/test_recurrent_modules.py
+++ b/tests/test_recurrent_modules.py
@@ -2,7 +2,7 @@
import pytest
import torch
-from neuralogic.core import Template, Settings, R
+from neuralogic.core import Template, Settings, R, V, Transformation
from neuralogic.dataset import Dataset, Sample
from neuralogic.nn.loss import MSE
@@ -196,3 +196,177 @@ def test_lstm_module(input_size, hidden_size, sequence_len, epochs):
result, _ = model(bd.samples)
assert np.allclose([float(x) for x in output[-1]], [float(x) for x in result[0][1]], atol=10e-5)
+
+
+@pytest.mark.parametrize(
+ "input_size, hidden_size, sequence_len, epochs",
+ [
+ (10, 5, 10, 500),
+ ],
+)
+def test_lstm_module_simple(input_size, hidden_size, sequence_len, epochs):
+ """
+    Test that the PyNeuraLogic LSTM template computes the same outputs as the PyTorch LSTM layer (with backprop).
+ """
+ torch_input = torch.randn((sequence_len, input_size))
+ h0 = torch.randn((1, hidden_size))
+ c0 = torch.randn((1, hidden_size))
+ target = torch.randn((hidden_size,))
+
+ rnn = torch.nn.LSTM(input_size, hidden_size, 1, bias=False)
+
+ template = Template()
+
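+    # One rule per LSTM gate: the input (i), forget (f), and output (o) gates are
+    # sigmoid-activated, the candidate state (n) is tanh-activated, and special.next(Z, T)
+    # chains each time step Z to its successor T. The cell state is c(T) = f * c(T-1) + i * n
+    # and the output is h(T) = o * tanh(c(T)).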
+ template += [
+ R.h(0) <= R.h0,
+ R.h__c(0) <= R.c0,
+ (
+ R.h__i(V.T)
+ <= R.f(V.T)[hidden_size, input_size] + R.h(V.Z)[hidden_size, hidden_size] + R.special.next(V.Z, V.T)
+ )
+ | [Transformation.SIGMOID],
+ (
+ R.h__f(V.T)
+ <= R.f(V.T)[hidden_size, input_size] + R.h(V.Z)[hidden_size, hidden_size] + R.special.next(V.Z, V.T)
+ )
+ | [Transformation.SIGMOID],
+ (
+ R.h__o(V.T)
+ <= R.f(V.T)[hidden_size, input_size] + R.h(V.Z)[hidden_size, hidden_size] + R.special.next(V.Z, V.T)
+ )
+ | [Transformation.SIGMOID],
+ (
+ R.h__n(V.T)
+ <= R.f(V.T)[hidden_size, input_size] + R.h(V.Z)[hidden_size, hidden_size] + R.special.next(V.Z, V.T)
+ )
+ | [Transformation.TANH],
+ R.h__c(V.T) <= R.h__f(V.T) * R.h__c(V.Z) + R.h__i(V.T) * R.h__n(V.T) + R.special.next(V.Z, V.T),
+ R.h(V.T) <= R.h__o(V.T) * Transformation.TANH(R.h__c(V.T)),
+ ]
+
+ model = template.build(
+ Settings(chain_pruning=False, iso_value_compression=False, optimizer=Adam(lr=0.001), error_function=MSE())
+ )
+
+ parameters = model.parameters()
+ torch_parameters = [parameter.tolist() for parameter in rnn.parameters()]
+
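+    # torch.nn.LSTM stacks its gate weights row-wise in (input, forget, cell, output) order;
+    # copy each block onto the matching rule weight (rules were added in i, f, o, n order,
+    # with input-side weights at even indices and hidden-side weights at odd indices).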
+ parameters["weights"][0] = [torch_parameters[0][i] for i in range(0, hidden_size)]
+ parameters["weights"][2] = [torch_parameters[0][i] for i in range(1 * hidden_size, 1 * hidden_size + hidden_size)]
+ parameters["weights"][4] = [torch_parameters[0][i] for i in range(3 * hidden_size, 3 * hidden_size + hidden_size)]
+ parameters["weights"][6] = [torch_parameters[0][i] for i in range(2 * hidden_size, 2 * hidden_size + hidden_size)]
+
+ parameters["weights"][1] = [torch_parameters[1][i] for i in range(0, hidden_size)]
+ parameters["weights"][3] = [torch_parameters[1][i] for i in range(1 * hidden_size, 1 * hidden_size + hidden_size)]
+ parameters["weights"][5] = [torch_parameters[1][i] for i in range(3 * hidden_size, 3 * hidden_size + hidden_size)]
+ parameters["weights"][7] = [torch_parameters[1][i] for i in range(2 * hidden_size, 2 * hidden_size + hidden_size)]
+
+ model.load_state_dict(parameters)
+
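+    # A single sample: the query is h(T) labeled with the target vector; the facts supply
+    # c0, h0, and the input sequence as f(1), ..., f(T).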
+ dataset = Dataset(
+ [
+ Sample(
+ R.h(sequence_len)[target.detach().numpy().tolist()],
+ [
+ R.c0[[float(c) for c in c0[0]]],
+ R.h0[[float(h) for h in h0[0]]],
+ *[R.f(i + 1)[[float(h) for h in torch_input[i]]] for i in range(sequence_len)],
+ ],
+ ),
+ ]
+ )
+
+ bd = model.build_dataset(dataset)
+
+ optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)
+ loss_fun = torch.nn.MSELoss()
+
+ for _ in range(epochs):
+ output, _ = rnn(torch_input, (h0, c0))
+ loss = loss_fun(output[-1], target)
+
+ optimizer.zero_grad(set_to_none=True)
+ loss.backward()
+ optimizer.step()
+
+ result, _ = model(bd.samples)
+ assert np.allclose([float(x) for x in output[-1]], [float(x) for x in result[0][1]], atol=10e-5)
+
+
+@pytest.mark.parametrize(
+ "input_size, hidden_size, sequence_len, epochs",
+ [
+ (10, 5, 10, 500),
+ ],
+)
+def test_gru_module_simple(input_size, hidden_size, sequence_len, epochs):
+ """Test that PyNeuraLogic GRU layer computes the same as PyTorch GRU layer (with backprop)"""
+ torch_input = torch.randn((sequence_len, input_size))
+ h0 = torch.randn((1, hidden_size))
+ target = torch.randn((hidden_size,))
+
+ rnn = torch.nn.GRU(input_size, hidden_size, 1, bias=False)
+
+ template = Template()
+
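+    # GRU as rules: the reset (r) and update (z) gates are sigmoid-activated, the candidate
+    # state (n) is tanh-activated, and Transformation.REVERSE provides the (1 - z) factor of
+    # the state update h(T) = (1 - z) * n + z * h(T-1).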
+ template += R.h(0) <= R.h0
+ template += (
+ R.h__r(V.T) <= R.f(V.T)[hidden_size, input_size] + R.h(V.Z)[hidden_size, hidden_size] + R.special.next(V.Z, V.T)
+ ) | [Transformation.SIGMOID]
+ template += (
+ R.h__z(V.T) <= R.f(V.T)[hidden_size, input_size] + R.h(V.Z)[hidden_size, hidden_size] + R.special.next(V.Z, V.T)
+ ) | [Transformation.SIGMOID]
+ template += (
+ R.h__n(V.T)
+ <= (R.h__r(V.T) * R.h(V.Z)[hidden_size, hidden_size])
+ + R.f(V.T)[hidden_size, input_size]
+ + R.special.next(V.Z, V.T)
+ ) | [Transformation.TANH]
+ template += R.h(V.T) <= (Transformation.REVERSE(R.h__z(V.T)) * R.h__n(V.T)) + (
+ R.h__z(V.T) * R.h(V.Z)
+ ) + R.special.next(V.Z, V.T)
+
+ model = template.build(
+ Settings(chain_pruning=False, iso_value_compression=False, optimizer=Adam(lr=0.001), error_function=MSE())
+ )
+
+ parameters = model.parameters()
+ torch_parameters = [parameter.tolist() for parameter in rnn.parameters()]
+
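+    # torch.nn.GRU stacks its gate weights row-wise in (reset, update, new) order; copy each
+    # block onto the weight of the corresponding rule in the template.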
+ parameters["weights"][0] = [torch_parameters[0][i] for i in range(0, hidden_size)]
+ parameters["weights"][2] = [torch_parameters[0][i] for i in range(1 * hidden_size, 1 * hidden_size + hidden_size)]
+ parameters["weights"][5] = [torch_parameters[0][i] for i in range(2 * hidden_size, 2 * hidden_size + hidden_size)]
+
+ parameters["weights"][1] = [torch_parameters[1][i] for i in range(0, hidden_size)]
+ parameters["weights"][3] = [torch_parameters[1][i] for i in range(1 * hidden_size, 1 * hidden_size + hidden_size)]
+ parameters["weights"][4] = [torch_parameters[1][i] for i in range(2 * hidden_size, 2 * hidden_size + hidden_size)]
+
+ model.load_state_dict(parameters)
+
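+    # A single sample: the query is h(T) labeled with the target vector; the facts supply
+    # h0 and the input sequence as f(1), ..., f(T).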
+ dataset = Dataset(
+ [
+ Sample(
+ R.h(sequence_len)[target.detach().numpy().tolist()],
+ [
+ R.h0[[float(h) for h in h0[0]]],
+ *[R.f(i + 1)[[float(h) for h in torch_input[i]]] for i in range(sequence_len)],
+ ],
+ )
+ ]
+ )
+
+ bd = model.build_dataset(dataset)
+
+ optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)
+ loss_fun = torch.nn.MSELoss()
+
+ for _ in range(epochs):
+ output, _ = rnn(torch_input, h0)
+ loss = loss_fun(output[-1], target)
+
+ optimizer.zero_grad(set_to_none=True)
+ loss.backward()
+ optimizer.step()
+
+ result, _ = model(bd.samples)
+ assert np.allclose([float(x) for x in output[-1]], [float(x) for x in result[0][1]], atol=10e-5)
diff --git a/tests/test_settings.py b/tests/test_settings.py
index 58ddbf2a..117d13aa 100644
--- a/tests/test_settings.py
+++ b/tests/test_settings.py
@@ -18,8 +18,6 @@
"error_function": SoftEntropy(),
"initializer": Uniform(5.0),
"initializer_uniform_scale": 5.0,
- "rule_transformation": Transformation.SIGMOID,
- "relation_transformation": Transformation.RELU,
}
],
)
diff --git a/tests/test_torch_function.py b/tests/test_torch_function.py
index a276b70e..1392fda3 100644
--- a/tests/test_torch_function.py
+++ b/tests/test_torch_function.py
@@ -2,9 +2,8 @@
from torch.nn import Sequential
import neuralogic
-from neuralogic.core import Relation, Template, R
+from neuralogic.core import Relation, Template, R, F
from neuralogic.nn.torch_function import NeuraLogic
-import neuralogic.nn.functional as F
def test_torch_function_with_parameters():
@@ -12,8 +11,7 @@ def test_torch_function_with_parameters():
neuralogic.manual_seed(1)
template = Template()
- template += (Relation.xor[1, 8] <= Relation.xy) | [F.identity]
- template += Relation.xor / 0 | [F.identity]
+ template += Relation.xor[1, 8] <= Relation.xy
def to_logic(tensor_data):
return [Relation.xy[tensor_data]]
@@ -30,11 +28,7 @@ def to_logic(tensor_data):
torch.nn.Tanh(),
NeuraLogic(
template,
- [
- R.xy[
- 8,
- ]
- ],
+ [R.xy[8,]],
R.xor,
to_logic,
),
@@ -65,8 +59,7 @@ def test_torch_function_without_parameters():
neuralogic.manual_seed(1)
template = Template()
- template += (Relation.xor <= Relation.xy) | [F.identity]
- template += Relation.xor / 0 | [F.identity]
+ template += Relation.xor <= Relation.xy
def to_logic(tensor_data):
return [Relation.xy[tensor_data]]
@@ -83,11 +76,7 @@ def to_logic(tensor_data):
torch.nn.Tanh(),
NeuraLogic(
template,
- [
- R.xy[
- 8,
- ]
- ],
+ [R.xy[8,]],
R.xor,
to_logic,
),
diff --git a/tests/test_xor_generalization.py b/tests/test_xor_generalization.py
index 0ee98346..201b6510 100644
--- a/tests/test_xor_generalization.py
+++ b/tests/test_xor_generalization.py
@@ -21,13 +21,14 @@
)
def test_xor_generalization_accurate(n: int, expected: List[int]) -> None:
manual_seed(0)
- max_number_of_max_vars = 20
dataset = Dataset()
template = Template()
- template += R.xor_at(0) <= R.val_at(0)
- template += R.xor_at(V.Y)["a":1, 8] <= (R.val_at(V.Y)["b":8, 1], R.xor_at(V.X)["c":8, 1], R.special.next(V.X, V.Y))
+ template += (R.xor_at(0) <= R.val_at(0)) | [Transformation.TANH]
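+    # Recursive step: xor_at(Y) folds val_at(Y) into xor_at(X) of the preceding index.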
+ template += (
+ R.xor_at(V.Y)["a":1, 8] <= (R.val_at(V.Y)["b":8, 1], R.xor_at(V.X)["c":8, 1], R.special.next(V.X, V.Y))
+ ) | [Transformation.TANH]
dataset.add_samples(
[
@@ -38,9 +39,7 @@ def test_xor_generalization_accurate(n: int, expected: List[int]) -> None:
]
)
- settings = Settings(
- epochs=5000, rule_transformation=Transformation.TANH, relation_transformation=Transformation.IDENTITY
- )
+ settings = Settings(epochs=5000)
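+    # The tanh transformations are now set per rule above, so no settings overrides are needed.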
evaluator = get_evaluator(template, settings)
evaluator.train(dataset, generator=False)
@@ -74,17 +73,20 @@ def test_xor_generalization(n: int, expected: List[int]) -> None:
template.add_rules([
# This rule does xor for the last pair
- R.xor(V.X, V.Y)["a":1, 8] <= (
+ (R.xor(V.X, V.Y)["a":1, 8] <= (
R.x(V.X)["b":8, 1], R.x(V.Y)["c":8, 1], R.hidden.xy(V.X, V.Y), R.hidden.n(V.Y)
- ),
+ )) | [Transformation.TANH],
# This rule recursively evaluates xor for X and xor(Y, Z)
- R.xor(V.X, V.Y)["a":1, 8] <= (
+ (R.xor(V.X, V.Y)["a":1, 8] <= (
R.x(V.X)["b":8, 1], R.xor(V.Y, V.Z)["c":8, 1], R.hidden.xy(V.X, V.Y), R.hidden.xy(V.Y, V.Z)
- ),
+ )) | [Transformation.TANH],
# Helper rule so that queries are just R.xor
- (R.xor <= R.xor(0, V.X))
+ (R.xor <= R.xor(0, V.X)) | [Transformation.TANH],
+
+ R.xor / 2 | [Transformation.TANH],
+ R.xor / 0 | [Transformation.TANH],
])
# The training dataset trains xor on two inputs, x(0) and x(1); n(1) means the max input index is 1