Skip to content

Commit 44c5e17

Browse files
committed
minor refactoring
1 parent cb92f3c commit 44c5e17

File tree

20 files changed

+90
-137
lines changed

20 files changed

+90
-137
lines changed

include/feature_generation/feature_generators/ccwl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace feature_generation {
2121

2222
CCWLFeatures(const std::string &filename);
2323

24-
Embedding embed(const std::shared_ptr<graph::Graph> &graph) override;
24+
Embedding embed_impl(const std::shared_ptr<graph::Graph> &graph) override;
2525

2626
void set_weights(const std::vector<double> &weights);
2727
};

include/feature_generation/feature_generators/iwl.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,12 @@ namespace feature_generation {
2828

2929
IWLFeatures(const std::string &filename);
3030

31-
Embedding embed(const std::shared_ptr<graph::Graph> &graph) override;
31+
Embedding embed_impl(const std::shared_ptr<graph::Graph> &graph) override;
3232

3333
protected:
3434
void collect_impl(const std::vector<graph::Graph> &graphs) override;
3535
void refine(const std::shared_ptr<graph::Graph> &graph,
3636
std::vector<int> &colours,
37-
std::vector<int> &colours_tmp,
3837
int iteration);
3938
};
4039
} // namespace feature_generation

include/feature_generation/feature_generators/kwl2.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace feature_generation {
2828

2929
KWL2Features(const std::string &filename);
3030

31-
Embedding embed(const std::shared_ptr<graph::Graph> &graph) override;
31+
Embedding embed_impl(const std::shared_ptr<graph::Graph> &graph) override;
3232

3333
protected:
3434
inline int get_initial_colour(int index,
@@ -39,7 +39,6 @@ namespace feature_generation {
3939
void collect_impl(const std::vector<graph::Graph> &graphs) override;
4040
void refine(const std::shared_ptr<graph::Graph> &graph,
4141
std::vector<int> &colours,
42-
std::vector<int> &colours_tmp,
4342
int iteration);
4443
};
4544
} // namespace feature_generation

include/feature_generation/feature_generators/lwl2.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace feature_generation {
2121

2222
LWL2Features(const std::string &filename);
2323

24-
Embedding embed(const std::shared_ptr<graph::Graph> &graph) override;
24+
Embedding embed_impl(const std::shared_ptr<graph::Graph> &graph) override;
2525

2626
protected:
2727
inline int get_initial_colour(int index,
@@ -33,7 +33,6 @@ namespace feature_generation {
3333
void refine(const std::shared_ptr<graph::Graph> &graph,
3434
std::vector<std::set<int>> &pair_to_neighbours,
3535
std::vector<int> &colours,
36-
std::vector<int> &colours_tmp,
3736
int iteration);
3837
};
3938
} // namespace feature_generation

include/feature_generation/feature_generators/niwl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace feature_generation {
1818

1919
NIWLFeatures(const std::string &filename);
2020

21-
Embedding embed(const std::shared_ptr<graph::Graph> &graph) override;
21+
Embedding embed_impl(const std::shared_ptr<graph::Graph> &graph) override;
2222
};
2323
} // namespace feature_generation
2424

include/feature_generation/feature_generators/wl.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,13 @@ namespace feature_generation {
2525

2626
WLFeatures(const std::string &filename);
2727

28-
Embedding embed(const std::shared_ptr<graph::Graph> &graph) override;
28+
Embedding embed_impl(const std::shared_ptr<graph::Graph> &graph) override;
2929

3030
protected:
3131
void collect_impl(const std::vector<graph::Graph> &graphs) override;
3232
void refine(const std::shared_ptr<graph::Graph> &graph,
3333
std::set<int> &nodes,
3434
std::vector<int> &colours,
35-
std::vector<int> &colours_tmp,
3635
int iteration);
3736
};
3837
} // namespace feature_generation

