Skip to content

Commit

Permalink
layer pruning for 2-lwl; collapse-layer-x pruning
Browse files Browse the repository at this point in the history
  • Loading branch information
DillonZChen committed Jan 19, 2025
1 parent b1abdd6 commit cb92f3c
Show file tree
Hide file tree
Showing 12 changed files with 138 additions and 52 deletions.
1 change: 1 addition & 0 deletions include/feature_generation/feature_generators/wl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ namespace feature_generation {
protected:
void collect_impl(const std::vector<graph::Graph> &graphs) override;
void refine(const std::shared_ptr<graph::Graph> &graph,
std::set<int> &nodes,
std::vector<int> &colours,
std::vector<int> &colours_tmp,
int iteration);
Expand Down
6 changes: 6 additions & 0 deletions include/feature_generation/features.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ namespace feature_generation {

/* Pruning functions */

// output maps equivalent features to the same group
std::map<int, int> get_equivalence_groups(const std::vector<Embedding> &X);
void prune_this_iteration(int iteration,
const std::vector<graph::Graph> &graphs,
std::vector<std::vector<int>> &cur_colours);
Expand Down Expand Up @@ -165,6 +167,10 @@ namespace feature_generation {

/* Util functions */

void log_iteration(int iteration) const {
std::cout << "[Iteration " << iteration << "]\nCollecting." << std::endl;
};

// get string representation of WL colours agnostic to the number of collected colours
std::string get_string_representation(const Embedding &embedding);
std::string get_string_representation(const planning::State &state);
Expand Down
1 change: 0 additions & 1 deletion include/feature_generation/pruning_options.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ namespace feature_generation {
static const std::string COLLAPSE_LAYER_X;

static const std::vector<std::string> get_all();
static const bool is_layer_pruning(const std::string &pruning_option);
};
} // namespace feature_generation

Expand Down
2 changes: 2 additions & 0 deletions include/graph/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ namespace graph {
int get_n_nodes() const;
int get_n_edges() const;

std::set<int> get_nodes_set() const;

std::string to_string() const;

// set to false when directly modifying the base graph to prevent excessive memory usage
Expand Down
11 changes: 5 additions & 6 deletions src/feature_generation/feature_generators/ccwl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,15 @@ namespace feature_generation {
// We use a sum function for the pool operator as described in the ccWL algorithm.
// To change this to max, we just need to replace += occurrences with std::max.

/* 1. Initialise embedding before pruning */
/* 1. Initialise embedding before pruning, and set up memory */
int categorical_size = get_n_features();
Embedding x0(categorical_size * 2, 0);

/* 2. Set up memory for WL updates */
int n_nodes = graph->nodes.size();
std::vector<int> colours(n_nodes);
std::vector<int> colours_tmp(n_nodes);
std::set<int> nodes = graph->get_nodes_set();

/* 3. Compute initial colours */
/* 2. Compute initial colours */
int col;
int is_seen_colour;
for (int node_i = 0; node_i < n_nodes; node_i++) {
Expand All @@ -51,9 +50,9 @@ namespace feature_generation {
}
}

/* 4. Main WL loop */
/* 3. Main WL loop */
for (int itr = 1; itr < iterations + 1; itr++) {
refine(graph, colours, colours_tmp, itr);
refine(graph, nodes, colours, colours_tmp, itr);
for (int node_i = 0; node_i < n_nodes; node_i++) {
col = colours[node_i];
is_seen_colour = (col != UNSEEN_COLOUR); // prevent branch prediction
Expand Down
29 changes: 19 additions & 10 deletions src/feature_generation/feature_generators/lwl2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,19 +131,17 @@ namespace feature_generation {

void LWL2Features::collect_impl(const std::vector<graph::Graph> &graphs) {
// intermediate graph colours during WL and extra memory for WL updates
std::vector<int> colours;
std::vector<int> colours_tmp;
std::vector<std::vector<int>> graph_colours;
std::vector<std::vector<int>> graph_colours_tmp;

// init colours
log_iteration(0);
for (size_t graph_i = 0; graph_i < graphs.size(); graph_i++) {
const auto graph = std::make_shared<graph::Graph>(graphs[graph_i]);
auto edges = graph->edges;
int n_nodes = graph->nodes.size();

int n_pairs = get_n_lwl2_pairs(n_nodes);

// intermediate colours
colours = std::vector<int>(n_pairs, 0);
colours_tmp = std::vector<int>(n_pairs, 0);
std::vector<int> colours(n_pairs, 0);

std::vector<int> pair_to_edge_label = get_lwl2_pair_to_edge_label(graph);
std::vector<std::set<int>> pair_to_neighbours = get_lwl2_pair_to_neighbours(graph);
Expand All @@ -157,10 +155,21 @@ namespace feature_generation {
}
}

// main WL loop
for (int iteration = 1; iteration < iterations + 1; iteration++) {
refine(graph, pair_to_neighbours, colours, colours_tmp, iteration);
graph_colours.push_back(colours);
graph_colours_tmp.push_back(std::vector<int>(n_pairs, 0));
}

// main WL loop
for (int itr = 1; itr < iterations + 1; itr++) {
log_iteration(itr);
for (size_t graph_i = 0; graph_i < graphs.size(); graph_i++) {
const auto graph = std::make_shared<graph::Graph>(graphs[graph_i]);
std::vector<std::set<int>> pair_to_neighbours = get_lwl2_pair_to_neighbours(graph);
refine(graph, pair_to_neighbours, graph_colours[graph_i], graph_colours_tmp[graph_i], itr);
}

// layer pruning
prune_this_iteration(itr, graphs, graph_colours);
}
}

Expand Down
30 changes: 21 additions & 9 deletions src/feature_generation/feature_generators/wl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <fstream>
#include <queue>
#include <set>
#include <sstream>

using json = nlohmann::json;
Expand All @@ -28,6 +29,7 @@ namespace feature_generation {
WLFeatures::WLFeatures(const std::string &filename) : Features(filename) {}

void WLFeatures::refine(const std::shared_ptr<graph::Graph> &graph,
std::set<int> &nodes,
std::vector<int> &colours,
std::vector<int> &colours_tmp,
int iteration) {
Expand All @@ -36,11 +38,14 @@ namespace feature_generation {
std::vector<int> neighbour_vector;
int new_colour_compressed;

for (size_t u = 0; u < graph->nodes.size(); u++) {
std::vector<int> nodes_to_discard;

for (const int u : nodes) {
// skip unseen colours
int current_colour = colours[u];
if (current_colour == UNSEEN_COLOUR) {
new_colour_compressed = UNSEEN_COLOUR;
nodes_to_discard.push_back(u);
goto end_of_iteration;
}
neighbour_container->clear();
Expand All @@ -50,6 +55,7 @@ namespace feature_generation {
int neighbour_colour = colours[edge.second];
if (neighbour_colour == UNSEEN_COLOUR) {
new_colour_compressed = UNSEEN_COLOUR;
nodes_to_discard.push_back(u);
goto end_of_iteration;
}

Expand All @@ -70,6 +76,11 @@ namespace feature_generation {
colours_tmp[u] = new_colour_compressed;
}

// discard nodes
for (const int u : nodes_to_discard) {
nodes.erase(u);
}

colours.swap(colours_tmp);
}

Expand All @@ -79,7 +90,7 @@ namespace feature_generation {
std::vector<std::vector<int>> graph_colours_tmp;

// init colours
std::cout << "collecting iteration 0" << std::endl;
log_iteration(0);
for (size_t graph_i = 0; graph_i < graphs.size(); graph_i++) {
const auto graph = std::make_shared<graph::Graph>(graphs[graph_i]);
int n_nodes = graph->nodes.size();
Expand All @@ -94,16 +105,16 @@ namespace feature_generation {
}

// main WL loop
for (int iteration = 1; iteration < iterations + 1; iteration++) {
std::cout << "collecting iteration " << iteration << std::endl;

for (int itr = 1; itr < iterations + 1; itr++) {
log_iteration(itr);
for (size_t graph_i = 0; graph_i < graphs.size(); graph_i++) {
const auto graph = std::make_shared<graph::Graph>(graphs[graph_i]);
refine(graph, graph_colours[graph_i], graph_colours_tmp[graph_i], iteration);
std::set<int> nodes = graph->get_nodes_set();
refine(graph, nodes, graph_colours[graph_i], graph_colours_tmp[graph_i], itr);
}

// layer pruning
prune_this_iteration(iteration, graphs, graph_colours);
prune_this_iteration(itr, graphs, graph_colours);
}
}

Expand All @@ -118,17 +129,18 @@ namespace feature_generation {
int n_nodes = graph->nodes.size();
std::vector<int> colours(n_nodes);
std::vector<int> colours_tmp(n_nodes);
std::set<int> nodes = graph->get_nodes_set();

/* 2. Compute initial colours */
for (int node_i = 0; node_i < n_nodes; node_i++) {
for (const int node_i : nodes) {
int col = get_colour_hash({graph->nodes[node_i]}, 0);
colours[node_i] = col;
add_colour_to_x(col, 0, x0);
}

/* 3. Main WL loop */
for (int itr = 1; itr < iterations + 1; itr++) {
refine(graph, colours, colours_tmp, itr);
refine(graph, nodes, colours, colours_tmp, itr);
for (const int col : colours) {
add_colour_to_x(col, itr, x0);
}
Expand Down
42 changes: 39 additions & 3 deletions src/feature_generation/features.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,16 @@ namespace feature_generation {

void Features::check_valid_configuration() {
// check pruning support
if (pruning == PruningOptions::COLLAPSE_LAYER &&
!std::set<std::string>({"wl"}).count(feature_name)) {
if (std::set<std::string>({
PruningOptions::COLLAPSE_LAYER,
PruningOptions::COLLAPSE_LAYER_X,
})
.count(pruning) &&
!std::set<std::string>({
"wl",
"2-lwl",
})
.count(feature_name)) {
std::cout << "WARNING: pruning option `" << pruning << "`"
<< " not yet supported for feature option `" << feature_name << "`. "
<< "Defaulting to no layer pruning." << std::endl;
Expand Down Expand Up @@ -305,6 +313,8 @@ namespace feature_generation {

collect_impl(graphs);

std::cout << "[complete]" << std::endl;

// bulk pruning
prune_bulk(graphs);
layer_redundancy_check();
Expand All @@ -315,7 +325,7 @@ namespace feature_generation {

// check features have been collected
if (get_n_features() == 0) {
throw std::runtime_error("No features have been collected.");
std::cout << "WARNING: no features have been collected" << std::endl;
}
}

Expand Down Expand Up @@ -364,6 +374,32 @@ namespace feature_generation {
}
}

/* Pruning functions (see pruning/ source files for specific implementations) */

std::map<int, int> Features::get_equivalence_groups(const std::vector<Embedding> &X) {
std::map<int, int> feature_group;
int n_features = X[0].size();
std::unordered_map<std::vector<int>, int, int_vector_hasher> canonical_group;
for (int colour = 0; colour < n_features; colour++) {
std::vector<int> feature;
for (size_t j = 0; j < X.size(); j++) {
feature.push_back(X[j][colour]);
}

int group;
if (canonical_group.count(feature) == 0) { // new feature
group = canonical_group.size();
canonical_group[feature] = group;
} else { // seen this feature before
group = canonical_group.at(feature);
}

feature_group[colour] = group;
}

return feature_group;
}

/* Prediction functions */

double Features::predict(const std::shared_ptr<graph::Graph> &graph) {
Expand Down
24 changes: 4 additions & 20 deletions src/feature_generation/pruning/bulk_pruners.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,28 +74,13 @@ namespace feature_generation {

// 1. compute equivalent features candidates
std::cout << "Computing equivalent feature candidates." << std::endl;
std::unordered_map<std::vector<int>, int, int_vector_hasher> canonical_group;
std::map<int, int> feature_group = get_equivalence_groups(X);
std::map<int, int> group_size;
std::map<int, int> feature_group;
for (int colour = 0; colour < n_features; colour++) {
std::vector<int> feature;
for (size_t j = 0; j < X.size(); j++) {
feature.push_back(X[j][colour]);
}

int group;
if (canonical_group.count(feature) == 0) { // new feature
group = canonical_group.size();
canonical_group[feature] = group;
} else { // seen this feature before
group = canonical_group.at(feature);
}

for (const auto &[_, group] : feature_group) {
if (group_size.count(group) == 0) {
group_size[group] = 0;
}
group_size.at(group)++;
feature_group[colour] = group;
group_size[group]++;
}

mark_distinct_features(feature_group, group_size);
Expand Down Expand Up @@ -125,8 +110,7 @@ namespace feature_generation {
}

changed += mark_distinct_features(feature_group, group_size);
std::cout << "changed: " << changed << ". candidates: " << feature_group.size()
<< std::endl;
std::cout << "changed: " << changed << ". candidates: " << feature_group.size() << std::endl;
if (changed == 0) {
break;
}
Expand Down
33 changes: 33 additions & 0 deletions src/feature_generation/pruning/layer_pruners.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ namespace feature_generation {
}

if (to_prune.size() != 0) {
std::cout << "Pruning " << to_prune.size() << " features." << std::endl;
std::map<int, int> remap = remap_colour_hash(to_prune);
for (size_t graph_i = 0; graph_i < graphs.size(); graph_i++) {
for (size_t node_i = 0; node_i < cur_colours[graph_i].size(); node_i++) {
Expand Down Expand Up @@ -69,7 +70,39 @@ namespace feature_generation {
std::set<int> Features::prune_collapse_layer_x(int iteration,
const std::vector<graph::Graph> &graphs,
std::vector<std::vector<int>> &cur_colours) {
int original_iterations = iterations;
iterations = iteration;
collecting = false;
collected = true;

std::set<int> features_to_prune;
std::vector<Embedding> X = embed_graphs(graphs);
std::map<int, int> feature_group = get_equivalence_groups(X);
std::map<int, std::vector<int>> group_to_features;
for (const auto &[colour, group] : feature_group) {
if (group_to_features.count(group) == 0) {
group_to_features[group] = std::vector<int>();
}
group_to_features[group].push_back(colour);
}
for (const auto &[_, features] : group_to_features) {
if (features.size() == 1) {
continue;
}
bool canonicalised = false;
for (int feature : features) {
if (!canonicalised) {
canonicalised = true;
continue;
} else if (colour_to_layer[feature] == iteration) {
features_to_prune.insert(feature);
}
}
}

collecting = true;
collected = false;
iterations = original_iterations;

return features_to_prune;
}
Expand Down
3 changes: 0 additions & 3 deletions src/feature_generation/pruning_options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,4 @@ namespace feature_generation {
COLLAPSE_LAYER_X,
};
}
const bool PruningOptions::is_layer_pruning(const std::string &pruning_option) {
return pruning_option == COLLAPSE_LAYER || pruning_option == COLLAPSE_LAYER_X;
}
} // namespace feature_generation
Loading

0 comments on commit cb92f3c

Please sign in to comment.