Skip to content

Commit

Permalink
Add default rule aggregation to settings
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasZahradnik committed Nov 5, 2023
1 parent a35e307 commit 45aa9ca
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
11 changes: 10 additions & 1 deletion neuralogic/core/settings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from neuralogic.nn.init import Initializer, Uniform
from neuralogic.nn.loss import MSE, ErrorFunction
from neuralogic.core.settings.settings_proxy import SettingsProxy
from neuralogic.core.constructs.function import Transformation, Combination
from neuralogic.core.constructs.function import Transformation, Combination, Aggregation
from neuralogic.optim import Optimizer, Adam


Expand All @@ -20,6 +20,7 @@ def __init__(
initializer: Initializer = Uniform(),
rule_transformation: Transformation = Transformation.TANH,
rule_combination: Combination = Combination.SUM,
rule_aggregation: Aggregation = Aggregation.AVG,
relation_transformation: Transformation = Transformation.TANH,
relation_combination: Combination = Combination.SUM,
iso_value_compression: bool = True,
Expand Down Expand Up @@ -127,6 +128,14 @@ def rule_combination(self) -> Combination:
def rule_combination(self, value: Combination):
self._update("rule_combination", value)

@property
def rule_aggregation(self) -> Aggregation:
return self.params["rule_aggregation"]

@rule_aggregation.setter
def rule_aggregation(self, value: Aggregation):
self._update("rule_aggregation", value)

def create_proxy(self) -> SettingsProxy:
proxy = SettingsProxy(**self.params)
self._proxies.add(proxy)
Expand Down
15 changes: 14 additions & 1 deletion neuralogic/core/settings/settings_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import neuralogic
from neuralogic import is_initialized, initialize
from neuralogic.core.constructs.function import Transformation, Combination
from neuralogic.core.constructs.function import Transformation, Combination, Aggregation
from neuralogic.core.enums import Grounder
from neuralogic.nn.init import Initializer
from neuralogic.nn.loss import MSE, SoftEntropy, CrossEntropy, ErrorFunction
Expand All @@ -20,6 +20,7 @@ def __init__(
initializer: Initializer,
rule_transformation: Transformation,
rule_combination: Combination,
rule_aggregation: Aggregation,
relation_transformation: Transformation,
relation_combination: Combination,
iso_value_compression: bool,
Expand Down Expand Up @@ -223,6 +224,14 @@ def rule_combination(self) -> Combination:
def rule_combination(self, value: Combination):
self.settings.ruleNeuronCombination = self.get_combination_function(value)

@property
def rule_aggregation(self) -> Aggregation:
return Aggregation(str(self.settings.aggNeuronAggregation))

@rule_aggregation.setter
def rule_aggregation(self, value: Aggregation):
self.settings.aggNeuronAggregation = self.get_aggregation_function(value)

@property
def debug_exporting(self) -> bool:
return self.settings.debugExporting
Expand All @@ -243,6 +252,10 @@ def get_combination_function(self, combination: Combination):
combination_name = str(combination)
return self.settings_class.parseCombination(combination_name)

def get_aggregation_function(self, aggregation: Aggregation):
aggregation_name = str(aggregation)
return self.settings_class.parseCombination(aggregation_name)

def get_transformation_function(self, transformation: Transformation):
transformation_name = str(transformation)
return self.settings_class.parseTransformation(transformation_name)
Expand Down

0 comments on commit 45aa9ca

Please sign in to comment.