diff --git a/brilirs/src/basic_block.rs b/brilirs/src/basic_block.rs index 4845d8cd0..d045d6ed4 100644 --- a/brilirs/src/basic_block.rs +++ b/brilirs/src/basic_block.rs @@ -1,40 +1,64 @@ use std::collections::HashMap; -// A program composed of basic blocks. -// (BB index of main program, list of BBs, mapping of label -> BB index) -pub type BBProgram = (Option, Vec, HashMap); +pub struct Function { + pub args: Vec, + pub return_type: Option, + pub blocks: Vec, -#[derive(Debug)] -pub struct BasicBlock { - pub instrs: Vec, - pub exit: Vec, + // Map from label to the index of the block that is the target of the label. + pub label_index: HashMap, } -impl BasicBlock { - fn new() -> BasicBlock { - BasicBlock { - instrs: Vec::new(), - exit: Vec::new(), - } +impl Function { + pub fn new(f: bril_rs::Function) -> Function { + let mut func = Function { + args: f.args.clone(), + return_type: f.return_type.clone(), + blocks: vec![], + label_index: HashMap::new(), + }; + func.add_blocks(f.instrs); + func.build_cfg(); + func } -} -pub fn find_basic_blocks(prog: bril_rs::Program) -> BBProgram { - let mut main_fn = None; - let mut blocks = Vec::new(); - let mut labels = HashMap::new(); + fn build_cfg(&mut self) { + let last_idx = self.blocks.len() - 1; + for (i, block) in self.blocks.iter_mut().enumerate() { + // If we're before the last block + if i < last_idx { + // Get the last instruction + let last_instr: &bril_rs::Code = block.instrs.last().unwrap(); + if let bril_rs::Code::Instruction(bril_rs::Instruction::Effect { op, labels, .. }) = + last_instr + { + if let bril_rs::EffectOps::Jump | bril_rs::EffectOps::Branch = op { + for l in labels { + block.exit.push( + *self + .label_index + .get(l) + .expect(&format!("No label {} found.", &l)), + ); + } + } + } else { + block.exit.push(i + 1); + } + } + } + } - let mut bb_helper = |func: bril_rs::Function| -> usize { + fn add_blocks(&mut self, instrs: Vec) { let mut curr_block = BasicBlock::new(); - let root_block = blocks.len(); let mut curr_label = None; - for instr in func.instrs.into_iter() { + for instr in instrs { match instr { bril_rs::Code::Label { ref label } => { if !curr_block.instrs.is_empty() { - blocks.push(curr_block); + self.blocks.push(curr_block); if let Some(old_label) = curr_label { - labels.insert(old_label, blocks.len() - 1); + self.label_index.insert(old_label, self.blocks.len() - 1); } curr_block = BasicBlock::new(); } @@ -46,9 +70,9 @@ pub fn find_basic_blocks(prog: bril_rs::Program) -> BBProgram { || op == bril_rs::EffectOps::Return => { curr_block.instrs.push(instr); - blocks.push(curr_block); + self.blocks.push(curr_block); if let Some(l) = curr_label { - labels.insert(l, blocks.len() - 1); + self.label_index.insert(l, self.blocks.len() - 1); curr_label = None; } curr_block = BasicBlock::new(); @@ -58,24 +82,55 @@ pub fn find_basic_blocks(prog: bril_rs::Program) -> BBProgram { } } } - if !curr_block.instrs.is_empty() { - blocks.push(curr_block); + // If we are here, the function ends without an explicit ret. To make + // processing easier, push a Return op onto the last block. + curr_block.instrs.push(RET.clone()); + self.blocks.push(curr_block); if let Some(l) = curr_label { - labels.insert(l, blocks.len() - 1); + self.label_index.insert(l, self.blocks.len() - 1); } } + } +} - root_block - }; +// A program represented as basic blocks. +pub struct BBProgram { + pub func_index: HashMap, +} - for func in prog.functions.into_iter() { - let func_name = func.name.clone(); - let func_block = bb_helper(func); - if func_name == "main" { - main_fn = Some(func_block); +impl BBProgram { + pub fn new(prog: bril_rs::Program) -> BBProgram { + let mut bbprog = BBProgram { + func_index: HashMap::new(), + }; + for func in prog.functions { + bbprog + .func_index + .insert(func.name.clone(), Function::new(func)); } + bbprog } +} + +#[derive(Debug)] +pub struct BasicBlock { + pub instrs: Vec, + pub exit: Vec, +} - (main_fn, blocks, labels) +impl BasicBlock { + fn new() -> BasicBlock { + BasicBlock { + instrs: Vec::new(), + exit: Vec::new(), + } + } } + +const RET: bril_rs::Code = bril_rs::Code::Instruction(bril_rs::Instruction::Effect { + op: bril_rs::EffectOps::Return, + args: vec![], + funcs: vec![], + labels: vec![], +}); diff --git a/brilirs/src/cfg.rs b/brilirs/src/cfg.rs deleted file mode 100644 index 9abc36591..000000000 --- a/brilirs/src/cfg.rs +++ /dev/null @@ -1,34 +0,0 @@ -use crate::basic_block::BasicBlock; - -use std::collections::HashMap; - -type CFG = Vec; - -pub fn build_cfg(mut blocks: Vec, label_to_block_idx: &HashMap) -> CFG { - let last_idx = blocks.len() - 1; - for (i, block) in blocks.iter_mut().enumerate() { - // If we're before the last block - if i < last_idx { - // Get the last instruction - let last_instr: &bril_rs::Code = block.instrs.last().unwrap(); - if let bril_rs::Code::Instruction(bril_rs::Instruction::Effect { op, labels, .. }) = - last_instr - { - match op { - bril_rs::EffectOps::Jump | bril_rs::EffectOps::Branch => { - for l in labels { - block.exit.push(label_to_block_idx[l]); - } - } - bril_rs::EffectOps::Return => {} - // TODO(yati): Do all effect ops end a BB? - _ => {} - } - } else { - block.exit.push(i + 1); - } - } - } - - blocks -} diff --git a/brilirs/src/interp.rs b/brilirs/src/interp.rs index d2db88e9f..8226fbd8d 100644 --- a/brilirs/src/interp.rs +++ b/brilirs/src/interp.rs @@ -2,13 +2,15 @@ use std::collections::HashMap; use std::convert::TryFrom; use std::fmt; -use crate::basic_block::{BBProgram, BasicBlock}; +use crate::basic_block::{BBProgram, BasicBlock, Function}; #[derive(Debug)] pub enum InterpError { BadJsonInt, BadJsonBool, NoMainFunction, + FuncNotFound(String), + NoRetValForfunc(String), BadNumArgs(usize, usize), // (expected, actual) BadNumLabels(usize, usize), // (expected, actual) VarNotFound(String), @@ -16,6 +18,7 @@ pub enum InterpError { LabelNotFound(String), BadValueType(bril_rs::Type, bril_rs::Type), // (expected, actual) IoError(Box), + BadCall(String, String), // (func name, reason). } fn check_asmt_type(expected: &bril_rs::Type, actual: &bril_rs::Type) -> Result<(), InterpError> { @@ -151,12 +154,15 @@ impl TryFrom<&Value> for f64 { } #[allow(clippy::float_cmp)] -fn execute_value_op( +fn execute_value_op( + prog: &BBProgram, op: &bril_rs::ValueOps, dest: &str, op_type: &bril_rs::Type, args: &[String], + funcs: &[String], value_store: &mut HashMap, + out: &mut W, ) -> Result<(), InterpError> { use bril_rs::ValueOps::*; match *op { @@ -271,7 +277,32 @@ fn execute_value_op( let args = get_args::(value_store, 2, args)?; value_store.insert(String::from(dest), Value::Bool(args[0] >= args[1])); } - Call => unreachable!(), // TODO(yati): Why is Call a ValueOp as well? + Call => { + assert!(funcs.len() == 1); + let func_info = prog + .func_index + .get(&funcs[0]) + .ok_or(InterpError::FuncNotFound(funcs[0].clone()))?; + + check_asmt_type( + func_info.return_type.as_ref().ok_or(InterpError::BadCall( + String::from(&funcs[0]), + String::from( + "Function does not return a value, but used on the right side of an assignment", + ), + ))?, + op_type, + )?; + + let vars = make_func_args(&funcs[0], func_info, args, value_store)?; + if let Some(val) = execute_func(&prog, &funcs[0], vars, out)? { + check_asmt_type(&val.get_type(), op_type)?; + value_store.insert(String::from(dest), val); + } else { + // This is a value-op call, so the target func must return a result. + return Err(InterpError::NoRetValForfunc(funcs[0].clone())); + } + } Phi | Alloc | Load | PtrAdd => unimplemented!(), } Ok(()) @@ -293,17 +324,56 @@ fn check_num_labels(expected: usize, labels: &[String]) -> Result<(), InterpErro } } -// Returns whether the program should continue running (i.e., if a Return was -// *not* executed). +// Returns a map from function parameter names to values of the call arguments +// that are bound to those parameters. +fn make_func_args( + func_name: &str, + func: &Function, + call_args: &[String], + vars: &HashMap, +) -> Result, InterpError> { + if func.args.len() != call_args.len() { + return Err(InterpError::BadCall( + String::from(func_name), + format!( + "Expected {} parameters, tried to pass {} args", + func.args.len(), + call_args.len() + ), + )); + } + let vals = get_values(vars, call_args.len(), call_args)?; + let mut args = HashMap::new(); + for (i, arg) in func.args.iter().enumerate() { + check_asmt_type(&arg.arg_type, &vals[i].get_type())?; + args.insert(arg.name.clone(), vals[i].clone()); + } + Ok(args) +} + +// Result of executing an effect operation. +enum EffectResult { + // Return from the current function without any value. + Return, + + // Return a given value from the current function. + ReturnWithVal(Value), + + // Continue execution of the current function. + Continue, +} + fn execute_effect_op( + prog: &BBProgram, op: &bril_rs::EffectOps, args: &[String], labels: &[String], + funcs: &[String], curr_block: &BasicBlock, value_store: &HashMap, - mut out: T, + out: &mut T, next_block_idx: &mut Option, -) -> Result { +) -> Result { use bril_rs::EffectOps::*; match op { Jump => { @@ -320,36 +390,55 @@ fn execute_effect_op( Return => { out.flush().map_err(|e| InterpError::IoError(Box::new(e)))?; // NOTE: This only works so long as `main` is the only function - return Ok(false); + if args.is_empty() { + return Ok(EffectResult::Return); + } + let retval = value_store + .get(&args[0]) + .ok_or(InterpError::VarNotFound(args[0].clone()))?; + return Ok(EffectResult::ReturnWithVal(retval.clone())); } Print => { + let vals = get_values(value_store, args.len(), args)?; writeln!( out, "{}", - args + vals .iter() - .map(|a| format!("{}", value_store[a])) + .map(|v| format!("{}", v)) .collect::>() .join(", ") ) .map_err(|e| InterpError::IoError(Box::new(e)))?; } Nop => {} - Call => unreachable!(), + Call => { + assert!(funcs.len() == 1); + let func = prog + .func_index + .get(&funcs[0]) + .ok_or(InterpError::FuncNotFound(funcs[0].clone()))?; + let vars = make_func_args(&funcs[0], func, args, value_store)?; + execute_func(&prog, &funcs[0], vars, out)?; + } Store | Free | Speculate | Commit | Guard => unimplemented!(), } - Ok(true) + Ok(EffectResult::Continue) } -pub fn execute(prog: BBProgram, mut out: T) -> Result<(), InterpError> { - let (main_fn, blocks, _labels) = prog; - let mut curr_block_idx: usize = main_fn.ok_or(InterpError::NoMainFunction)?; - - // Map from variable name to value. - let mut value_store: HashMap = HashMap::new(); - +fn execute_func( + prog: &BBProgram, + func: &str, + mut vars: HashMap, + out: &mut T, +) -> Result, InterpError> { + let f = prog + .func_index + .get(func) + .ok_or(InterpError::FuncNotFound(String::from(func)))?; + let mut curr_block_idx = 0; loop { - let curr_block = &blocks[curr_block_idx]; + let curr_block = &f.blocks[curr_block_idx]; let curr_instrs = &curr_block.instrs; let mut next_block_idx = if curr_block.exit.len() == 1 { Some(curr_block.exit[0]) @@ -367,44 +456,58 @@ pub fn execute(prog: BBProgram, mut out: T) -> Result<(), Int value, } => { check_asmt_type(const_type, &value.get_type())?; - value_store.insert(dest.clone(), Value::from(value)); + vars.insert(dest.clone(), Value::from(value)); } bril_rs::Instruction::Value { op, dest, op_type, args, + funcs, .. } => { - execute_value_op(op, dest, op_type, args, &mut value_store)?; + execute_value_op(&prog, op, dest, op_type, args, funcs, &mut vars, out)?; } bril_rs::Instruction::Effect { - op, args, labels, .. + op, + args, + labels, + funcs, + .. } => { - let should_continue = execute_effect_op( + match execute_effect_op( + prog, op, args, labels, + funcs, &curr_block, - &value_store, - &mut out, + &vars, + out, &mut next_block_idx, - )?; - - // TODO(yati): Correct only when main is the only function. - if !should_continue { - return Ok(()); - } + )? { + EffectResult::Continue => {} + EffectResult::Return => { + return Ok(None); + } + EffectResult::ReturnWithVal(val) => { + return Ok(Some(val)); + } + }; } } } } - if let Some(idx) = next_block_idx { curr_block_idx = idx; } else { out.flush().map_err(|e| InterpError::IoError(Box::new(e)))?; - return Ok(()); + return Ok(None); } } } + +pub fn execute(prog: BBProgram, out: &mut T) -> Result<(), InterpError> { + // Ignore return value of @main. + execute_func(&prog, "main", HashMap::new(), out).map(|_| ()) +} diff --git a/brilirs/src/lib.rs b/brilirs/src/lib.rs index a2f5aefbd..62813bc4a 100644 --- a/brilirs/src/lib.rs +++ b/brilirs/src/lib.rs @@ -1,5 +1,4 @@ mod basic_block; -mod cfg; mod interp; #[macro_use] @@ -11,11 +10,11 @@ extern crate serde; extern crate serde_derive; extern crate serde_json; -pub fn run_input(input: Box, out: T) { +pub fn run_input(input: Box, mut out: T) { let prog = bril_rs::load_program_from_read(input); - let (main_idx, blocks, label_index) = basic_block::find_basic_blocks(prog); - let blocks = cfg::build_cfg(blocks, &label_index); - if let Err(e) = interp::execute((main_idx, blocks, label_index), out) { + println!("{:?}", &prog); + let bbprog = basic_block::BBProgram::new(prog); + if let Err(e) = interp::execute(bbprog, &mut out) { error!("{:?}", e); } } diff --git a/brilirs/testdata/call-with-args.json b/brilirs/testdata/call-with-args.json new file mode 100644 index 000000000..9d5ce2066 --- /dev/null +++ b/brilirs/testdata/call-with-args.json @@ -0,0 +1,88 @@ +{ + "functions": [ + { + "instrs": [ + { + "dest": "x", + "op": "const", + "type": "int", + "value": 2 + }, + { + "dest": "y", + "op": "const", + "type": "int", + "value": 2 + }, + { + "args": [ + "x", + "y" + ], + "dest": "z", + "funcs": [ + "add2" + ], + "op": "call", + "type": "int" + }, + { + "args": [ + "y" + ], + "op": "print" + }, + { + "args": [ + "z" + ], + "op": "print" + } + ], + "name": "main" + }, + { + "args": [ + { + "name": "m", + "type": "int" + }, + { + "name": "n", + "type": "int" + } + ], + "instrs": [ + { + "args": [ + "m", + "n" + ], + "dest": "w", + "op": "add", + "type": "int" + }, + { + "dest": "y", + "op": "const", + "type": "int", + "value": 5 + }, + { + "args": [ + "w" + ], + "op": "print" + }, + { + "args": [ + "w" + ], + "op": "ret" + } + ], + "name": "add2", + "type": "int" + } + ] +} diff --git a/brilirs/testdata/call.json b/brilirs/testdata/call.json new file mode 100644 index 000000000..35970122e --- /dev/null +++ b/brilirs/testdata/call.json @@ -0,0 +1,47 @@ +{ + "functions": [ + { + "instrs": [ + { + "dest": "v", + "op": "const", + "type": "int", + "value": 2 + }, + { + "funcs": [ + "print4" + ], + "op": "call" + }, + { + "args": [ + "v" + ], + "op": "print" + } + ], + "name": "main" + }, + { + "instrs": [ + { + "dest": "v", + "op": "const", + "type": "int", + "value": 4 + }, + { + "args": [ + "v" + ], + "op": "print" + }, + { + "op": "ret" + } + ], + "name": "print4" + } + ] +} diff --git a/brilirs/tests/interp_test.rs b/brilirs/tests/interp_test.rs index 93299bfbb..df9b5fa3b 100644 --- a/brilirs/tests/interp_test.rs +++ b/brilirs/tests/interp_test.rs @@ -56,4 +56,6 @@ interp_tests! { or: "./testdata/or.json", id: "./testdata/id.json", br: "./testdata/br.json", + call: "./testdata/call.json", + call_with_args: "./testdata/call-with-args.json", }