diff --git a/src/irx/builders/llvmliteir.py b/src/irx/builders/llvmliteir.py index e978b70..247cff1 100644 --- a/src/irx/builders/llvmliteir.py +++ b/src/irx/builders/llvmliteir.py @@ -155,6 +155,7 @@ class LLVMLiteIRVisitor(BuilderVisitor): _llvm: VariablesLLVM function_protos: dict[str, astx.FunctionPrototype] + struct_types: dict[str, ir.IdentifiedStructType] result_stack: list[ir.Value | ir.Function] = [] def __init__(self) -> None: @@ -164,6 +165,7 @@ def __init__(self) -> None: # named_values as instance variable so it isn't shared across instances self.named_values: dict[str, Any] = {} self.function_protos: dict[str, astx.FunctionPrototype] = {} + self.struct_types: dict[str, ir.IdentifiedStructType] = {} self.result_stack: list[ir.Value | ir.Function] = [] self.initialize() @@ -2266,6 +2268,35 @@ def visit(self, node: astx.VariableDeclaration) -> None: if self.named_values.get(node.name): raise Exception(f"Identifier already declared: {node.name}") + if isinstance(node.type_, system.StructType): + struct_name = node.type_.struct_name + if struct_name not in self.struct_types: + raise Exception(f"Struct '{struct_name}' not defined.") + + llvm_struct_type = self.struct_types[struct_name] + self._llvm.ir_builder.position_at_start( + self._llvm.ir_builder.function.entry_basic_block + ) + alloca = self._llvm.ir_builder.alloca( + llvm_struct_type, None, node.name + ) + self._llvm.ir_builder.position_at_end(self._llvm.ir_builder.block) + + if node.value is not None: + raise Exception( + "Struct initialization with values not yet supported." + ) + + zero_fields = [ + ir.Constant(field_type, 0) + for field_type in llvm_struct_type.elements + ] + zero_init = ir.Constant(llvm_struct_type, zero_fields) + self._llvm.ir_builder.store(zero_init, alloca) + + self.named_values[node.name] = alloca + return + type_str = node.type_.__class__.__name__.lower() # Emit the initializer diff --git a/src/irx/system.py b/src/irx/system.py index d96272a..d2c5de1 100644 --- a/src/irx/system.py +++ b/src/irx/system.py @@ -47,3 +47,10 @@ def get_struct(self, simplified: bool = False) -> astx.base.ReprStruct: key = f"Cast[{self.target_type}]" value = self.value.get_struct(simplified) return self._prepare_struct(key, value, simplified) + + +class StructType(astx.DataType): + """Type reference for previously defined structs.""" + + def __init__(self, struct_name: str) -> None: + self.struct_name = struct_name diff --git a/tests/test_struct_type.py b/tests/test_struct_type.py new file mode 100644 index 0000000..fd7c280 --- /dev/null +++ b/tests/test_struct_type.py @@ -0,0 +1,113 @@ +"""Tests for StructType variable declarations.""" + +from typing import Type + +import astx +import pytest + +from irx.builders.base import Builder +from irx.builders.llvmliteir import LLVMLiteIR +from irx.system import StructType + +from .conftest import check_result + + +@pytest.mark.parametrize("builder_class", [LLVMLiteIR]) +def test_struct_variable_declaration( + builder_class: Type[Builder], +) -> None: + """Test struct variable declaration with defined struct.""" + builder = builder_class() + module = builder.module() + + struct_def = astx.StructDefStmt( + name="Point", + attributes=[ + astx.VariableDeclaration("x", astx.Int32()), + astx.VariableDeclaration("y", astx.Int32()), + ], + ) + + struct_var = astx.VariableDeclaration( + name="p", + type_=StructType(struct_name="Point"), + ) + + block = astx.Block() + block.append(struct_def) + block.append(struct_var) + block.append(astx.FunctionReturn(astx.LiteralInt32(0))) + + proto = astx.FunctionPrototype( + name="main", args=astx.Arguments(), return_type=astx.Int32() + ) + fn = astx.FunctionDef(prototype=proto, body=block) + module.block.append(fn) + + check_result("build", builder, module, expected_output="0") + + +@pytest.mark.parametrize("builder_class", [LLVMLiteIR]) +def test_struct_variable_undefined_error( + builder_class: Type[Builder], +) -> None: + """Test error when struct variable references undefined struct.""" + builder = builder_class() + module = builder.module() + + struct_var = astx.VariableDeclaration( + name="p", + type_=StructType(struct_name="Point"), + ) + + block = astx.Block() + block.append(struct_var) + block.append(astx.FunctionReturn(astx.LiteralInt32(0))) + + proto = astx.FunctionPrototype( + name="main", args=astx.Arguments(), return_type=astx.Int32() + ) + fn = astx.FunctionDef(prototype=proto, body=block) + module.block.append(fn) + + with pytest.raises(Exception, match=r"Struct 'Point' not defined"): + builder.build(module, output_file="/tmp/test") + + +@pytest.mark.parametrize("builder_class", [LLVMLiteIR]) +def test_struct_variable_with_value_error( + builder_class: Type[Builder], +) -> None: + """Test error when struct variable has value initialization.""" + builder = builder_class() + module = builder.module() + + struct_def = astx.StructDefStmt( + name="Point", + attributes=[ + astx.VariableDeclaration("x", astx.Int32()), + astx.VariableDeclaration("y", astx.Int32()), + ], + ) + + struct_var = astx.VariableDeclaration( + name="p", + type_=StructType(struct_name="Point"), + value=astx.LiteralInt32(42), + ) + + block = astx.Block() + block.append(struct_def) + block.append(struct_var) + block.append(astx.FunctionReturn(astx.LiteralInt32(0))) + + proto = astx.FunctionPrototype( + name="main", args=astx.Arguments(), return_type=astx.Int32() + ) + fn = astx.FunctionDef(prototype=proto, body=block) + module.block.append(fn) + + with pytest.raises( + Exception, match=r"Struct initialization with values not yet supported" + ): + builder.build(module, output_file="/tmp/test")