Skip to content

Commit

Permalink
moving helper functions to functions.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-sheets committed Mar 7, 2023
1 parent 93d48a9 commit 39bbcf3
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 167 deletions.
169 changes: 5 additions & 164 deletions src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,22 @@ use std::convert::TryFrom;
use std::ops::Range;

use crate::ast::{
get_all_func_calls, get_cil_name, Argument, CascadeString, Declaration, Expression, FuncCall,
LetBinding, Machine, Module, PolicyFile, Statement,
Argument, CascadeString, Declaration, Expression, FuncCall, LetBinding, Machine, Module,
PolicyFile, Statement,
};
use crate::constants;
use crate::context::{BindableObject, BlockType, Context as BlockContext};
use crate::error::{
add_or_create_compile_error, CascadeErrors, CompileError, ErrorItem, InternalError,
};
use crate::functions::{
determine_castable, initialize_castable, initialize_terminated, search_for_recursion,
ArgForValidation, FSContextType, FileSystemContextRule, FunctionArgument, FunctionClass,
FunctionInfo, FunctionMap, ValidatedCall, ValidatedStatement,
};
use crate::internal_rep::{
generate_sid_rules, validate_derive_args, Annotated,
AnnotationInfo, Associated, BoundTypeInfo, ClassList, Context, Sid, TypeInfo, TypeInstance,
TypeMap,
generate_sid_rules, validate_derive_args, Annotated, AnnotationInfo, Associated, BoundTypeInfo,
ClassList, Context, Sid, TypeInfo, TypeInstance, TypeMap,
};
use crate::machine::{MachineMap, ModuleMap, ValidatedMachine, ValidatedModule};
use crate::warning::{Warnings, WithWarnings};
Expand Down Expand Up @@ -476,165 +476,6 @@ pub fn validate_rules(statements: &BTreeSet<ValidatedStatement>) -> Result<(), C
errors.into_result(())
}

// Go through all functions and check if they are castable
// based only on their args
pub fn initialize_castable(functions: &mut FunctionMap, types: &TypeMap) {
for func in functions.values_mut() {
for arg in &func.args {
if arg.param_type.is_associated_resource(types) {
func.is_castable = false;

// If we have found one associated resource
// we can continue
continue;
}
}
}
}

// Go through all of the functions and check if they are castable
// base on functions they call.
pub fn determine_castable(functions: &mut FunctionMap, types: &TypeMap) -> u64 {
let mut num_changed: u64 = 0;
// We need tmp_functions to avoid a immutable borrow after a mutable one.
let tmp_functions = functions.clone();
'outer: for func in functions.values_mut() {
// If we are already false there is no reason to check our called functions.
if !func.is_castable {
continue;
}
for call in func.original_body {
if let Statement::Call(call) = call {
if let Some(inner_func) = tmp_functions.get(&call.get_cil_name()) {
if !inner_func.is_castable {
num_changed += 1;
func.is_castable = false;
continue 'outer;
}
}
for arg in &call.args {
if let Argument::Var(arg) = &arg.0 {
// Need to special case this.*
if arg.to_string().contains("this.") {
num_changed += 1;
func.is_castable = false;
continue 'outer;
}
if let Some(ti) = types.get(arg.as_ref()) {
if ti.is_associated_resource(types) {
num_changed += 1;
func.is_castable = false;
continue 'outer;
}
}
}
}
}
}
}
num_changed
}

pub fn initialize_terminated<'a>(functions: &'a FunctionMap<'a>) -> (Vec<String>, Vec<String>) {
let mut term_ret_vec: Vec<String> = Vec::new();
let mut nonterm_ret_vec: Vec<String> = Vec::new();

for func in functions.values() {
let mut is_term = true;

let func_calls = get_all_func_calls(func.original_body.to_vec());

for call in func_calls {
match call.check_builtin() {
Some(_) => {
continue;
}
None => {
is_term = false;
break;
}
}
}
if is_term {
term_ret_vec.push(func.get_cil_name().clone());
} else {
nonterm_ret_vec.push(func.get_cil_name().clone());
}
}

(term_ret_vec, nonterm_ret_vec)
}

pub fn search_for_recursion(
terminated_list: &mut Vec<String>,
functions: &mut Vec<String>,
function_map: &FunctionMap,
) -> Result<(), CascadeErrors> {
let mut removed: u64 = 1;
while removed > 0 {
removed = 0;
for func in functions.clone().iter() {
let mut is_term = false;
if let Some(function_info) = function_map.get(func) {
let func_calls = get_all_func_calls(function_info.original_body.to_vec());
for call in func_calls {
let mut call_cil_name = call.get_cil_name();
// If we are calling something with this it must be in the same class
// so hand place the class name
if call_cil_name.contains("this-") {
if let FunctionClass::Type(class) = function_info.class {
call_cil_name = get_cil_name(Some(&class.name), &call.name)
} else {
return Err(CascadeErrors::from(
ErrorItem::make_compile_or_internal_error(
"Could not determine class for 'this.' function call",
Some(function_info.declaration_file),
call.get_name_range(),
"Perhaps you meant to place the function in a resource or domain?",
),
));
}
}

if terminated_list.contains(&call_cil_name) {
is_term = true;
break;
}
}
if is_term {
terminated_list.push(function_info.get_cil_name());
removed += 1;
let index = functions
.iter()
.position(|x| *x == function_info.get_cil_name())
.unwrap();
functions.remove(index);
}
}
}
}

if !functions.is_empty() {
let mut error: Option<CompileError> = None;

for func in functions {
if let Some(function_info) = function_map.get(func) {
error = Some(add_or_create_compile_error(
error,
"Recursive Function call found",
function_info.declaration_file,
function_info.get_declaration_range().unwrap_or_default(),
"Calls cannot recursively call each other",
));
}
}
// Unwrap is safe since we need to go through the loop above at least once
return Err(CascadeErrors::from(error.unwrap()));
}

Ok(())
}

