Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,7 @@ impl Engine {
.set_functions(gather_functions(&self.modules)?);
self.interpreter.gather_rules()?;
self.interpreter.process_imports()?;
self.interpreter.constant_fold()?;
self.prepared = true;
}

Expand Down
168 changes: 154 additions & 14 deletions src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@
Value(Value),
}

#[derive(Debug, Clone)]
struct Optimized {
data: Value,
processed: BTreeSet<Ref<Rule>>,
processed_paths: Value,
}

#[derive(Debug, Clone)]
pub struct Interpreter {
modules: Vec<Ref<Module>>,
Expand Down Expand Up @@ -73,6 +80,10 @@
gather_prints: bool,
prints: Vec<String>,
rule_paths: Set<String>,
is_constant_folding: bool,
optimized: Option<Optimized>,
has_side_effects: bool,
rule_value_conflict_error: bool,
}

impl Default for Interpreter {
Expand Down Expand Up @@ -197,6 +208,10 @@
gather_prints: false,
prints: Vec::default(),
rule_paths: Set::new(),
is_constant_folding: false,
has_side_effects: false,
optimized: None,
rule_value_conflict_error: false,
}
}

Expand Down Expand Up @@ -255,9 +270,16 @@
}

pub fn clean_internal_evaluation_state(&mut self) {
self.data = self.init_data.clone();
self.processed.clear();
self.processed_paths = Value::new_object();
if let Some(optimized) = &self.optimized {
self.data = optimized.data.clone();
// TODO: Check use of processed and processed_paths
self.processed = optimized.processed.clone();
self.processed_paths = optimized.processed_paths.clone();
} else {
self.data = self.init_data.clone();
self.processed.clear();
self.processed_paths = Value::new_object();
}
self.loop_var_values.clear();
self.scopes = vec![Scope::new()];
self.contexts = vec![];
Expand Down Expand Up @@ -1281,6 +1303,7 @@
// Mark modified rules as processed.
if let Some(rules) = self.rules.get(&target) {
for r in rules {
// TODO: check if this is correct for constant folding
self.processed.insert(r.clone());
}
}
Expand Down Expand Up @@ -1507,6 +1530,15 @@

self.scopes.pop();

// When constant folding, evaluate one iteration of the loop so as to
// fold constants within the loop body. But return false to prevent
// the loop itself from ptentially being treated as a constant.
//if self.is_constant_folding {
// return Ok(false);
//}
// Ignore previous comment.
// TODO: Implement expression value caching for constant values.

// Return true if at least on iteration returned true
Ok(result)
}
Expand Down Expand Up @@ -2055,6 +2087,8 @@
}

