From 4a910568d8eadd177b3b89f9a26b140d73f16311 Mon Sep 17 00:00:00 2001 From: Morgan Thomas Date: Tue, 9 Dec 2025 03:00:48 +0000 Subject: [PATCH 01/16] Block-level scoping (#92) and variable manager (#88) --- crates/lean_compiler/src/a_simplify_lang.rs | 217 ++++++++++++++-- .../src/b_compile_intermediate.rs | 242 +++++++++++------- crates/lean_compiler/src/grammar.pest | 3 + crates/lean_compiler/src/lang.rs | 30 ++- crates/lean_compiler/src/lib.rs | 4 +- .../src/parser/parsers/statement.rs | 12 + crates/lean_compiler/tests/test_compiler.rs | 26 +- crates/lean_vm/src/diagnostics/error.rs | 4 +- crates/lean_vm/src/execution/memory.rs | 2 +- crates/lean_vm/src/execution/tests.rs | 2 +- .../recursion_program.lean_lang | 7 +- 11 files changed, 400 insertions(+), 149 deletions(-) diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index 50f95b59..bef4e851 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -3,7 +3,7 @@ use crate::{ ir::HighLevelOperation, lang::{ AssumeBoolean, Boolean, Condition, ConstExpression, ConstMallocLabel, ConstantValue, Expression, Function, - Line, Program, SimpleExpr, Var, + Line, Program, SimpleExpr, Var, Context, Scope, }, }; use lean_vm::{SourceLineNumber, Table, TableT}; @@ -72,6 +72,7 @@ pub enum SimpleLine { value: SimpleExpr, arms: Vec>, // patterns = 0, 1, ... }, + ForwardDeclaration { var: Var }, Assignment { var: VarOrConstMallocAccess, operation: HighLevelOperation, @@ -148,6 +149,7 @@ pub enum SimpleLine { } pub fn simplify_program(mut program: Program) -> SimpleProgram { + check_program_scoping(&program); handle_inlined_functions(&mut program); handle_const_arguments(&mut program); let mut new_functions = BTreeMap::new(); @@ -187,6 +189,173 @@ pub fn simplify_program(mut program: Program) -> SimpleProgram { } } +/// Analyzes the program to verify that each variable is defined in each context where it is used. 
+fn check_program_scoping(program: &Program) { + for (_, function) in program.functions.iter() { + let mut scope = Scope { vars: BTreeSet::new() }; + for (arg, _) in function.arguments.iter() { + scope.vars.insert(arg.clone()); + } + let mut ctx = Context { scopes: vec![scope] }; + + check_block_scoping(&function.body, &mut ctx); + } +} + +/// Analyzes the block to verify that each variable is defined in each context where it is used. +fn check_block_scoping(block: &Vec, ctx: &mut Context) { + for line in block.iter() { + match line { + Line::ForwardDeclaration { var } => { + let last_scope = ctx.scopes.last_mut().unwrap(); + assert!(!last_scope.vars.contains(var), "Variable declared multiple times in the same scope: {:?}", var); + last_scope.vars.insert(var.clone()); + }, + Line::Match { value, arms } => { + check_expr_scoping(value, ctx); + for (_, arm) in arms { + ctx.scopes.push(Scope { vars: BTreeSet::new() }); + check_block_scoping(arm, ctx); + ctx.scopes.pop(); + } + }, + Line::Assignment { var, value } => { + check_expr_scoping(value, ctx); + let last_scope = ctx.scopes.last_mut().unwrap(); + assert!(!last_scope.vars.contains(var), "Variable declared multiple times in the same scope: {:?}", var); + last_scope.vars.insert(var.clone()); + }, + Line::ArrayAssign { array, index, value } => { + check_simple_expr_scoping(array, ctx); + check_expr_scoping(index, ctx); + check_expr_scoping(value, ctx); + }, + Line::Assert(boolean, _) => { + check_boolean_scoping(boolean, ctx); + }, + Line::IfCondition { condition, then_branch, else_branch, line_number: _ } => { + check_condition_scoping(condition, ctx); + for branch in [then_branch, else_branch] { + ctx.scopes.push(Scope { vars: BTreeSet::new() }); + check_block_scoping(branch, ctx); + ctx.scopes.pop(); + } + }, + Line::ForLoop { iterator, start, end, body, rev: _, unroll: _, line_number: _ } => { + check_expr_scoping(start, ctx); + check_expr_scoping(end, ctx); + let mut new_scope_vars = BTreeSet::new(); + 
new_scope_vars.insert(iterator.clone()); + ctx.scopes.push(Scope { vars: new_scope_vars }); + check_block_scoping(body, ctx); + ctx.scopes.pop(); + }, + Line::FunctionCall { function_name: _, args, return_data, line_number: _ } => { + for arg in args { + check_expr_scoping(arg, ctx); + } + let last_scope = ctx.scopes.last_mut().unwrap(); + for var in return_data { + assert!(!last_scope.vars.contains(var), "Variable declared multiple times in the same scope: {:?}", var); + last_scope.vars.insert(var.clone()); + } + }, + Line::FunctionRet { return_data } => { + for expr in return_data { + check_expr_scoping(expr, ctx); + } + }, + Line::Precompile { table: _, args } => { + for arg in args { + check_expr_scoping(arg, ctx); + } + }, + Line::Break | Line::Panic | Line::LocationReport { .. } => {}, + Line::Print { line_info: _, content } => { + for expr in content { + check_expr_scoping(expr, ctx); + } + }, + Line::MAlloc { var, size, vectorized: _, vectorized_len } => { + check_expr_scoping(size, ctx); + check_expr_scoping(vectorized_len, ctx); + let last_scope = ctx.scopes.last_mut().unwrap(); + assert!(!last_scope.vars.contains(var), "Variable declared multiple times in the same scope: {:?}", var); + last_scope.vars.insert(var.clone()); + }, + Line::DecomposeBits { var, to_decompose } => { + for expr in to_decompose { + check_expr_scoping(expr, ctx); + } + let last_scope = ctx.scopes.last_mut().unwrap(); + assert!(!last_scope.vars.contains(var), "Variable declared multiple times in the same scope: {:?}", var); + last_scope.vars.insert(var.clone()); + }, + Line::DecomposeCustom { args } => { + for arg in args { + check_expr_scoping(arg, ctx); + } + }, + Line::PrivateInputStart { result } => { + let last_scope = ctx.scopes.last_mut().unwrap(); + assert!(!last_scope.vars.contains(result), "Variable declared multiple times in the same scope: {result}"); + last_scope.vars.insert(result.clone()); + } + } + } +} + +/// Analyzes the expression to verify that each variable is 
defined in the given context. +fn check_expr_scoping(expr: &Expression, ctx: &Context) { + match expr { + Expression::Value(simple_expr) => { + check_simple_expr_scoping(simple_expr, ctx); + }, + Expression::ArrayAccess { array, index } => { + check_simple_expr_scoping(array, ctx); + check_expr_scoping(&*index, ctx); + }, + Expression::Binary { left, operation: _, right } => { + check_expr_scoping(&*left, ctx); + check_expr_scoping(&*right, ctx); + }, + Expression::Log2Ceil { value } => { + check_expr_scoping(&*value, ctx); + }, + } +} + +/// Analyzes the simple expression to verify that each variable is defined in the given context. +fn check_simple_expr_scoping(expr: &SimpleExpr, ctx: &Context) { + match expr { + SimpleExpr::Var(v) => { + assert!(ctx.defines(&v), "Variable used but not defined: {:?}", v) + }, + SimpleExpr::Constant(_) => {}, + SimpleExpr::ConstMallocAccess { .. } => {}, + } +} + +fn check_boolean_scoping(boolean: &Boolean, ctx: &Context) { + match boolean { + Boolean::Equal { left, right } | Boolean::Different { left, right } => { + check_expr_scoping(left, ctx); + check_expr_scoping(right, ctx); + }, + } +} + +fn check_condition_scoping(condition: &Condition, ctx: &Context) { + match condition { + Condition::Expression(expr, _) => { + check_expr_scoping(expr, ctx); + }, + Condition::Comparison(boolean) => { + check_boolean_scoping(boolean, ctx); + } + } +} + #[derive(Debug, Clone, Default)] struct Counters { aux_vars: usize, @@ -205,7 +374,6 @@ struct ArrayManager { pub struct ConstMalloc { counter: usize, map: BTreeMap, - forbidden_vars: BTreeSet, // vars shared between branches of an if/else } impl ArrayManager { @@ -231,6 +399,9 @@ fn simplify_lines( let mut res = Vec::new(); for line in lines { match line { + Line::ForwardDeclaration { var } => { + res.push(SimpleLine::ForwardDeclaration { var: var.clone() }); + }, Line::Match { value, arms } => { let simple_value = simplify_expr(value, &mut res, counters, array_manager, const_malloc); let 
mut simple_arms = vec![]; @@ -386,17 +557,6 @@ fn simplify_lines( } }; - let forbidden_vars_before = const_malloc.forbidden_vars.clone(); - - let then_internal_vars = find_variable_usage(then_branch).0; - let else_internal_vars = find_variable_usage(else_branch).0; - let new_forbidden_vars = then_internal_vars - .intersection(&else_internal_vars) - .cloned() - .collect::>(); - - const_malloc.forbidden_vars.extend(new_forbidden_vars); - let mut array_manager_then = array_manager.clone(); let then_branch_simplified = simplify_lines( then_branch, @@ -418,8 +578,6 @@ fn simplify_lines( const_malloc, ); - const_malloc.forbidden_vars = forbidden_vars_before; - *array_manager = array_manager_else.clone(); // keep the intersection both branches array_manager.valid = array_manager @@ -481,6 +639,8 @@ fn simplify_lines( counter: const_malloc.counter, ..ConstMalloc::default() }; + // TODO: what is array manager, and does it need to be updated + // to make block-level scoping work? let valid_aux_vars_in_array_manager_before = array_manager.valid.clone(); array_manager.valid.clear(); let simplified_body = simplify_lines( @@ -612,12 +772,8 @@ fn simplify_lines( let simplified_size = simplify_expr(size, &mut res, counters, array_manager, const_malloc); let simplified_vectorized_len = simplify_expr(vectorized_len, &mut res, counters, array_manager, const_malloc); - if simplified_size.is_constant() && !*vectorized && const_malloc.forbidden_vars.contains(var) { - println!("TODO: Optimization missed: Requires to align const malloc in if/else branches"); - } match simplified_size { - SimpleExpr::Constant(const_size) if !*vectorized && !const_malloc.forbidden_vars.contains(var) => { - // TODO do this optimization even if we are in an if/else branch + SimpleExpr::Constant(const_size) if !*vectorized => { let label = const_malloc.counter; const_malloc.counter += 1; const_malloc.map.insert(var.clone(), label); @@ -638,7 +794,6 @@ fn simplify_lines( } } Line::DecomposeBits { var, 
to_decompose } => { - assert!(!const_malloc.forbidden_vars.contains(var), "TODO"); let simplified_to_decompose = to_decompose .iter() .map(|expr| simplify_expr(expr, &mut res, counters, array_manager, const_malloc)) @@ -772,6 +927,9 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet, BTreeSet) { for line in lines { match line { + Line::ForwardDeclaration { var } => { + internal_vars.insert(var.clone()); + }, Line::Match { value, arms } => { on_new_expr(value, &internal_vars, &mut external_vars); for (_, statements) in arms { @@ -921,7 +1079,7 @@ pub fn inline_lines(lines: &mut Vec, args: &BTreeMap, res let inline_internal_var = |var: &mut Var| { assert!( !args.contains_key(var), - "Variable {var} is both an argument and assigned in the inlined function" + "Variable {var} is both an argument and declared in the inlined function" ); *var = format!("@inlined_var_{inlining_count}_{var}"); }; @@ -929,6 +1087,9 @@ pub fn inline_lines(lines: &mut Vec, args: &BTreeMap, res let mut lines_to_replace = vec![]; for (i, line) in lines.iter_mut().enumerate() { match line { + Line::ForwardDeclaration { var } => { + inline_internal_var(var); + } Line::Match { value, arms } => { inline_expr(value, args, inlining_count); for (_, statements) in arms { @@ -1239,6 +1400,9 @@ fn replace_vars_for_unroll( replace_vars_for_unroll(statements, iterator, unroll_index, iterator_value, internal_vars); } } + Line::ForwardDeclaration { var } => { + *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); + } Line::Assignment { var, value } => { assert!(var != iterator, "Weird"); *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); @@ -1699,6 +1863,7 @@ fn get_function_called(lines: &[Line], function_called: &mut Vec) { get_function_called(body, function_called); } Line::Assignment { .. } + | Line::ForwardDeclaration { .. } | Line::ArrayAssign { .. } | Line::Assert { .. } | Line::FunctionRet { .. 
} @@ -1724,6 +1889,9 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { replace_vars_by_const_in_lines(statements, map); } } + Line::ForwardDeclaration { var } => { + assert!(!map.contains_key(var), "Variable {var} is a constant"); + } Line::Assignment { var, value } => { assert!(!map.contains_key(var), "Variable {var} is a constant"); replace_vars_by_const_in_expr(value, map); @@ -1831,6 +1999,9 @@ impl SimpleLine { fn to_string_with_indent(&self, indent: usize) -> String { let spaces = " ".repeat(indent); let line_str = match self { + Self::ForwardDeclaration { var } => { + format!("var {var}") + } Self::Match { value, arms } => { let arms_str = arms .iter() @@ -1880,7 +2051,7 @@ impl SimpleLine { ) } Self::RawAccess { res, index, shift } => { - format!("memory[{index} + {shift}] = {res}") + format!("{res} = memory[{index} + {shift}]") } Self::TestZero { operation, arg0, arg1 } => { format!("0 = {arg0} {operation} {arg1}") diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs index 18cebdf8..d9c62a96 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/b_compile_intermediate.rs @@ -14,31 +14,56 @@ struct Compiler { if_counter: usize, call_counter: usize, func_name: String, - var_positions: BTreeMap, // var -> memory offset from fp + stack_frame_layout: StackFrameLayout, args_count: usize, stack_size: usize, + stack_pos: usize, +} + +#[derive(Default)] +struct StackFrameLayout { + // Innermost lexical scope last + scopes: Vec, +} + +#[derive(Default)] +struct ScopeLayout { + var_positions: BTreeMap, // var -> memory offset from fp const_mallocs: BTreeMap, // const_malloc_label -> start = memory offset from fp } impl Compiler { + fn is_in_scope(&self, var: &Var) -> bool { + for scope in self.stack_frame_layout.scopes.iter() { + if let Some(_offset) = scope.var_positions.get(var) { + return true; + } + } + false + } + fn get_offset(&self, 
var: &VarOrConstMallocAccess) -> ConstExpression { match var { - VarOrConstMallocAccess::Var(var) => (*self - .var_positions - .get(var) - .unwrap_or_else(|| panic!("Variable {var} not in scope"))) - .into(), - VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset } => ConstExpression::Binary { - left: Box::new( - self.const_mallocs - .get(malloc_label) - .copied() - .unwrap_or_else(|| panic!("Const malloc {malloc_label} not in scope")) - .into(), - ), - operation: HighLevelOperation::Add, - right: Box::new(offset.clone()), - }, + VarOrConstMallocAccess::Var(var) => { + for scope in self.stack_frame_layout.scopes.iter().rev() { + if let Some(offset) = scope.var_positions.get(var) { + return (*offset).into(); + } + } + panic!("Variable {var} not in scope"); + } + VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset } => { + for scope in self.stack_frame_layout.scopes.iter().rev() { + if let Some(base) = scope.const_mallocs.get(malloc_label) { + return ConstExpression::Binary { + left: Box::new((*base).into()), + operation: HighLevelOperation::Add, + right: Box::new((*offset).clone()), + }; + } + } + panic!("Const malloc {malloc_label} not in scope"); + } } } } @@ -68,18 +93,10 @@ impl IntermediateValue { }, SimpleExpr::Constant(c) => Self::Constant(c.clone()), SimpleExpr::ConstMallocAccess { malloc_label, offset } => Self::MemoryAfterFp { - offset: ConstExpression::Binary { - left: Box::new( - compiler - .const_mallocs - .get(malloc_label) - .copied() - .unwrap_or_else(|| panic!("Const malloc {malloc_label} not in scope")) - .into(), - ), - operation: HighLevelOperation::Add, - right: Box::new(offset.clone()), - }, + offset: compiler.get_offset(&VarOrConstMallocAccess::ConstMallocAccess { + malloc_label: *malloc_label, + offset: offset.clone() + }), }, } } @@ -110,28 +127,23 @@ fn compile_function( function: &SimpleFunction, compiler: &mut Compiler, ) -> Result, String> { - let mut internal_vars = find_internal_vars(&function.instructions); - 
- internal_vars.retain(|var| !function.arguments.contains(var)); - // memory layout: pc, fp, args, return_vars, internal_vars let mut stack_pos = 2; // Reserve space for pc and fp - let mut var_positions = BTreeMap::new(); + let function_scope_layout = ScopeLayout::default(); + compiler.stack_frame_layout = StackFrameLayout { + scopes: vec![function_scope_layout], + }; + let function_scope_layout = &mut compiler.stack_frame_layout.scopes[0]; for (i, var) in function.arguments.iter().enumerate() { - var_positions.insert(var.clone(), stack_pos + i); + function_scope_layout.var_positions.insert(var.clone(), stack_pos + i); } stack_pos += function.arguments.len(); stack_pos += function.n_returned_vars; - for (i, var) in internal_vars.iter().enumerate() { - var_positions.insert(var.clone(), stack_pos + i); - } - stack_pos += internal_vars.len(); - compiler.func_name = function.name.clone(); - compiler.var_positions = var_positions; + compiler.stack_pos = stack_pos; compiler.stack_size = stack_pos; compiler.args_count = function.arguments.len(); @@ -156,23 +168,34 @@ fn compile_lines( for (i, line) in lines.iter().enumerate() { match line { + SimpleLine::ForwardDeclaration { var } => { + let mut current_scope_layout = compiler.stack_frame_layout.scopes.last_mut().unwrap(); + current_scope_layout.var_positions.insert(var.clone(), compiler.stack_pos); + compiler.stack_pos += 1; + } + SimpleLine::Assignment { var, operation, arg0, arg1, } => { - instructions.push(IntermediateInstruction::computation( - *operation, - IntermediateValue::from_simple_expr(arg0, compiler), - IntermediateValue::from_simple_expr(arg1, compiler), - IntermediateValue::from_var_or_const_malloc_access(var, compiler), - )); + let arg0 = IntermediateValue::from_simple_expr(arg0, compiler); + let arg1 = IntermediateValue::from_simple_expr(arg1, compiler); - mark_vars_as_declared(&[arg0, arg1], declared_vars); if let VarOrConstMallocAccess::Var(var) = var { declared_vars.insert(var.clone()); + let mut 
current_scope_layout = compiler.stack_frame_layout.scopes.last_mut().unwrap(); + current_scope_layout.var_positions.insert(var.clone(), compiler.stack_pos); + compiler.stack_pos += 1; } + + instructions.push(IntermediateInstruction::computation( + *operation, + arg0, + arg1, + IntermediateValue::from_var_or_const_malloc_access(var, compiler), + )); } SimpleLine::TestZero { operation, arg0, arg1 } => { @@ -182,22 +205,22 @@ fn compile_lines( IntermediateValue::from_simple_expr(arg1, compiler), IntermediateValue::Constant(0.into()), )); - - mark_vars_as_declared(&[arg0, arg1], declared_vars); } SimpleLine::Match { value, arms } => { + compiler.stack_frame_layout.scopes.push(ScopeLayout::default()); + let match_index = compiler.match_blocks.len(); let end_label = Label::match_end(match_index); let value_simplified = IntermediateValue::from_simple_expr(value, compiler); let mut compiled_arms = vec![]; - let original_stack_size = compiler.stack_size; - let mut new_stack_size = original_stack_size; + let saved_stack_pos = compiler.stack_pos; + let mut new_stack_pos = saved_stack_pos; for (i, arm) in arms.iter().enumerate() { let mut arm_declared_vars = declared_vars.clone(); - compiler.stack_size = original_stack_size; + compiler.stack_pos = saved_stack_pos; let arm_instructions = compile_lines( function_name, arm, @@ -206,23 +229,23 @@ fn compile_lines( &mut arm_declared_vars, )?; compiled_arms.push(arm_instructions); - new_stack_size = compiler.stack_size.max(new_stack_size); *declared_vars = if i == 0 { arm_declared_vars } else { declared_vars.intersection(&arm_declared_vars).cloned().collect() }; + new_stack_pos = new_stack_pos.max(compiler.stack_pos); } - compiler.stack_size = new_stack_size; + compiler.stack_pos = new_stack_pos; compiler.match_blocks.push(MatchBlock { function_name: function_name.clone(), match_cases: compiled_arms, }); let value_scaled_offset = IntermediateValue::MemoryAfterFp { - offset: compiler.stack_size.into(), + offset: 
compiler.stack_pos.into(), }; - compiler.stack_size += 1; + compiler.stack_pos += 1; instructions.push(IntermediateInstruction::Computation { operation: Operation::Mul, arg_a: value_simplified, @@ -231,9 +254,9 @@ fn compile_lines( }); let jump_dest_offset = IntermediateValue::MemoryAfterFp { - offset: compiler.stack_size.into(), + offset: compiler.stack_pos.into(), }; - compiler.stack_size += 1; + compiler.stack_pos += 1; instructions.push(IntermediateInstruction::Computation { operation: Operation::Add, arg_a: value_scaled_offset, @@ -248,6 +271,11 @@ fn compile_lines( let remaining = compile_lines(function_name, &lines[i + 1..], compiler, final_jump, declared_vars)?; compiler.bytecode.insert(end_label, remaining); + compiler.stack_frame_layout.scopes.pop(); + compiler.stack_pos = saved_stack_pos; + // It is not necessary to update compiler.stack_size here because the preceding call to + // compile lines should have done so. + return Ok(instructions); } @@ -257,6 +285,8 @@ fn compile_lines( else_branch, line_number, } => { + compiler.stack_frame_layout.scopes.push(ScopeLayout::default()); + validate_vars_declared(&[condition], declared_vars)?; let if_id = compiler.if_counter; @@ -272,16 +302,16 @@ fn compile_lines( let condition_simplified = IntermediateValue::from_simple_expr(condition, compiler); // 1/c (or 0 if c is zero) - let condition_inverse_offset = compiler.stack_size; - compiler.stack_size += 1; + let condition_inverse_offset = compiler.stack_pos; + compiler.stack_pos += 1; instructions.push(IntermediateInstruction::Inverse { arg: condition_simplified.clone(), res_offset: condition_inverse_offset, }); // c x 1/c - let product_offset = compiler.stack_size; - compiler.stack_size += 1; + let product_offset = compiler.stack_pos; + compiler.stack_pos += 1; instructions.push(IntermediateInstruction::Computation { operation: Operation::Mul, arg_a: condition_simplified.clone(), @@ -292,10 +322,12 @@ fn compile_lines( offset: product_offset.into(), }, }); + // 
It is not necessary to update compiler.stack_size here because the preceding call to + // compile lines should have done so. // 1 - (c x 1/c) - let one_minus_product_offset = compiler.stack_size; - compiler.stack_size += 1; + let one_minus_product_offset = compiler.stack_pos; + compiler.stack_pos += 1; instructions.push(IntermediateInstruction::Computation { operation: Operation::Add, arg_a: IntermediateValue::MemoryAfterFp { @@ -329,7 +361,7 @@ fn compile_lines( updated_fp: None, }); - let original_stack = compiler.stack_size; + let saved_stack_pos = compiler.stack_pos; let mut then_declared_vars = declared_vars.clone(); let then_instructions = compile_lines( @@ -339,9 +371,9 @@ fn compile_lines( Some(end_label.clone()), &mut then_declared_vars, )?; - let then_stack = compiler.stack_size; - compiler.stack_size = original_stack; + let then_stack_pos = compiler.stack_pos; + compiler.stack_pos = saved_stack_pos; let mut else_declared_vars = declared_vars.clone(); let else_instructions = compile_lines( function_name, @@ -350,24 +382,29 @@ fn compile_lines( Some(end_label.clone()), &mut else_declared_vars, )?; - let else_stack = compiler.stack_size; - - compiler.stack_size = then_stack.max(else_stack); - *declared_vars = then_declared_vars.intersection(&else_declared_vars).cloned().collect(); compiler.bytecode.insert(if_label, then_instructions); compiler.bytecode.insert(else_label, else_instructions); + compiler.stack_frame_layout.scopes.pop(); + compiler.stack_pos = compiler.stack_pos.max(then_stack_pos); + let remaining = compile_lines(function_name, &lines[i + 1..], compiler, final_jump, declared_vars)?; compiler.bytecode.insert(end_label, remaining); + // It is not necessary to update compiler.stack_size here because the preceding call to + // compile_lines should have done so. return Ok(instructions); } SimpleLine::RawAccess { res, index, shift } => { + // TODO: why is validate_vars_declared here? 
validate_vars_declared(&[index], declared_vars)?; - if let SimpleExpr::Var(var) = res { + if let SimpleExpr::Var(var) = res && !compiler.is_in_scope(var) { declared_vars.insert(var.clone()); + let mut current_scope_layout = compiler.stack_frame_layout.scopes.last_mut().unwrap(); + current_scope_layout.var_positions.insert(var.clone(), compiler.stack_pos); + compiler.stack_pos += 1; } let shift_0 = match index { SimpleExpr::Constant(c) => c.clone(), @@ -390,8 +427,8 @@ fn compile_lines( compiler.call_counter += 1; let return_label = Label::return_from_call(call_id, *line_number); - let new_fp_pos = compiler.stack_size; - compiler.stack_size += 1; + let new_fp_pos = compiler.stack_pos; + compiler.stack_pos += 1; instructions.extend(setup_function_call( callee_function_name, @@ -401,8 +438,14 @@ fn compile_lines( compiler, )?); + // TODO: why is validate_vars_declared here? validate_vars_declared(args, declared_vars)?; declared_vars.extend(return_data.iter().cloned()); + for var in return_data.iter() { + let mut current_scope_layout = compiler.stack_frame_layout.scopes.last_mut().unwrap(); + current_scope_layout.var_positions.insert(var.clone(), compiler.stack_pos); + compiler.stack_pos += 1; + } let after_call = { let mut instructions = Vec::new(); @@ -430,6 +473,8 @@ fn compile_lines( }; compiler.bytecode.insert(return_label, after_call); + // It is not necessary to update compiler.stack_size here because the preceding call to + // compile_lines should have done so. 
return Ok(instructions); } @@ -453,9 +498,9 @@ fn compile_lines( if compiler.func_name == "main" { // pc -> ending_pc, fp -> 0 let zero_value_offset = IntermediateValue::MemoryAfterFp { - offset: compiler.stack_size.into(), + offset: compiler.stack_pos.into(), }; - compiler.stack_size += 1; + compiler.stack_pos += 1; instructions.push(IntermediateInstruction::Computation { operation: Operation::Add, arg_a: IntermediateValue::Constant(0.into()), @@ -477,6 +522,9 @@ fn compile_lines( vectorized, vectorized_len, } => { + let mut current_scope_layout = compiler.stack_frame_layout.scopes.last_mut().unwrap(); + current_scope_layout.var_positions.insert(var.clone(), compiler.stack_pos); + compiler.stack_pos += 1; declared_vars.insert(var.clone()); instructions.push(IntermediateInstruction::RequestMemory { offset: compiler.get_offset(&var.clone().into()), @@ -487,15 +535,26 @@ fn compile_lines( } SimpleLine::ConstMalloc { var, size, label } => { let size = size.naive_eval().unwrap().to_usize(); // TODO not very good; + declared_vars.insert(var.clone()); + let mut current_scope_layout = compiler.stack_frame_layout.scopes.last_mut().unwrap(); + current_scope_layout.var_positions.insert(var.clone(), compiler.stack_pos); + compiler.stack_pos += 1; + current_scope_layout.const_mallocs.insert(*label, compiler.stack_pos); handle_const_malloc(declared_vars, &mut instructions, compiler, var, size, label); + compiler.stack_pos += size; } SimpleLine::DecomposeBits { var, to_decompose, label, } => { + declared_vars.insert(var.clone()); + let mut current_scope_layout = compiler.stack_frame_layout.scopes.last_mut().unwrap(); + current_scope_layout.var_positions.insert(var.clone(), compiler.stack_pos); + compiler.stack_pos += 1; + instructions.push(IntermediateInstruction::DecomposeBits { - res_offset: compiler.stack_size, + res_offset: compiler.stack_pos, to_decompose: to_decompose .iter() .map(|expr| IntermediateValue::from_simple_expr(expr, compiler)) @@ -510,6 +569,7 @@ fn 
compile_lines( F::bits() * to_decompose.len(), label, ); + compiler.stack_pos += F::bits() * to_decompose.len(); } SimpleLine::DecomposeCustom { args } => { assert!(args.len() >= 3); @@ -524,8 +584,12 @@ fn compile_lines( .collect(), }); } - SimpleLine::PrivateInputStart { result } => { - declared_vars.insert(result.clone()); + SimpleLine::PrivateInputStart { result } => { + if !compiler.is_in_scope(result) { + let mut current_scope_layout = compiler.stack_frame_layout.scopes.last_mut().unwrap(); + current_scope_layout.var_positions.insert(result.clone(), compiler.stack_pos); + compiler.stack_pos += 1; + } instructions.push(IntermediateInstruction::PrivateInputStart { res_offset: compiler.get_offset(&result.clone().into()), }); @@ -545,6 +609,8 @@ fn compile_lines( } } + compiler.stack_size = compiler.stack_size.max(compiler.stack_pos); + if let Some(jump_label) = final_jump { instructions.push(IntermediateInstruction::Jump { dest: IntermediateValue::label(jump_label), @@ -563,28 +629,17 @@ fn handle_const_malloc( size: usize, label: &ConstMallocLabel, ) { - declared_vars.insert(var.clone()); instructions.push(IntermediateInstruction::Computation { operation: Operation::Add, - arg_a: IntermediateValue::Constant(compiler.stack_size.into()), + arg_a: IntermediateValue::Constant(compiler.stack_pos.into()), arg_c: IntermediateValue::Fp, res: IntermediateValue::MemoryAfterFp { offset: compiler.get_offset(&var.clone().into()), }, }); - compiler.const_mallocs.insert(*label, compiler.stack_size); - compiler.stack_size += size; } // Helper functions -fn mark_vars_as_declared>(vocs: &[VoC], declared: &mut BTreeSet) { - for voc in vocs { - if let SimpleExpr::Var(v) = voc.borrow() { - declared.insert(v.clone()); - } - } -} - fn validate_vars_declared>(vocs: &[VoC], declared: &BTreeSet) -> Result<(), String> { for voc in vocs { if let SimpleExpr::Var(v) = voc.borrow() @@ -664,6 +719,9 @@ fn find_internal_vars(lines: &[SimpleLine]) -> BTreeSet { let mut internal_vars = 
BTreeSet::new(); for line in lines { match line { + SimpleLine::ForwardDeclaration { var } => { + internal_vars.insert(var.clone()); + } SimpleLine::Match { arms, .. } => { for arm in arms { internal_vars.extend(find_internal_vars(arm)); diff --git a/crates/lean_compiler/src/grammar.pest b/crates/lean_compiler/src/grammar.pest index d6fc7d1a..8de5a975 100644 --- a/crates/lean_compiler/src/grammar.pest +++ b/crates/lean_compiler/src/grammar.pest @@ -16,6 +16,7 @@ return_count = { "->" ~ number } // Statements statement = { + forward_declaration | single_assignment | array_assign | if_statement | @@ -34,6 +35,8 @@ return_statement = { "return" ~ (tuple_expression)? ~ ";" } break_statement = { "break" ~ ";" } continue_statement = { "continue" ~ ";" } +forward_declaration = { "var" ~ identifier ~ ";" } + single_assignment = { identifier ~ "=" ~ expression ~ ";" } array_assign = { identifier ~ "[" ~ expression ~ "]" ~ "=" ~ expression ~ ";" } diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index a9db3c76..f2202df5 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -1,7 +1,7 @@ use lean_vm::*; use multilinear_toolkit::prelude::*; use p3_util::log2_ceil_usize; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{Display, Formatter}; use utils::ToUsize; @@ -307,6 +307,7 @@ pub enum Line { value: Expression, arms: Vec<(usize, Vec)>, }, + ForwardDeclaration { var: Var }, Assignment { var: Var, value: Expression, @@ -378,6 +379,30 @@ pub enum Line { location: SourceLineNumber, }, } + +/// A context specifying which variables are in scope. +pub struct Context { + /// A list of lexical scopes, innermost scope last. 
+ pub scopes: Vec, +} + +impl Context { + pub fn defines(&self, var: &Var) -> bool { + for scope in self.scopes.iter() { + if scope.vars.contains(var) { + return true; + } + } + false + } +} + +#[derive(Default)] +pub struct Scope { + /// A set of declared variables. + pub vars: BTreeSet, +} + impl Display for Expression { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { @@ -418,6 +443,9 @@ impl Line { .join("\n"); format!("match {value} {{\n{arms_str}\n{spaces}}}") } + Self::ForwardDeclaration { var } => { + format!("var {var}") + } Self::Assignment { var, value } => { format!("{var} = {value}") } diff --git a/crates/lean_compiler/src/lib.rs b/crates/lean_compiler/src/lib.rs index 750d072b..130d69d6 100644 --- a/crates/lean_compiler/src/lib.rs +++ b/crates/lean_compiler/src/lib.rs @@ -16,9 +16,9 @@ pub fn compile_program(program: String) -> Bytecode { let (parsed_program, function_locations) = parse_program(&program).unwrap(); // println!("Parsed program: {}", parsed_program.to_string()); let simple_program = simplify_program(parsed_program); - // println!("Simplified program: {}", simple_program.to_string()); + println!("Simplified program: {}", simple_program.to_string()); let intermediate_bytecode = compile_to_intermediate_bytecode(simple_program).unwrap(); - // println!("Intermediate Bytecode:\n\n{}", intermediate_bytecode.to_string()); + println!("Intermediate Bytecode:\n\n{}", intermediate_bytecode.to_string()); // println!("Function Locations: \n"); // for (loc, name) in function_locations.iter() { diff --git a/crates/lean_compiler/src/parser/parsers/statement.rs b/crates/lean_compiler/src/parser/parsers/statement.rs index eb13e585..983acf93 100644 --- a/crates/lean_compiler/src/parser/parsers/statement.rs +++ b/crates/lean_compiler/src/parser/parsers/statement.rs @@ -19,6 +19,7 @@ impl Parse for StatementParser { let inner = next_inner_pair(&mut pair.into_inner(), "statement body")?; match inner.as_rule() { + 
Rule::forward_declaration => ForwardDeclarationParser::parse(inner, ctx), Rule::single_assignment => AssignmentParser::parse(inner, ctx), Rule::array_assign => ArrayAssignParser::parse(inner, ctx), Rule::if_statement => IfStatementParser::parse(inner, ctx), @@ -35,6 +36,17 @@ impl Parse for StatementParser { } } +/// Parser for forward declarations of variables. +pub struct ForwardDeclarationParser; + +impl Parse for ForwardDeclarationParser { + fn parse(pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { + let mut inner = pair.into_inner(); + let var = next_inner_pair(&mut inner, "variable name")?.as_str().to_string(); + Ok(Line::ForwardDeclaration { var }) + } +} + /// Parser for variable assignments. pub struct AssignmentParser; diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index cf9edafa..530fb000 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -385,27 +385,11 @@ fn test_inlined() { fn test_match() { let program = r#" fn main() { - for x in 0..3 unroll { - func_match(x); - } - for x in 0..2 unroll { - match x { - 0 => { - y = 10 * (x + 8); - z = 10 * y; - print(z); - } - 1 => { - y = 10 * x; - z = func_2(y); - print(z); - } - } - } + func_match(1); return; } - fn func_match(x) inline { + fn func_match(x) { match x { 0 => { print(41); @@ -415,8 +399,6 @@ fn test_match() { print(y + 1); } 2 => { - y = 10 * x; - print(y); } } return; @@ -425,10 +407,6 @@ fn test_match() { fn func_1(x) -> 1 { return x * x * x * x; } - - fn func_2(x) inline -> 1 { - return x * x * x * x * x * x; - } "#; compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } diff --git a/crates/lean_vm/src/diagnostics/error.rs b/crates/lean_vm/src/diagnostics/error.rs index 1659d3c9..c18ab185 100644 --- a/crates/lean_vm/src/diagnostics/error.rs +++ b/crates/lean_vm/src/diagnostics/error.rs @@ -25,8 +25,8 @@ pub enum RunnerError { 
#[error("Computation invalid: {0} != {1}")] NotEqual(F, F), - #[error("Undefined memory access")] - UndefinedMemory, + #[error("Undefined memory access: {0}")] + UndefinedMemory(usize), #[error("Program counter out of bounds")] PCOutOfBounds, diff --git a/crates/lean_vm/src/execution/memory.rs b/crates/lean_vm/src/execution/memory.rs index a162fb46..00a685d9 100644 --- a/crates/lean_vm/src/execution/memory.rs +++ b/crates/lean_vm/src/execution/memory.rs @@ -25,7 +25,7 @@ impl Memory { /// /// Returns an error if the address is uninitialized pub fn get(&self, index: usize) -> Result { - self.0.get(index).copied().flatten().ok_or(RunnerError::UndefinedMemory) + self.0.get(index).copied().flatten().ok_or(RunnerError::UndefinedMemory(index)) } /// Sets a value at a memory address diff --git a/crates/lean_vm/src/execution/tests.rs b/crates/lean_vm/src/execution/tests.rs index feaaa596..e42ab656 100644 --- a/crates/lean_vm/src/execution/tests.rs +++ b/crates/lean_vm/src/execution/tests.rs @@ -13,7 +13,7 @@ fn test_basic_memory_operations() { assert_eq!(memory.get(5).unwrap(), F::from_usize(42)); // Test undefined memory access - assert!(matches!(memory.get(1), Err(RunnerError::UndefinedMemory))); + assert!(matches!(memory.get(1), Err(RunnerError::UndefinedMemory(1)))); } #[test] diff --git a/crates/rec_aggregation/recursion_program.lean_lang b/crates/rec_aggregation/recursion_program.lean_lang index e68036f3..37e20e4e 100644 --- a/crates/rec_aggregation/recursion_program.lean_lang +++ b/crates/rec_aggregation/recursion_program.lean_lang @@ -302,15 +302,16 @@ fn sample_stir_indexes_and_fold(fs_state, num_queries, merkle_leaves_in_basefiel fs_states_b = malloc(num_queries + 1); fs_states_b[0] = fs_state_9; + var n_chunks_per_answer; // the number of chunk of 8 field elements per merkle leaf opened if merkle_leaves_in_basefield == 1 { - n_chuncks_per_answer = two_pow_folding_factor / 8; // "/ 8" because initial merkle leaves are in the basefield + n_chunks_per_answer = 
two_pow_folding_factor / 8; // "/ 8" because initial merkle leaves are in the basefield } else { - n_chuncks_per_answer = two_pow_folding_factor * DIM / 8; + n_chunks_per_answer = two_pow_folding_factor * DIM / 8; } for i in 0..num_queries { - new_fs_state, answer = fs_hint(fs_states_b[i], n_chuncks_per_answer); + new_fs_state, answer = fs_hint(fs_states_b[i], n_chunks_per_answer); fs_states_b[i + 1] = new_fs_state; answers[i] = answer; } From 9d7beab952cb71a853479dff9ffb035c792314f8 Mon Sep 17 00:00:00 2001 From: Morgan Thomas Date: Tue, 9 Dec 2025 03:06:18 +0000 Subject: [PATCH 02/16] Block-level scoping (#92) and variable manager (#88) --- crates/lean_compiler/src/a_simplify_lang.rs | 8 +- .../src/b_compile_intermediate.rs | 93 ++++++------------- .../src/parser/parsers/statement.rs | 2 +- crates/lean_compiler/tests/test_compiler.rs | 21 +++++ 4 files changed, 57 insertions(+), 67 deletions(-) diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index bef4e851..7a39fc8c 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -329,7 +329,7 @@ fn check_expr_scoping(expr: &Expression, ctx: &Context) { fn check_simple_expr_scoping(expr: &SimpleExpr, ctx: &Context) { match expr { SimpleExpr::Var(v) => { - assert!(ctx.defines(&v), "Variable defined but not used: {:?}", v) + assert!(ctx.defines(&v), "Variable used but not defined: {:?}", v) }, SimpleExpr::Constant(_) => {}, SimpleExpr::ConstMallocAccess { .. } => {}, @@ -639,8 +639,6 @@ fn simplify_lines( counter: const_malloc.counter, ..ConstMalloc::default() }; - // TODO: what is array manager, and does it need to be updated - // to make block-level scoping work? 
let valid_aux_vars_in_array_manager_before = array_manager.valid.clone(); array_manager.valid.clear(); let simplified_body = simplify_lines( @@ -1608,6 +1606,10 @@ fn handle_inlined_functions_helper( if let Some(func) = inlined_functions.get(&*function_name) { let mut inlined_lines = vec![]; + for var in return_data.iter() { + inlined_lines.push(Line::ForwardDeclaration { var : var.clone() }); + } + let mut simplified_args = vec![]; for arg in args { if let Expression::Value(simple_expr) = arg { diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs index d9c62a96..f51db66f 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/b_compile_intermediate.rs @@ -147,13 +147,11 @@ fn compile_function( compiler.stack_size = stack_pos; compiler.args_count = function.arguments.len(); - let mut declared_vars: BTreeSet = function.arguments.iter().cloned().collect(); compile_lines( &Label::function(function.name.clone()), &function.instructions, compiler, None, - &mut declared_vars, ) } @@ -162,14 +160,13 @@ fn compile_lines( lines: &[SimpleLine], compiler: &mut Compiler, final_jump: Option