diff --git a/include/feature_generation/feature_generators/kwl2.hpp b/include/feature_generation/feature_generators/kwl2.hpp index 54d1f65..486ad94 100644 --- a/include/feature_generation/feature_generators/kwl2.hpp +++ b/include/feature_generation/feature_generators/kwl2.hpp @@ -31,7 +31,6 @@ namespace feature_generation { Embedding embed(const std::shared_ptr &graph) override; protected: - void init_neighbour_container() override; inline int get_initial_colour(int index, int u, int v, diff --git a/include/feature_generation/feature_generators/wl.hpp b/include/feature_generation/feature_generators/wl.hpp index ca29c43..a6fa6b0 100644 --- a/include/feature_generation/feature_generators/wl.hpp +++ b/include/feature_generation/feature_generators/wl.hpp @@ -28,7 +28,6 @@ namespace feature_generation { Embedding embed(const std::shared_ptr &graph) override; protected: - void init_neighbour_container() override; void collect_impl(const std::vector &graphs) override; void refine(const std::shared_ptr &graph, std::vector &colours, diff --git a/include/feature_generation/features.hpp b/include/feature_generation/features.hpp index f4f7986..3c541aa 100644 --- a/include/feature_generation/features.hpp +++ b/include/feature_generation/features.hpp @@ -94,9 +94,6 @@ namespace feature_generation { // common init for initialisation and loading from file void initialise_variables(); - // neighbour container initialisation depends on feature generator - virtual void init_neighbour_container() = 0; - // main collection body virtual void collect_impl(const std::vector &graphs) = 0; diff --git a/src/feature_generation/feature_generators/kwl2.cpp b/src/feature_generation/feature_generators/kwl2.cpp index 3b78e03..01ee38a 100644 --- a/src/feature_generation/feature_generators/kwl2.cpp +++ b/src/feature_generation/feature_generators/kwl2.cpp @@ -26,10 +26,6 @@ namespace feature_generation { KWL2Features::KWL2Features(const std::string &filename) : Features(filename) {} - void KWL2Features::init_neighbour_container() { - std::cout << "error: KWL2Features neighbour container not implemented yet" << std::endl; - } - int kwl2_pair_to_index_map(int n, int i, int j) { // map pair where 0 <= i, j < n to vec index return i * n + j; diff --git a/src/feature_generation/feature_generators/wl.cpp b/src/feature_generation/feature_generators/wl.cpp index 315545d..9d360c6 100644 --- a/src/feature_generation/feature_generators/wl.cpp +++ b/src/feature_generation/feature_generators/wl.cpp @@ -1,6 +1,5 @@ #include "../../../include/feature_generation/feature_generators/wl.hpp" -#include "../../../include/feature_generation/neighbour_containers/wl_neighbour_container.hpp" #include "../../../include/graph/graph_generator_factory.hpp" #include "../../../include/utils/nlohmann/json.hpp" @@ -28,10 +27,6 @@ namespace feature_generation { WLFeatures::WLFeatures(const std::string &filename) : Features(filename) {} - void WLFeatures::init_neighbour_container() { - neighbour_container = std::make_shared(multiset_hash); - } - void WLFeatures::refine(const std::shared_ptr &graph, std::vector &colours, std::vector &colours_tmp, diff --git a/src/feature_generation/features.cpp b/src/feature_generation/features.cpp index 1737f65..e27d0a0 100644 --- a/src/feature_generation/features.cpp +++ b/src/feature_generation/features.cpp @@ -1,6 +1,7 @@ #include "../../include/feature_generation/features.hpp" #include "../../include/feature_generation/maxsat.hpp" +#include "../../include/feature_generation/neighbour_containers/wl_neighbour_container.hpp" #include "../../include/graph/graph_generator_factory.hpp" #include "../../include/utils/nlohmann/json.hpp" @@ -43,11 +44,18 @@ namespace feature_generation { graph_generator = graph::create_graph_generator(graph_representation, *domain); seen_colour_statistics = std::vector>(2, std::vector(iterations + 1, 0)); - init_neighbour_container(); + + // We use a factory style method here instead of a virtual function as this is called + // from a constructor, from which virtual functions are not allowed to be called. + if (std::set({"wl", "ccwl", "iwl", "niwl"}).count(feature_name)) { + neighbour_container = std::make_shared(multiset_hash); + } else { + std::cout << "error: neighbour container not yet implemented for feature_name=" + << feature_name << std::endl; + } } std::vector> Features::new_layer_to_colours() const { - // plus 1 because zeroth iteration is also included return std::vector>(iterations + 1, std::set()); }