pub fn prevalidate_functions(
functions: &mut FunctionMap,
types: &TypeMap,
Expand Down
167 changes: 164 additions & 3 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ use codespan_reporting::files::SimpleFile;

use crate::alias_map::{AliasMap, Declared};
use crate::ast::{
get_cil_name, Annotation, Argument, BuiltIns, CascadeString, DeclaredArgument, FuncCall,
FuncDecl, IpAddr, Port, Statement,
get_all_func_calls, get_cil_name, Annotation, Argument, BuiltIns, CascadeString,
DeclaredArgument, FuncCall, FuncDecl, IpAddr, Port, Statement,
};
use crate::constants;
use crate::context::{BlockType, Context as BlockContext};
use crate::error::{CascadeErrors, CompileError, ErrorItem, InternalError};
use crate::error::{
add_or_create_compile_error, CascadeErrors, CompileError, ErrorItem, InternalError,
};
use crate::internal_rep::{
convert_class_name_if_this, type_name_from_string, typeinfo_from_string, Annotated,
AnnotationInfo, BoundTypeInfo, ClassList, Context, TypeInfo, TypeInstance, TypeMap,
Expand Down Expand Up @@ -2704,6 +2706,165 @@ impl<'a> From<&'a FunctionArgument<'a>> for ExpectedArgInfo<'a, '_> {
}
}

// Go through all functions and check if they are castable
// based only on their args
pub fn initialize_castable(functions: &mut FunctionMap, types: &TypeMap) {
for func in functions.values_mut() {
for arg in &func.args {
if arg.param_type.is_associated_resource(types) {
func.is_castable = false;

// If we have found one associated resource
// we can continue
continue;
}
}
}
}

// Go through all of the functions and check if they are castable
// base on functions they call.
pub fn determine_castable(functions: &mut FunctionMap, types: &TypeMap) -> u64 {
let mut num_changed: u64 = 0;
// We need tmp_functions to avoid a immutable borrow after a mutable one.
let tmp_functions = functions.clone();
'outer: for func in functions.values_mut() {
// If we are already false there is no reason to check our called functions.
if !func.is_castable {
continue;
}
for call in func.original_body {
if let Statement::Call(call) = call {
if let Some(inner_func) = tmp_functions.get(&call.get_cil_name()) {
if !inner_func.is_castable {
num_changed += 1;
func.is_castable = false;
continue 'outer;
}
}
for arg in &call.args {
if let Argument::Var(arg) = &arg.0 {
// Need to special case this.*
if arg.to_string().contains("this.") {
num_changed += 1;
func.is_castable = false;
continue 'outer;
}
if let Some(ti) = types.get(arg.as_ref()) {
if ti.is_associated_resource(types) {
num_changed += 1;
func.is_castable = false;
continue 'outer;
}
}
}
}
}
}
}
num_changed
}

pub fn initialize_terminated<'a>(functions: &'a FunctionMap<'a>) -> (Vec<String>, Vec<String>) {
let mut term_ret_vec: Vec<String> = Vec::new();
let mut nonterm_ret_vec: Vec<String> = Vec::new();

for func in functions.values() {
let mut is_term = true;

let func_calls = get_all_func_calls(func.original_body.to_vec());

for call in func_calls {
match call.check_builtin() {
Some(_) => {
continue;
}
None => {
is_term = false;
break;
}
}
}
if is_term {
term_ret_vec.push(func.get_cil_name().clone());
} else {
nonterm_ret_vec.push(func.get_cil_name().clone());
}
}

(term_ret_vec, nonterm_ret_vec)
}

pub fn search_for_recursion(
terminated_list: &mut Vec<String>,
functions: &mut Vec<String>,
function_map: &FunctionMap,
) -> Result<(), CascadeErrors> {
let mut removed: u64 = 1;
while removed > 0 {
removed = 0;
for func in functions.clone().iter() {
let mut is_term = false;
if let Some(function_info) = function_map.get(func) {
let func_calls = get_all_func_calls(function_info.original_body.to_vec());
for call in func_calls {
let mut call_cil_name = call.get_cil_name();
// If we are calling something with this it must be in the same class
// so hand place the class name
if call_cil_name.contains("this-") {
if let FunctionClass::Type(class) = function_info.class {
call_cil_name = get_cil_name(Some(&class.name), &call.name)
} else {
return Err(CascadeErrors::from(
ErrorItem::make_compile_or_internal_error(
"Could not determine class for 'this.' function call",
Some(function_info.declaration_file),
call.get_name_range(),
"Perhaps you meant to place the function in a resource or domain?",
),
));
}
}

if terminated_list.contains(&call_cil_name) {
is_term = true;
break;
}
}
if is_term {
terminated_list.push(function_info.get_cil_name());
removed += 1;
let index = functions
.iter()
.position(|x| *x == function_info.get_cil_name())
.unwrap();
functions.remove(index);
}
}
}
}

if !functions.is_empty() {
let mut error: Option<CompileError> = None;

for func in functions {
if let Some(function_info) = function_map.get(func) {
error = Some(add_or_create_compile_error(
error,
"Recursive Function call found",
function_info.declaration_file,
function_info.get_declaration_range().unwrap_or_default(),
"Calls cannot recursively call each other",
));
}
}
// Unwrap is safe since we need to go through the loop above at least once
return Err(CascadeErrors::from(error.unwrap()));
}

Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit 39bbcf3

Please sign in to comment.