Skip to content

Commit

Permalink
Bind more functions
Browse files Browse the repository at this point in the history
  • Loading branch information
DillonZChen committed Jan 20, 2025
1 parent 44c5e17 commit 0fd5f25
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
6 changes: 3 additions & 3 deletions include/feature_generation/features.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,6 @@ namespace feature_generation {
// for iteration j = 0, ..., iterations - 1
std::vector<std::vector<long>> seen_colour_statistics;

// convert states to graphs
std::vector<graph::Graph> convert_to_graphs(const data::Dataset dataset);

// get hashed colour if it exists, and constructs it if it doesn't
int get_colour_hash(const std::vector<int> &colour, const int iteration);

Expand Down Expand Up @@ -115,6 +112,9 @@ namespace feature_generation {

/* Feature generation functions */

// convert states to graphs
std::vector<graph::Graph> convert_to_graphs(const data::Dataset dataset);

// collect training colours
void collect_from_dataset(const data::Dataset dataset);
void collect(const std::vector<graph::Graph> &graphs);
Expand Down
6 changes: 3 additions & 3 deletions include/feature_generation/pruning_options.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef FEATURE_GENERATION_PRUNING_OPTIONS_HPP
#define FEATURE_GENERATION_PRUNING_OPTIONS_HPP
#ifndef FEATURE_GENERATION_FEATURE_PRUNING_OPTIONS_HPP
#define FEATURE_GENERATION_FEATURE_PRUNING_OPTIONS_HPP

#include <string>
#include <vector>
Expand All @@ -16,4 +16,4 @@ namespace feature_generation {
};
} // namespace feature_generation

#endif // FEATURE_GENERATION_PRUNING_OPTIONS_HPP
#endif // FEATURE_GENERATION_FEATURE_PRUNING_OPTIONS_HPP
4 changes: 4 additions & 0 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,8 @@ py::class_<feature_generation::Features>(feature_generation_m, "_Features")
"dataset"_a)
.def("collect", py::overload_cast<const std::vector<graph::Graph> &>(&feature_generation::Features::collect),
"graphs"_a)
.def("convert_to_graphs", &feature_generation::Features::convert_to_graphs,
"dataset"_a)
.def("set_problem", &feature_generation::Features::set_problem,
"problem"_a)
.def("get_string_representation", py::overload_cast<const feature_generation::Embedding &>(&feature_generation::Features::get_string_representation),
Expand All @@ -467,6 +469,8 @@ py::class_<feature_generation::Features>(feature_generation_m, "_Features")
"dataset"_a)
.def("embed", py::overload_cast<const std::vector<graph::Graph> &>(&feature_generation::Features::embed_graphs),
"graphs"_a)
.def("embed", py::overload_cast<const graph::Graph &>(&feature_generation::Features::embed_graph),
"graph"_a)
.def("embed", py::overload_cast<const planning::State &>(&feature_generation::Features::embed_state),
"state"_a)
.def("get_n_features", &feature_generation::Features::get_n_features)
Expand Down

0 comments on commit 0fd5f25

Please sign in to comment.