From 2058a585027ea02b4cc76485d0e0a28001194ce6 Mon Sep 17 00:00:00 2001 From: Dillon Date: Sat, 18 Jan 2025 21:08:28 +1100 Subject: [PATCH] 2-wl containers --- .../wl2_neighbour_container.hpp | 32 ++++++ src/feature_generation/features.cpp | 3 + .../wl2_neighbour_container.cpp | 105 ++++++++++++++++++ 3 files changed, 140 insertions(+) create mode 100644 include/feature_generation/neighbour_containers/wl2_neighbour_container.hpp create mode 100644 src/feature_generation/neighbour_containers/wl2_neighbour_container.cpp diff --git a/include/feature_generation/neighbour_containers/wl2_neighbour_container.hpp b/include/feature_generation/neighbour_containers/wl2_neighbour_container.hpp new file mode 100644 index 0000000..8b12f9a --- /dev/null +++ b/include/feature_generation/neighbour_containers/wl2_neighbour_container.hpp @@ -0,0 +1,32 @@ +#ifndef FEATURE_GENERATION_NEIGHBOUR_CONTAINERS_WL2_NEIGHBOUR_CONTAINER_HPP +#define FEATURE_GENERATION_NEIGHBOUR_CONTAINERS_WL2_NEIGHBOUR_CONTAINER_HPP + +#include "../neighbour_container.hpp" + +#include +#include +#include + +namespace feature_generation { + class WL2NeighbourContainer : public NeighbourContainer { + public: + WL2NeighbourContainer(bool multiset_hash); + + void clear() override; + void insert(const int colour); + void insert(const int colour0, const int colour1) override; + std::vector to_vector() const override; + + // pairs are + std::vector> deconstruct(const std::vector &colours) const; + std::vector get_neighbour_colours(const std::vector &colours) const override; + + std::vector remap(const std::vector &input, const std::map &remap) override; + + private: + std::set neighbours_set; + std::map neighbours_mset; + }; +} // namespace feature_generation + +#endif // FEATURE_GENERATION_NEIGHBOUR_CONTAINERS_WL2_NEIGHBOUR_CONTAINER_HPP diff --git a/src/feature_generation/features.cpp b/src/feature_generation/features.cpp index e27d0a0..fd2c9bf 100644 --- a/src/feature_generation/features.cpp +++ b/src/feature_generation/features.cpp @@ -1,6 +1,7 @@ #include "../../include/feature_generation/features.hpp" #include "../../include/feature_generation/maxsat.hpp" +#include "../../include/feature_generation/neighbour_containers/wl2_neighbour_container.hpp" #include "../../include/feature_generation/neighbour_containers/wl_neighbour_container.hpp" #include "../../include/graph/graph_generator_factory.hpp" #include "../../include/utils/nlohmann/json.hpp" @@ -49,6 +50,8 @@ namespace feature_generation { // from a constructor, from which virtual functions are not allowed to be called. if (std::set({"wl", "ccwl", "iwl", "niwl"}).count(feature_name)) { neighbour_container = std::make_shared(multiset_hash); + } else if (std::set({"2-kwl", "2-lwl"}).count(feature_name)) { + neighbour_container = std::make_shared(multiset_hash); } else { std::cout << "error: neighbour container not yet implemented for feature_name=" << feature_name << std::endl; diff --git a/src/feature_generation/neighbour_containers/wl2_neighbour_container.cpp b/src/feature_generation/neighbour_containers/wl2_neighbour_container.cpp new file mode 100644 index 0000000..94c6680 --- /dev/null +++ b/src/feature_generation/neighbour_containers/wl2_neighbour_container.cpp @@ -0,0 +1,105 @@ +#include "../../../include/feature_generation/neighbour_containers/wl2_neighbour_container.hpp" + +#include + +namespace feature_generation { + WL2NeighbourContainer::WL2NeighbourContainer(bool multiset_hash) + : NeighbourContainer(multiset_hash) {} + + void WL2NeighbourContainer::clear() { + if (multiset_hash) { + neighbours_mset.clear(); + } else { + neighbours_set.clear(); + } + } + + void WL2NeighbourContainer::insert(const int colour) { + if (multiset_hash) { + if (neighbours_mset.count(colour) > 0) + neighbours_mset[colour]++; + else + neighbours_mset[colour] = 1; + } else { + neighbours_set.insert(colour); + } + } + + void WL2NeighbourContainer::insert(const int colour1, const int colour2) { + insert(colour1); + insert(colour2); + } + + std::vector WL2NeighbourContainer::to_vector() const { + std::vector vec; + if (multiset_hash) { + for (const auto &[colour, count] : neighbours_mset) { + vec.push_back(colour); + vec.push_back(count); + } + } else { + for (const auto &colour : neighbours_set) { + vec.push_back(colour); + } + } + return vec; + } + + std::vector> + WL2NeighbourContainer::deconstruct(const std::vector &colours) const { + std::vector> output; + + int inc; + if (multiset_hash) { + inc = 2; + } else { + inc = 1; + } + + if (colours.size() % inc != 1) { + std::cout << "error: key " << to_string(colours) << " has size() % " << inc + << " != 1 for multiset_hash=" << multiset_hash << std::endl; + exit(-1); + } + + for (size_t i = 1; i < colours.size(); i += inc) { + int node_colour = colours.at(i); + int n_occurrences; + if (multiset_hash) { + n_occurrences = colours.at(i + 1); + } else { + n_occurrences = 1; + } + output.push_back(std::pair(node_colour, n_occurrences)); + } + + return output; + } + + std::vector + WL2NeighbourContainer::get_neighbour_colours(const std::vector &colours) const { + std::vector neighbour_colours = {colours.at(0)}; + for (const auto &[node_colour, n_occurrences] : deconstruct(colours)) { + neighbour_colours.push_back(node_colour); + } + return neighbour_colours; + } + + std::vector WL2NeighbourContainer::remap(const std::vector &input, + const std::map &remap) { + clear(); + + std::vector output = {remap.at(input.at(0))}; + + for (const auto &[node_colour, n_occurrences] : deconstruct(input)) { + for (int i = 0; i < n_occurrences; i++) { + insert(remap.at(node_colour)); + } + } + + std::vector vec = to_vector(); + output.insert(output.end(), vec.begin(), vec.end()); + + return output; + } +} // namespace feature_generation