From ee5f95eb04d52a316ae7424e36f3cce64ab1c230 Mon Sep 17 00:00:00 2001 From: Eduardo Arias Date: Fri, 9 Aug 2024 06:54:35 -0700 Subject: [PATCH] Added support to run unit tests in a multithreaded context - This is controlled by specifying the 'mtstress' argument when running `unit_test`. - The goal is to detect if the operator/transformation fails in this context. - In this mode, the test will be executed 5'000 times in 50 threads concurrently. - Allocation & initialization of the operator/transformation is performed once in the main thread, while the evaluation is executed in the threads. - This is consistent with the library's support for multithreading, where initialization and loading of rules is expected to run once. See issue #3215. --- test/common/modsecurity_test.cc | 19 ++-- test/common/modsecurity_test.h | 8 +- test/unit/unit.cc | 190 +++++++++++++++++++++++--------- test/unit/unit_test.cc | 8 +- test/unit/unit_test.h | 9 +- 5 files changed, 165 insertions(+), 69 deletions(-) diff --git a/test/common/modsecurity_test.cc b/test/common/modsecurity_test.cc index 227571919b..0f769054e8 100644 --- a/test/common/modsecurity_test.cc +++ b/test/common/modsecurity_test.cc @@ -93,13 +93,13 @@ bool ModSecurityTest::load_test_json(const std::string &file) { template -std::pair>* +void ModSecurityTest::load_tests(const std::string &path) { DIR *dir; struct dirent *ent; struct stat buffer; - if ((dir = opendir(path.c_str())) == NULL) { + if ((dir = opendir(path.c_str())) == nullptr) { /* if target is a file, use it as a single test. */ if (stat(path.c_str(), &buffer) == 0) { if (load_test_json(path) == false) { @@ -107,10 +107,10 @@ ModSecurityTest::load_tests(const std::string &path) { std::cout << std::endl; } } - return NULL; + return; } - while ((ent = readdir(dir)) != NULL) { + while ((ent = readdir(dir)) != nullptr) { std::string filename = ent->d_name; std::string json = ".json"; if (filename.size() < json.size() @@ -123,16 +123,15 @@ ModSecurityTest::load_tests(const std::string &path) { } } closedir(dir); - - return NULL; } template -std::pair>* ModSecurityTest::load_tests() { - return load_tests(this->target); +void ModSecurityTest::load_tests() { + load_tests(this->target); } + template void ModSecurityTest::cmd_options(int argc, char **argv) { int i = 1; @@ -144,6 +143,10 @@ void ModSecurityTest::cmd_options(int argc, char **argv) { i++; m_count_all = true; } + if (argc > i && strcmp(argv[i], "mtstress") == 0) { + i++; + m_test_multithreaded = true; + } if (std::getenv("AUTOMAKE_TESTS")) { m_automake_output = true; } diff --git a/test/common/modsecurity_test.h b/test/common/modsecurity_test.h index 79a168f71c..8b55a16c62 100644 --- a/test/common/modsecurity_test.h +++ b/test/common/modsecurity_test.h @@ -34,12 +34,13 @@ template class ModSecurityTest : ModSecurityTest() : m_test_number(0), m_automake_output(false), - m_count_all(false) { } + m_count_all(false), + m_test_multithreaded(false) { } std::string header(); void cmd_options(int, char **); - std::pair>* load_tests(); - std::pair>* load_tests(const std::string &path); + void load_tests(); + void load_tests(const std::string &path); bool load_test_json(const std::string &file); std::string target; @@ -48,6 +49,7 @@ template class ModSecurityTest : int m_test_number; bool m_automake_output; bool m_count_all; + bool m_test_multithreaded; }; } // namespace modsecurity_test diff --git a/test/unit/unit.cc b/test/unit/unit.cc index 46013ec744..7b62e7f7e3 100644 --- a/test/unit/unit.cc +++ b/test/unit/unit.cc @@ -15,7 +15,9 @@ #include #include - +#include +#include +#include #include #include #include @@ -38,6 +40,7 @@ using modsecurity_test::UnitTest; +using modsecurity_test::UnitTestResult; using modsecurity_test::ModSecurityTest; using modsecurity_test::ModSecurityTestResults; using modsecurity::actions::transformations::Transformation; @@ -53,64 +56,149 @@ void print_help() { } -void perform_unit_test(ModSecurityTest *test, UnitTest *t, - ModSecurityTestResults* res) { - std::string error; +struct OperatorTest { + using ItemType = Operator; + + static ItemType* init(const UnitTest &t) { + auto op = Operator::instantiate(t.name, t.param); + assert(op != nullptr); + + std::string error; + op->init(t.filename, &error); + + return op; + } + + static UnitTestResult eval(ItemType &op, const UnitTest &t) { + return {op.evaluate(nullptr, nullptr, t.input, nullptr), {}}; + } + + static bool check(const UnitTestResult &result, const UnitTest &t) { + return result.ret != t.ret; + } +}; + + +struct TransformationTest { + using ItemType = Transformation; + + static ItemType* init(const UnitTest &t) { + auto tfn = Transformation::instantiate("t:" + t.name); + assert(tfn != nullptr); + + return tfn; + } + + static UnitTestResult eval(ItemType &tfn, const UnitTest &t) { + return {1, tfn.evaluate(t.input, nullptr)}; + } + + static bool check(const UnitTestResult &result, const UnitTest &t) { + return result.output != t.output; + } +}; + + +template +UnitTestResult perform_unit_test_once(const UnitTest &t) { + std::unique_ptr item(TestType::init(t)); + assert(item.get() != nullptr); + + return TestType::eval(*item.get(), t); +} + + +template +UnitTestResult perform_unit_test_multithreaded(const UnitTest &t) { + + constexpr auto NUM_THREADS = 50; + constexpr auto ITERATIONS = 5'000; + + std::array threads; + std::array results; + + std::unique_ptr item(TestType::init(t)); + assert(item.get() != nullptr); + + for (auto i = 0; i != threads.size(); ++i) + { + auto &result = results[i]; + threads[i] = std::thread( + [&item, &t, &result]() + { + for (auto j = 0; j != ITERATIONS; ++j) + result = TestType::eval(*item.get(), t); + }); + } + + UnitTestResult ret; + + for (auto i = 0; i != threads.size(); ++i) + { + threads[i].join(); + if (TestType::check(results[i], t)) + ret = results[i]; // error value, keep iterating to join all threads + else if(i == 0) + ret = results[i]; // initial value + } + + return ret; // cppcheck-suppress uninitvar ; false positive, ret assigned at least once in previous loop +} + + +template +void perform_unit_test_helper(const ModSecurityTest &test, UnitTest &t, + ModSecurityTestResults &res) { + + if (!test.m_test_multithreaded) + t.result = perform_unit_test_once(t); + else + t.result = perform_unit_test_multithreaded(t); + + if (TestType::check(t.result, t)) { + res.push_back(&t); + if (test.m_automake_output) { + std::cout << "FAIL "; + } + } else if (test.m_automake_output) { + std::cout << "PASS "; + } +} + + +void perform_unit_test(const ModSecurityTest &test, UnitTest &t, + ModSecurityTestResults &res) { bool found = true; - if (test->m_automake_output) { + if (test.m_automake_output) { std::cout << ":test-result: "; } - if (t->resource.empty() == false) { - found = (std::find(resources.begin(), resources.end(), t->resource) - != resources.end()); + if (t.resource.empty() == false) { + found = std::find(resources.begin(), resources.end(), t.resource) + != resources.end(); } if (!found) { - t->skipped = true; - res->push_back(t); - if (test->m_automake_output) { + t.skipped = true; + res.push_back(&t); + if (test.m_automake_output) { std::cout << "SKIP "; } } - if (t->type == "op") { - Operator *op = Operator::instantiate(t->name, t->param); - op->init(t->filename, &error); - int ret = op->evaluate(NULL, NULL, t->input, NULL); - t->obtained = ret; - if (ret != t->ret) { - res->push_back(t); - if (test->m_automake_output) { - std::cout << "FAIL "; - } - } else if (test->m_automake_output) { - std::cout << "PASS "; - } - delete op; - } else if (t->type == "tfn") { - Transformation *tfn = Transformation::instantiate("t:" + t->name); - std::string ret = tfn->evaluate(t->input, NULL); - t->obtained = 1; - t->obtainedOutput = ret; - if (ret != t->output) { - res->push_back(t); - if (test->m_automake_output) { - std::cout << "FAIL "; - } - } else if (test->m_automake_output) { - std::cout << "PASS "; - } - delete tfn; + if (t.type == "op") { + perform_unit_test_helper(test, t, res); + } else if (t.type == "tfn") { + perform_unit_test_helper(test, t, res); } else { - std::cerr << "Failed. Test type is unknown: << " << t->type; + std::cerr << "Failed. Test type is unknown: << " << t.type; std::cerr << std::endl; } - if (test->m_automake_output) { - std::cout << t->name << " " - << modsecurity::utils::string::toHexIfNeeded(t->input) + if (test.m_automake_output) { + std::cout << t.name << " " + << modsecurity::utils::string::toHexIfNeeded(t.input) << std::endl; } } @@ -151,17 +239,15 @@ int main(int argc, char **argv) { test.load_tests("test-cases/secrules-language-tests/transformations"); } - for (std::pair *> a : test) { - std::vector *tests = a.second; - + for (auto& [filename, tests] : test) { total += tests->size(); - for (UnitTest *t : *tests) { + for (auto t : *tests) { ModSecurityTestResults r; if (!test.m_automake_output) { - std::cout << " " << a.first << "...\t"; + std::cout << " " << filename << "...\t"; } - perform_unit_test(&test, t, &r); + perform_unit_test(test, *t, r); if (!test.m_automake_output) { int skp = 0; @@ -191,7 +277,7 @@ int main(int argc, char **argv) { std::cout << "Total >> " << total << std::endl; } - for (UnitTest *t : results) { + for (const auto t : results) { std::cout << t->print() << std::endl; } @@ -216,8 +302,8 @@ int main(int argc, char **argv) { } for (auto a : test) { - auto *vec = a.second; - for(auto *t : *vec) + auto vec = a.second; + for(auto t : *vec) delete t; delete vec; } diff --git a/test/unit/unit_test.cc b/test/unit/unit_test.cc index d15533a3cc..dee4f1ae12 100644 --- a/test/unit/unit_test.cc +++ b/test/unit/unit_test.cc @@ -102,15 +102,15 @@ std::string UnitTest::print() { i << " \"param\": \"" << this->param << "\"" << std::endl; i << " \"output\": \"" << this->output << "\"" << std::endl; i << "}" << std::endl; - if (this->ret != this->obtained) { + if (this->ret != this->result.ret) { i << "Expecting: \"" << this->ret << "\" - returned: \""; - i << this->obtained << "\"" << std::endl; + i << this->result.ret << "\"" << std::endl; } - if (this->output != this->obtainedOutput) { + if (this->output != this->result.output) { i << "Expecting: \""; i << modsecurity::utils::string::toHexIfNeeded(this->output); i << "\" - returned: \""; - i << modsecurity::utils::string::toHexIfNeeded(this->obtainedOutput); + i << modsecurity::utils::string::toHexIfNeeded(this->result.output); i << "\""; i << std::endl; } diff --git a/test/unit/unit_test.h b/test/unit/unit_test.h index d200db5dbc..81d99d1424 100644 --- a/test/unit/unit_test.h +++ b/test/unit/unit_test.h @@ -25,6 +25,12 @@ namespace modsecurity_test { +class UnitTestResult { + public: + int ret; + std::string output; +}; + class UnitTest { public: static UnitTest *from_yajl_node(const yajl_val &); @@ -39,9 +45,8 @@ class UnitTest { std::string filename; std::string output; int ret; - int obtained; int skipped; - std::string obtainedOutput; + UnitTestResult result; }; } // namespace modsecurity_test