Skip to content

Commit

Permalink
Implement support for mixed in transformations with combinations
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasZahradnik committed Jun 5, 2024
1 parent 1a91620 commit 91b2d33
Show file tree
Hide file tree
Showing 9 changed files with 126 additions and 122 deletions.
81 changes: 2 additions & 79 deletions neuralogic/core/constructs/function/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from typing import Dict

import jpype

from neuralogic.core.constructs.function.concat import ConcatComb, Concat
from neuralogic.core.constructs.function.function import Transformation, Combination, Aggregation, Function
from neuralogic.core.constructs.function.function_container import FContainer
from neuralogic.core.constructs.function.reshape import Reshape
from neuralogic.core.constructs.function.mixed_combination import MixedCombination
from neuralogic.core.constructs.function.slice import Slice
from neuralogic.core.constructs.function.softmax import Softmax

Expand Down Expand Up @@ -34,77 +30,4 @@
Aggregation.SOFTMAX = Softmax("SOFTMAX")


class CombinationWrap:
__slots__ = "left", "right", "combination"

def __init__(self, left, right, combination: Combination):
self.left = left
self.right = right
self.combination = combination

def __add__(self, other):
return CombinationWrap(self, other, Combination.SUM)

def __mul__(self, other):
return CombinationWrap(self, other, Combination.ELPRODUCT)

def __matmul__(self, other):
return CombinationWrap(self, other, Combination.PRODUCT)

def __str__(self):
if not isinstance(self.left, CombinationWrap) and not isinstance(self.right, CombinationWrap):
return f"{self.combination}"
if not isinstance(self.left, CombinationWrap):
return f"{self.combination}({self.right.to_str()})"
if not isinstance(self.right, CombinationWrap):
return f"{self.combination}({self.left.to_str()})"

return f"{self.combination}({self.left.to_str()}, {self.right.to_str()})"

def __iter__(self):
if isinstance(self.left, CombinationWrap):
for a in self.left:
yield a
if not isinstance(self.left, CombinationWrap):
yield self.left

if isinstance(self.right, CombinationWrap):
for a in self.right:
yield a
if not isinstance(self.right, CombinationWrap):
yield self.right

def to_combination(self) -> Combination:
combination_graph = self._get_combination_node({}, 0)
return MixedCombination(name=self.to_str(), combination_graph=combination_graph)

def _get_combination_node(self, input_counter: Dict[int, int], start_index: int = 0):
left_node = None
right_node = None

left_index = -1
right_index = -1

if isinstance(self.left, CombinationWrap):
left_node = self.left._get_combination_node(input_counter)
else:
if id(self.left) not in input_counter:
input_counter[id(self.left)] = len(input_counter) + start_index
left_index = input_counter[id(self.left)]

if isinstance(self.right, CombinationWrap):
right_node = self.right._get_combination_node(input_counter)
else:
if id(self.right) not in input_counter:
input_counter[id(self.right)] = len(input_counter) + start_index
right_index = input_counter[id(self.right)]

class_name = "cz.cvut.fel.ida.algebra.functions.combination.MixedCombination.MixedCombinationNode"

return jpype.JClass(class_name)(self.combination.get(), left_node, right_node, left_index, right_index)

def to_str(self):
return self.__str__()


__all__ = ["Transformation", "Combination", "Aggregation", "Function", "CombinationWrap"]
__all__ = ["Transformation", "Combination", "Aggregation", "Function", "FContainer"]
26 changes: 24 additions & 2 deletions neuralogic/core/constructs/function/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,31 @@ class Transformation(Function):
SLICE: "Transformation"
RESHAPE: "Transformation"

def get(self):
name = "".join(s.capitalize() for s in self.name.split("_"))

if name == "Transp":
name = "Transposition"
if name == "Norm":
name = "Normalization"

if name in ("Identity", "Transposition", "Softmax", "Sparsemax", "Normalization", "Slice", "Reshape"):
return jpype.JClass(f"cz.cvut.fel.ida.algebra.functions.transformation.joint.{name}")()
return jpype.JClass(f"cz.cvut.fel.ida.algebra.functions.transformation.elementwise.{name}")()

def __call__(self, *args, **kwargs):
from neuralogic.core.constructs import relation
from neuralogic.core.constructs.function.function_container import FContainer

if len(args) == 0 or args[0] is None:
return self

arg = args[0]
if isinstance(arg, relation.BaseRelation):
if arg.function is not None:
return FContainer((arg,), self)
return arg.attach_activation_function(self)
raise NotImplementedError
return FContainer(args, self)


class Combination(Function):
Expand All @@ -87,7 +102,7 @@ class Combination(Function):
COSSIM: "Combination"

def get(self):
name = self.name.capitalize()
name = "".join(s.capitalize() for s in self.name.split("_"))

if name in ("Sum", "Max", "Min", "Avg", "Count"):
return jpype.JClass(f"cz.cvut.fel.ida.algebra.functions.aggregation.{name}")()
Expand All @@ -96,6 +111,13 @@ def get(self):

return jpype.JClass(f"cz.cvut.fel.ida.algebra.functions.combination.{name}")()

def __call__(self, *args, **kwargs):
from neuralogic.core.constructs.function.function_container import FContainer

if len(args) == 0 or args[0] is None:
return self
return FContainer(args, self)


