Skip to content

Commit

Permalink
2-wl containers
Browse files Browse the repository at this point in the history
  • Loading branch information
DillonZChen committed Jan 18, 2025
1 parent a0ea3f6 commit 2058a58
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#ifndef FEATURE_GENERATION_NEIGHBOUR_CONTAINERS_WL2_NEIGHBOUR_CONTAINER_HPP
#define FEATURE_GENERATION_NEIGHBOUR_CONTAINERS_WL2_NEIGHBOUR_CONTAINER_HPP

#include "../neighbour_container.hpp"

#include <map>
#include <set>
#include <utility>

namespace feature_generation {
class WL2NeighbourContainer : public NeighbourContainer {
public:
WL2NeighbourContainer(bool multiset_hash);

void clear() override;
void insert(const int colour);
void insert(const int colour0, const int colour1) override;
std::vector<int> to_vector() const override;

// pairs are <node_colour, n_occurrence>
std::vector<std::pair<int, int>> deconstruct(const std::vector<int> &colours) const;
std::vector<int> get_neighbour_colours(const std::vector<int> &colours) const override;

std::vector<int> remap(const std::vector<int> &input, const std::map<int, int> &remap) override;

private:
std::set<int> neighbours_set;
std::map<int, int> neighbours_mset;
};
} // namespace feature_generation

#endif // FEATURE_GENERATION_NEIGHBOUR_CONTAINERS_WL2_NEIGHBOUR_CONTAINER_HPP
3 changes: 3 additions & 0 deletions src/feature_generation/features.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "../../include/feature_generation/features.hpp"

#include "../../include/feature_generation/maxsat.hpp"
#include "../../include/feature_generation/neighbour_containers/wl2_neighbour_container.hpp"
#include "../../include/feature_generation/neighbour_containers/wl_neighbour_container.hpp"
#include "../../include/graph/graph_generator_factory.hpp"
#include "../../include/utils/nlohmann/json.hpp"
Expand Down Expand Up @@ -49,6 +50,8 @@ namespace feature_generation {
// from a constructor, from which virtual functions are not allowed to be called.
if (std::set<std::string>({"wl", "ccwl", "iwl", "niwl"}).count(feature_name)) {
neighbour_container = std::make_shared<WLNeighbourContainer>(multiset_hash);
} else if (std::set<std::string>({"2-kwl", "2-lwl"}).count(feature_name)) {
neighbour_container = std::make_shared<WL2NeighbourContainer>(multiset_hash);
} else {
std::cout << "error: neighbour container not yet implemented for feature_name="
<< feature_name << std::endl;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#include "../../../include/feature_generation/neighbour_containers/wl2_neighbour_container.hpp"

#include <iostream>

namespace feature_generation {
WL2NeighbourContainer::WL2NeighbourContainer(bool multiset_hash)
: NeighbourContainer(multiset_hash) {}

void WL2NeighbourContainer::clear() {
if (multiset_hash) {
neighbours_mset.clear();
} else {
neighbours_set.clear();
}
}

void WL2NeighbourContainer::insert(const int colour) {
if (multiset_hash) {
if (neighbours_mset.count(colour) > 0)
neighbours_mset[colour]++;
else
neighbours_mset[colour] = 1;
} else {
neighbours_set.insert(colour);
}
}

void WL2NeighbourContainer::insert(const int colour1, const int colour2) {
insert(colour1);
insert(colour2);
}

std::vector<int> WL2NeighbourContainer::to_vector() const {
std::vector<int> vec;
if (multiset_hash) {
for (const auto &[colour, count] : neighbours_mset) {
vec.push_back(colour);
vec.push_back(count);
}
} else {
for (const auto &colour : neighbours_set) {
vec.push_back(colour);
}
}
return vec;
}

std::vector<std::pair<int, int>>
WL2NeighbourContainer::deconstruct(const std::vector<int> &colours) const {
std::vector<std::pair<int, int>> output;

int inc;
if (multiset_hash) {
inc = 2;
} else {
inc = 1;
}

if (colours.size() % inc != 1) {
std::cout << "error: key " << to_string(colours) << " has size() % " << inc
<< " != 1 for multiset_hash=" << multiset_hash << std::endl;
exit(-1);
}

for (size_t i = 1; i < colours.size(); i += inc) {
int node_colour = colours.at(i);
int n_occurrences;
if (multiset_hash) {
n_occurrences = colours.at(i + 1);
} else {
n_occurrences = 1;
}
output.push_back(std::pair<int, int>(node_colour, n_occurrences));
}

return output;
}

std::vector<int>
WL2NeighbourContainer::get_neighbour_colours(const std::vector<int> &colours) const {
std::vector<int> neighbour_colours = {colours.at(0)};
for (const auto &[node_colour, n_occurrences] : deconstruct(colours)) {
neighbour_colours.push_back(node_colour);
}
return neighbour_colours;
}

std::vector<int> WL2NeighbourContainer::remap(const std::vector<int> &input,
const std::map<int, int> &remap) {
clear();

std::vector<int> output = {remap.at(input.at(0))};

for (const auto &[node_colour, n_occurrences] : deconstruct(input)) {
for (int i = 0; i < n_occurrences; i++) {
insert(remap.at(node_colour));
}
}

std::vector<int> vec = to_vector();
output.insert(output.end(), vec.begin(), vec.end());

return output;
}
} // namespace feature_generation

0 comments on commit 2058a58

Please sign in to comment.