Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 56 additions & 41 deletions src/irx/builders/llvmliteir.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,24 @@ def splat_scalar(


@typechecked
def safe_pop(lst: list[ir.Value | ir.Function]) -> ir.Value | ir.Function:
"""Implement a safe pop operation for lists."""
def safe_pop(
lst: list[Any], context: str = ""
) -> Optional[ir.Value | ir.Function]:
"""Pop from result stack with optional context and type check."""
try:
return lst.pop()
val = lst.pop()
except IndexError:
if context:
raise IndexError(f"Popping from an empty stack: {context}")
return None

if val is not None and not isinstance(val, (ir.Value, ir.Function)):
raise TypeError(
f"Unexpected stack value type {type(val)!r} at "
f"{context or 'safe_pop'}"
)
return val


@typechecked
class VariablesLLVM:
Expand Down Expand Up @@ -193,7 +204,6 @@ def initialize(self) -> None:
self._init_native_size_types()

# initialize the target registry etc.
llvm.initialize()
llvm.initialize_all_asmprinters()
llvm.initialize_all_targets()
llvm.initialize_native_target()
Expand Down Expand Up @@ -284,7 +294,11 @@ def get_function(self, name: str) -> Optional[ir.Function]:

if name in self.function_protos:
self.visit(self.function_protos[name])
return cast(ir.Function, self.result_stack.pop())
fn = safe_pop(self.result_stack, "get_function")

if not isinstance(fn, ir.Function):
raise TypeError(f"Expected ir.Function, got {type(fn)!r}")
return fn

return None

Expand Down Expand Up @@ -432,7 +446,7 @@ def visit(self, node: astx.UnaryOp) -> None:
"""Translate an ASTx UnaryOp expression."""
if node.op_code == "++":
self.visit(node.operand)
operand_val = safe_pop(self.result_stack)
operand_val = safe_pop(self.result_stack, "UnaryOp (++) operand")

one = ir.Constant(operand_val.type, 1)

Expand All @@ -450,7 +464,7 @@ def visit(self, node: astx.UnaryOp) -> None:

elif node.op_code == "--":
self.visit(node.operand)
operand_val = safe_pop(self.result_stack)
operand_val = safe_pop(self.result_stack, "UnaryOp (--) operand")
one = ir.Constant(operand_val.type, 1)
result = self._llvm.ir_builder.sub(operand_val, one, "dectmp")

Expand All @@ -464,7 +478,7 @@ def visit(self, node: astx.UnaryOp) -> None:

elif node.op_code == "!":
self.visit(node.operand)
val = safe_pop(self.result_stack)
val = safe_pop(self.result_stack, "UnaryOp operand")
result = self._llvm.ir_builder.xor(
val, ir.Constant(val.type, 1), "nottmp"
)
Expand Down Expand Up @@ -497,7 +511,7 @@ def visit(self, node: astx.BinaryOp) -> None:

# Codegen the rhs.
self.visit(node.rhs)
llvm_rhs = safe_pop(self.result_stack)
llvm_rhs = safe_pop(self.result_stack, "BinaryOp (= rhs)")

if not llvm_rhs:
raise Exception("codegen: Invalid rhs expression.")
Expand All @@ -513,10 +527,10 @@ def visit(self, node: astx.BinaryOp) -> None:
return

self.visit(node.lhs)
llvm_lhs = safe_pop(self.result_stack)
llvm_lhs = safe_pop(self.result_stack, "BinaryOp (lhs)")

self.visit(node.rhs)
llvm_rhs = safe_pop(self.result_stack)
llvm_rhs = safe_pop(self.result_stack, "BinaryOp (rhs)")

if not llvm_lhs or not llvm_rhs:
raise Exception("codegen: Invalid lhs/rhs")
Expand Down Expand Up @@ -593,7 +607,9 @@ def visit(self, node: astx.BinaryOp) -> None:
if not hasattr(node, "fma_rhs"):
raise Exception("FMA requires a third operand (fma_rhs)")
self.visit(node.fma_rhs)
llvm_fma_rhs = safe_pop(self.result_stack)
llvm_fma_rhs = safe_pop(
self.result_stack, "BinaryOp (fma_rhs)"
)
if llvm_fma_rhs.type != llvm_lhs.type:
raise Exception(
f"FMA operand type mismatch: "
Expand Down Expand Up @@ -845,19 +861,15 @@ def visit(self, block: astx.Block) -> None:
result = None
for node in block.nodes:
self.visit(node)
try:
result = self.result_stack.pop()
except IndexError:
# some nodes doesn't add anything in the stack
pass
result = safe_pop(self.result_stack)
if result is not None:
self.result_stack.append(result)

@dispatch # type: ignore[no-redef]
def visit(self, node: astx.IfStmt) -> None:
"""Translate IF statement."""
self.visit(node.condition)
cond_v = self.result_stack.pop()
cond_v = safe_pop(self.result_stack, "IfStmt condition")
if not cond_v:
raise Exception("codegen: Invalid condition expression.")

Expand Down Expand Up @@ -890,7 +902,7 @@ def visit(self, node: astx.IfStmt) -> None:
# Emit then value.
self._llvm.ir_builder.position_at_start(then_bb)
self.visit(node.then)
then_v = self.result_stack.pop()
then_v = safe_pop(self.result_stack, "IfStmt then block")
if not then_v:
raise Exception("codegen: `Then` expression is invalid.")

Expand All @@ -904,7 +916,7 @@ def visit(self, node: astx.IfStmt) -> None:
else_v = None
if node.else_ is not None:
self.visit(node.else_)
else_v = self.result_stack.pop()
else_v = safe_pop(self.result_stack, "IfStmt else block")
else:
else_v = ir.Constant(self._llvm.INT32_TYPE, 0)

Expand Down Expand Up @@ -943,7 +955,7 @@ def visit(self, expr: astx.WhileStmt) -> None:

# Emit the condition.
self.visit(expr.condition)
cond_val = self.result_stack.pop()
cond_val = safe_pop(self.result_stack, "WhileStmt condition")
if not cond_val:
raise Exception("codegen: Invalid condition expression.")

Expand All @@ -970,7 +982,7 @@ def visit(self, expr: astx.WhileStmt) -> None:

# Emit the body of the loop.
self.visit(expr.body)
body_val = self.result_stack.pop()
body_val = safe_pop(self.result_stack, "WhileStmt body")

if not body_val:
return
Expand All @@ -993,7 +1005,7 @@ def visit(self, expr: astx.VariableAssignment) -> None:

# Codegen the value expression on the right-hand side
self.visit(expr.value)
llvm_value = safe_pop(self.result_stack)
llvm_value = safe_pop(self.result_stack, "VariableAssignment (value)")

if not llvm_value:
raise Exception("codegen: Invalid value in VariableAssignment.")
Expand Down Expand Up @@ -1023,7 +1035,9 @@ def visit(self, node: astx.ForCountLoopStmt) -> None:

# Emit the start code first, without 'variable' in scope.
self.visit(node.initializer)
initializer_val = self.result_stack.pop()
initializer_val = safe_pop(
self.result_stack, "ForCountLoop initializer"
)
if not initializer_val:
raise Exception("codegen: Invalid start argument.")

Expand All @@ -1044,7 +1058,7 @@ def visit(self, node: astx.ForCountLoopStmt) -> None:

# Emit condition check (e.g., i < 10)
self.visit(node.condition)
cond_val = self.result_stack.pop()
cond_val = safe_pop(self.result_stack, "ForCountLoop condition")

# Create blocks for loop body and after loop
loop_body_bb = self._llvm.ir_builder.function.append_basic_block(
Expand All @@ -1060,11 +1074,11 @@ def visit(self, node: astx.ForCountLoopStmt) -> None:
# Emit loop body
self._llvm.ir_builder.position_at_start(loop_body_bb)
self.visit(node.body)
_body_val = self.result_stack.pop()
_body_val = safe_pop(self.result_stack, "ForCountLoop body")

# Emit update expression
self.visit(node.update)
update_val = self.result_stack.pop()
update_val = safe_pop(self.result_stack, "ForCountLoop update")

# Store updated value
self._llvm.ir_builder.store(update_val, var_addr)
Expand Down Expand Up @@ -1100,7 +1114,7 @@ def visit(self, node: astx.ForRangeLoopStmt) -> None:

# Emit the start code first, without 'variable' in scope.
self.visit(node.start)
start_val = self.result_stack.pop()
start_val = safe_pop(self.result_stack, "ForRangeLoop start")
if not start_val:
raise Exception("codegen: Invalid start argument.")
self._llvm.ir_builder.store(start_val, var_addr)
Expand All @@ -1116,15 +1130,15 @@ def visit(self, node: astx.ForRangeLoopStmt) -> None:

# Emit the body of the loop.
self.visit(node.body)
body_val = self.result_stack.pop()
body_val = safe_pop(self.result_stack, "ForRangeLoop body")

if not body_val:
return

# Emit the step value.
if node.step:
self.visit(node.step)
step_val = self.result_stack.pop()
step_val = safe_pop(self.result_stack, "ForRangeLoop step")
if not step_val:
return
else:
Expand All @@ -1137,7 +1151,7 @@ def visit(self, node: astx.ForRangeLoopStmt) -> None:

# Compute the end condition.
self.visit(node.end)
end_cond = self.result_stack.pop()
end_cond = safe_pop(self.result_stack, "ForRangeLoop end")
if not end_cond:
return

Expand Down Expand Up @@ -1544,7 +1558,7 @@ def visit(self, node: astx.LiteralList) -> None:
llvm_elems: list[ir.Value] = []
for elem in node.elements:
self.visit(elem)
v = self.result_stack.pop()
v = safe_pop(self.result_stack, "LiteralList element")
if v is None:
raise Exception("LiteralList: invalid element lowering.")
llvm_elems.append(v)
Expand Down Expand Up @@ -1875,7 +1889,7 @@ def visit(self, node: astx.FunctionCall) -> None:
llvm_args = []
for arg in node.args:
self.visit(arg)
llvm_arg = self.result_stack.pop()
llvm_arg = safe_pop(self.result_stack, "FunctionCall argument")
if not llvm_arg:
raise Exception("codegen: Invalid callee argument.")
llvm_args.append(llvm_arg)
Expand Down Expand Up @@ -1940,10 +1954,7 @@ def visit(self, node: astx.FunctionReturn) -> None:
"""Translate ASTx FunctionReturn to LLVM-IR."""
if node.value is not None:
self.visit(node.value)
try:
retval = self.result_stack.pop()
except IndexError:
retval = None
retval = safe_pop(self.result_stack, "FunctionReturn value")
else:
retval = None

Expand Down Expand Up @@ -1976,7 +1987,9 @@ def visit(self, node: astx.InlineVariableDeclaration) -> None:
# Emit the initializer
if node.value is not None:
self.visit(node.value)
init_val = self.result_stack.pop()
init_val = safe_pop(
self.result_stack, "InlineVariableDeclaration initializer"
)
if init_val is None:
raise Exception("Initializer code generation failed.")
# Default zero value based on type
Expand Down Expand Up @@ -2110,7 +2123,7 @@ def _get_or_create_format_global(self, fmt: str) -> ir.GlobalVariable:
def visit(self, node: system.Cast) -> None:
"""Translate Cast expression to LLVM-IR."""
self.visit(node.value)
value = self.result_stack.pop()
value = safe_pop(self.result_stack, "Cast value")
target_type_str = node.target_type.__class__.__name__.lower()
target_type = self._llvm.get_data_type(target_type_str)

Expand Down Expand Up @@ -2234,7 +2247,7 @@ def visit(self, node: system.PrintExpr) -> None:
else:
# For variables and other expressions
self.visit(node.message)
ptr = safe_pop(self.result_stack)
ptr = safe_pop(self.result_stack, "PrintExpr (message)")
if not ptr:
raise Exception("Invalid message in PrintExpr")

Expand Down Expand Up @@ -2271,7 +2284,9 @@ def visit(self, node: astx.VariableDeclaration) -> None:
# Emit the initializer
if node.value is not None:
self.visit(node.value)
init_val = self.result_stack.pop()
init_val = safe_pop(
self.result_stack, "VariableDeclaration initializer"
)
if init_val is None:
raise Exception("Initializer code generation failed.")

Expand Down
Loading