class Aggregation(Function):
AVG: "Aggregation"
Expand Down
63 changes: 63 additions & 0 deletions neuralogic/core/constructs/function/function_container.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import Dict

import jpype

from neuralogic.core.constructs.function.function_graph import FunctionGraph
from neuralogic.core.constructs.function.function import Function, Combination


class FContainer:
__slots__ = "nodes", "function"

def __init__(self, nodes, function: Function):
self.function = function
self.nodes = nodes

def __add__(self, other):
return FContainer((self, other), Combination.SUM)

def __mul__(self, other):
return FContainer((self, other), Combination.ELPRODUCT)

def __matmul__(self, other):
return FContainer((self, other), Combination.PRODUCT)

def __str__(self):
args = ", ".join(node.to_str() for node in self.nodes if isinstance(node, FContainer))

if args:
return f"{self.function}({args})"
return f"{self.function}"

def __iter__(self):
for node in self.nodes:
if isinstance(node, FContainer):
for a in node:
yield a
else:
yield node

def to_function(self) -> Function:
graph = self._get_function_node({}, 0)
return FunctionGraph(name=self.to_str(), function_graph=graph)

def _get_function_node(self, input_counter: Dict[int, int], start_index: int = 0):
next_indices = [-1] * len(self.nodes)
next_nodes = [None] * len(self.nodes)

for i, node in enumerate(self.nodes):
if isinstance(node, FContainer):
next_nodes[i] = node._get_function_node(input_counter)
else:
idx = id(node)

if idx not in input_counter:
input_counter[idx] = len(input_counter) + start_index
next_indices[i] = input_counter[idx]

class_name = "cz.cvut.fel.ida.algebra.functions.combination.FunctionGraph.FunctionGraphNode"

return jpype.JClass(class_name)(self.function.get(), next_nodes, next_indices)

def to_str(self):
return self.__str__()
27 changes: 27 additions & 0 deletions neuralogic/core/constructs/function/function_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import jpype

from neuralogic.core.constructs.function.function import Function


class FunctionGraph(Function):
__slots__ = ("function_graph",)

def __init__(
self,
name: str,
*,
function_graph,
):
super().__init__(name)
self.function_graph = function_graph

def is_parametrized(self) -> bool:
return True

def get(self):
return jpype.JClass("cz.cvut.fel.ida.algebra.functions.combination.FunctionGraph")(
self.name, self.function_graph
)

def __str__(self):
return self.name
31 changes: 0 additions & 31 deletions neuralogic/core/constructs/function/mixed_combination.py

This file was deleted.

8 changes: 4 additions & 4 deletions neuralogic/core/constructs/java_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jpype

from neuralogic import is_initialized, initialize
from neuralogic.core.constructs.function import CombinationWrap
from neuralogic.core.constructs.function import FContainer
from neuralogic.core.constructs.term import Variable, Constant
from neuralogic.core.settings import SettingsProxy, Settings

Expand Down Expand Up @@ -308,7 +308,7 @@ def get_rule(self, rule):
else:
java_rule.setWeight(weight)

if isinstance(rule.body, CombinationWrap):
if isinstance(rule.body, FContainer):
processed_relations = {}
body_relation = []
for relation in rule.body:
Expand All @@ -331,9 +331,9 @@ def get_rule(self, rule):
java_rule.allowDuplicitGroundings = bool(rule.metadata.duplicit_grounding)

metadata = rule.metadata
if isinstance(rule.body, CombinationWrap):
if isinstance(rule.body, FContainer):
metadata = metadata.copy()
metadata.combination = rule.body.to_combination()
metadata.combination = rule.body.to_function()

java_rule.setMetadata(self.get_metadata(metadata, self.rule_metadata))

Expand Down
8 changes: 4 additions & 4 deletions neuralogic/core/constructs/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from neuralogic.core.constructs.predicate import Predicate
from neuralogic.core.constructs import rule, factories
from neuralogic.core.constructs.function import Transformation, Combination, CombinationWrap
from neuralogic.core.constructs.function import Transformation, Combination, FContainer


class BaseRelation:
Expand Down Expand Up @@ -127,13 +127,13 @@ def __and__(self, other) -> rule.RuleBody:
raise NotImplementedError

def __add__(self, other):
return CombinationWrap(self, other, Combination.SUM)
return FContainer((self, other), Combination.SUM)

def __mul__(self, other):
return CombinationWrap(self, other, Combination.ELPRODUCT)
return FContainer((self, other), Combination.ELPRODUCT)

def __matmul__(self, other):
return CombinationWrap(self, other, Combination.PRODUCT)
return FContainer((self, other), Combination.PRODUCT)


class WeightedRelation(BaseRelation):
Expand Down
4 changes: 2 additions & 2 deletions neuralogic/core/constructs/rule.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Iterable, Optional

from neuralogic.core.constructs.function import CombinationWrap
from neuralogic.core.constructs.function import FContainer
from neuralogic.core.constructs.metadata import Metadata


Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(self, head, body):

self.body = body

if not isinstance(self.body, CombinationWrap):
if not isinstance(self.body, FContainer):
self.body = list(body)

if self.is_ellipsis_templated():
Expand Down
Binary file modified neuralogic/jar/NeuraLogic.jar
Binary file not shown.

0 comments on commit 91b2d33

Please sign in to comment.