diff --git a/.github/workflows/publish_pypi.yml b/.github/workflows/publish_pypi.yml
index d6125a2..8716b30 100644
--- a/.github/workflows/publish_pypi.yml
+++ b/.github/workflows/publish_pypi.yml
@@ -33,10 +33,17 @@ jobs:
       - name: Display Python version
         run: python -c "import sys; print(sys.version)"

+      - name: Pass tests
+        run: |
+          python -m pip install -r tests/test-requirements.txt
+          pytest tests/check_not_debug.py
+          pytest
+
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
           pip install setuptools wheel twine auditwheel pybind11-stubgen
+
       - name: Build and publish
         env:
           TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
diff --git a/include/feature_generation/features.hpp b/include/feature_generation/features.hpp
index 964093a..60dd1b7 100644
--- a/include/feature_generation/features.hpp
+++ b/include/feature_generation/features.hpp
@@ -70,7 +70,7 @@ namespace feature_generation {
     std::shared_ptr<planning::Domain> domain;
     std::shared_ptr<graph::GraphGenerator> graph_generator;
     bool collected;
-    bool collapse_pruned;
+    bool pruned;
     bool collecting;
     int cur_collecting_layer;
     std::shared_ptr<NeighbourContainer> neighbour_container;
@@ -130,13 +130,16 @@ namespace feature_generation {

     /* Pruning functions */

-    std::set<int> features_to_prune_this_iteration(int iteration,
-                                                   std::vector<std::vector<int>> &cur_colours);
-    std::set<int> features_to_prune(const std::vector<graph::Graph> &graphs);
+    void prune_this_iteration(int iteration,
+                              const std::vector<graph::Graph> &graphs,
+                              std::vector<std::vector<int>> &cur_colours);
+    void prune_bulk(const std::vector<graph::Graph> &graphs);

-    std::set<int> greedy_iteration_pruner(int iteration,
-                                          std::vector<std::vector<int>> &cur_colours);
-    std::set<int> maxsat_bulk_pruner(std::vector<Embedding> X);
+    std::set<int> prune_collapse_layer(int iteration, std::vector<std::vector<int>> &cur_colours);
+    std::set<int> prune_collapse_layer_x(int iteration,
+                                         const std::vector<graph::Graph> &graphs,
+                                         std::vector<std::vector<int>> &cur_colours);
+    std::set<int> prune_maxsat(std::vector<Embedding> X);

     /* Prediction functions */
diff --git a/include/feature_generation/pruning_options.hpp b/include/feature_generation/pruning_options.hpp
index ce7ec41..ced9f76 100644
--- a/include/feature_generation/pruning_options.hpp
+++ b/include/feature_generation/pruning_options.hpp
@@ -10,8 +10,10 @@ namespace feature_generation {
     static const std::string NONE;
     static const std::string COLLAPSE_ALL;
     static const std::string COLLAPSE_LAYER;
+    static const std::string COLLAPSE_LAYER_X;

     static const std::vector<std::string> get_all();
+    static bool is_layer_pruning(const std::string &pruning_option);
   };
 }  // namespace feature_generation
diff --git a/setup.py b/setup.py
index 9040199..fe5c6c0 100644
--- a/setup.py
+++ b/setup.py
@@ -7,18 +7,22 @@
 # Read version from wlplan/__version__.py file
 exec(open("wlplan/__version__.py").read())

+# Compiler flags
+COMPILER_FLAGS = [
+    # "-O3",
+    # "-DDEBUG",  # uncomment only for local debug builds; see tests/check_not_debug.py
+]
+
 files = [glob("src/*.cpp"), glob("src/**/*.cpp"), glob("src/**/**/*.cpp")]

-ext_modules = [
-    Pybind11Extension(
-        "_wlplan",
-        # Sort input source files if you glob sources to ensure bit-for-bit
-        # reproducible builds (https://github.com/pybind/python_example/pull/53)
-        sorted([f for file_group in files for f in file_group]),
-        # Example: passing in the version to the compiled code
-        define_macros=[("WLPLAN_VERSION", __version__)],
-    ),
-]
+ext_module = Pybind11Extension(
+    "_wlplan",
+    # Sort input source files if you glob sources to ensure bit-for-bit
+    # reproducible builds (https://github.com/pybind/python_example/pull/53)
+    sorted([f for file_group in files for f in file_group]),
+    define_macros=[("WLPLAN_VERSION", __version__)],
+)
+ext_module._add_cflags(COMPILER_FLAGS)

 setup(
     name="wlplan",
@@ -30,7 +34,7 @@
     long_description_content_type="text/markdown",
     packages=["wlplan", "_wlplan"],
     package_data={"_wlplan": ["py.typed", "*.pyi", "**/*.pyi"]},
-    ext_modules=ext_modules,
+    ext_modules=[ext_module],
     cmdclass={"build_ext": build_ext},
     project_urls={"GitHub": "https://github.com/DillonZChen/wlplan"},
     license="MIT License",
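Note that `_add_cflags` is a private `Pybind11Extension` helper, so a build that needs to stay stable across pybind11 releases can route the flags through the standard `extra_compile_args` keyword instead. A minimal sketch of that wiring (the empty flag list and placeholder version string are assumptions; `-DDEBUG` belongs only in local debug builds, which is why it is left commented out above and checked by tests/check_not_debug.py below):

    from glob import glob
    from pybind11.setup_helpers import Pybind11Extension

    COMPILER_FLAGS = []  # e.g. ["-DDEBUG"] for a local debug build only

    ext_module = Pybind11Extension(
        "_wlplan",
        sorted(glob("src/**/*.cpp", recursive=True)),  # sorted for reproducible builds
        define_macros=[("WLPLAN_VERSION", "0.0.0")],  # placeholder version string
        extra_compile_args=COMPILER_FLAGS,  # public alternative to _add_cflags
    )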
diff --git a/src/feature_generation/feature_generators/wl.cpp b/src/feature_generation/feature_generators/wl.cpp
index c5af445..0d4f7f5 100644
--- a/src/feature_generation/feature_generators/wl.cpp
+++ b/src/feature_generation/feature_generators/wl.cpp
@@ -120,27 +120,11 @@ namespace feature_generation {
       }

       // layer pruning
-      std::set<int> to_prune = features_to_prune_this_iteration(iteration, graph_colours);
-      if (to_prune.size() != 0) {
-        std::map<int, int> remap = remap_colour_hash(to_prune);
-        for (size_t graph_i = 0; graph_i < graphs.size(); graph_i++) {
-          for (size_t node_i = 0; node_i < graph_colours[graph_i].size(); node_i++) {
-            int col = graph_colours[graph_i][node_i];
-            if (remap.count(col) > 0) {
-              graph_colours[graph_i][node_i] = remap[col];
-            } else {
-              graph_colours[graph_i][node_i] = UNSEEN_COLOUR;
-            }
-          }
-        }
-      }
+      prune_this_iteration(iteration, graphs, graph_colours);
     }

     // bulk pruning
-    std::set<int> to_prune = features_to_prune(graphs);
-    if (to_prune.size() != 0) {
-      remap_colour_hash(to_prune);
-    }
+    prune_bulk(graphs);

     layer_redundancy_check();
   }
diff --git a/src/feature_generation/features.cpp b/src/feature_generation/features.cpp
index 74b4b9e..632ad59 100644
--- a/src/feature_generation/features.cpp
+++ b/src/feature_generation/features.cpp
@@ -29,7 +29,7 @@ namespace feature_generation {
     this->domain = std::make_shared<planning::Domain>(domain);
     graph_generator = graph::create_graph_generator(graph_representation, domain);
     collected = false;
-    collapse_pruned = false;
+    pruned = false;
     collecting = false;
     neighbour_container = std::make_shared<NeighbourContainer>(multiset_hash);
     seen_colour_statistics = std::vector<std::vector<int>>(2, std::vector<int>(iterations, 0));
@@ -149,7 +149,9 @@ namespace feature_generation {
     // make new_colours a copy of colours
     std::vector<int> new_colours = colours;

-    // debug_vec(colours);  // DEBUG
+#ifdef DEBUG
+    debug_vec(colours);
+#endif

     // colours should always show up in remap by their construction
     for (const int i : get_neighbour_colour_indices(colours)) {
@@ -197,23 +199,25 @@ namespace feature_generation {
     }

     //////////////////////////////////////////
-    // DEBUG
-    // std::cout << "initial_colours" << std::endl;
-    // for (const int i : seen_initial_colours) {
-    //   std::cout << "INITIAL " << i << std::endl;
-    // }
-    // std::cout << "old_hash" << std::endl;
-    // for (const auto &[key, val] : colour_hash) {
-    //   std::cout << "HASH "; debug_hash(key, val);
-    // }
-    // std::cout << "to_prune" << std::endl;
-    // for (const int i : to_prune) {
-    //   std::cout << "PRUNE " << i << std::endl;
-    // }
-    // std::cout << "remap" << std::endl;
-    // for (const auto &[key, val] : remap) {
-    //   std::cout << "REMAP " << key << " -> " << val << std::endl;
-    // }
+#ifdef DEBUG
+    std::cout << "initial_colours" << std::endl;
+    for (const int i : seen_initial_colours) {
+      std::cout << "INITIAL " << i << std::endl;
+    }
+    std::cout << "old_hash" << std::endl;
+    for (const auto &[key, val] : colour_hash) {
+      std::cout << "HASH ";
+      debug_hash(key, val);
+    }
+    std::cout << "to_prune" << std::endl;
+    for (const int i : to_prune) {
+      std::cout << "PRUNE " << i << std::endl;
+    }
+    std::cout << "remap" << std::endl;
+    for (const auto &[key, val] : remap) {
+      std::cout << "REMAP " << key << " -> " << val << std::endl;
+    }
+#endif
     //////////////////////////////////////////

     // remap keys
@@ -267,8 +271,8 @@ namespace feature_generation {
   }

   void Features::collect(const std::vector<graph::Graph> &graphs) {
-    if (pruning == PruningOptions::COLLAPSE_LAYER && collapse_pruned) {
-      std::cout << "collect with collapse pruning can only be called at most once" << std::endl;
+    if (pruning != PruningOptions::NONE && pruned) {
+      std::cout << "collect with pruning can be called at most once" << std::endl;
       exit(-1);
     }

@@ -276,8 +280,8 @@ namespace feature_generation {
     collect_impl(graphs);

-    if (pruning == PruningOptions::COLLAPSE_LAYER) {
-      collapse_pruned = true;
+    if (pruning != PruningOptions::NONE) {
+      pruned = true;
     }
     collected = true;
     collecting = false;
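The guard above makes `collect` a one-shot operation whenever pruning is enabled, since pruning remaps the colour hash in place and a second pass could not extend the remapped hash consistently. In Python terms (a sketch; `domain` and `dataset` stand for objects obtained via `get_dataset` as in tests/colours.py):

    from wlplan.feature_generation import WLFeatures

    feature_generator = WLFeatures(
        domain=domain,
        graph_representation="ilg",
        iterations=4,
        pruning="collapse-layer",
        multiset_hash=True,
    )
    feature_generator.collect(dataset)  # first call prunes and remaps the colour hash
    feature_generator.collect(dataset)  # second call exits with an error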
diff --git a/src/feature_generation/maxsat.cpp b/src/feature_generation/maxsat.cpp
index e21e16e..86cf0de 100644
--- a/src/feature_generation/maxsat.cpp
+++ b/src/feature_generation/maxsat.cpp
@@ -106,7 +106,9 @@ namespace feature_generation {

     std::string maxsat_wcnf_string = to_string();

-    // std::cout << maxsat_wcnf_string << std::endl;  // DEBUG
+#ifdef DEBUG
+    std::cout << maxsat_wcnf_string << std::endl;
+#endif

     py::object pysat_rc2 = py::module::import("pysat.examples.rc2").attr("RC2");
     py::object pysat_wcnf = py::module::import("pysat.formula").attr("WCNF");
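The two pysat objects imported through pybind11 here behave as in the following pure-Python sketch (the clauses are hypothetical; RC2 minimises the total weight of falsified soft clauses, so it satisfies as many weighted preferences as the hard clauses allow):

    from pysat.examples.rc2 import RC2
    from pysat.formula import WCNF

    wcnf = WCNF()
    wcnf.append([1], weight=1)  # soft: prefer variable 1 true
    wcnf.append([2], weight=1)  # soft: prefer variable 2 true
    wcnf.append([-1, 2])        # hard: variable 1 implies variable 2

    with RC2(wcnf) as solver:
        print(solver.compute())  # [1, 2]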
diff --git a/src/feature_generation/pruning/bulk_pruners.cpp b/src/feature_generation/pruning/bulk_pruners.cpp
index 72f880e..d90bd5e 100644
--- a/src/feature_generation/pruning/bulk_pruners.cpp
+++ b/src/feature_generation/pruning/bulk_pruners.cpp
@@ -5,56 +5,72 @@
 const int KEEP = -1;

 namespace feature_generation {
-  std::set<int> Features::features_to_prune(const std::vector<graph::Graph> &graphs) {
+  void Features::prune_bulk(const std::vector<graph::Graph> &graphs) {
+    std::set<int> to_prune;
     if (pruning == PruningOptions::COLLAPSE_ALL) {
       collected = true;
       std::vector<Embedding> X = embed_graphs(graphs);
-      return maxsat_bulk_pruner(X);
+      to_prune = prune_maxsat(X);
     } else {
-      return std::set<int>();
+      to_prune = std::set<int>();
     }
-  }

-  inline void log_feature_info(std::vector<int> feature_group, std::vector<int> group_size) {
-    int n_distinct_features = 0;
-    int n_equivalent_features = 0;
-    int n_equivalence_groups = 0;
-    for (const int group : feature_group) {
-      if (group == KEEP) {
-        n_distinct_features++;
-      } else if (group_size.at(group) > 1) {
-        n_equivalent_features++;
-      } else {  // (group_size.at(group) <= 1)
-        std::cout << "error: equivalence group has a distinct feature" << std::endl;
-        exit(-1);
-      }
+    if (to_prune.size() != 0) {
+      remap_colour_hash(to_prune);
     }
-    for (const int size : group_size) {
-      if (size > 1) {
-        n_equivalence_groups++;
+  }
+
+  inline void log_feature_info(int n_features,
+                               std::set<int> prune_candidates,
+                               std::map<int, int> feature_group,
+                               std::map<int, int> group_size) {
+    int n_candidate_features = prune_candidates.size();
+    int n_distinct_features = n_features - n_candidate_features;
+    int n_equivalence_groups = group_size.size();
+    for (const auto &[group, size] : group_size) {
+      if (size <= 1) {
+        std::cout << "error: equivalence groups should have size > 1" << std::endl;
       }
     }
     std::cout << "Distinct features: " << n_distinct_features << std::endl;
-    std::cout << "Equivalent features: " << n_equivalent_features << std::endl;
+    std::cout << "Prune candidates: " << n_candidate_features << std::endl;
     std::cout << "Equivalence groups: " << n_equivalence_groups << std::endl;
     std::cout << "Lower bound features: " << n_distinct_features + n_equivalence_groups << std::endl;
   }

-  inline bool mark_distinct_features(std::vector<int> &feature_group,
-                                     std::vector<int> &group_size) {
-    bool changed = false;
-    for (size_t colour = 0; colour < feature_group.size(); colour++) {
+  inline int mark_distinct_features(std::set<int> &prune_candidates,
+                                    std::map<int, int> &feature_group,
+                                    std::map<int, int> &group_size) {
+    int changed = 0;
+    std::set<int> distinct_groups;
+    std::set<int> distinct_features;
+
+    // mark groups with size <= 1 and their corresponding feature as distinct
+    for (const int colour : prune_candidates) {
       int group = feature_group.at(colour);
-      if (group >= 0 && group_size.at(group) <= 1) {
-        feature_group[colour] = KEEP;
-        changed = true;
+      if (group >= 0 && group_size.at(group) <= 1) {  // group collapsed to one feature, so keep it
+        changed++;
+        distinct_features.insert(colour);
+        distinct_groups.insert(group);
       }
     }
+
+    // kept features should be erased from feature_group and group_size
+    for (const int colour : distinct_features) {
+      feature_group.erase(colour);
+      prune_candidates.erase(colour);
+    }
+    for (const int group : distinct_groups) {
+      group_size.erase(group);
+    }
+
+    std::cout << "changed: " << changed << ". candidates: " << prune_candidates.size() << std::endl;
+
     return changed;
   }

-  std::set<int> Features::maxsat_bulk_pruner(std::vector<Embedding> X) {
+  std::set<int> Features::prune_maxsat(std::vector<Embedding> X) {
     std::cout << "Minimising equivalent features..." << std::endl;

     // 0. construct feature dependency graph
@@ -67,98 +83,108 @@ namespace feature_generation {
       std::vector<int> indices = get_neighbour_colour_indices(neighbours);
       for (const int i : indices) {
         int ancestor = neighbours[i];
-        // std::cout << "FDG " << ancestor << " ~> " << colour << std::endl;  // DEBUG
         edges_fw.at(ancestor).push_back(colour);
         edges_bw.at(colour).push_back(ancestor);
+#ifdef DEBUG
+        std::cout << "FDG " << ancestor << " -> " << colour << std::endl;
+#endif
       }
     }

     // 1. compute equivalent features candidates
     std::cout << "Computing equivalent feature candidates." << std::endl;
     std::unordered_map<std::vector<int>, int, int_vector_hasher> canonical_group;
-    std::vector<int> group_size(n_features, 0);
-    std::vector<int> feature_group(n_features, 0);
+    std::map<int, int> group_size;
+    std::map<int, int> feature_group;
+    std::set<int> prune_candidates;
     for (int colour = 0; colour < n_features; colour++) {
+      prune_candidates.insert(colour);
+
       std::vector<int> feature;
       for (size_t j = 0; j < X.size(); j++) {
         feature.push_back(X[j][colour]);
       }
+
       int group;
-      if (canonical_group.count(feature) == 0) {
+      if (canonical_group.count(feature) == 0) {  // new feature
         group = canonical_group.size();
         canonical_group[feature] = group;
       } else {  // seen this feature before
         group = canonical_group.at(feature);
       }
+
+      if (group_size.count(group) == 0) {
+        group_size[group] = 0;
+      }
       group_size.at(group)++;
-      feature_group.at(colour) = group;
+      feature_group[colour] = group;
     }
-    mark_distinct_features(feature_group, group_size);
-    log_feature_info(feature_group, group_size);
+    mark_distinct_features(prune_candidates, feature_group, group_size);
+    // log_feature_info(n_features, prune_candidates, feature_group, group_size);

     // 2. mark features that should not be thrown out from highest iteration down
     std::cout << "Pruning features via dependency graph." << std::endl;
     int dp_iterations = 0;
     while (true) {
-      std::cout << "Pruning DP iteration=" << dp_iterations << std::endl;
+      std::cout << "DP iteration=" << dp_iterations << ". ";
       dp_iterations++;
+      int changed = 0;
       for (int itr = iterations; itr >= 0; itr--) {
         for (const int colour : layer_to_colours[itr]) {
-          if (feature_group.at(colour) != KEEP) {
+          if (prune_candidates.count(colour)) {
             continue;
           }
           for (const int ancestor_colour : edges_bw.at(colour)) {
-            int ancestor_group = feature_group.at(ancestor_colour);
-            feature_group.at(ancestor_colour) = KEEP;
-            if (ancestor_group == KEEP) {
+            // ancestors of kept colours must also be kept; skip those already marked distinct
+            if (!prune_candidates.count(ancestor_colour)) {
               continue;
             }
+            int ancestor_group = feature_group.at(ancestor_colour);
             group_size.at(ancestor_group)--;
+            prune_candidates.erase(ancestor_colour);
+            changed++;
           }
         }
       }
-      bool changed = mark_distinct_features(feature_group, group_size);
-      if (!changed) {
+      changed += mark_distinct_features(prune_candidates, feature_group, group_size);
+      if (changed == 0) {
         break;
       }
     }
-    log_feature_info(feature_group, group_size);
+    // log_feature_info(n_features, prune_candidates, feature_group, group_size);

     // 3. maxsat
     std::cout << "Encoding MaxSAT." << std::endl;

     // get groups
     std::map<int, std::vector<int>> group_to_features;
-    std::set<int> variables;
-    for (int colour = 0; colour < n_features; colour++) {
+    for (const int colour : prune_candidates) {
       int group = feature_group.at(colour);
-      if (group != KEEP) {
-        if (group_to_features.count(group) == 0) {
-          group_to_features[group] = std::vector<int>();
-        }
-        group_to_features[group].push_back(colour);
-        variables.insert(colour);
+      if (group_to_features.count(group) == 0) {
+        group_to_features[group] = std::vector<int>();
       }
+      group_to_features[group].push_back(colour);
     }

     std::vector<MaxSatClause> clauses;

     // variable=T indicates feature to be thrown out
     // equivalently, ~variable=T indicates feature to be kept
-    for (const int variable : variables) {
-      clauses.push_back(MaxSatClause({variable}, {false}, 1, false));
+    for (const int colour : prune_candidates) {
+      clauses.push_back(MaxSatClause({colour}, {false}, 1, false));
     }

     // a thrown out variable forces children to be thrown out
     // i.e., variable => child_1 & ... & child_n which is equivalent to
     // (~variable | child_1) & ... & (~variable | child_n)
-    for (const int ancestor : variables) {
+    for (const int ancestor : prune_candidates) {
       for (const int child : edges_fw.at(ancestor)) {
-        if (variables.count(child) == 0) {
-          continue;
+        if (!prune_candidates.count(child)) {
+          std::cout << "error: child of prune candidate is not a candidate" << std::endl;
+          exit(-1);
         }
         clauses.push_back(MaxSatClause({ancestor, child}, {true, false}, 0, true));
       }
     }
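Step 1 above groups features by their column of the embedding matrix X: two features land in the same equivalence group exactly when they take the same value on every training graph. A toy sketch of that grouping with made-up data:

    X = [
        [1, 1, 0, 2],  # feature values on graph 0
        [3, 3, 1, 0],  # feature values on graph 1
    ]
    canonical_group, feature_group, group_size = {}, {}, {}
    for colour in range(len(X[0])):
        feature = tuple(row[colour] for row in X)  # the feature's column
        group = canonical_group.setdefault(feature, len(canonical_group))
        feature_group[colour] = group
        group_size[group] = group_size.get(group, 0) + 1

    print(feature_group)  # {0: 0, 1: 0, 2: 1, 3: 2}: features 0 and 1 coincide
    print(group_size)     # {0: 2, 1: 1, 2: 1}: only group 0 can be pruned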
"; dp_iterations++; + int changed = 0; for (int itr = iterations; itr >= 0; itr--) { for (const int colour : layer_to_colours[itr]) { - if (feature_group.at(colour) != KEEP) { + if (prune_candidates.count(colour)) { continue; } for (const int ancestor_colour : edges_bw.at(colour)) { - int ancestor_group = feature_group.at(ancestor_colour); - feature_group.at(ancestor_colour) = KEEP; - if (ancestor_group == KEEP) { + // mark prune candidates as distinct, so skip non-candidates + if (!prune_candidates.count(ancestor_colour)) { continue; } + int ancestor_group = feature_group.at(ancestor_colour); group_size.at(ancestor_group)--; + prune_candidates.erase(ancestor_colour); + changed++; } } } - bool changed = mark_distinct_features(feature_group, group_size); - if (!changed) { + changed += mark_distinct_features(prune_candidates, feature_group, group_size); + if (changed == 0) { break; } } - log_feature_info(feature_group, group_size); + // log_feature_info(n_features, prune_candidates, feature_group, group_size); // 3. maxsat std::cout << "Encoding MaxSAT." << std::endl; // get groups std::map> group_to_features; - std::set variables; - for (int colour = 0; colour < n_features; colour++) { + for (const int colour : prune_candidates) { int group = feature_group.at(colour); - if (group != KEEP) { - if (group_to_features.count(group) == 0) { - group_to_features[group] = std::vector(); - } - group_to_features[group].push_back(colour); - variables.insert(colour); + if (group_to_features.count(group) == 0) { + group_to_features[group] = std::vector(); } + group_to_features[group].push_back(colour); } std::vector clauses; // variable=T indicates feature to be thrown out // equivalently, ~variable=T indicates feature to be kept - for (const int variable : variables) { - clauses.push_back(MaxSatClause({variable}, {false}, 1, false)); + for (const int colour : prune_candidates) { + clauses.push_back(MaxSatClause({colour}, {false}, 1, false)); } // a thrown out variable forces children to be thrown out // i.e., variable => child_1 & ... & child_n which is equivalent to // (~variable | child_1) & ... 
diff --git a/src/feature_generation/pruning_options.cpp b/src/feature_generation/pruning_options.cpp
index 1c6d688..1eb1a07 100644
--- a/src/feature_generation/pruning_options.cpp
+++ b/src/feature_generation/pruning_options.cpp
@@ -4,11 +4,16 @@ namespace feature_generation {
   const std::string PruningOptions::NONE = "none";
   const std::string PruningOptions::COLLAPSE_ALL = "collapse-all";
   const std::string PruningOptions::COLLAPSE_LAYER = "collapse-layer";
+  const std::string PruningOptions::COLLAPSE_LAYER_X = "collapse-layer-x";

   const std::vector<std::string> PruningOptions::get_all() {
     return {
         NONE,
         COLLAPSE_ALL,
         COLLAPSE_LAYER,
+        COLLAPSE_LAYER_X,
     };
   }
+
+  bool PruningOptions::is_layer_pruning(const std::string &pruning_option) {
+    return pruning_option == COLLAPSE_LAYER || pruning_option == COLLAPSE_LAYER_X;
+  }
 }  // namespace feature_generation
diff --git a/tests/check_not_debug.py b/tests/check_not_debug.py
new file mode 100644
index 0000000..2b5be34
--- /dev/null
+++ b/tests/check_not_debug.py
@@ -0,0 +1,19 @@
+import os
+
+
+def test_not_debug():
+    CUR_DIR = os.path.dirname(os.path.abspath(__file__))
+    ROOT_DIR = os.path.normpath(os.path.join(CUR_DIR, ".."))
+
+    SETUP_SCRIPT = os.path.join(ROOT_DIR, "setup.py")
+    with open(SETUP_SCRIPT, "r") as f:
+        setup_script = f.read()
+
+    found_debug = False
+    for line in setup_script.split("\n"):
+        toks = line.split("#")
+        if "-DDEBUG" in toks[0]:  # only inspect code before any comment
+            found_debug = True
+            break
+
+    assert not found_debug, "Found -DDEBUG in setup.py"
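The test keeps only the text before the first '#' on each line, so a commented-out flag passes while an active one fails. A quick illustration:

    for line in ['    # "-DDEBUG",', '    "-DDEBUG",']:
        code = line.split("#")[0]
        print(repr(line), "->", "flagged" if "-DDEBUG" in code else "ok")
    # '    # "-DDEBUG",' -> ok
    # '    "-DDEBUG",'   -> flagged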
diff --git a/tests/colours.py b/tests/colours.py
index 4831c1e..5bfd3ce 100644
--- a/tests/colours.py
+++ b/tests/colours.py
@@ -19,9 +19,10 @@
     # "spanner": 350,
     # "transport": 3787,
 }
+DOMAINS = sorted(FD_COLOURS.keys())


-def colours_test(domain_name, iterations, Class):
+def colours_test(domain_name: str, iterations: int, Class, pruning: str = None):
     logging.info(f"L={iterations}")

     n_features = {}
@@ -34,13 +35,16 @@ def colours_test(domain_name, iterations, Class):
     }

     for desc, config in configs.items():
-        domain, dataset, _ = get_dataset(domain_name, keep_statics=config["keep_statics"])
+        keep_statics = config["keep_statics"]
+        multiset_hash = config["multiset_hash"]
+        logging.info(f"{keep_statics=}, {multiset_hash=}")
+        domain, dataset, _ = get_dataset(domain_name, keep_statics=keep_statics)

         feature_generator = Class(
             domain=domain,
             graph_representation="ilg",
             iterations=iterations,
-            pruning=None,
-            multiset_hash=config["multiset_hash"],
+            pruning=pruning,
+            multiset_hash=multiset_hash,
         )
         feature_generator.collect(dataset)

         X = np.array(feature_generator.embed(dataset)).astype(float)
diff --git a/tests/test_iwl.py b/tests/test_iwl.py
index 7bcbec1..5d39efe 100644
--- a/tests/test_iwl.py
+++ b/tests/test_iwl.py
@@ -1,13 +1,13 @@
 import logging

 import pytest
-from colours import FD_COLOURS, colours_test
+from colours import DOMAINS, colours_test

 from wlplan.feature_generation import IWLFeatures

 LOGGER = logging.getLogger(__name__)


-@pytest.mark.parametrize("domain_name", sorted(FD_COLOURS.keys()))
+@pytest.mark.parametrize("domain_name", DOMAINS)
 def test_domain(domain_name):
     colours_test(domain_name, 2, IWLFeatures)
diff --git a/tests/test_kwl2.py b/tests/test_kwl2.py
index a585a1c..5a2c318 100644
--- a/tests/test_kwl2.py
+++ b/tests/test_kwl2.py
@@ -1,14 +1,14 @@
 import logging

 import pytest
-from colours import FD_COLOURS, colours_test
+from colours import DOMAINS, colours_test

 from wlplan.feature_generation import KWL2Features

 LOGGER = logging.getLogger(__name__)


-@pytest.mark.parametrize("domain_name", sorted(FD_COLOURS.keys()))
+@pytest.mark.parametrize("domain_name", DOMAINS)
 def test_domain(domain_name):
     logging.info("Skipped because too memory intensive")
     # colours_test(domain_name, 2, KWL2Features)
diff --git a/tests/test_lwl2.py b/tests/test_lwl2.py
index 1175edb..20aa8c6 100644
--- a/tests/test_lwl2.py
+++ b/tests/test_lwl2.py
@@ -1,13 +1,13 @@
 import logging

 import pytest
-from colours import FD_COLOURS, colours_test
+from colours import DOMAINS, colours_test

 from wlplan.feature_generation import LWL2Features

 LOGGER = logging.getLogger(__name__)


-@pytest.mark.parametrize("domain_name", sorted(FD_COLOURS.keys()))
+@pytest.mark.parametrize("domain_name", DOMAINS)
 def test_domain(domain_name):
     colours_test(domain_name, 2, LWL2Features)
diff --git a/tests/test_niwl.py b/tests/test_niwl.py
index 1c285e8..e37a759 100644
--- a/tests/test_niwl.py
+++ b/tests/test_niwl.py
@@ -1,13 +1,13 @@
 import logging

 import pytest
-from colours import FD_COLOURS, colours_test
+from colours import DOMAINS, colours_test

 from wlplan.feature_generation import NIWLFeatures

 LOGGER = logging.getLogger(__name__)


-@pytest.mark.parametrize("domain_name", sorted(FD_COLOURS.keys()))
+@pytest.mark.parametrize("domain_name", DOMAINS)
 def test_domain(domain_name):
     colours_test(domain_name, 2, NIWLFeatures)
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
new file mode 100644
index 0000000..6f6a5c7
--- /dev/null
+++ b/tests/test_pruning.py
@@ -0,0 +1,16 @@
+import logging
+from itertools import product
+
+import pytest
+from colours import DOMAINS, colours_test
+
+from wlplan.feature_generation import PruningOptions, WLFeatures
+
+LOGGER = logging.getLogger(__name__)
+
+
+@pytest.mark.parametrize("domain_name,pruning", product(DOMAINS, PruningOptions.get_all()))
+def test_domain(domain_name, pruning):
+    if pruning == PruningOptions.NONE:
+        pytest.skip("pruning=none is already covered by test_wl.py")
+    colours_test(domain_name, 4, WLFeatures, pruning)
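Spelled out directly, a single parametrisation of this test runs the following configuration (a sketch: `domain` and `dataset` come from `get_dataset` as in tests/colours.py, and the option strings mirror those registered in pruning_options.cpp):

    from wlplan.feature_generation import PruningOptions, WLFeatures

    print(PruningOptions.get_all())
    # ['none', 'collapse-all', 'collapse-layer', 'collapse-layer-x']

    feature_generator = WLFeatures(
        domain=domain,
        graph_representation="ilg",
        iterations=4,
        pruning="collapse-all",  # routes collect through the MaxSAT bulk pruner
        multiset_hash=True,
    )
    feature_generator.collect(dataset)
    X = feature_generator.embed(dataset)  # embeddings over the pruned feature set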
diff --git a/tests/test_wl.py b/tests/test_wl.py
index d7cc17b..17387d6 100644
--- a/tests/test_wl.py
+++ b/tests/test_wl.py
@@ -1,13 +1,13 @@
 import logging

 import pytest
-from colours import FD_COLOURS, colours_test
+from colours import DOMAINS, colours_test

 from wlplan.feature_generation import WLFeatures

 LOGGER = logging.getLogger(__name__)


-@pytest.mark.parametrize("domain_name", sorted(FD_COLOURS.keys()))
+@pytest.mark.parametrize("domain_name", DOMAINS)
 def test_domain(domain_name):
     colours_test(domain_name, 4, WLFeatures)