include/feature_generation/features.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,9 @@ namespace feature_generation {
9797
// common init for initialisation and loading from file
9898
void initialise_variables();
9999

100-
// main collection body
100+
// main virtual functions
101101
virtual void collect_impl(const std::vector<graph::Graph> &graphs) = 0;
102+
virtual Embedding embed_impl(const std::shared_ptr<graph::Graph> &graph) = 0;
102103

103104
public:
104105
Features(const std::string feature_name,
@@ -124,7 +125,8 @@ namespace feature_generation {
124125
std::vector<Embedding> embed_graphs(const std::vector<graph::Graph> &graphs);
125126
Embedding embed_graph(const graph::Graph &graph);
126127
Embedding embed_state(const planning::State &state);
127-
virtual Embedding embed(const std::shared_ptr<graph::Graph> &graph) = 0;
128+
Embedding embed(const std::shared_ptr<graph::Graph> &graph);
129+
128130

129131
void add_colour_to_x(int colour, int iteration, Embedding &x);
130132

src/feature_generation/feature_generator_loader.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ std::shared_ptr<feature_generation::Features> load_feature_generator(const std::
2828
} else if (feature_name == "niwl") {
2929
feature_generator = std::make_shared<feature_generation::NIWLFeatures>(save_file);
3030
} else {
31-
std::cout << "Feature name " << feature_name << " not recognised. Exiting." << std::endl;
32-
exit(-1);
31+
throw std::runtime_error("Feature name " + feature_name + " not recognised.");
3332
}
3433
std::cout << "Feature generator loaded!" << std::endl;
3534
return feature_generator;

src/feature_generation/feature_generators/ccwl.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,7 @@ namespace feature_generation {
1818

1919
CCWLFeatures::CCWLFeatures(const std::string &filename) : WLFeatures(filename) {}
2020

21-
Embedding CCWLFeatures::embed(const std::shared_ptr<graph::Graph> &graph) {
22-
collecting = false;
23-
if (!collected) {
24-
throw std::runtime_error("CCWLFeatures::collect() must be called before embedding");
25-
}
26-
21+
Embedding CCWLFeatures::embed_impl(const std::shared_ptr<graph::Graph> &graph) {
2722
// New additions to the WL algorithm are indicated with the [NUMERIC] comments.
2823
// We use a sum function for the pool operator as described in the ccWL algorithm.
2924
// To change this to max, we just need to replace += occurrences with std::max.
@@ -33,7 +28,6 @@ namespace feature_generation {
3328
Embedding x0(categorical_size * 2, 0);
3429
int n_nodes = graph->nodes.size();
3530
std::vector<int> colours(n_nodes);
36-
std::vector<int> colours_tmp(n_nodes);
3731
std::set<int> nodes = graph->get_nodes_set();
3832

3933
/* 2. Compute initial colours */
@@ -52,7 +46,7 @@ namespace feature_generation {
5246

5347
/* 3. Main WL loop */
5448
for (int itr = 1; itr < iterations + 1; itr++) {
55-
refine(graph, nodes, colours, colours_tmp, itr);
49+
refine(graph, nodes, colours, itr);
5650
for (int node_i = 0; node_i < n_nodes; node_i++) {
5751
col = colours[node_i];
5852
is_seen_colour = (col != UNSEEN_COLOUR); // prevent branch prediction

src/feature_generation/feature_generators/iwl.cpp

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,14 @@ namespace feature_generation {
2929

3030
void IWLFeatures::refine(const std::shared_ptr<graph::Graph> &graph,
3131
std::vector<int> &colours,
32-
std::vector<int> &colours_tmp,
3332
int iteration) {
3433
// memory for storing string and hashed int representation of colours
3534
std::vector<int> new_colour;
3635
std::vector<int> neighbour_vector;
3736
int new_colour_compressed;
3837

38+
std::vector<int> new_colours(colours.size(), UNSEEN_COLOUR);
39+
3940
for (size_t u = 0; u < graph->nodes.size(); u++) {
4041
// skip unseen colours
4142
if (colours[u] == UNSEEN_COLOUR) {
@@ -65,16 +66,15 @@ namespace feature_generation {
6566
new_colour_compressed = get_colour_hash(new_colour, iteration);
6667

6768
end_of_iteration:
68-
colours_tmp[u] = new_colour_compressed;
69+
new_colours[u] = new_colour_compressed;
6970
}
7071

71-
colours.swap(colours_tmp);
72+
colours = new_colours;
7273
}
7374

7475
void IWLFeatures::collect_impl(const std::vector<graph::Graph> &graphs) {
75-
// intermediate graph colours during WL and extra memory for WL updates
76+
// intermediate graph colours during WL
7677
std::vector<int> colours;
77-
std::vector<int> colours_tmp;
7878

7979
// init colours
8080
for (size_t graph_i = 0; graph_i < graphs.size(); graph_i++) {
@@ -84,7 +84,6 @@ namespace feature_generation {
8484
// individualisation for each node
8585
for (int node_i = 0; node_i < n_nodes; node_i++) {
8686
colours = std::vector<int>(n_nodes, 0);
87-
colours_tmp = std::vector<int>(n_nodes, 0);
8887

8988
for (int u = 0; u < n_nodes; u++) {
9089
std::vector<int> colour_key = {graph->nodes[u]};
@@ -97,30 +96,22 @@ namespace feature_generation {
9796

9897
// main WL loop
9998
for (int iteration = 1; iteration < iterations + 1; iteration++) {
100-
refine(graph, colours, colours_tmp, iteration);
99+
refine(graph, colours, iteration);
101100
}
102101
}
103102
}
104103
}
105104

106-
Embedding IWLFeatures::embed(const std::shared_ptr<graph::Graph> &graph) {
107-
collecting = false;
108-
if (!collected) {
109-
throw std::runtime_error("IWLFeatures::collect() must be called before embedding");
110-
}
111-
105+
Embedding IWLFeatures::embed_impl(const std::shared_ptr<graph::Graph> &graph) {
112106
/* 1. Initialise embedding */
113107
Embedding x0(get_n_features(), 0);
114-
115-
/* 2. Set up memory for WL updates */
116108
int n_nodes = graph->nodes.size();
117109

118110
/* Individualisation */
119111
for (int node_i = 0; node_i < n_nodes; node_i++) {
120112
std::vector<int> colours(n_nodes);
121-
std::vector<int> colours_tmp(n_nodes);
122113

123-
/* 3. Compute initial colours */
114+
/* 2. Compute initial colours */
124115
for (int u = 0; u < n_nodes; u++) {
125116
std::vector<int> colour_key = {graph->nodes[u]};
126117
if (u == node_i) {
@@ -130,9 +121,9 @@ namespace feature_generation {
130121
add_colour_to_x(col, 0, x0);
131122
}
132123

133-
/* 4. Main WL loop */
124+
/* 3. Main WL loop */
134125
for (int itr = 1; itr < iterations + 1; itr++) {
135-
refine(graph, colours, colours_tmp, itr);
126+
refine(graph, colours, itr);
136127
for (const int col : colours) {
137128
add_colour_to_x(col, itr, x0);
138129
}

src/feature_generation/feature_generators/kwl2.cpp

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,15 @@ namespace feature_generation {
3535

3636
void KWL2Features::refine(const std::shared_ptr<graph::Graph> &graph,
3737
std::vector<int> &colours,
38-
std::vector<int> &colours_tmp,
3938
int iteration) {
4039
// memory for storing string and hashed int representation of colours
4140
std::vector<int> new_colour;
4241
std::vector<int> neighbour_vector;
4342
int new_colour_compressed, pair1, pair2, pair1_col, pair2_col;
4443
int n_nodes = graph->nodes.size();
4544

45+
std::vector<int> new_colours(colours.size(), UNSEEN_COLOUR);
46+
4647
for (int u = 0; u < n_nodes; u++) {
4748
for (int v = 0; v < n_nodes; v++) {
4849
int index = kwl2_pair_to_index_map(n_nodes, u, v);
@@ -76,11 +77,11 @@ namespace feature_generation {
7677
new_colour_compressed = get_colour_hash(new_colour, iteration);
7778

7879
end_of_iteration:
79-
colours_tmp[index] = new_colour_compressed;
80+
new_colours[index] = new_colour_compressed;
8081
}
8182
}
8283

83-
colours.swap(colours_tmp);
84+
colours = new_colours;
8485
}
8586

8687
std::vector<int> get_kwl2_pair_to_edge_label(std::shared_ptr<graph::Graph> graph) {
@@ -109,9 +110,8 @@ namespace feature_generation {
109110
}
110111

111112
void KWL2Features::collect_impl(const std::vector<graph::Graph> &graphs) {
112-
// intermediate graph colours during WL and extra memory for WL updates
113+
// intermediate graph colours during WL
113114
std::vector<int> colours;
114-
std::vector<int> colours_tmp;
115115

116116
for (size_t graph_i = 0; graph_i < graphs.size(); graph_i++) {
117117
const auto graph = std::make_shared<graph::Graph>(graphs[graph_i]);
@@ -122,7 +122,6 @@ namespace feature_generation {
122122

123123
// intermediate colours
124124
colours = std::vector<int>(n_pairs, 0);
125-
colours_tmp = std::vector<int>(n_pairs, 0);
126125

127126
std::vector<int> pair_to_edge_label = get_kwl2_pair_to_edge_label(graph);
128127

@@ -137,29 +136,22 @@ namespace feature_generation {
137136

138137
// main WL loop
139138
for (int iteration = 1; iteration < iterations + 1; iteration++) {
140-
refine(graph, colours, colours_tmp, iteration);
139+
refine(graph, colours, iteration);
141140
}
142141
}
143142
}
144143

145-
Embedding KWL2Features::embed(const std::shared_ptr<graph::Graph> &graph) {
146-
collecting = false;
147-
if (!collected) {
148-
throw std::runtime_error("KWL2Features::collect() must be called before embedding");
149-
}
150-
144+
Embedding KWL2Features::embed_impl(const std::shared_ptr<graph::Graph> &graph) {
151145
/* 1. Initialise embedding before pruning */
152146
Embedding x0(get_n_features(), 0);
153147

154-
/* 2. Set up memory for WL updates */
155148
int n_nodes = graph->nodes.size();
156149
int n_pairs = get_n_kwl2_pairs(n_nodes);
157150
std::vector<int> colours(n_pairs);
158-
std::vector<int> colours_tmp(n_pairs);
159151

160152
std::vector<int> pair_to_edge_label = get_kwl2_pair_to_edge_label(graph);
161153

162-
/* 3. Compute initial colours */
154+
/* 2. Compute initial colours */
163155
for (int u = 0; u < n_nodes; u++) {
164156
for (int v = 0; v < n_nodes; v++) {
165157
int index = kwl2_pair_to_index_map(n_nodes, u, v);
@@ -169,9 +161,9 @@ namespace feature_generation {
169161
}
170162
}
171163

172-
/* 4. Main WL loop */
164+
/* 3. Main WL loop */
173165
for (int itr = 1; itr < iterations + 1; itr++) {
174-
refine(graph, colours, colours_tmp, itr);
166+
refine(graph, colours, itr);
175167
for (const int col : colours) {
176168
add_colour_to_x(col, itr, x0);
177169
}

0 commit comments

Comments
 (0)