diff --git a/src/irx/builders/llvmliteir.py b/src/irx/builders/llvmliteir.py index e978b70..4b90660 100644 --- a/src/irx/builders/llvmliteir.py +++ b/src/irx/builders/llvmliteir.py @@ -69,13 +69,24 @@ def splat_scalar( @typechecked -def safe_pop(lst: list[ir.Value | ir.Function]) -> ir.Value | ir.Function: - """Implement a safe pop operation for lists.""" +def safe_pop( + lst: list[Any], context: str = "" +) -> Optional[ir.Value | ir.Function]: + """Pop from result stack with optional context and type check.""" try: - return lst.pop() + val = lst.pop() except IndexError: + if context: + raise IndexError(f"Popping from an empty stack: {context}") return None + if val is not None and not isinstance(val, (ir.Value, ir.Function)): + raise TypeError( + f"Unexpected stack value type {type(val)!r} at " + f"{context or 'safe_pop'}" + ) + return val + @typechecked class VariablesLLVM: @@ -193,7 +204,6 @@ def initialize(self) -> None: self._init_native_size_types() # initialize the target registry etc. - llvm.initialize() llvm.initialize_all_asmprinters() llvm.initialize_all_targets() llvm.initialize_native_target() @@ -284,7 +294,11 @@ def get_function(self, name: str) -> Optional[ir.Function]: if name in self.function_protos: self.visit(self.function_protos[name]) - return cast(ir.Function, self.result_stack.pop()) + fn = safe_pop(self.result_stack, "get_function") + + if not isinstance(fn, ir.Function): + raise TypeError(f"Expected ir.Function, got {type(fn)!r}") + return fn return None @@ -432,7 +446,7 @@ def visit(self, node: astx.UnaryOp) -> None: """Translate an ASTx UnaryOp expression.""" if node.op_code == "++": self.visit(node.operand) - operand_val = safe_pop(self.result_stack) + operand_val = safe_pop(self.result_stack, "UnaryOp (++) operand") one = ir.Constant(operand_val.type, 1) @@ -450,7 +464,7 @@ def visit(self, node: astx.UnaryOp) -> None: elif node.op_code == "--": self.visit(node.operand) - operand_val = safe_pop(self.result_stack) + operand_val = safe_pop(self.result_stack, "UnaryOp (--) operand") one = ir.Constant(operand_val.type, 1) result = self._llvm.ir_builder.sub(operand_val, one, "dectmp") @@ -464,7 +478,7 @@ def visit(self, node: astx.UnaryOp) -> None: elif node.op_code == "!": self.visit(node.operand) - val = safe_pop(self.result_stack) + val = safe_pop(self.result_stack, "UnaryOp operand") result = self._llvm.ir_builder.xor( val, ir.Constant(val.type, 1), "nottmp" ) @@ -497,7 +511,7 @@ def visit(self, node: astx.BinaryOp) -> None: # Codegen the rhs. self.visit(node.rhs) - llvm_rhs = safe_pop(self.result_stack) + llvm_rhs = safe_pop(self.result_stack, "BinaryOp (= rhs)") if not llvm_rhs: raise Exception("codegen: Invalid rhs expression.") @@ -513,10 +527,10 @@ def visit(self, node: astx.BinaryOp) -> None: return self.visit(node.lhs) - llvm_lhs = safe_pop(self.result_stack) + llvm_lhs = safe_pop(self.result_stack, "BinaryOp (lhs)") self.visit(node.rhs) - llvm_rhs = safe_pop(self.result_stack) + llvm_rhs = safe_pop(self.result_stack, "BinaryOp (rhs)") if not llvm_lhs or not llvm_rhs: raise Exception("codegen: Invalid lhs/rhs") @@ -593,7 +607,9 @@ def visit(self, node: astx.BinaryOp) -> None: if not hasattr(node, "fma_rhs"): raise Exception("FMA requires a third operand (fma_rhs)") self.visit(node.fma_rhs) - llvm_fma_rhs = safe_pop(self.result_stack) + llvm_fma_rhs = safe_pop( + self.result_stack, "BinaryOp (fma_rhs)" + ) if llvm_fma_rhs.type != llvm_lhs.type: raise Exception( f"FMA operand type mismatch: " @@ -845,11 +861,7 @@ def visit(self, block: astx.Block) -> None: result = None for node in block.nodes: self.visit(node) - try: - result = self.result_stack.pop() - except IndexError: - # some nodes doesn't add anything in the stack - pass + result = safe_pop(self.result_stack) if result is not None: self.result_stack.append(result) @@ -857,7 +869,7 @@ def visit(self, block: astx.Block) -> None: def visit(self, node: astx.IfStmt) -> None: """Translate IF statement.""" self.visit(node.condition) - cond_v = self.result_stack.pop() + cond_v = safe_pop(self.result_stack, "IfStmt condition") if not cond_v: raise Exception("codegen: Invalid condition expression.") @@ -890,7 +902,7 @@ def visit(self, node: astx.IfStmt) -> None: # Emit then value. self._llvm.ir_builder.position_at_start(then_bb) self.visit(node.then) - then_v = self.result_stack.pop() + then_v = safe_pop(self.result_stack, "IfStmt then block") if not then_v: raise Exception("codegen: `Then` expression is invalid.") @@ -904,7 +916,7 @@ def visit(self, node: astx.IfStmt) -> None: else_v = None if node.else_ is not None: self.visit(node.else_) - else_v = self.result_stack.pop() + else_v = safe_pop(self.result_stack, "IfStmt else block") else: else_v = ir.Constant(self._llvm.INT32_TYPE, 0) @@ -943,7 +955,7 @@ def visit(self, expr: astx.WhileStmt) -> None: # Emit the condition. self.visit(expr.condition) - cond_val = self.result_stack.pop() + cond_val = safe_pop(self.result_stack, "WhileStmt condition") if not cond_val: raise Exception("codegen: Invalid condition expression.") @@ -970,7 +982,7 @@ def visit(self, expr: astx.WhileStmt) -> None: # Emit the body of the loop. self.visit(expr.body) - body_val = self.result_stack.pop() + body_val = safe_pop(self.result_stack, "WhileStmt body") if not body_val: return @@ -993,7 +1005,7 @@ def visit(self, expr: astx.VariableAssignment) -> None: # Codegen the value expression on the right-hand side self.visit(expr.value) - llvm_value = safe_pop(self.result_stack) + llvm_value = safe_pop(self.result_stack, "VariableAssignment (value)") if not llvm_value: raise Exception("codegen: Invalid value in VariableAssignment.") @@ -1023,7 +1035,9 @@ def visit(self, node: astx.ForCountLoopStmt) -> None: # Emit the start code first, without 'variable' in scope. self.visit(node.initializer) - initializer_val = self.result_stack.pop() + initializer_val = safe_pop( + self.result_stack, "ForCountLoop initializer" + ) if not initializer_val: raise Exception("codegen: Invalid start argument.") @@ -1044,7 +1058,7 @@ def visit(self, node: astx.ForCountLoopStmt) -> None: # Emit condition check (e.g., i < 10) self.visit(node.condition) - cond_val = self.result_stack.pop() + cond_val = safe_pop(self.result_stack, "ForCountLoop condition") # Create blocks for loop body and after loop loop_body_bb = self._llvm.ir_builder.function.append_basic_block( @@ -1060,11 +1074,11 @@ def visit(self, node: astx.ForCountLoopStmt) -> None: # Emit loop body self._llvm.ir_builder.position_at_start(loop_body_bb) self.visit(node.body) - _body_val = self.result_stack.pop() + _body_val = safe_pop(self.result_stack, "ForCountLoop body") # Emit update expression self.visit(node.update) - update_val = self.result_stack.pop() + update_val = safe_pop(self.result_stack, "ForCountLoop update") # Store updated value self._llvm.ir_builder.store(update_val, var_addr) @@ -1100,7 +1114,7 @@ def visit(self, node: astx.ForRangeLoopStmt) -> None: # Emit the start code first, without 'variable' in scope. self.visit(node.start) - start_val = self.result_stack.pop() + start_val = safe_pop(self.result_stack, "ForRangeLoop start") if not start_val: raise Exception("codegen: Invalid start argument.") self._llvm.ir_builder.store(start_val, var_addr) @@ -1116,7 +1130,7 @@ def visit(self, node: astx.ForRangeLoopStmt) -> None: # Emit the body of the loop. self.visit(node.body) - body_val = self.result_stack.pop() + body_val = safe_pop(self.result_stack, "ForRangeLoop body") if not body_val: return @@ -1124,7 +1138,7 @@ def visit(self, node: astx.ForRangeLoopStmt) -> None: # Emit the step value. if node.step: self.visit(node.step) - step_val = self.result_stack.pop() + step_val = safe_pop(self.result_stack, "ForRangeLoop step") if not step_val: return else: @@ -1137,7 +1151,7 @@ def visit(self, node: astx.ForRangeLoopStmt) -> None: # Compute the end condition. self.visit(node.end) - end_cond = self.result_stack.pop() + end_cond = safe_pop(self.result_stack, "ForRangeLoop end") if not end_cond: return @@ -1544,7 +1558,7 @@ def visit(self, node: astx.LiteralList) -> None: llvm_elems: list[ir.Value] = [] for elem in node.elements: self.visit(elem) - v = self.result_stack.pop() + v = safe_pop(self.result_stack, "LiteralList element") if v is None: raise Exception("LiteralList: invalid element lowering.") llvm_elems.append(v) @@ -1875,7 +1889,7 @@ def visit(self, node: astx.FunctionCall) -> None: llvm_args = [] for arg in node.args: self.visit(arg) - llvm_arg = self.result_stack.pop() + llvm_arg = safe_pop(self.result_stack, "FunctionCall argument") if not llvm_arg: raise Exception("codegen: Invalid callee argument.") llvm_args.append(llvm_arg) @@ -1940,10 +1954,7 @@ def visit(self, node: astx.FunctionReturn) -> None: """Translate ASTx FunctionReturn to LLVM-IR.""" if node.value is not None: self.visit(node.value) - try: - retval = self.result_stack.pop() - except IndexError: - retval = None + retval = safe_pop(self.result_stack, "FunctionReturn value") else: retval = None @@ -1976,7 +1987,9 @@ def visit(self, node: astx.InlineVariableDeclaration) -> None: # Emit the initializer if node.value is not None: self.visit(node.value) - init_val = self.result_stack.pop() + init_val = safe_pop( + self.result_stack, "InlineVariableDeclaration initializer" + ) if init_val is None: raise Exception("Initializer code generation failed.") # Default zero value based on type @@ -2110,7 +2123,7 @@ def _get_or_create_format_global(self, fmt: str) -> ir.GlobalVariable: def visit(self, node: system.Cast) -> None: """Translate Cast expression to LLVM-IR.""" self.visit(node.value) - value = self.result_stack.pop() + value = safe_pop(self.result_stack, "Cast value") target_type_str = node.target_type.__class__.__name__.lower() target_type = self._llvm.get_data_type(target_type_str) @@ -2234,7 +2247,7 @@ def visit(self, node: system.PrintExpr) -> None: else: # For variables and other expressions self.visit(node.message) - ptr = safe_pop(self.result_stack) + ptr = safe_pop(self.result_stack, "PrintExpr (message)") if not ptr: raise Exception("Invalid message in PrintExpr") @@ -2271,7 +2284,9 @@ def visit(self, node: astx.VariableDeclaration) -> None: # Emit the initializer if node.value is not None: self.visit(node.value) - init_val = self.result_stack.pop() + init_val = safe_pop( + self.result_stack, "VariableDeclaration initializer" + ) if init_val is None: raise Exception("Initializer code generation failed.")