diff --git a/include/feature_generation/features.hpp b/include/feature_generation/features.hpp index 38cd890..6bc786c 100644 --- a/include/feature_generation/features.hpp +++ b/include/feature_generation/features.hpp @@ -80,9 +80,6 @@ namespace feature_generation { // for iteration j = 0, ..., iterations - 1 std::vector> seen_colour_statistics; - // convert states to graphs - std::vector convert_to_graphs(const data::Dataset dataset); - // get hashed colour if it exists, and constructs it if it doesn't int get_colour_hash(const std::vector &colour, const int iteration); @@ -115,6 +112,9 @@ namespace feature_generation { /* Feature generation functions */ + // convert states to graphs + std::vector convert_to_graphs(const data::Dataset dataset); + // collect training colours void collect_from_dataset(const data::Dataset dataset); void collect(const std::vector &graphs); diff --git a/include/feature_generation/pruning_options.hpp b/include/feature_generation/pruning_options.hpp index a7759e1..8ae8271 100644 --- a/include/feature_generation/pruning_options.hpp +++ b/include/feature_generation/pruning_options.hpp @@ -1,5 +1,5 @@ -#ifndef FEATURE_GENERATION_PRUNING_OPTIONS_HPP -#define FEATURE_GENERATION_PRUNING_OPTIONS_HPP +#ifndef FEATURE_GENERATION_FEATURE_PRUNING_OPTIONS_HPP +#define FEATURE_GENERATION_FEATURE_PRUNING_OPTIONS_HPP #include #include @@ -16,4 +16,4 @@ namespace feature_generation { }; } // namespace feature_generation -#endif // FEATURE_GENERATION_PRUNING_OPTIONS_HPP +#endif // FEATURE_GENERATION_FEATURE_PRUNING_OPTIONS_HPP diff --git a/src/main.cpp b/src/main.cpp index 45b57de..45b2d0b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -457,6 +457,8 @@ py::class_(feature_generation_m, "_Features") "dataset"_a) .def("collect", py::overload_cast &>(&feature_generation::Features::collect), "graphs"_a) + .def("convert_to_graphs", &feature_generation::Features::convert_to_graphs, + "dataset"_a) .def("set_problem", &feature_generation::Features::set_problem, "problem"_a) .def("get_string_representation", py::overload_cast(&feature_generation::Features::get_string_representation), @@ -467,6 +469,8 @@ py::class_(feature_generation_m, "_Features") "dataset"_a) .def("embed", py::overload_cast &>(&feature_generation::Features::embed_graphs), "graphs"_a) + .def("embed", py::overload_cast(&feature_generation::Features::embed_graph), + "graph"_a) .def("embed", py::overload_cast(&feature_generation::Features::embed_state), "state"_a) .def("get_n_features", &feature_generation::Features::get_n_features)