fn eval_array_compr(&mut self, term: &ExprRef, query: &Ref<Query>) -> Result<Value> {
let hse = self.has_side_effects;
self.has_side_effects = false;
// Push new context
self.contexts.push(Context {
output_expr: Some(term.clone()),
Expand All @@ -2066,13 +2100,21 @@
// Evaluate body first.
self.eval_query(query)?;

self.has_side_effects = hse || self.has_side_effects;
if self.is_constant_folding && self.has_side_effects {
return Ok(Value::Undefined);
}

match self.contexts.pop() {
Some(ctx) => Ok(ctx.value),
None => bail!("internal error: context already popped"),
}
}

fn eval_set_compr(&mut self, term: &ExprRef, query: &Ref<Query>) -> Result<Value> {
let hse = self.has_side_effects;
self.has_side_effects = false;

// Push new context
self.contexts.push(Context {
output_expr: Some(term.clone()),
Expand All @@ -2083,6 +2125,11 @@

self.eval_query(query)?;

self.has_side_effects = hse || self.has_side_effects;
if self.is_constant_folding && self.has_side_effects {
return Ok(Value::Undefined);
}

match self.contexts.pop() {
Some(ctx) => Ok(ctx.value),
None => bail!("internal error: context already popped"),
Expand All @@ -2095,6 +2142,9 @@
value: &ExprRef,
query: &Ref<Query>,
) -> Result<Value> {
let hse = self.has_side_effects;
self.has_side_effects = false;

// Push new context
self.contexts.push(Context {
key_expr: Some(key.clone()),
Expand All @@ -2106,6 +2156,11 @@

self.eval_query(query)?;

self.has_side_effects = hse || self.has_side_effects;
if self.is_constant_folding && self.has_side_effects {
return Ok(Value::Undefined);
}

match self.contexts.pop() {
Some(ctx) => Ok(ctx.value),
None => bail!("internal error: context already popped"),
Expand Down Expand Up @@ -2137,6 +2192,12 @@
return Ok(Value::Undefined);
}

// TODO: Allow builtins that can be constant folded
if self.is_constant_folding {
self.has_side_effects = true;
return Ok(Value::Undefined);
}

let cache = builtins::must_cache(name);
if let Some(name) = &cache {
if let Some(v) = self.builtins_cache.get(&(name, args.clone())) {
Expand Down Expand Up @@ -2316,6 +2377,10 @@
extension = Some(ext);
(&empty, None)
} else if fcn_path == "print" {
if self.is_constant_folding {
// Ignore side-effects in constant folding.
return Ok(Value::Undefined);
}
return self.eval_print(span, params, param_values);
}
// Look up builtin function.
Expand Down Expand Up @@ -2344,6 +2409,11 @@
}

if let Some((nargs, ext)) = extension {
if self.is_constant_folding {
// Extensions are not supported in constant folding.
return Ok(Value::Undefined);
}

if param_values.len() != *nargs as usize {
bail!(span.error("incorrect number of parameters supplied to extension"));
}
Expand Down Expand Up @@ -2621,15 +2691,21 @@
}

// Evaluate the associated default rules after non-default rules
if let Some(rules) = self.default_rules.get(&path) {
matched = true;
for (r, _) in rules.clone() {
if !self.processed.contains(&r) {
let module = self.get_rule_module(&r)?;
let prev_module = self.set_current_module(Some(module))?;
self.eval_default_rule(&r)?;
self.set_current_module(prev_module)?;
if self.is_constant_folding {
// We don't want to evaluate default rules at this point.
// A non default rule with the same path could have failed due to it needing input.
// TODO: Cleanup rule evaluation bookmarking.
} else {
if let Some(rules) = self.default_rules.get(&path) {
matched = true;
for (r, _) in rules.clone() {
if !self.processed.contains(&r) {
let module = self.get_rule_module(&r)?;
let prev_module = self.set_current_module(Some(module))?;
self.eval_default_rule(&r)?;
self.set_current_module(prev_module)?;
}
}

Check warning

Code scanning / clippy

this else { if .. } block can be collapsed Warning

this else { if .. } block can be collapsed
}
}

Expand Down Expand Up @@ -2660,6 +2736,12 @@

fn mark_processed(&mut self, path: &[&str]) -> Result<()> {
let obj = self.processed_paths.make_or_get_value_mut(path)?;

if self.is_constant_folding && obj == &Value::Undefined {
// If constant folding, then do not register undefined values.
return Ok(());
}

if obj == &Value::Undefined {
*obj = Value::new_object();
}
Expand All @@ -2677,6 +2759,11 @@

// Handle input.
if name.text() == "input" {
if self.is_constant_folding {
// When constant folding, expressions cannot depend on input.
self.has_side_effects = true;
return Ok(Value::Undefined);
}
return Ok(Self::get_value_chained(self.input.clone(), fields));
}

Expand Down Expand Up @@ -3259,11 +3346,13 @@
}

if let Some((_, r)) = conflict {
self.rule_value_conflict_error = true;
bail!(refr.span().error(&format!(
"rule conflicts with the following rule:\n{}",
r.span().message("", "defined here")
)));
}

self.rule_values
.insert(path.to_vec(), (value.clone(), refr.clone()));

Expand All @@ -3285,19 +3374,31 @@

let value = self.eval_rule_bodies(ctx, span, rule_body)?;
let package_components = self.eval_rule_ref(&module.package.refr)?;

if value != Value::Undefined {
for (path, value) in value.as_object()? {
let mut full_path = package_components.clone();
full_path.append(&mut path.as_array()?.clone());
self.check_rule_path(refr, &full_path, value, is_set)?;
/*if self.is_constant_folding && self.has_side_effects {
// Do not update rule value.
} else {*/
self.update_rule_value(span, full_path, value.clone(), is_set)?;
//}
}
} else if is_set {
if let Ok(mut comps) = self.eval_rule_ref(refr) {
let mut full_path = package_components;
full_path.append(&mut comps);
self.update_rule_value(span, full_path, Value::new_set(), true)?;
if self.is_constant_folding && self.has_side_effects {
// Do not create empty set if constant folding and rule failed to evaluate.
} else {
self.update_rule_value(
span,
full_path,
Value::new_set(),
true,
)?;
}
}
} else if is_object {
// Fetch the rule, ignoring the key.
Expand All @@ -3314,7 +3415,15 @@
}
}
}
self.processed.insert(rule.clone());

if self.is_constant_folding {
if value != Value::Undefined && !self.has_side_effects {
// When constant folding, record only those rules that have successfully evaluated.
self.processed.insert(rule.clone());
}
} else {
self.processed.insert(rule.clone());
}
}
RuleHead::Func {
refr, args, assign, ..
Expand Down Expand Up @@ -3392,7 +3501,10 @@
let scopes = core::mem::take(&mut self.scopes);
let prev_module = self.set_current_module(Some(module.clone()))?;

let hse = self.has_side_effects;
self.has_side_effects = false;
let res = self.eval_rule_impl(module, rule);
self.has_side_effects = hse || self.has_side_effects;

self.set_current_module(prev_module)?;
self.scopes = scopes;
Expand Down Expand Up @@ -3485,6 +3597,34 @@
}
}

pub fn constant_fold(&mut self) -> Result<()> {
self.is_constant_folding = true;
self.optimized = None;
self.clean_internal_evaluation_state();
self.rule_value_conflict_error = false;
// TODO: Maybe use scheduler information to determine which rules to evaluate and in which order.
for module in &self.modules.clone() {
for rule in &module.policy {
if let Rule::Spec { head, .. } = rule.as_ref() {
if matches!(&head, RuleHead::Compr { .. } | RuleHead::Set { .. }) {
if let Err(e) = self.eval_rule(module, rule) {
if self.rule_value_conflict_error {
bail!(e);
}
}
}
}
}
}
self.is_constant_folding = false;
self.optimized = Some(Optimized {
data: self.data.clone(),
processed: self.processed.clone(),
processed_paths: self.processed_paths.clone(),
});
Ok(())
}

fn get_rule_path_components(mut refr: &Ref<Expr>) -> Result<Vec<Rc<str>>> {
let mut components: Vec<Rc<str>> = vec![];
loop {
Expand Down
Loading