diff --git a/data/error_policies/parent_call.cas b/data/error_policies/call_casting.cas similarity index 100% rename from data/error_policies/parent_call.cas rename to data/error_policies/call_casting.cas diff --git a/data/error_policies/functions_no_term.cas b/data/error_policies/functions_no_term.cas new file mode 100644 index 00000000..f63985b6 --- /dev/null +++ b/data/error_policies/functions_no_term.cas @@ -0,0 +1,20 @@ +resource my_file { + fn read(domain source) { + allow(source, this, file, [ read open getattr ]); + other_read(source); + } + + fn other_read(domain source) { + allow(source, this, file, [ read open getattr ]); + third_read(source); + } + + fn third_read(domain source) { + allow(source, this, file, [ read open getattr ]); + read(source); + } +} + +domain my_domain { + my_file.read(this); // TODO: support 'this' as default argument +} diff --git a/data/error_policies/functions_recursion.cas b/data/error_policies/functions_recursion.cas new file mode 100644 index 00000000..690f90f1 --- /dev/null +++ b/data/error_policies/functions_recursion.cas @@ -0,0 +1,24 @@ +resource my_file { + fn read(domain source) { + allow(source, this, file, [ read open getattr ]); + other_read(source); + } + + fn other_read(domain source) { + allow(source, this, file, [ read open getattr ]); + third_read(source); + } + + fn third_read(domain source) { + allow(source, this, file, [ read open getattr ]); + read(source); + } + + fn term_read(domain source) { + allow(source, this, file, [ read open getattr ]); + } +} + +domain my_domain { + my_file.read(this); // TODO: support 'this' as default argument +} diff --git a/data/expected_cil/parent_call.cil b/data/expected_cil/call_casting.cil similarity index 100% rename from data/expected_cil/parent_call.cil rename to data/expected_cil/call_casting.cil diff --git a/data/policies/parent_call.cas b/data/policies/call_casting.cas similarity index 100% rename from data/policies/parent_call.cas rename to data/policies/call_casting.cas diff --git a/src/compile.rs b/src/compile.rs index 97f58f29..b86db9a0 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -520,6 +520,102 @@ pub fn determine_castable(functions: &mut FunctionMap, types: &TypeMap) -> u64 { num_changed } +pub fn initialize_terminated<'a>( + functions: &'a FunctionMap<'a>, +) -> (Vec, Vec>) { + let mut term_ret_vec: Vec = Vec::new(); + let mut nonterm_ret_vec: Vec = Vec::new(); + + for func in functions.values() { + let mut is_term = true; + + for call in func.original_body { + if let Statement::Call(call) = call { + match call.check_builtin() { + Some(_) => { + continue; + } + None => { + is_term = false; + break; + } + } + } + } + if is_term { + term_ret_vec.push(func.get_cil_name().clone()); + } else { + nonterm_ret_vec.push(func.clone()); + } + } + + (term_ret_vec, nonterm_ret_vec) +} + +pub fn search_for_recursion( + terminated_list: &mut Vec, + functions: &mut Vec, +) -> Result<(), CascadeErrors> { + let mut removed: u64 = 1; + while removed > 0 { + removed = 0; + for func in functions.clone().iter_mut() { + let mut is_term = true; + for call in func.original_body { + if let Statement::Call(call) = call { + let mut call_cil_name = call.get_cil_name(); + // If we are calling something with this it must be in the same class + // so hand place the class name + if call_cil_name.contains("this-") { + if let Some(class) = func.class { + call_cil_name = + call_cil_name.replace("this-", &(class.name.to_string() + "-")) + } else { + return Err(CascadeErrors::from(ErrorItem::make_compile_or_internal_error( + "Could not determine class for 'this.' function call", + Some(func.declaration_file), + call.get_name_range(), + "Perhaps you meant to place the function in a resource or domain?"))); + } + } + + if terminated_list.contains(&call_cil_name) { + continue; + } + is_term = false; + } + } + if is_term { + terminated_list.push(func.get_cil_name()); + removed += 1; + let index = functions + .iter() + .position(|x| *x.get_cil_name() == func.get_cil_name()) + .unwrap(); + functions.remove(index); + } + } + } + + if !functions.is_empty() { + let mut error: Option = None; + + for func in functions { + error = Some(add_or_create_compile_error( + error, + "Recursive Function call found", + func.declaration_file, + func.get_declaration_range().unwrap_or_default(), + "Calls cannot recursively call each other", + )); + } + // Unwrap is safe since we need to go through the loop above at least once + return Err(CascadeErrors::from(error.unwrap())); + } + + Ok(()) +} + pub fn prevalidate_functions( functions: &mut FunctionMap, types: &TypeMap, @@ -531,6 +627,26 @@ pub fn prevalidate_functions( num_changed = determine_castable(functions, types); } + let (mut terminated_functions, mut nonterm_functions) = initialize_terminated(functions); + + if terminated_functions.is_empty() && !nonterm_functions.is_empty() { + let mut error: Option = None; + + for func in functions.values() { + error = Some(add_or_create_compile_error( + error, + "No terminating call found", + func.declaration_file, + func.get_declaration_range().unwrap_or_default(), + "All function calls found in possible recursive loop", + )); + } + // Unwrap is safe since we need to go through the loop above at least once + return Err(CascadeErrors::from(error.unwrap())); + } + + search_for_recursion(&mut terminated_functions, &mut nonterm_functions)?; + Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index 3f3cf337..dd39f56a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1376,8 +1376,8 @@ mod tests { } #[test] - fn invalid_parent_call() { - error_policy_test!("parent_call.cas", 6, ErrorItem::Compile(_)); + fn invalid_call_casting() { + error_policy_test!("call_casting.cas", 6, ErrorItem::Compile(_)); } #[test] @@ -1386,9 +1386,9 @@ mod tests { } #[test] - fn valid_parent_call() { + fn valid_call_casting() { valid_policy_test( - "parent_call.cas", + "call_casting.cas", &[ "call bar-read (foo dom)", "call bar-foobar (foo dom)", @@ -1401,4 +1401,14 @@ mod tests { 0, ); } + + #[test] + fn invalid_function_recursion_test() { + error_policy_test!("functions_recursion.cas", 1, ErrorItem::Compile(_)); + } + + #[test] + fn invalid_function_noterm_test() { + error_policy_test!("functions_no_term.cas", 1, ErrorItem::Compile(_)); + } }