Skip to content

Commit

Permalink
feat: Benchmarks for exprtk evaluator (#189)
Browse files Browse the repository at this point in the history
* feat: exprtk evaluator support RBAC with domain

Signed-off-by: stonex <1479765922@qq.com>

* feat: exprtk evaluator support rbac with pattern

Signed-off-by: stonex <1479765922@qq.com>

* feat: add benchmark for exprtk evaluator

Signed-off-by: stonex <1479765922@qq.com>

* perf: only compile once in Eval when eval string don't change.

Signed-off-by: stonex <1479765922@qq.com>

* fix: repair regex function and add function judge whether's null

Signed-off-by: stonex <1479765922@qq.com>

* fix: clean all symbol table to avoid haning pointer in exprtk.

Signed-off-by: stonex <1479765922@qq.com>

* chore: use switch for select

Signed-off-by: stonex <1479765922@qq.com>
  • Loading branch information
sheny1xuan authored Mar 15, 2022
1 parent dc1499d commit 9c7b326
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 54 deletions.
9 changes: 7 additions & 2 deletions casbin/enforcer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ bool Enforcer::m_enforce(const std::string& matcher, std::shared_ptr<IEvaluator>
m_log.LogPrint("Policy Rule: ", p_vals);
if(p_tokens.size() != p_vals.size())
return false;
m_func_map.evalator->Clean(m_model->m["p"]);
m_func_map.evalator->Clean(m_model->m["p"], false);
m_func_map.evalator->InitialObject("p");
for(int j = 0 ; j < p_tokens.size() ; j++) {
size_t index = p_tokens[j].find("_");
Expand Down Expand Up @@ -173,7 +173,7 @@ bool Enforcer::m_enforce(const std::string& matcher, std::shared_ptr<IEvaluator>
} else {
// Push initial value for p in symbol table
// If p don't in symbol table, the evaluate result will be invalid.
m_func_map.evalator->Clean(m_model->m["p"]);
m_func_map.evalator->Clean(m_model->m["p"], false);
m_func_map.evalator->InitialObject("p");
for(int j = 0 ; j < p_tokens.size() ; j++) {
size_t index = p_tokens[j].find("_");
Expand Down Expand Up @@ -367,6 +367,11 @@ void Enforcer::SetWatcher(std::shared_ptr<Watcher> watcher) {
watcher->SetUpdateCallback(func);
}

// SetWatcher sets the current evaluator.
void Enforcer::SetEvaluator(std::shared_ptr<IEvaluator> evaluator) {
this->m_evalator = evaluator;
}

// GetRoleManager gets the current role manager.
std::shared_ptr<RoleManager> Enforcer ::GetRoleManager() {
return this->rm;
Expand Down
76 changes: 45 additions & 31 deletions casbin/model/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,45 @@

namespace casbin {
bool ExprtkEvaluator::Eval(const std::string& expression_string) {
expression.register_symbol_table(symbol_table);
// replace (&& -> and), (|| -> or)
auto replaced_string = std::regex_replace(expression_string, std::regex("&&"), "and");
replaced_string = std::regex_replace(replaced_string, std::regex("\\|{2}"), "or");
// replace string "" -> ''
replaced_string = std::regex_replace(replaced_string, std::regex("\""), "\'");
if (this->expression_string_ != expression_string) {
this->expression_string_ = expression_string;
// replace (&& -> and), (|| -> or)
auto replaced_string = std::regex_replace(expression_string, std::regex("&&"), "and");
replaced_string = std::regex_replace(replaced_string, std::regex("\\|{2}"), "or");
// replace string "" -> ''
replaced_string = std::regex_replace(replaced_string, std::regex("\""), "\'");

return parser.compile(replaced_string, expression);
}

return parser.compile(replaced_string, expression);
return this->parser.error_count() == 0;
}

void ExprtkEvaluator::InitialObject(std::string identifier) {
void ExprtkEvaluator::InitialObject(const std::string& identifier) {
// symbol_table.add_stringvar("");
}

void ExprtkEvaluator::PushObjectString(std::string target, std::string proprity, const std::string& var) {
void ExprtkEvaluator::PushObjectString(const std::string& target, const std::string& proprity, const std::string& var) {
auto identifier = target + "." + proprity;
this->symbol_table.add_stringvar(identifier, const_cast<std::string&>(var));

if (!symbol_table.symbol_exists(identifier)) {
identifiers_[identifier] = std::make_unique<std::string>("");
this->symbol_table.add_stringvar(identifier, *identifiers_[identifier]);
}
symbol_table.get_stringvar(identifier)->ref() = var;
}

void ExprtkEvaluator::PushObjectJson(std::string target, std::string proprity, const nlohmann::json& var) {
void ExprtkEvaluator::PushObjectJson(const std::string& target, const std::string& proprity, const nlohmann::json& var) {
auto identifier = target + "." + proprity;
// this->symbol_table.add_stringvar(identifier, const_cast<std::string&>(var));
}

void ExprtkEvaluator::LoadFunctions() {

AddFunction("keyMatch", ExprtkFunctionFactory::GetExprtkFunction(ExprtkFunctionType::KeyMatch, 2));
AddFunction("keyMatch2", ExprtkFunctionFactory::GetExprtkFunction(ExprtkFunctionType::KeyMatch2, 2));
AddFunction("keyMatch3", ExprtkFunctionFactory::GetExprtkFunction(ExprtkFunctionType::KeyMatch3, 2));
AddFunction("regexMatch", ExprtkFunctionFactory::GetExprtkFunction(ExprtkFunctionType::RegexMatch, 2));
AddFunction("ipMatch", ExprtkFunctionFactory::GetExprtkFunction(ExprtkFunctionType::IpMatch, 2));
}

void ExprtkEvaluator::LoadGFunction(std::shared_ptr<RoleManager> rm, const std::string& name, int narg) {
Expand All @@ -69,31 +82,29 @@ namespace casbin {
}

bool ExprtkEvaluator::GetBoolen() {
return expression.value();
return bool(this->expression);
}

float ExprtkEvaluator::GetFloat() {
return expression.value();
}

void ExprtkEvaluator::Clean(AssertionMap& section) {
for (auto& [assertion_name, assertion]: section.assertion_map) {
std::vector<std::string> raw_tokens = assertion->tokens;

for(int j = 0 ; j < raw_tokens.size() ; j++) {
size_t index = raw_tokens[j].find("_");
std::string token = raw_tokens[j].substr(index + 1);
auto identifier = assertion_name + "." + token;
if (symbol_table.get_stringvar(identifier) != nullptr) {
symbol_table.remove_stringvar(identifier);
}
}
void ExprtkEvaluator::Clean(AssertionMap& section, bool after_enforce) {
if (after_enforce == false) {
return;
}

this->symbol_table.clear();
this->expression_string_ = "";
this->Functions.clear();
this->identifiers_.clear();
}

void ExprtkEvaluator::AddFunction(const std::string& func_name, std::shared_ptr<exprtk_func_t> func) {
this->Functions.push_back(func);
symbol_table.add_function(func_name, *func);
if (func != nullptr) {
this->Functions.push_back(func);
symbol_table.add_function(func_name, *func);
}
}

void ExprtkEvaluator::PrintSymbol() {
Expand All @@ -104,21 +115,24 @@ namespace casbin {
for (auto& var: var_list) {
printf(" %s: %s\n" , var.c_str(), symbol_table.get_stringvar(var)->ref().c_str());
}
printf("Current error: %s\n", parser.error().c_str());
// printf("Current exprsio string: %s\n", parser.current_token);
printf("Current value: %d\n", bool(this->expression));
}

bool DuktapeEvaluator::Eval(const std::string& expression) {
return casbin::Eval(scope, expression);
}

void DuktapeEvaluator::InitialObject(std::string identifier) {
void DuktapeEvaluator::InitialObject(const std::string& identifier) {
PushObject(scope, identifier);
}

void DuktapeEvaluator::PushObjectString(std::string target, std::string proprity, const std::string& var) {
void DuktapeEvaluator::PushObjectString(const std::string& target, const std::string& proprity, const std::string& var) {
PushStringPropToObject(scope, target, var, proprity);
}

void DuktapeEvaluator::PushObjectJson(std::string target, std::string proprity, const nlohmann::json& var) {
void DuktapeEvaluator::PushObjectJson(const std::string& target, const std::string& proprity, const nlohmann::json& var) {
PushObject(scope, proprity);
PushObjectPropFromJson(scope, var, proprity);
PushObjectPropToObject(scope, target, proprity);
Expand Down Expand Up @@ -183,7 +197,7 @@ namespace casbin {
return casbin::GetFloat(scope);
}

void DuktapeEvaluator::Clean(AssertionMap& section) {
void DuktapeEvaluator::Clean(AssertionMap& section, bool after_enforce) {
if (scope != nullptr) {
for (auto& [assertion_name, assertion]: section.assertion_map) {
std::vector<std::string> raw_tokens = assertion->tokens;
Expand Down
2 changes: 2 additions & 0 deletions include/casbin/enforcer.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ class Enforcer : public IEnforcer {
void SetAdapter(std::shared_ptr<Adapter> adapter);
// SetWatcher sets the current watcher.
void SetWatcher(std::shared_ptr<Watcher> watcher);
// SetWatcher sets the current watcher.
void SetEvaluator(std::shared_ptr<IEvaluator> evaluator);
// GetRoleManager gets the current role manager.
std::shared_ptr<RoleManager> GetRoleManager();
// SetRoleManager sets the current role manager.
Expand Down
29 changes: 17 additions & 12 deletions include/casbin/model/evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ namespace casbin {
std::list<std::string> func_list;
virtual bool Eval(const std::string& expression) = 0;

virtual void InitialObject(std::string target) = 0;
virtual void InitialObject(const std::string& target) = 0;

virtual void PushObjectString(std::string target, std::string proprity, const std::string& var) = 0;
virtual void PushObjectString(const std::string& target, const std::string& proprity, const std::string& var) = 0;

virtual void PushObjectJson(std::string target, std::string proprity, const nlohmann::json& var) = 0;
virtual void PushObjectJson(const std::string& target, const std::string& proprity, const nlohmann::json& var) = 0;

virtual void LoadFunctions() = 0;

Expand All @@ -52,23 +52,28 @@ namespace casbin {

virtual float GetFloat() = 0;

virtual void Clean(AssertionMap& section) = 0;
virtual void Clean(AssertionMap& section, bool after_enforce = true) = 0;
};

class ExprtkEvaluator : public IEvaluator {
private:
std::string expression_string_;
symbol_table_t symbol_table;
expression_t expression;
parser_t parser;
std::vector<std::shared_ptr<exprtk_func_t>> Functions;
std::unordered_map<std::string, std::unique_ptr<std::string>> identifiers_;
public:
ExprtkEvaluator() {
this->expression.register_symbol_table(this->symbol_table);
};
bool Eval(const std::string& expression);

void InitialObject(std::string target);
void InitialObject(const std::string& target);

void PushObjectString(std::string target, std::string proprity, const std::string& var);
void PushObjectString(const std::string& target, const std::string& proprity, const std::string& var);

void PushObjectJson(std::string target, std::string proprity, const nlohmann::json& var);
void PushObjectJson(const std::string& target, const std::string& proprity, const nlohmann::json& var);

void LoadFunctions();

Expand All @@ -82,7 +87,7 @@ namespace casbin {

float GetFloat();

void Clean(AssertionMap& section);
void Clean(AssertionMap& section, bool after_enforce = true);

void PrintSymbol();

Expand All @@ -103,11 +108,11 @@ namespace casbin {

bool Eval(const std::string& expression);

void InitialObject(std::string target);
void InitialObject(const std::string& target);

void PushObjectString(std::string target, std::string proprity, const std::string& var);
void PushObjectString(const std::string& target, const std::string& proprity, const std::string& var);

void PushObjectJson(std::string target, std::string proprity, const nlohmann::json& var);
void PushObjectJson(const std::string& target, const std::string& proprity, const nlohmann::json& var);

void LoadFunctions();

Expand All @@ -121,7 +126,7 @@ namespace casbin {

float GetFloat();

void Clean(AssertionMap& section);
void Clean(AssertionMap& section, bool after_enforce = true);
// For duktape
void AddFunction(const std::string& func_name, Function f, Index nargs);

Expand Down
94 changes: 89 additions & 5 deletions include/casbin/model/exprtk_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

#include "casbin/exprtk/exprtk.hpp"
#include "casbin/rbac/role_manager.h"
#include "casbin/rbac/default_role_manager.h"
#include "casbin/util/util.h"

namespace casbin {
using numerical_type = float;
Expand Down Expand Up @@ -98,19 +100,101 @@ namespace casbin {
}
};


struct ExprtkOtherFunction : public exprtk::igeneric_function<numerical_type>
{
typedef typename exprtk::igeneric_function<numerical_type>::generic_type
generic_type;

typedef typename generic_type::scalar_view scalar_t;
typedef typename generic_type::vector_view vector_t;
typedef typename generic_type::string_view string_t;

typedef typename exprtk::igeneric_function<numerical_type>::parameter_list_t
parameter_list_t;
private:
casbin::MatchingFunc func_;
public:
ExprtkOtherFunction(const std::string& idenfier, casbin::MatchingFunc func)
: exprtk::igeneric_function<numerical_type>(idenfier), func_(func)
{}

ExprtkOtherFunction()
: exprtk::igeneric_function<numerical_type>("ss")
{}

inline numerical_type operator()(parameter_list_t parameters) {
bool res = false;

// check value cnt
if (parameters.size() != 2) {
return numerical_type(res);
}

// check value type
for (std::size_t i = 0; i < parameters.size(); ++i) {
generic_type& gt = parameters[i];

if (generic_type::e_scalar == gt.type) {
return numerical_type(res);
}
else if (generic_type::e_vector == gt.type) {
return numerical_type(res);
}
}

std::string name1 = exprtk::to_str(string_t(parameters[0]));
std::string name2 = exprtk::to_str(string_t(parameters[1]));

if(this->func_ == nullptr)
res = name1 == name2;
else {
res = this->func_(name1, name2);
}

return numerical_type(res);
}
};

enum class ExprtkFunctionType {
Unknown,
Gfunction,
KeyMatch,
KeyMatch2,
KeyMatch3,
RegexMatch,
IpMatch,
};

class ExprtkFunctionFactory {
public:
static std::shared_ptr<exprtk_func_t> GetExprtkFunction(ExprtkFunctionType type, int narg, std::shared_ptr<RoleManager> rm = nullptr) {
if (type == ExprtkFunctionType::Gfunction) {
std::string idenfier(narg, 'S');
return std::make_shared<ExprtkGFunction>(idenfier, rm);
} else {
return nullptr;
std::string idenfier(narg, 'S');
std::shared_ptr<exprtk_func_t> func = nullptr;
switch (type) {
case ExprtkFunctionType::Gfunction:
func = std::make_shared<ExprtkGFunction>(idenfier, rm);
break;
case ExprtkFunctionType::KeyMatch:
func.reset(new ExprtkOtherFunction(idenfier, KeyMatch));
break;
case ExprtkFunctionType::KeyMatch2:
func.reset(new ExprtkOtherFunction(idenfier, KeyMatch));
break;
case ExprtkFunctionType::KeyMatch3:
func.reset(new ExprtkOtherFunction(idenfier, KeyMatch));
break;
case ExprtkFunctionType::IpMatch:
func.reset(new ExprtkOtherFunction(idenfier, KeyMatch));
break;
case ExprtkFunctionType::RegexMatch:
func.reset(new ExprtkOtherFunction(idenfier, KeyMatch));
break;
default:
func = nullptr;
}

return func;
}
};
}
Expand Down
Loading

0 comments on commit 9c7b326

Please sign in to comment.