Skip to content

Commit

Permalink
Constant propagation and folding.
Browse files Browse the repository at this point in the history
  • Loading branch information
mikebenfield committed Jan 24, 2025
1 parent 72a9acc commit 89ca990
Show file tree
Hide file tree
Showing 280 changed files with 3,407 additions and 1,284 deletions.
3 changes: 1 addition & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion compiler/ast/src/common/identifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use std::{
/// Attention - When adding or removing fields from this struct,
/// please remember to update its Serialize and Deserialize implementation
/// to reflect the new struct instantiation.
#[derive(Clone, Copy)]
#[derive(Clone, Copy, Default)]
pub struct Identifier {
/// The symbol that the user wrote, e.g., `foo`.
pub name: Symbol,
Expand Down
2 changes: 1 addition & 1 deletion compiler/ast/src/expressions/err.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
use super::*;

/// Represents a syntactically invalid expression.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ErrExpression {
/// The span of the invalid expression.
pub span: Span,
Expand Down
6 changes: 6 additions & 0 deletions compiler/ast/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ pub enum Expression {
Unit(UnitExpression),
}

impl Default for Expression {
fn default() -> Self {
Expression::Err(Default::default())
}
}

impl Node for Expression {
fn span(&self) -> Span {
use Expression::*;
Expand Down
10 changes: 8 additions & 2 deletions compiler/ast/src/expressions/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl UnaryOperation {
})
}

/// Represents the opera.tor as a string.
/// Represents the operator as a string.
fn as_str(self) -> &'static str {
match self {
Self::Abs => "abs",
Expand All @@ -77,6 +77,12 @@ impl UnaryOperation {
}
}

impl fmt::Display for UnaryOperation {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.as_str())
}
}

/// An unary expression applying an operator to an inner expression.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UnaryExpression {
Expand All @@ -92,7 +98,7 @@ pub struct UnaryExpression {

impl fmt::Display for UnaryExpression {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}{}", self.op.as_str(), self.receiver)
write!(f, "({}).{}()", self.receiver, self.op)
}
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/ast/src/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use serde::{Deserialize, Serialize};
use std::fmt;

