diff --git a/casbin/enforcer.cpp b/casbin/enforcer.cpp index a04308cd..a44a739e 100644 --- a/casbin/enforcer.cpp +++ b/casbin/enforcer.cpp @@ -102,7 +102,7 @@ bool Enforcer::m_enforce(const std::string& matcher, std::shared_ptr m_log.LogPrint("Policy Rule: ", p_vals); if(p_tokens.size() != p_vals.size()) return false; - m_func_map.evalator->Clean(m_model->m["p"]); + m_func_map.evalator->Clean(m_model->m["p"], false); m_func_map.evalator->InitialObject("p"); for(int j = 0 ; j < p_tokens.size() ; j++) { size_t index = p_tokens[j].find("_"); @@ -173,7 +173,7 @@ bool Enforcer::m_enforce(const std::string& matcher, std::shared_ptr } else { // Push initial value for p in symbol table // If p don't in symbol table, the evaluate result will be invalid. - m_func_map.evalator->Clean(m_model->m["p"]); + m_func_map.evalator->Clean(m_model->m["p"], false); m_func_map.evalator->InitialObject("p"); for(int j = 0 ; j < p_tokens.size() ; j++) { size_t index = p_tokens[j].find("_"); @@ -367,6 +367,11 @@ void Enforcer::SetWatcher(std::shared_ptr watcher) { watcher->SetUpdateCallback(func); } +// SetWatcher sets the current evaluator. +void Enforcer::SetEvaluator(std::shared_ptr evaluator) { + this->m_evalator = evaluator; +} + // GetRoleManager gets the current role manager. std::shared_ptr Enforcer ::GetRoleManager() { return this->rm; diff --git a/casbin/model/evaluator.cpp b/casbin/model/evaluator.cpp index 760ba356..4ec79cab 100644 --- a/casbin/model/evaluator.cpp +++ b/casbin/model/evaluator.cpp @@ -20,32 +20,45 @@ namespace casbin { bool ExprtkEvaluator::Eval(const std::string& expression_string) { - expression.register_symbol_table(symbol_table); - // replace (&& -> and), (|| -> or) - auto replaced_string = std::regex_replace(expression_string, std::regex("&&"), "and"); - replaced_string = std::regex_replace(replaced_string, std::regex("\\|{2}"), "or"); - // replace string "" -> '' - replaced_string = std::regex_replace(replaced_string, std::regex("\""), "\'"); + if (this->expression_string_ != expression_string) { + this->expression_string_ = expression_string; + // replace (&& -> and), (|| -> or) + auto replaced_string = std::regex_replace(expression_string, std::regex("&&"), "and"); + replaced_string = std::regex_replace(replaced_string, std::regex("\\|{2}"), "or"); + // replace string "" -> '' + replaced_string = std::regex_replace(replaced_string, std::regex("\""), "\'"); + + return parser.compile(replaced_string, expression); + } - return parser.compile(replaced_string, expression); + return this->parser.error_count() == 0; } - void ExprtkEvaluator::InitialObject(std::string identifier) { + void ExprtkEvaluator::InitialObject(const std::string& identifier) { // symbol_table.add_stringvar(""); } - void ExprtkEvaluator::PushObjectString(std::string target, std::string proprity, const std::string& var) { + void ExprtkEvaluator::PushObjectString(const std::string& target, const std::string& proprity, const std::string& var) { auto identifier = target + "." + proprity; - this->symbol_table.add_stringvar(identifier, const_cast(var)); + + if (!symbol_table.symbol_exists(identifier)) { + identifiers_[identifier] = std::make_unique(""); + this->symbol_table.add_stringvar(identifier, *identifiers_[identifier]); + } + symbol_table.get_stringvar(identifier)->ref() = var; } - void ExprtkEvaluator::PushObjectJson(std::string target, std::string proprity, const nlohmann::json& var) { + void ExprtkEvaluator::PushObjectJson(const std::string& target, const std::string& proprity, const nlohmann::json& var) { auto identifier = target + "." + proprity; // this->symbol_table.add_stringvar(identifier, const_cast(var)); } void ExprtkEvaluator::LoadFunctions() { - + AddFunction("keyMatch", ExprtkFunctionFactory::GetExprtkFunction(ExprtkFunctionType::KeyMatch, 2)); + AddFunction("keyMatch2", ExprtkFunctionFactory::GetExprtkFunction(ExprtkFunctionType::KeyMatch2, 2)); + AddFunction("keyMatch3", ExprtkFunctionFactory::GetExprtkFunction(ExprtkFunctionType::KeyMatch3, 2)); + AddFunction("regexMatch", ExprtkFunctionFactory::GetExprtkFunction(ExprtkFunctionType::RegexMatch, 2)); + AddFunction("ipMatch", ExprtkFunctionFactory::GetExprtkFunction(ExprtkFunctionType::IpMatch, 2)); } void ExprtkEvaluator::LoadGFunction(std::shared_ptr rm, const std::string& name, int narg) { @@ -69,31 +82,29 @@ namespace casbin { } bool ExprtkEvaluator::GetBoolen() { - return expression.value(); + return bool(this->expression); } float ExprtkEvaluator::GetFloat() { return expression.value(); } - void ExprtkEvaluator::Clean(AssertionMap& section) { - for (auto& [assertion_name, assertion]: section.assertion_map) { - std::vector raw_tokens = assertion->tokens; - - for(int j = 0 ; j < raw_tokens.size() ; j++) { - size_t index = raw_tokens[j].find("_"); - std::string token = raw_tokens[j].substr(index + 1); - auto identifier = assertion_name + "." + token; - if (symbol_table.get_stringvar(identifier) != nullptr) { - symbol_table.remove_stringvar(identifier); - } - } + void ExprtkEvaluator::Clean(AssertionMap& section, bool after_enforce) { + if (after_enforce == false) { + return; } + + this->symbol_table.clear(); + this->expression_string_ = ""; + this->Functions.clear(); + this->identifiers_.clear(); } void ExprtkEvaluator::AddFunction(const std::string& func_name, std::shared_ptr func) { - this->Functions.push_back(func); - symbol_table.add_function(func_name, *func); + if (func != nullptr) { + this->Functions.push_back(func); + symbol_table.add_function(func_name, *func); + } } void ExprtkEvaluator::PrintSymbol() { @@ -104,21 +115,24 @@ namespace casbin { for (auto& var: var_list) { printf(" %s: %s\n" , var.c_str(), symbol_table.get_stringvar(var)->ref().c_str()); } + printf("Current error: %s\n", parser.error().c_str()); + // printf("Current exprsio string: %s\n", parser.current_token); + printf("Current value: %d\n", bool(this->expression)); } bool DuktapeEvaluator::Eval(const std::string& expression) { return casbin::Eval(scope, expression); } - void DuktapeEvaluator::InitialObject(std::string identifier) { + void DuktapeEvaluator::InitialObject(const std::string& identifier) { PushObject(scope, identifier); } - void DuktapeEvaluator::PushObjectString(std::string target, std::string proprity, const std::string& var) { + void DuktapeEvaluator::PushObjectString(const std::string& target, const std::string& proprity, const std::string& var) { PushStringPropToObject(scope, target, var, proprity); } - void DuktapeEvaluator::PushObjectJson(std::string target, std::string proprity, const nlohmann::json& var) { + void DuktapeEvaluator::PushObjectJson(const std::string& target, const std::string& proprity, const nlohmann::json& var) { PushObject(scope, proprity); PushObjectPropFromJson(scope, var, proprity); PushObjectPropToObject(scope, target, proprity); @@ -183,7 +197,7 @@ namespace casbin { return casbin::GetFloat(scope); } - void DuktapeEvaluator::Clean(AssertionMap& section) { + void DuktapeEvaluator::Clean(AssertionMap& section, bool after_enforce) { if (scope != nullptr) { for (auto& [assertion_name, assertion]: section.assertion_map) { std::vector raw_tokens = assertion->tokens; diff --git a/include/casbin/enforcer.h b/include/casbin/enforcer.h index 95775374..dd61bae9 100644 --- a/include/casbin/enforcer.h +++ b/include/casbin/enforcer.h @@ -127,6 +127,8 @@ class Enforcer : public IEnforcer { void SetAdapter(std::shared_ptr adapter); // SetWatcher sets the current watcher. void SetWatcher(std::shared_ptr watcher); + // SetWatcher sets the current watcher. + void SetEvaluator(std::shared_ptr evaluator); // GetRoleManager gets the current role manager. std::shared_ptr GetRoleManager(); // SetRoleManager sets the current role manager. diff --git a/include/casbin/model/evaluator.h b/include/casbin/model/evaluator.h index 7a3a858e..67904033 100644 --- a/include/casbin/model/evaluator.h +++ b/include/casbin/model/evaluator.h @@ -34,11 +34,11 @@ namespace casbin { std::list func_list; virtual bool Eval(const std::string& expression) = 0; - virtual void InitialObject(std::string target) = 0; + virtual void InitialObject(const std::string& target) = 0; - virtual void PushObjectString(std::string target, std::string proprity, const std::string& var) = 0; + virtual void PushObjectString(const std::string& target, const std::string& proprity, const std::string& var) = 0; - virtual void PushObjectJson(std::string target, std::string proprity, const nlohmann::json& var) = 0; + virtual void PushObjectJson(const std::string& target, const std::string& proprity, const nlohmann::json& var) = 0; virtual void LoadFunctions() = 0; @@ -52,23 +52,28 @@ namespace casbin { virtual float GetFloat() = 0; - virtual void Clean(AssertionMap& section) = 0; + virtual void Clean(AssertionMap& section, bool after_enforce = true) = 0; }; class ExprtkEvaluator : public IEvaluator { private: + std::string expression_string_; symbol_table_t symbol_table; expression_t expression; parser_t parser; std::vector> Functions; + std::unordered_map> identifiers_; public: + ExprtkEvaluator() { + this->expression.register_symbol_table(this->symbol_table); + }; bool Eval(const std::string& expression); - void InitialObject(std::string target); + void InitialObject(const std::string& target); - void PushObjectString(std::string target, std::string proprity, const std::string& var); + void PushObjectString(const std::string& target, const std::string& proprity, const std::string& var); - void PushObjectJson(std::string target, std::string proprity, const nlohmann::json& var); + void PushObjectJson(const std::string& target, const std::string& proprity, const nlohmann::json& var); void LoadFunctions(); @@ -82,7 +87,7 @@ namespace casbin { float GetFloat(); - void Clean(AssertionMap& section); + void Clean(AssertionMap& section, bool after_enforce = true); void PrintSymbol(); @@ -103,11 +108,11 @@ namespace casbin { bool Eval(const std::string& expression); - void InitialObject(std::string target); + void InitialObject(const std::string& target); - void PushObjectString(std::string target, std::string proprity, const std::string& var); + void PushObjectString(const std::string& target, const std::string& proprity, const std::string& var); - void PushObjectJson(std::string target, std::string proprity, const nlohmann::json& var); + void PushObjectJson(const std::string& target, const std::string& proprity, const nlohmann::json& var); void LoadFunctions(); @@ -121,7 +126,7 @@ namespace casbin { float GetFloat(); - void Clean(AssertionMap& section); + void Clean(AssertionMap& section, bool after_enforce = true); // For duktape void AddFunction(const std::string& func_name, Function f, Index nargs); diff --git a/include/casbin/model/exprtk_config.h b/include/casbin/model/exprtk_config.h index 71fe6aff..360170c2 100644 --- a/include/casbin/model/exprtk_config.h +++ b/include/casbin/model/exprtk_config.h @@ -21,6 +21,8 @@ #include "casbin/exprtk/exprtk.hpp" #include "casbin/rbac/role_manager.h" +#include "casbin/rbac/default_role_manager.h" +#include "casbin/util/util.h" namespace casbin { using numerical_type = float; @@ -98,19 +100,101 @@ namespace casbin { } }; + + struct ExprtkOtherFunction : public exprtk::igeneric_function + { + typedef typename exprtk::igeneric_function::generic_type + generic_type; + + typedef typename generic_type::scalar_view scalar_t; + typedef typename generic_type::vector_view vector_t; + typedef typename generic_type::string_view string_t; + + typedef typename exprtk::igeneric_function::parameter_list_t + parameter_list_t; + private: + casbin::MatchingFunc func_; + public: + ExprtkOtherFunction(const std::string& idenfier, casbin::MatchingFunc func) + : exprtk::igeneric_function(idenfier), func_(func) + {} + + ExprtkOtherFunction() + : exprtk::igeneric_function("ss") + {} + + inline numerical_type operator()(parameter_list_t parameters) { + bool res = false; + + // check value cnt + if (parameters.size() != 2) { + return numerical_type(res); + } + + // check value type + for (std::size_t i = 0; i < parameters.size(); ++i) { + generic_type& gt = parameters[i]; + + if (generic_type::e_scalar == gt.type) { + return numerical_type(res); + } + else if (generic_type::e_vector == gt.type) { + return numerical_type(res); + } + } + + std::string name1 = exprtk::to_str(string_t(parameters[0])); + std::string name2 = exprtk::to_str(string_t(parameters[1])); + + if(this->func_ == nullptr) + res = name1 == name2; + else { + res = this->func_(name1, name2); + } + + return numerical_type(res); + } + }; + enum class ExprtkFunctionType { + Unknown, Gfunction, + KeyMatch, + KeyMatch2, + KeyMatch3, + RegexMatch, + IpMatch, }; class ExprtkFunctionFactory { public: static std::shared_ptr GetExprtkFunction(ExprtkFunctionType type, int narg, std::shared_ptr rm = nullptr) { - if (type == ExprtkFunctionType::Gfunction) { - std::string idenfier(narg, 'S'); - return std::make_shared(idenfier, rm); - } else { - return nullptr; + std::string idenfier(narg, 'S'); + std::shared_ptr func = nullptr; + switch (type) { + case ExprtkFunctionType::Gfunction: + func = std::make_shared(idenfier, rm); + break; + case ExprtkFunctionType::KeyMatch: + func.reset(new ExprtkOtherFunction(idenfier, KeyMatch)); + break; + case ExprtkFunctionType::KeyMatch2: + func.reset(new ExprtkOtherFunction(idenfier, KeyMatch)); + break; + case ExprtkFunctionType::KeyMatch3: + func.reset(new ExprtkOtherFunction(idenfier, KeyMatch)); + break; + case ExprtkFunctionType::IpMatch: + func.reset(new ExprtkOtherFunction(idenfier, KeyMatch)); + break; + case ExprtkFunctionType::RegexMatch: + func.reset(new ExprtkOtherFunction(idenfier, KeyMatch)); + break; + default: + func = nullptr; } + + return func; } }; } diff --git a/tests/benchmarks/model_b.cpp b/tests/benchmarks/model_b.cpp index 43684321..fdcfe433 100644 --- a/tests/benchmarks/model_b.cpp +++ b/tests/benchmarks/model_b.cpp @@ -37,29 +37,42 @@ static void BenchmarkRaw(benchmark::State& state) { BENCHMARK(BenchmarkRaw); +template static void BenchmarkBasicModel(benchmark::State& state) { casbin::Enforcer e(basic_model_path, basic_policy_path); + auto evaluator = std::make_shared(); + e.SetEvaluator(evaluator); + casbin::DataList params = {"alice", "data1", "read"}; for(auto _ : state) e.Enforce(params); } -BENCHMARK(BenchmarkBasicModel); +BENCHMARK_TEMPLATE(BenchmarkBasicModel, casbin::DuktapeEvaluator); +BENCHMARK_TEMPLATE(BenchmarkBasicModel, casbin::ExprtkEvaluator); +template static void BenchmarkRBACModel(benchmark::State& state) { casbin::Enforcer e(rbac_model_path, rbac_policy_path); + auto evaluator = std::make_shared(); + e.SetEvaluator(evaluator); + casbin::DataList params = {"alice", "data2", "read"}; for (auto _ : state) e.Enforce(params); } -BENCHMARK(BenchmarkRBACModel); +BENCHMARK_TEMPLATE(BenchmarkRBACModel, casbin::DuktapeEvaluator); +BENCHMARK_TEMPLATE(BenchmarkRBACModel, casbin::ExprtkEvaluator); +template static void BenchmarkRBACModelSmall(benchmark::State& state) { casbin::Enforcer e(rbac_model_path); + auto evaluator = std::make_shared(); + e.SetEvaluator(evaluator); // 100 roles, 10 resources. for(int i = 0; i < 100; ++i) @@ -74,16 +87,21 @@ static void BenchmarkRBACModelSmall(benchmark::State& state) { e.Enforce(params); } -BENCHMARK(BenchmarkRBACModelSmall); +BENCHMARK_TEMPLATE(BenchmarkRBACModelSmall, casbin::DuktapeEvaluator); +BENCHMARK_TEMPLATE(BenchmarkRBACModelSmall, casbin::ExprtkEvaluator); +template static void BenchmarkRBACModelWithResourceRoles(benchmark::State& state) { casbin::Enforcer e(rbac_with_resource_roles_model_path, rbac_with_resource_roles_policy_path); + auto evaluator = std::make_shared(); + e.SetEvaluator(evaluator); casbin::DataList params = {"alice", "data1", "read"}; for (auto _ : state) e.Enforce(params); } -BENCHMARK(BenchmarkRBACModelWithResourceRoles); +BENCHMARK_TEMPLATE(BenchmarkRBACModelWithResourceRoles, casbin::DuktapeEvaluator); +BENCHMARK_TEMPLATE(BenchmarkRBACModelWithResourceRoles, casbin::ExprtkEvaluator); static void BenchmarkRBACModelWithDomains(benchmark::State& state) { casbin::Enforcer e(rbac_with_domains_model_path, rbac_with_domains_policy_path); diff --git a/tests/model_enforcer_test.cpp b/tests/model_enforcer_test.cpp index 59c5c129..7746d96d 100644 --- a/tests/model_enforcer_test.cpp +++ b/tests/model_enforcer_test.cpp @@ -829,6 +829,23 @@ TEST(TestModelEnforcer, TestRBACModelWithPattern) { evaluator = InitializeParams("bob", "/pen/2", "GET"); TestEnforce(e, evaluator, true); + evaluator = InitializeParams("alice", "/book/1", "GET"); + TestEnforce(e, evaluator, true); + evaluator = InitializeParams("alice", "/book/2", "GET"); + TestEnforce(e, evaluator, true); + evaluator = InitializeParams("alice", "/pen/1", "GET"); + TestEnforce(e, evaluator, true); + evaluator = InitializeParams("alice", "/pen/2", "GET"); + TestEnforce(e, evaluator, false); + evaluator = InitializeParams("bob", "/book/1", "GET"); + TestEnforce(e, evaluator, false); + evaluator = InitializeParams("bob", "/book/2", "GET"); + TestEnforce(e, evaluator, false); + evaluator = InitializeParams("bob", "/pen/1", "GET"); + TestEnforce(e, evaluator, true); + evaluator = InitializeParams("bob", "/pen/2", "GET"); + TestEnforce(e, evaluator, true); + // AddMatchingFunc() is actually setting a function because only one function is allowed, // so when we set "KeyMatch3", we are actually replacing "KeyMatch2" with "KeyMatch3". e.AddNamedMatchingFunc("p", "", casbin::KeyMatch3);