diff --git a/src/compile.rs b/src/compile.rs index bf5b6d6a..298be6ce 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -8,8 +8,8 @@ use std::convert::TryFrom; use std::ops::Range; use crate::ast::{ - get_all_func_calls, get_cil_name, Argument, CascadeString, Declaration, Expression, FuncCall, - LetBinding, Machine, Module, PolicyFile, Statement, + Argument, CascadeString, Declaration, Expression, FuncCall, LetBinding, Machine, Module, + PolicyFile, Statement, }; use crate::constants; use crate::context::{BindableObject, BlockType, Context as BlockContext}; @@ -17,13 +17,13 @@ use crate::error::{ add_or_create_compile_error, CascadeErrors, CompileError, ErrorItem, InternalError, }; use crate::functions::{ + determine_castable, initialize_castable, initialize_terminated, search_for_recursion, ArgForValidation, FSContextType, FileSystemContextRule, FunctionArgument, FunctionClass, FunctionInfo, FunctionMap, ValidatedCall, ValidatedStatement, }; use crate::internal_rep::{ - generate_sid_rules, validate_derive_args, Annotated, - AnnotationInfo, Associated, BoundTypeInfo, ClassList, Context, Sid, TypeInfo, TypeInstance, - TypeMap, + generate_sid_rules, validate_derive_args, Annotated, AnnotationInfo, Associated, BoundTypeInfo, + ClassList, Context, Sid, TypeInfo, TypeInstance, TypeMap, }; use crate::machine::{MachineMap, ModuleMap, ValidatedMachine, ValidatedModule}; use crate::warning::{Warnings, WithWarnings}; @@ -476,165 +476,6 @@ pub fn validate_rules(statements: &BTreeSet) -> Result<(), C errors.into_result(()) } -// Go through all functions and check if they are castable -// based only on their args -pub fn initialize_castable(functions: &mut FunctionMap, types: &TypeMap) { - for func in functions.values_mut() { - for arg in &func.args { - if arg.param_type.is_associated_resource(types) { - func.is_castable = false; - - // If we have found one associated resource - // we can continue - continue; - } - } - } -} - -// Go through all of the functions and check if they are castable -// base on functions they call. -pub fn determine_castable(functions: &mut FunctionMap, types: &TypeMap) -> u64 { - let mut num_changed: u64 = 0; - // We need tmp_functions to avoid a immutable borrow after a mutable one. - let tmp_functions = functions.clone(); - 'outer: for func in functions.values_mut() { - // If we are already false there is no reason to check our called functions. - if !func.is_castable { - continue; - } - for call in func.original_body { - if let Statement::Call(call) = call { - if let Some(inner_func) = tmp_functions.get(&call.get_cil_name()) { - if !inner_func.is_castable { - num_changed += 1; - func.is_castable = false; - continue 'outer; - } - } - for arg in &call.args { - if let Argument::Var(arg) = &arg.0 { - // Need to special case this.* - if arg.to_string().contains("this.") { - num_changed += 1; - func.is_castable = false; - continue 'outer; - } - if let Some(ti) = types.get(arg.as_ref()) { - if ti.is_associated_resource(types) { - num_changed += 1; - func.is_castable = false; - continue 'outer; - } - } - } - } - } - } - } - num_changed -} - -pub fn initialize_terminated<'a>(functions: &'a FunctionMap<'a>) -> (Vec, Vec) { - let mut term_ret_vec: Vec = Vec::new(); - let mut nonterm_ret_vec: Vec = Vec::new(); - - for func in functions.values() { - let mut is_term = true; - - let func_calls = get_all_func_calls(func.original_body.to_vec()); - - for call in func_calls { - match call.check_builtin() { - Some(_) => { - continue; - } - None => { - is_term = false; - break; - } - } - } - if is_term { - term_ret_vec.push(func.get_cil_name().clone()); - } else { - nonterm_ret_vec.push(func.get_cil_name().clone()); - } - } - - (term_ret_vec, nonterm_ret_vec) -} - -pub fn search_for_recursion( - terminated_list: &mut Vec, - functions: &mut Vec, - function_map: &FunctionMap, -) -> Result<(), CascadeErrors> { - let mut removed: u64 = 1; - while removed > 0 { - removed = 0; - for func in functions.clone().iter() { - let mut is_term = false; - if let Some(function_info) = function_map.get(func) { - let func_calls = get_all_func_calls(function_info.original_body.to_vec()); - for call in func_calls { - let mut call_cil_name = call.get_cil_name(); - // If we are calling something with this it must be in the same class - // so hand place the class name - if call_cil_name.contains("this-") { - if let FunctionClass::Type(class) = function_info.class { - call_cil_name = get_cil_name(Some(&class.name), &call.name) - } else { - return Err(CascadeErrors::from( - ErrorItem::make_compile_or_internal_error( - "Could not determine class for 'this.' function call", - Some(function_info.declaration_file), - call.get_name_range(), - "Perhaps you meant to place the function in a resource or domain?", - ), - )); - } - } - - if terminated_list.contains(&call_cil_name) { - is_term = true; - break; - } - } - if is_term { - terminated_list.push(function_info.get_cil_name()); - removed += 1; - let index = functions - .iter() - .position(|x| *x == function_info.get_cil_name()) - .unwrap(); - functions.remove(index); - } - } - } - } - - if !functions.is_empty() { - let mut error: Option = None; - - for func in functions { - if let Some(function_info) = function_map.get(func) { - error = Some(add_or_create_compile_error( - error, - "Recursive Function call found", - function_info.declaration_file, - function_info.get_declaration_range().unwrap_or_default(), - "Calls cannot recursively call each other", - )); - } - } - // Unwrap is safe since we need to go through the loop above at least once - return Err(CascadeErrors::from(error.unwrap())); - } - - Ok(()) -} - pub fn prevalidate_functions( functions: &mut FunctionMap, types: &TypeMap, diff --git a/src/functions.rs b/src/functions.rs index cc564b08..e25af7a5 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -14,12 +14,14 @@ use codespan_reporting::files::SimpleFile; use crate::alias_map::{AliasMap, Declared}; use crate::ast::{ - get_cil_name, Annotation, Argument, BuiltIns, CascadeString, DeclaredArgument, FuncCall, - FuncDecl, IpAddr, Port, Statement, + get_all_func_calls, get_cil_name, Annotation, Argument, BuiltIns, CascadeString, + DeclaredArgument, FuncCall, FuncDecl, IpAddr, Port, Statement, }; use crate::constants; use crate::context::{BlockType, Context as BlockContext}; -use crate::error::{CascadeErrors, CompileError, ErrorItem, InternalError}; +use crate::error::{ + add_or_create_compile_error, CascadeErrors, CompileError, ErrorItem, InternalError, +}; use crate::internal_rep::{ convert_class_name_if_this, type_name_from_string, typeinfo_from_string, Annotated, AnnotationInfo, BoundTypeInfo, ClassList, Context, TypeInfo, TypeInstance, TypeMap, @@ -2704,6 +2706,165 @@ impl<'a> From<&'a FunctionArgument<'a>> for ExpectedArgInfo<'a, '_> { } } +// Go through all functions and check if they are castable +// based only on their args +pub fn initialize_castable(functions: &mut FunctionMap, types: &TypeMap) { + for func in functions.values_mut() { + for arg in &func.args { + if arg.param_type.is_associated_resource(types) { + func.is_castable = false; + + // If we have found one associated resource + // we can continue + continue; + } + } + } +} + +// Go through all of the functions and check if they are castable +// base on functions they call. +pub fn determine_castable(functions: &mut FunctionMap, types: &TypeMap) -> u64 { + let mut num_changed: u64 = 0; + // We need tmp_functions to avoid a immutable borrow after a mutable one. + let tmp_functions = functions.clone(); + 'outer: for func in functions.values_mut() { + // If we are already false there is no reason to check our called functions. + if !func.is_castable { + continue; + } + for call in func.original_body { + if let Statement::Call(call) = call { + if let Some(inner_func) = tmp_functions.get(&call.get_cil_name()) { + if !inner_func.is_castable { + num_changed += 1; + func.is_castable = false; + continue 'outer; + } + } + for arg in &call.args { + if let Argument::Var(arg) = &arg.0 { + // Need to special case this.* + if arg.to_string().contains("this.") { + num_changed += 1; + func.is_castable = false; + continue 'outer; + } + if let Some(ti) = types.get(arg.as_ref()) { + if ti.is_associated_resource(types) { + num_changed += 1; + func.is_castable = false; + continue 'outer; + } + } + } + } + } + } + } + num_changed +} + +pub fn initialize_terminated<'a>(functions: &'a FunctionMap<'a>) -> (Vec, Vec) { + let mut term_ret_vec: Vec = Vec::new(); + let mut nonterm_ret_vec: Vec = Vec::new(); + + for func in functions.values() { + let mut is_term = true; + + let func_calls = get_all_func_calls(func.original_body.to_vec()); + + for call in func_calls { + match call.check_builtin() { + Some(_) => { + continue; + } + None => { + is_term = false; + break; + } + } + } + if is_term { + term_ret_vec.push(func.get_cil_name().clone()); + } else { + nonterm_ret_vec.push(func.get_cil_name().clone()); + } + } + + (term_ret_vec, nonterm_ret_vec) +} + +pub fn search_for_recursion( + terminated_list: &mut Vec, + functions: &mut Vec, + function_map: &FunctionMap, +) -> Result<(), CascadeErrors> { + let mut removed: u64 = 1; + while removed > 0 { + removed = 0; + for func in functions.clone().iter() { + let mut is_term = false; + if let Some(function_info) = function_map.get(func) { + let func_calls = get_all_func_calls(function_info.original_body.to_vec()); + for call in func_calls { + let mut call_cil_name = call.get_cil_name(); + // If we are calling something with this it must be in the same class + // so hand place the class name + if call_cil_name.contains("this-") { + if let FunctionClass::Type(class) = function_info.class { + call_cil_name = get_cil_name(Some(&class.name), &call.name) + } else { + return Err(CascadeErrors::from( + ErrorItem::make_compile_or_internal_error( + "Could not determine class for 'this.' function call", + Some(function_info.declaration_file), + call.get_name_range(), + "Perhaps you meant to place the function in a resource or domain?", + ), + )); + } + } + + if terminated_list.contains(&call_cil_name) { + is_term = true; + break; + } + } + if is_term { + terminated_list.push(function_info.get_cil_name()); + removed += 1; + let index = functions + .iter() + .position(|x| *x == function_info.get_cil_name()) + .unwrap(); + functions.remove(index); + } + } + } + } + + if !functions.is_empty() { + let mut error: Option = None; + + for func in functions { + if let Some(function_info) = function_map.get(func) { + error = Some(add_or_create_compile_error( + error, + "Recursive Function call found", + function_info.declaration_file, + function_info.get_declaration_range().unwrap_or_default(), + "Calls cannot recursively call each other", + )); + } + } + // Unwrap is safe since we need to go through the loop above at least once + return Err(CascadeErrors::from(error.unwrap())); + } + + Ok(()) +} + #[cfg(test)] mod tests { use super::*;