/// A function definition.
#[derive(Clone, Serialize, Deserialize)]
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct Function {
/// Annotations on the function.
pub annotations: Vec<Annotation>,
Expand Down
3 changes: 2 additions & 1 deletion compiler/ast/src/functions/variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ use serde::{Deserialize, Serialize};
/// A regular function is not permitted to manipulate records.
/// An asynchronous function contains on-chain operations.
/// An inline function is directly copied at the call site.
#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum Variant {
#[default]
Inline,
Function,
Transition,
Expand Down
3 changes: 3 additions & 0 deletions compiler/ast/src/program/program_scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ impl From<Stub> for ProgramScope {
impl fmt::Display for ProgramScope {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(f, "program {} {{", self.program_id)?;
for (_, const_decl) in self.consts.iter() {
writeln!(f, " {const_decl}")?;
}
for (_, struct_) in self.structs.iter() {
writeln!(f, " {struct_}")?;
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/ast/src/statement/assign.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub struct AssignStatement {

impl fmt::Display for AssignStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{} = {};", self.place, self.value)
write!(f, "{} = {}", self.place, self.value)
}
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/ast/src/statement/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl fmt::Display for Block {
if self.statements.is_empty() {
writeln!(f, "\t")?;
} else {
self.statements.iter().try_for_each(|statement| writeln!(f, "\t{statement}"))?;
self.statements.iter().try_for_each(|statement| writeln!(f, "\t{statement};"))?;
}
write!(f, "}}")
}
Expand Down
6 changes: 2 additions & 4 deletions compiler/ast/src/statement/const_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use serde::{Deserialize, Serialize};
use std::fmt;

/// A constant declaration statement.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)]
#[derive(Clone, Default, PartialEq, Eq, Serialize, Deserialize, Debug)]
pub struct ConstDeclaration {
/// The place to assign to. As opposed to `DefinitionStatement`, this can only be an identifier
pub place: Identifier,
Expand All @@ -37,9 +37,7 @@ pub struct ConstDeclaration {

impl fmt::Display for ConstDeclaration {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.place)?;
write!(f, ": {}", self.type_)?;
write!(f, " = {};", self.value)
write!(f, "const {}: {} = {}", self.place, self.type_, self.value)
}
}

Expand Down
5 changes: 1 addition & 4 deletions compiler/ast/src/statement/definition/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,7 @@ pub struct DefinitionStatement {

impl fmt::Display for DefinitionStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{} ", self.declaration_type)?;
write!(f, "{}", self.place)?;
write!(f, ": {}", self.type_)?;
write!(f, " = {};", self.value)
write!(f, "{} {}: {} = {}", self.declaration_type, self.place, self.type_, self.value)
}
}

Expand Down
69 changes: 55 additions & 14 deletions compiler/compiler/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,14 @@ impl<'a, N: Network> Compiler<'a, N> {
}

/// Runs the type checker pass.
pub fn type_checker_pass(&'a self, symbol_table: SymbolTable) -> Result<(SymbolTable, StructGraph, CallGraph)> {
let (symbol_table, struct_graph, call_graph) =
pub fn type_checker_pass(&'a self, symbol_table: &mut SymbolTable) -> Result<(StructGraph, CallGraph)> {
let (struct_graph, call_graph) =
TypeChecker::do_pass((&self.ast, self.handler, symbol_table, &self.type_table, NetworkLimits {
max_array_elements: N::MAX_ARRAY_ELEMENTS,
max_mappings: N::MAX_MAPPINGS,
max_functions: N::MAX_FUNCTIONS,
}))?;
Ok((symbol_table, struct_graph, call_graph))
Ok((struct_graph, call_graph))
}

/// Runs the static analysis pass.
Expand All @@ -172,22 +172,64 @@ impl<'a, N: Network> Compiler<'a, N> {
))
}

/// Run const propagation and loop unrolling until we hit a fixed point or find an error.
pub fn const_propagation_and_unroll_loop(&mut self, symbol_table: &mut SymbolTable) -> Result<()> {
loop {
let loop_unroll_output = self.loop_unrolling_pass(symbol_table)?;

let const_prop_output = self.const_propagation_pass(symbol_table)?;

if !const_prop_output.changed && !loop_unroll_output.loop_unrolled {
// We've got a fixed point, so see if we have any errors.
if let Some(not_evaluated_span) = const_prop_output.const_not_evaluated {
return Err(CompilerError::const_not_evaluated(not_evaluated_span).into());
}

if let Some(not_evaluated_span) = const_prop_output.array_index_not_evaluated {
return Err(CompilerError::array_index_not_evaluated(not_evaluated_span).into());
}

if let Some(not_unrolled_span) = loop_unroll_output.loop_not_unrolled {
return Err(CompilerError::loop_bounds_not_evaluated(not_unrolled_span).into());
}

if self.compiler_options.output.unrolled_ast {
self.write_ast_to_json("unrolled_ast.json")?;
}

return Ok(());
}
}
}

/// Runs the const propagation pass.
pub fn const_propagation_pass(&mut self, symbol_table: &mut SymbolTable) -> Result<ConstPropagatorOutput> {
let (ast, output) = ConstPropagator::do_pass((
std::mem::take(&mut self.ast),
self.handler,
symbol_table,
&self.type_table,
&self.node_builder,
))?;

self.ast = ast;

Ok(output)
}

/// Runs the loop unrolling pass.
pub fn loop_unrolling_pass(&mut self, symbol_table: SymbolTable) -> Result<SymbolTable> {
let (ast, symbol_table) = Unroller::do_pass((
pub fn loop_unrolling_pass(&mut self, symbol_table: &mut SymbolTable) -> Result<UnrollerOutput> {
let (ast, output) = Unroller::do_pass((
std::mem::take(&mut self.ast),
self.handler,
&self.node_builder,
symbol_table,
&self.type_table,
))?;
self.ast = ast;

if self.compiler_options.output.unrolled_ast {
self.write_ast_to_json("unrolled_ast.json")?;
}
self.ast = ast;

Ok(symbol_table)
Ok(output)
}

/// Runs the static single assignment pass.
Expand Down Expand Up @@ -283,14 +325,13 @@ impl<'a, N: Network> Compiler<'a, N> {

/// Runs the compiler stages.
pub fn compiler_stages(&mut self) -> Result<(SymbolTable, StructGraph, CallGraph)> {
let st = self.symbol_table_pass()?;
let mut st = self.symbol_table_pass()?;

let (st, struct_graph, call_graph) = self.type_checker_pass(st)?;
let (struct_graph, call_graph) = self.type_checker_pass(&mut st)?;

self.static_analysis_pass(&st)?;

// TODO: Make this pass optional.
let st = self.loop_unrolling_pass(st)?;
self.const_propagation_and_unroll_loop(&mut st)?;

self.static_single_assignment_pass(&st)?;

Expand Down
6 changes: 3 additions & 3 deletions compiler/compiler/tests/integration/utilities/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,17 +251,17 @@ pub fn temp_dir() -> PathBuf {
pub fn compile_and_process<'a>(parsed: &'a mut Compiler<'a, CurrentNetwork>) -> Result<String, LeoError> {
parsed.add_import_stubs()?;

let st = parsed.symbol_table_pass()?;
let mut st = parsed.symbol_table_pass()?;

CheckUniqueNodeIds::new().visit_program(&parsed.ast.ast);

let (st, struct_graph, call_graph) = parsed.type_checker_pass(st)?;
let (struct_graph, call_graph) = parsed.type_checker_pass(&mut st)?;

parsed.static_analysis_pass(&st)?;

CheckUniqueNodeIds::new().visit_program(&parsed.ast.ast);

let st = parsed.loop_unrolling_pass(st)?;
parsed.const_propagation_and_unroll_loop(&mut st)?;

parsed.static_single_assignment_pass(&st)?;

Expand Down
3 changes: 3 additions & 0 deletions compiler/passes/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ workspace = true
[dependencies.leo-errors]
workspace = true

[dependencies.leo-interpreter]
workspace = true

[dependencies.leo-span]
workspace = true

Expand Down
16 changes: 16 additions & 0 deletions compiler/passes/src/common/symbol_table/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ impl LocalTable {
})),
}
}

fn dup(&self, new_id: NodeID) -> Self {
let mut inner = self.inner.borrow().clone();
inner.id = new_id;
LocalTable { inner: Rc::new(RefCell::new(inner)) }
}
}

impl SymbolTable {
Expand Down Expand Up @@ -139,10 +145,20 @@ impl SymbolTable {
self.local = id.map(|id| {
let parent = self.local.as_ref().map(|table| table.inner.borrow().id);
let new_local_table = self.all_locals.entry(id).or_insert_with(|| LocalTable::new(id, parent));
assert_eq!(parent, new_local_table.inner.borrow().parent, "Entered scopes out of order.");
new_local_table.clone()
});
}

pub fn enter_scope_duped(&mut self, new_id: NodeID, old_id: NodeID) {
let old_local_table = self.all_locals.get(&old_id).expect("Must have an old scope to dup from.");
let new_local_table = old_local_table.dup(new_id);
let parent = self.local.as_ref().map(|table| table.inner.borrow().id);
new_local_table.inner.borrow_mut().parent = parent;
self.all_locals.insert(new_id, new_local_table.clone());
self.local = Some(new_local_table);
}

/// Enther the parent scope of the current scope (or the global scope if there is no local parent scope).
pub fn enter_parent(&mut self) {
let parent: Option<NodeID> = self.local.as_ref().and_then(|table| table.inner.borrow().parent);
Expand Down
Loading

0 comments on commit 89ca990

Please sign in to comment.