From 68322fc0eb8773f14d6e0002a1b85bd62a971fd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Zahradn=C3=ADk?= Date: Sun, 28 Apr 2024 13:07:35 +0200 Subject: [PATCH] Show correct length of dataset when running neuralization --- neuralogic/core/builder/components.py | 20 ++++++++++++++------ neuralogic/core/builder/dataset_builder.py | 4 +--- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/neuralogic/core/builder/components.py b/neuralogic/core/builder/components.py index a9e9b6d7..97ba3435 100644 --- a/neuralogic/core/builder/components.py +++ b/neuralogic/core/builder/components.py @@ -234,21 +234,29 @@ def draw_grounding( class GroundedDataset: """GroundedDataset represents grounded examples that are not neuralized yet.""" - __slots__ = "length", "_groundings", "_groundings_list", "_builder" + __slots__ = "_groundings", "_groundings_list", "_builder" - def __init__(self, groundings, length, builder): - self.length = length + def __init__(self, groundings, builder): self._groundings = groundings self._groundings_list = None self._builder = builder - def __getitem__(self, item): + def _to_list(self): if self._groundings_list is None: self._groundings = self._groundings.collect(jpype.JClass("java.util.stream.Collectors").toList()) self._groundings_list = [Grounding(g) for g in self._groundings] + + def __getitem__(self, item): + self._to_list() return self._groundings_list[item] + def __len__(self): + self._to_list() + return len(self._groundings_list) + def neuralize(self, progress: bool): if self._groundings_list is not None: - return self._builder.neuralize(self._groundings.stream(), progress, self.length) - return self._builder.neuralize(self._groundings, progress, self.length) + return self._builder.neuralize(self._groundings.stream(), progress, len(self)) + if progress: + return self._builder.neuralize(self._groundings, progress, len(self)) + return self._builder.neuralize(self._groundings, progress, 0) diff --git a/neuralogic/core/builder/dataset_builder.py b/neuralogic/core/builder/dataset_builder.py index 1a7bb742..45235cb2 100644 --- a/neuralogic/core/builder/dataset_builder.py +++ b/neuralogic/core/builder/dataset_builder.py @@ -134,7 +134,6 @@ def ground_dataset( settings.settings.parallelTraining = True builder = Builder(settings) - length = None if isinstance(dataset, datasets.Dataset): self.examples_counter = 0 @@ -171,7 +170,6 @@ def ground_dataset( queries, examples, one_query_per_example, example_queries ) - length = len(logic_samples) groundings = builder.ground_from_logic_samples(self.parsed_template, logic_samples) self.java_factory.weight_factory = weight_factory @@ -187,7 +185,7 @@ def ground_dataset( else: raise NotImplementedError - return GroundedDataset(groundings, length, builder) + return GroundedDataset(groundings, builder) def build_dataset( self,