From 337a783f1b2cae759a2714a6ee5e09eb9d11fd89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Zahradn=C3=ADk?= Date: Thu, 10 Oct 2024 22:25:59 +0200 Subject: [PATCH] Fix neuralogic as torch function --- neuralogic/core/builder/components.py | 4 ++-- neuralogic/nn/torch_function.py | 17 +++-------------- 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/neuralogic/core/builder/components.py b/neuralogic/core/builder/components.py index 9095853e..c82e5f4c 100644 --- a/neuralogic/core/builder/components.py +++ b/neuralogic/core/builder/components.py @@ -121,7 +121,7 @@ def get_fact(self, fact): if term_str[0] == term_str[0].upper() and term_str[0] != term_str[0].lower(): raise ValueError(f"{fact} is not a fact") - return self.get_nodes(fact, "FactAtom") + return self.get_nodes(fact, "FactNeuron") def set_fact_value(self, fact, value) -> int: for term in fact.terms: @@ -130,7 +130,7 @@ def set_fact_value(self, fact, value) -> int: if term_str[0] == term_str[0].upper() and term_str[0] != term_str[0].lower(): raise ValueError(f"{fact} is not a fact") - node = self.get_nodes(fact, "FactAtom") + node = self.get_nodes(fact, "FactNeuron") if len(node) == 0: return -1 diff --git a/neuralogic/nn/torch_function.py b/neuralogic/nn/torch_function.py index 4af7c59d..d7e7256a 100644 --- a/neuralogic/nn/torch_function.py +++ b/neuralogic/nn/torch_function.py @@ -1,4 +1,3 @@ -import json from typing import Callable, Any, List, Union import torch @@ -14,10 +13,9 @@ class _NeuraLogicFunction(Function): @staticmethod - def forward(ctx, mapping, value_factory, sample, model, number_format, dtype, *inputs): + def forward(ctx, mapping, value_factory, sample, model, dtype, *inputs): ctx.model = model ctx.sample = sample - ctx.number_format = number_format ctx.dtype = dtype ctx.mapping = mapping @@ -30,19 +28,12 @@ def forward(ctx, mapping, value_factory, sample, model, number_format, dtype, *i def backward(ctx: Any, *grad_outputs: Any) -> Any: model = ctx.model sample = ctx.sample - number_format = ctx.number_format dtype = ctx.dtype backproper, weight_updater = model.backprop(sample, -grad_outputs[0].detach().numpy()) - state_index = backproper.stateIndex gradients = tuple( - -torch.tensor( - json.loads( - str(sample.get_fact(fact).getComputationView(state_index).getGradient().toString(number_format)) - ), - dtype=dtype, - ).reshape(fact.weight.shape) + -torch.tensor(sample.get_fact(fact)[0].gradient, dtype=dtype).reshape(fact.weight.shape) for fact in ctx.mapping ) @@ -50,7 +41,7 @@ def backward(ctx: Any, *grad_outputs: Any) -> Any: trainer.updateWeights(model.strategy.getCurrentModel(), weight_updater) trainer.invalidateSample(trainer.getInvalidation(), sample.java_sample) - return (None, None, None, None, None, None, *gradients) + return (None, None, None, None, None, *gradients) class NeuraLogic(nn.Module): @@ -74,7 +65,6 @@ def __init__( self.to_logic = to_logic self.model = template.build(settings) - self.number_format = self.model.settings.settings_class.superDetailedNumberFormat dataset = Dataset(Sample(output_relation, input_facts)) self.sample = self.model.build_dataset(dataset, learnable_facts=True).samples[0] @@ -91,7 +81,6 @@ def forward(self, *inputs, **kwargs): self.value_factory, self.sample, self.model, - self.number_format, self.dtype, *(fact.weight for fact in mapping), )