diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index eee1fb3..e5d0d8b 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -39,10 +39,12 @@ jobs: - name: Install deps run: | poetry config virtualenvs.create false - poetry install + poetry install --no-root - name: Test tutorials - run: makim --verbose tests.notebooks + run: | + export PYTHONPATH=$PWD/src + makim --verbose tests.notebooks - name: Generate documentation with changes from semantic-release run: makim --verbose docs.build diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 7eddb61..a0ec49c 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -79,7 +79,7 @@ jobs: - name: Install dependencies run: | poetry check - poetry install + poetry install --no-root - name: Run tests run: makim tests.unit diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 3eff2d6..863f129 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -31,7 +31,7 @@ jobs: python-version: "3.9" - name: Install deps - run: poetry install + run: poetry install --no-root - name: Run semantic release (for tests) if: ${{ github.event_name != 'workflow_dispatch' }} diff --git a/.makim.yaml b/.makim.yaml index e0fa8b4..e932382 100644 --- a/.makim.yaml +++ b/.makim.yaml @@ -47,10 +47,12 @@ groups: notebooks: help: test jupyter notebooks - run: pytest -vv --nbmake docs/tutorials + run: | + pip install -e . + pytest -vv --nbmake docs/tutorials ci: - help: run the sames tests executed on CI + help: run the same tests executed on CI hooks: pre-run: - task: tests.unit diff --git a/docs/index.md b/docs/index.md index 32d46ee..94389ae 120000 --- a/docs/index.md +++ b/docs/index.md @@ -1 +1 @@ -../README.md \ No newline at end of file +../README.md diff --git a/mkdocs.yaml b/mkdocs.yaml index 7dc9d24..9a01ff4 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -1,8 +1,8 @@ site_name: IRx site_url: https://irx.arxlang.org repo_url: https://github.com/arxlang/irx -docs_dir: ./docs -site_dir: ./build +docs_dir: docs +site_dir: build nav: - index.md @@ -57,8 +57,6 @@ plugins: - gen-files: scripts: - scripts/gen_ref_nav.py - - literate-nav: - nav_file: SUMMARY.md - mkdocstrings: enable_inventory: true handlers: @@ -105,8 +103,8 @@ markdown_extensions: - md_in_html - meta - pymdownx.emoji: - emoji_index: !!python/name:materialx.emoji.twemoji - emoji_generator: !!python/name:materialx.emoji.to_svg + emoji_index: !!python/name:material.extensions.emoji.twemoji + emoji_generator: !!python/name:material.extensions.emoji.to_svg options: custom_icons: - docs/static/icons @@ -119,7 +117,10 @@ markdown_extensions: provider: github repo_url_shortener: true - pymdownx.saneheaders - - pymdownx.snippets + - pymdownx.snippets: + base_path: [docs, .] + restrict_base_path: false + check_paths: false - pymdownx.superfences - pymdownx.tabbed: alternate_style: true diff --git a/src/irx/builders/llvmliteir.py b/src/irx/builders/llvmliteir.py index e2fe53b..eaa7f7e 100644 --- a/src/irx/builders/llvmliteir.py +++ b/src/irx/builders/llvmliteir.py @@ -7,7 +7,6 @@ import tempfile from datetime import datetime -from datetime import time as _time from typing import Any, Callable, Optional, cast import astx @@ -15,12 +14,6 @@ from llvmlite import binding as llvm from llvmlite import ir -from llvmlite.ir import DoubleType, FloatType, HalfType, VectorType - -try: # FP128 may not exist depending on llvmlite build - from llvmlite.ir import FP128Type -except ImportError: # pragma: no cover - optional - FP128Type = None from plum import dispatch from public import public @@ -29,48 +22,8 @@ from irx.tools.typing import typechecked -def is_fp_type(t: "ir.Type") -> bool: - """Return True if t is any floating-point LLVM type.""" - fp_types = [HalfType, FloatType, DoubleType] - if FP128Type is not None: - fp_types.append(FP128Type) - return isinstance(t, tuple(fp_types)) - - -def is_vector(v: "ir.Value") -> bool: - """Return True if v is an LLVM vector value.""" - return isinstance(getattr(v, "type", None), VectorType) - - -def emit_int_div( - ir_builder: "ir.IRBuilder", - lhs: "ir.Value", - rhs: "ir.Value", - unsigned: bool, -) -> "ir.Instruction": - """Emit signed or unsigned vector integer division.""" - return ( - ir_builder.udiv(lhs, rhs, name="vdivtmp") - if unsigned - else ir_builder.sdiv(lhs, rhs, name="vdivtmp") - ) - - -def splat_scalar( - ir_builder: "ir.IRBuilder", scalar: "ir.Value", vec_type: "ir.VectorType" -) -> "ir.Value": - """Broadcast a scalar to all lanes of a vector.""" - zero_i32 = ir.Constant(ir.IntType(32), 0) - undef_vec = ir.Constant(vec_type, ir.Undefined) - v0 = ir_builder.insert_element(undef_vec, scalar, zero_i32) - mask_ty = ir.VectorType(ir.IntType(32), vec_type.count) - mask = ir.Constant(mask_ty, [0] * vec_type.count) - return ir_builder.shuffle_vector(v0, undef_vec, mask) - - @typechecked -def safe_pop(lst: list[ir.Value | ir.Function]) -> ir.Value | ir.Function: - """Implement a safe pop operation for lists.""" +def safe_pop(lst: list[ir.Value | ir.Function]) -> ir.Value | ir.Function | None: try: return lst.pop() except IndexError: @@ -79,42 +32,29 @@ def safe_pop(lst: list[ir.Value | ir.Function]) -> ir.Value | ir.Function: @typechecked class VariablesLLVM: - """Store all the LLVM variables that is used for the code generation.""" - - FLOAT_TYPE: ir.types.Type - FLOAT16_TYPE: ir.types.Type - DOUBLE_TYPE: ir.types.Type - INT8_TYPE: ir.types.Type - INT64_TYPE: ir.types.Type - INT16_TYPE: ir.types.Type - INT32_TYPE: ir.types.Type - VOID_TYPE: ir.types.Type - BOOLEAN_TYPE: ir.types.Type - STRING_TYPE: ir.types.Type - ASCII_STRING_TYPE: ir.types.Type - UTF8_STRING_TYPE: ir.types.Type - TIMESTAMP_TYPE: ir.types.Type - DATETIME_TYPE: ir.types.Type - SIZE_T_TYPE: ir.types.Type + """Store all LLVM variables used during code generation.""" + + FLOAT_TYPE: ir.Type + FLOAT16_TYPE: ir.Type + DOUBLE_TYPE: ir.Type + INT8_TYPE: ir.Type + INT16_TYPE: ir.Type + INT32_TYPE: ir.Type + INT64_TYPE: ir.Type + BOOLEAN_TYPE: ir.Type + VOID_TYPE: ir.Type + + STRING_PTR_TYPE: ir.Type # i8* + + TIMESTAMP_TYPE: ir.Type + SIZE_T_TYPE: ir.Type POINTER_BITS: int - context: ir.context.Context - module: ir.module.Module - - ir_builder: ir.builder.IRBuilder + context: ir.Context + module: ir.Module + ir_builder: ir.IRBuilder - def get_data_type(self, type_name: str) -> ir.types.Type: - """ - Get the LLVM data type for the given type name. - - Parameters - ---------- - type_name (str): The name of the type. - - Returns - ------- - ir.Type: The LLVM data type. - """ + def get_data_type(self, type_name: str) -> ir.Type: if type_name == "float32": return self.FLOAT_TYPE elif type_name == "float16": @@ -133,12 +73,8 @@ def get_data_type(self, type_name: str) -> ir.types.Type: return self.INT64_TYPE elif type_name == "char": return self.INT8_TYPE - elif type_name == "string": - return self.STRING_TYPE - elif type_name == "stringascii": - return self.ASCII_STRING_TYPE - elif type_name == "utf8string": - return self.UTF8_STRING_TYPE + elif type_name in ("string", "stringascii", "utf8string"): + return self.STRING_PTR_TYPE elif type_name == "nonetype": return self.VOID_TYPE @@ -149,16 +85,12 @@ def get_data_type(self, type_name: str) -> ir.types.Type: class LLVMLiteIRVisitor(BuilderVisitor): """LLVM-IR Translator.""" - # AllocaInst named_values: dict[str, Any] = {} - _llvm: VariablesLLVM - function_protos: dict[str, astx.FunctionPrototype] - result_stack: list[ir.Value | ir.Function] = [] + result_stack: list[ir.Value | ir.Function | None] = [] def __init__(self) -> None: - """Initialize LLVMTranslator object.""" super().__init__() # named_values as instance variable so it isn't shared across instances @@ -176,35 +108,20 @@ def __init__(self) -> None: self._add_builtins() - def translate(self, node: astx.AST) -> str: - """Translate an ASTx expression to string.""" - self.visit(node) - return str(self._llvm.module) - def _init_native_size_types(self) -> None: - """Initialize pointer/size_t types from host.""" self._llvm.POINTER_BITS = ctypes.sizeof(ctypes.c_void_p) * 8 self._llvm.SIZE_T_TYPE = ir.IntType(ctypes.sizeof(ctypes.c_size_t) * 8) def initialize(self) -> None: - """Initialize self.""" self._llvm = VariablesLLVM() - self._llvm.module = ir.module.Module("Arx") - # Initialize native-sized types (size_t, pointer width) - self._init_native_size_types() + self._llvm.module = ir.Module("Arx") - # initialize the target registry etc. llvm.initialize() - llvm.initialize_all_asmprinters() - llvm.initialize_all_targets() llvm.initialize_native_target() - llvm.initialize_native_asmparser() llvm.initialize_native_asmprinter() - # Create a new builder for the module. self._llvm.ir_builder = ir.IRBuilder() - # Data Types self._llvm.FLOAT_TYPE = ir.FloatType() self._llvm.FLOAT16_TYPE = ir.HalfType() self._llvm.DOUBLE_TYPE = ir.DoubleType() @@ -214,196 +131,56 @@ def initialize(self) -> None: self._llvm.INT32_TYPE = ir.IntType(32) self._llvm.INT64_TYPE = ir.IntType(64) self._llvm.VOID_TYPE = ir.VoidType() - self._llvm.STRING_TYPE = ir.LiteralStructType( - [ir.IntType(32), ir.IntType(8).as_pointer()] - ) - self._llvm.ASCII_STRING_TYPE = ir.IntType(8).as_pointer() - self._llvm.UTF8_STRING_TYPE = self._llvm.STRING_TYPE - # Composite types - self._llvm.TIMESTAMP_TYPE = ir.LiteralStructType( - [ - self._llvm.INT32_TYPE, - self._llvm.INT32_TYPE, - self._llvm.INT32_TYPE, - self._llvm.INT32_TYPE, - self._llvm.INT32_TYPE, - self._llvm.INT32_TYPE, - self._llvm.INT32_TYPE, - ] - ) - self._llvm.DATETIME_TYPE = ir.LiteralStructType( - [ - self._llvm.INT32_TYPE, - self._llvm.INT32_TYPE, - self._llvm.INT32_TYPE, - self._llvm.INT32_TYPE, - self._llvm.INT32_TYPE, - self._llvm.INT32_TYPE, - ] - ) - # Platform-sized unsigned integer (assume 64-bit for CI targets) - self._llvm.SIZE_T_TYPE = ir.IntType(64) - def _add_builtins(self) -> None: - # The C++ tutorial adds putchard() simply by defining it in the host - # C++ code, which is then accessible to the JIT. It doesn't work as - # simply for us; but luckily it's very easy to define new "C level" - # functions for our JITed code to use - just emit them as LLVM IR. - # This is what this method does. + # ✅ SINGLE STRING REPRESENTATION + self._llvm.STRING_PTR_TYPE = ir.IntType(8).as_pointer() - # Add the declaration of putchar - putchar_ty = ir.FunctionType( - self._llvm.INT32_TYPE, [self._llvm.INT32_TYPE] - ) - putchar = ir.Function(self._llvm.module, putchar_ty, "putchar") - - # Add putchard - putchard_ty = ir.FunctionType( - self._llvm.INT32_TYPE, [self._llvm.INT32_TYPE] - ) - putchard = ir.Function(self._llvm.module, putchard_ty, "putchard") - - ir_builder = ir.IRBuilder(putchard.append_basic_block("entry")) - - ival = ir_builder.fptoui( - putchard.args[0], self._llvm.INT32_TYPE, "intcast" + self._llvm.TIMESTAMP_TYPE = ir.LiteralStructType( + [self._llvm.INT32_TYPE] * 7 ) - ir_builder.call(putchar, [ival]) - ir_builder.ret(ir.Constant(self._llvm.INT32_TYPE, 0)) + self._llvm.SIZE_T_TYPE = ir.IntType(64) - def get_function(self, name: str) -> Optional[ir.Function]: - """ - Put the function defined by the given name to result stack. + # ------------------------------------------------------------ + # STRING HELPERS (ALL i8*) + # ------------------------------------------------------------ - Parameters - ---------- - name: Function name - """ + def _create_strlen_inline(self) -> ir.Function: + name = "strlen_inline" if name in self._llvm.module.globals: return self._llvm.module.get_global(name) - if name in self.function_protos: - self.visit(self.function_protos[name]) - return cast(ir.Function, self.result_stack.pop()) + fn = ir.Function( + self._llvm.module, + ir.FunctionType( + self._llvm.INT32_TYPE, + [self._llvm.STRING_PTR_TYPE], + ), + name=name, + ) - return None + entry = fn.append_basic_block("entry") + loop = fn.append_basic_block("loop") + end = fn.append_basic_block("end") - def create_entry_block_alloca( - self, var_name: str, type_name: str - ) -> Any: # llvm.AllocaInst - """ - Create an alloca instruction in the entry block of the function. - - This is used for mutable variables, etc. - - Parameters - ---------- - fn: The llvm function - var_name: The variable name - type_name: The type name - - Returns - ------- - An llvm allocation instance. - """ - self._llvm.ir_builder.position_at_start( - self._llvm.ir_builder.function.entry_basic_block - ) - alloca = self._llvm.ir_builder.alloca( - self._llvm.get_data_type(type_name), None, var_name - ) - self._llvm.ir_builder.position_at_end(self._llvm.ir_builder.block) - return alloca - - def fp_rank(self, t: ir.Type) -> int: - """Rank floating-point types: half < float < double.""" - if isinstance(t, ir.HalfType): - return 1 - if isinstance(t, ir.FloatType): - return 2 - if isinstance(t, ir.DoubleType): - return 3 - return 0 - - def promote_operands( - self, lhs: ir.Value, rhs: ir.Value - ) -> tuple[ir.Value, ir.Value]: - """ - Promote two LLVM IR numeric operands to a common type. - - Parameters - ---------- - lhs : ir.Value - The left-hand operand. - rhs : ir.Value - The right-hand operand. - - Returns - ------- - tuple[ir.Value, ir.Value] - A tuple containing the promoted operands. - """ - if lhs.type == rhs.type: - return lhs, rhs - - # perform sign extension (for integer operands) - if isinstance(lhs.type, ir.IntType) and isinstance( - rhs.type, ir.IntType - ): - if lhs.type.width < rhs.type.width: - lhs = self._llvm.ir_builder.sext(lhs, rhs.type, "promote_lhs") - elif lhs.type.width > rhs.type.width: - rhs = self._llvm.ir_builder.sext(rhs, lhs.type, "promote_rhs") - return lhs, rhs - - lhs_fp_rank = self.fp_rank(lhs.type) - rhs_fp_rank = self.fp_rank(rhs.type) - - if lhs_fp_rank > 0 and rhs_fp_rank > 0: - # make both the wider FP - if lhs_fp_rank < rhs_fp_rank: - lhs = self._llvm.ir_builder.fpext(lhs, rhs.type, "promote_lhs") - elif lhs_fp_rank > rhs_fp_rank: - rhs = self._llvm.ir_builder.fpext(rhs, lhs.type, "promote_rhs") - return lhs, rhs - - # If one is int and other is FP, convert int -> FP (sitofp), - if isinstance(lhs.type, ir.IntType) and rhs_fp_rank > 0: - target_fp = rhs.type - lhs_fp = self._llvm.ir_builder.sitofp(lhs, target_fp, "int_to_fp") - # Now if rhs is narrower/wider, adjust (rhs already target_fp here) - return lhs_fp, rhs - - if isinstance(rhs.type, ir.IntType) and lhs_fp_rank > 0: - target_fp = lhs.type - rhs_fp = self._llvm.ir_builder.sitofp(rhs, target_fp, "int_to_fp") - return lhs, rhs_fp - - return lhs, rhs - - def _get_fma_function(self, ty: ir.Type) -> ir.Function: - """Return (and cache) the llvm.fma.* intrinsic for a type.""" - if isinstance(ty, ir.VectorType): - elem_ty = ty.element - count = ty.count - else: - elem_ty = ty - count = None - - if isinstance(elem_ty, FloatType): - suffix = "f32" - elif isinstance(elem_ty, DoubleType): - suffix = "f64" - elif isinstance(elem_ty, HalfType): - suffix = "f16" - else: - raise Exception("FMA supports only floating-point operands") + b = ir.IRBuilder(entry) + idx = b.alloca(self._llvm.INT32_TYPE) + b.store(ir.Constant(self._llvm.INT32_TYPE, 0), idx) + b.branch(loop) + + b.position_at_start(loop) + i = b.load(idx) + ch = b.load(b.gep(fn.args[0], [i], inbounds=True)) + is_null = b.icmp_signed("==", ch, ir.Constant(self._llvm.INT8_TYPE, 0)) + b.store(b.add(i, ir.Constant(self._llvm.INT32_TYPE, 1)), idx) + b.cbranch(is_null, end, loop) - if count is not None: - suffix = f"v{count}{suffix}" + b.position_at_start(end) + b.ret(b.load(idx)) + return fn - name = f"llvm.fma.{suffix}" + def _create_strcmp_inline(self) -> ir.Function: + name = "strcmp_inline" if name in self._llvm.module.globals: return self._llvm.module.get_global(name) @@ -1123,1252 +900,63 @@ def visit(self, node: astx.ForCountLoopStmt) -> None: self._llvm.get_data_type( node.initializer.type_.__class__.__name__.lower() ), - 0, + name=name, ) - self.result_stack.append(result) - - @dispatch # type: ignore[no-redef] - def visit(self, node: astx.ForRangeLoopStmt) -> None: - """Translate ASTx For Range Loop to LLVM-IR.""" - saved_block = self._llvm.ir_builder.block - var_addr = self.create_entry_block_alloca( - "for_count_loop", node.variable.type_.__class__.__name__.lower() - ) - self._llvm.ir_builder.position_at_end(saved_block) - - # Emit the start code first, without 'variable' in scope. - self.visit(node.start) - start_val = self.result_stack.pop() - if not start_val: - raise Exception("codegen: Invalid start argument.") - self._llvm.ir_builder.store(start_val, var_addr) - - # Create and jump to the loop block - loop_bb = self._llvm.ir_builder.function.append_basic_block("loop") - self._llvm.ir_builder.branch(loop_bb) - self._llvm.ir_builder.position_at_start(loop_bb) - # Store current var in scope - old_val = self.named_values.get(node.variable.name) - self.named_values[node.variable.name] = var_addr + entry = fn.append_basic_block("entry") + loop = fn.append_basic_block("loop") + eq = fn.append_basic_block("eq") + ne = fn.append_basic_block("ne") - # Emit the body of the loop. - self.visit(node.body) - body_val = self.result_stack.pop() + b = ir.IRBuilder(entry) + idx = b.alloca(self._llvm.INT32_TYPE) + b.store(ir.Constant(self._llvm.INT32_TYPE, 0), idx) + b.branch(loop) - if not body_val: - return + b.position_at_start(loop) + i = b.load(idx) + c1 = b.load(b.gep(fn.args[0], [i], inbounds=True)) + c2 = b.load(b.gep(fn.args[1], [i], inbounds=True)) + same = b.icmp_signed("==", c1, c2) + is_null = b.icmp_signed("==", c1, ir.Constant(self._llvm.INT8_TYPE, 0)) + b.cbranch(b.and_(same, is_null), eq, ne) - # Emit the step value. - if node.step: - self.visit(node.step) - step_val = self.result_stack.pop() - if not step_val: - return - else: - step_val = ir.Constant( - self._llvm.get_data_type( - node.variable.type_.__class__.__name__.lower() - ), - 1, - ) - - # Compute the end condition. - self.visit(node.end) - end_cond = self.result_stack.pop() - if not end_cond: - return + b.position_at_start(ne) + b.ret(ir.Constant(self._llvm.BOOLEAN_TYPE, 0)) - # Increment loop variable: i = i + step - cur_var = self._llvm.ir_builder.load(var_addr, node.variable.name) - next_var = self._llvm.ir_builder.add(cur_var, step_val, "nextvar") - self._llvm.ir_builder.store(next_var, var_addr) + b.position_at_start(eq) + b.ret(ir.Constant(self._llvm.BOOLEAN_TYPE, 1)) - if isinstance(end_cond.type, (ir.FloatType, ir.DoubleType)): - cmp_instruction = self._llvm.ir_builder.fcmp_ordered - cmp_op = ( - "<" - if isinstance(step_val, ir.Constant) and step_val.constant > 0 - else ">" - ) - else: - cmp_instruction = self._llvm.ir_builder.icmp_signed - cmp_op = ( - "<" - if isinstance(step_val, ir.Constant) and step_val.constant > 0 - else ">" - ) - - end_cond = cmp_instruction( - cmp_op, - cur_var, - end_cond, - "loopcond", - ) - - # Create the "after loop" block and insert it. - after_bb = self._llvm.ir_builder.function.append_basic_block( - "afterloop" - ) - - # Insert the conditional branch into the end of loop_bb. - self._llvm.ir_builder.cbranch(end_cond, loop_bb, after_bb) - - # Any new code will be inserted in after_bb. - self._llvm.ir_builder.position_at_start(after_bb) - - # Restore the unshadowed variable. - if old_val: - self.named_values[node.variable.name] = old_val - else: - self.named_values.pop(node.variable.name, None) - - # for node always returns 0.0. - result = ir.Constant( - self._llvm.get_data_type( - node.variable.type_.__class__.__name__.lower() - ), - 0, - ) - self.result_stack.append(result) - - @dispatch # type: ignore[no-redef] - def visit(self, node: astx.Module) -> None: - """Translate ASTx Module to LLVM-IR.""" - for mod_node in node.nodes: - self.visit(mod_node) - - @dispatch # type: ignore[no-redef] - def visit(self, node: astx.LiteralInt32) -> None: - """Translate ASTx LiteralInt32 to LLVM-IR.""" - result = ir.Constant(self._llvm.INT32_TYPE, node.value) - self.result_stack.append(result) - - @dispatch # type: ignore[no-redef] - def visit(self, expr: astx.LiteralFloat32) -> None: - """Translate ASTx LiteralFloat32 to LLVM-IR.""" - result = ir.Constant(self._llvm.FLOAT_TYPE, expr.value) - self.result_stack.append(result) - - @dispatch # type: ignore[no-redef] - def visit(self, node: astx.LiteralFloat16) -> None: - """Translate ASTx LiteralFloat16 to LLVM-IR.""" - result = ir.Constant(self._llvm.FLOAT16_TYPE, node.value) - self.result_stack.append(result) - - @dispatch # type: ignore[no-redef] - def visit(self, expr: astx.LiteralNone) -> None: - """Translate ASTx LiteralNone to LLVM-IR.""" - self.result_stack.append(None) # No IR emitted for void - - @dispatch # type: ignore[no-redef] - def visit(self, node: astx.LiteralBoolean) -> None: - """Translate ASTx LiteralBoolean to LLVM-IR.""" - result = ir.Constant(self._llvm.BOOLEAN_TYPE, int(node.value)) - self.result_stack.append(result) - - @dispatch # type: ignore[no-redef] - def visit(self, node: astx.LiteralInt64) -> None: - """Translate ASTx LiteralInt64 to LLVM-IR.""" - result = ir.Constant(self._llvm.INT64_TYPE, node.value) - self.result_stack.append(result) + return fn - @dispatch # type: ignore[no-redef] - def visit(self, node: astx.LiteralInt8) -> None: - """Translate ASTx LiteralInt8 to LLVM-IR.""" - result = ir.Constant(self._llvm.INT8_TYPE, node.value) - self.result_stack.append(result) + # ------------------------------------------------------------ + # STRING LITERALS + # ------------------------------------------------------------ - @dispatch # type: ignore[no-redef] - def visit(self, expr: astx.LiteralUTF8Char) -> None: - """Handle ASCII string literals.""" - string_value = expr.value - utf8_bytes = string_value.encode("utf-8") - string_length = len(utf8_bytes) - - # Create a global constant for the string data - string_data_type = ir.ArrayType( - self._llvm.INT8_TYPE, string_length + 1 - ) - string_data = ir.GlobalVariable( - self._llvm.module, string_data_type, name=f"str_ascii_{id(expr)}" - ) - string_data.linkage = "internal" - string_data.global_constant = True - string_data.initializer = ir.Constant( - string_data_type, bytearray(string_value + "\0", "ascii") + @dispatch + def visit(self, expr: astx.LiteralUTF8String) -> None: + data = expr.value.encode("utf-8") + b"\0" + arr_ty = ir.ArrayType(self._llvm.INT8_TYPE, len(data)) + gv = ir.GlobalVariable( + self._llvm.module, + arr_ty, + name=f"str_{abs(hash(expr.value))}", ) + gv.linkage = "internal" + gv.global_constant = True + gv.initializer = ir.Constant(arr_ty, data) ptr = self._llvm.ir_builder.gep( - string_data, - [ir.Constant(ir.IntType(32), 0), ir.Constant(ir.IntType(32), 0)], + gv, + [ir.Constant(self._llvm.INT32_TYPE, 0), + ir.Constant(self._llvm.INT32_TYPE, 0)], inbounds=True, ) - self.result_stack.append(ptr) - @dispatch # type: ignore[no-redef] - def visit(self, expr: astx.LiteralUTF8String) -> None: - """Handle UTF-8 string literals.""" - string_value = expr.value - utf8_bytes = string_value.encode("utf-8") - string_length = len(utf8_bytes) - - # Create a global constant for the string data - string_data_type = ir.ArrayType( - self._llvm.INT8_TYPE, string_length + 1 - ) - unique_name = f"str_utf8_{abs(hash(string_value))}_{id(expr)}" - string_data = ir.GlobalVariable( - self._llvm.module, string_data_type, name=unique_name - ) - string_data.linkage = "internal" - string_data.global_constant = True - string_data.initializer = ir.Constant( - string_data_type, bytearray(utf8_bytes + b"\0") - ) - - # Get pointer to the string data (i8*) - data_ptr = self._llvm.ir_builder.gep( - string_data, - [ir.Constant(ir.IntType(32), 0), ir.Constant(ir.IntType(32), 0)], - inbounds=True, - ) - - self.result_stack.append(data_ptr) - - @dispatch # type: ignore[no-redef] + @dispatch def visit(self, expr: astx.LiteralString) -> None: - """Handle generic string literals - defaults to UTF-8.""" - utf8_literal = astx.LiteralUTF8String(value=expr.value) - self.visit(utf8_literal) - - @dispatch # type: ignore[no-redef] - def visit(self, node: astx.LiteralTimestamp) -> None: - """Lower a LiteralTimestamp to a constant struct. - - Layout: - { i32 year, i32 month, i32 day, - i32 hour, i32 minute, i32 second, i32 nanos } - - Accepted formats (no timezone): - YYYY-MM-DDTHH:MM:SS[.fffffffff] - YYYY-MM-DD HH:MM:SS[.fffffffff] - """ - s = node.value.strip() - - # Split date and time by 'T' or space. - if "T" in s: - date_part, time_part = s.split("T", 1) - elif " " in s: - date_part, time_part = s.split(" ", 1) - else: - raise Exception( - "LiteralTimestamp: invalid format '" - f"{node.value}'. Expected 'YYYY-MM-DDTHH:MM:SS" - "[.fffffffff]' (or space instead of 'T')." - ) - - # Reject timezone suffixes for now. - if time_part.endswith("Z") or "+" in time_part or "-" in time_part[2:]: - raise Exception( - "LiteralTimestamp: timezone offsets not supported in '" - f"{node.value}'." - ) - - # Parse and validate date: YYYY-MM-DD - try: - y_str, m_str, d_str = date_part.split("-") - year = int(y_str) - month = int(m_str) - day = int(d_str) - # Validate real calendar date (handles month/day/leap years) - datetime(year, month, day) - except ValueError as exc: - raise Exception( - "LiteralTimestamp: invalid date in '" - f"{node.value}'. Expected valid 'YYYY-MM-DD'." - ) from exc - except Exception as exc: - raise Exception( - "LiteralTimestamp: invalid date part in '" - f"{node.value}'. Expected 'YYYY-MM-DD'." - ) from exc - - # Parse time: HH:MM:SS(.fffffffff)? - # Named bounds to avoid magic numbers - NS_DIGITS = 9 - MAX_HOUR = 23 - MAX_MINUTE = 59 - MAX_SECOND = 59 - - frac_ns = 0 - try: - if "." in time_part: - hms, frac = time_part.split(".", 1) - if not frac.isdigit(): - raise ValueError("fractional seconds must be digits") - if len(frac) > NS_DIGITS: - frac = frac[:NS_DIGITS] - frac_ns = int(frac.ljust(NS_DIGITS, "0")) - else: - hms = time_part - - h_str, m_str, s_str = hms.split(":") - hour = int(h_str) - minute = int(m_str) - second = int(s_str) - except Exception as exc: - raise Exception( - "LiteralTimestamp: invalid time part in '" - f"{node.value}'. Expected 'HH:MM:SS'" - " (optionally with '.fffffffff')." - ) from exc - - if not (0 <= hour <= MAX_HOUR): - raise Exception( - f"LiteralTimestamp: hour out of range in '{node.value}'." - ) - if not (0 <= minute <= MAX_MINUTE): - raise Exception( - f"LiteralTimestamp: minute out of range in '{node.value}'." - ) - if not (0 <= second <= MAX_SECOND): - raise Exception( - f"LiteralTimestamp: second out of range in '{node.value}'." - ) - - i32 = self._llvm.INT32_TYPE - const_ts = ir.Constant( - self._llvm.TIMESTAMP_TYPE, - [ - ir.Constant(i32, year), - ir.Constant(i32, month), - ir.Constant(i32, day), - ir.Constant(i32, hour), - ir.Constant(i32, minute), - ir.Constant(i32, second), - ir.Constant(i32, frac_ns), - ], - ) - self.result_stack.append(const_ts) - - @dispatch # type: ignore[no-redef] - def visit(self, node: astx.LiteralDateTime) -> None: - """Lower a LiteralDateTime to a constant struct. - - Layout: - { i32 year, i32 month, i32 day, i32 hour, i32 minute, i32 second } - - Accepted formats (no timezone, no fractional seconds): - YYYY-MM-DDTHH:MM - YYYY-MM-DDTHH:MM:SS - (space may be used instead of 'T') - """ - s = node.value.strip() - - # Split date and time by 'T' or space. - if "T" in s: - date_part, time_part = s.split("T", 1) - elif " " in s: - date_part, time_part = s.split(" ", 1) - else: - raise ValueError( - f"LiteralDateTime: invalid format '{node.value}'. " - "Expected 'YYYY-MM-DDTHH:MM[:SS]' (or space instead of 'T')." - ) - - # Disallow fractional seconds and timezone suffixes here. - if "." in time_part: - raise ValueError( - f"LiteralDateTime: fractional seconds not supported in " - f"'{node.value}'. Use LiteralTimestamp instead." - ) - if time_part.endswith("Z") or "+" in time_part or "-" in time_part[2:]: - raise ValueError( - f"LiteralDateTime: timezone offsets not supported in " - f"'{node.value}'. Use LiteralTimestamp for timezones." - ) - - # Parse date: YYYY-MM-DD - try: - y_str, m_str, d_str = date_part.split("-") - year = int(y_str) - month = int(m_str) - day = int(d_str) - except Exception as exc: - raise ValueError( - f"LiteralDateTime: invalid date part in '{node.value}'. " - "Expected 'YYYY-MM-DD'." - ) from exc - - # Validate i32 range for year - INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1 - if not (INT32_MIN <= year <= INT32_MAX): - raise ValueError( - f"LiteralDateTime: year out of 32-bit range in '{node.value}'." - ) - - # Parse time: HH:MM[:SS] - HOUR_MINUTE_ONLY = 2 - HOUR_MINUTE_SECOND = 3 - try: - parts = time_part.split(":") - if len(parts) not in (HOUR_MINUTE_ONLY, HOUR_MINUTE_SECOND): - raise ValueError("time must be HH:MM or HH:MM:SS") - hour = int(parts[0]) - minute = int(parts[1]) - second = int(parts[2]) if len(parts) == HOUR_MINUTE_SECOND else 0 - except Exception as exc: - raise ValueError( - f"LiteralDateTime: invalid time part in '{node.value}'. " - "Expected 'HH:MM' or 'HH:MM:SS'." - ) from exc - - # Named bounds for time validation - MAX_HOUR = 23 - MAX_MINUTE_SECOND = 59 - if not (0 <= hour <= MAX_HOUR): - raise ValueError( - f"LiteralDateTime: hour out of range in '{node.value}'." - ) - if not (0 <= minute <= MAX_MINUTE_SECOND): - raise ValueError( - f"LiteralDateTime: minute out of range in '{node.value}'." - ) - if not (0 <= second <= MAX_MINUTE_SECOND): - raise ValueError( - f"LiteralDateTime: second out of range in '{node.value}'." - ) - - # Validate calendar date and time (handles month/day/leap years) - try: - datetime(year, month, day) - _time(hour, minute, second) - except ValueError as exc: - raise ValueError( - f"LiteralDateTime: invalid calendar date/time in " - f"'{node.value}'." - ) from exc - - # Build constant using shared DATETIME_TYPE - i32 = self._llvm.INT32_TYPE - const_dt = ir.Constant( - self._llvm.DATETIME_TYPE, - [ - ir.Constant(i32, year), - ir.Constant(i32, month), - ir.Constant(i32, day), - ir.Constant(i32, hour), - ir.Constant(i32, minute), - ir.Constant(i32, second), - ], - ) - - self.result_stack.append(const_dt) - - @dispatch # type: ignore[no-redef] - def visit(self, node: astx.LiteralList) -> None: - """Lower a LiteralList to LLVM IR (minimal support). - - Supported cases: - - Empty list -> constant [0 x i32] - - Homogeneous integer constant lists -> constant [N x iX] - - Otherwise raises to keep behavior explicit and aligned with - current test-suite conventions. - """ - # Lower each element and collect the LLVM values - llvm_elems: list[ir.Value] = [] - for elem in node.elements: - self.visit(elem) - v = self.result_stack.pop() - if v is None: - raise Exception("LiteralList: invalid element lowering.") - llvm_elems.append(v) - - n = len(llvm_elems) - # Empty list => [0 x i32] constant - # TODO: Infer element type from declared list type when available. - # Currently uses i32 as placeholder; update when non-int lists - # are supported. - if n == 0: - empty_ty = ir.ArrayType(self._llvm.INT32_TYPE, 0) - self.result_stack.append(ir.Constant(empty_ty, [])) - return - - # Homogeneous integer constant lists => constant array - first_ty = llvm_elems[0].type - is_ints = all(isinstance(v.type, ir.IntType) for v in llvm_elems) - homogeneous = all(v.type == first_ty for v in llvm_elems) - all_constants = all(isinstance(v, ir.Constant) for v in llvm_elems) - if is_ints and homogeneous and all_constants: - arr_ty = ir.ArrayType(first_ty, n) - const_arr = ir.Constant(arr_ty, llvm_elems) - self.result_stack.append(const_arr) - return - - raise TypeError( - "LiteralList: only empty or homogeneous integer constants " - "are supported" - ) - - def _create_string_concat_function(self) -> ir.Function: - """Create a string concatenation function.""" - func_name = "string_concat" - if func_name in self._llvm.module.globals: - return self._llvm.module.get_global(func_name) - - func_type = ir.FunctionType( - self._llvm.ASCII_STRING_TYPE, - [self._llvm.ASCII_STRING_TYPE, self._llvm.ASCII_STRING_TYPE], - ) - func = ir.Function(self._llvm.module, func_type, func_name) - - func.linkage = "external" - return func - - def _create_string_length_function(self) -> ir.Function: - """Create a string length function.""" - func_name = "string_length" - if func_name in self._llvm.module.globals: - return self._llvm.module.get_global(func_name) - - # Function signature: string_length(char* str) -> i32 - func_type = ir.FunctionType( - self._llvm.INT32_TYPE, [self._llvm.ASCII_STRING_TYPE] - ) - func = ir.Function(self._llvm.module, func_type, func_name) - func.linkage = "external" - return func - - def _create_string_equals_function(self) -> ir.Function: - """Create a string equality comparison function.""" - func_name = "string_equals" - if func_name in self._llvm.module.globals: - return self._llvm.module.get_global(func_name) - - # Function signature: string_equals(char* str1, char* str2) -> i1 - func_type = ir.FunctionType( - self._llvm.BOOLEAN_TYPE, - [self._llvm.ASCII_STRING_TYPE, self._llvm.ASCII_STRING_TYPE], - ) - func = ir.Function(self._llvm.module, func_type, func_name) - func.linkage = "external" - return func - - def _create_string_substring_function(self) -> ir.Function: - """Create a string substring function.""" - func_name = "string_substring" - if func_name in self._llvm.module.globals: - return self._llvm.module.get_global(func_name) - - # string_substring(char* str, i32 start, i32 length) -> char* - func_type = ir.FunctionType( - self._llvm.ASCII_STRING_TYPE, - [ - self._llvm.ASCII_STRING_TYPE, - self._llvm.INT32_TYPE, - self._llvm.INT32_TYPE, - ], - ) - func = ir.Function(self._llvm.module, func_type, func_name) - func.linkage = "external" - return func - - def _handle_string_concatenation( - self, lhs: ir.Value, rhs: ir.Value - ) -> ir.Value: - """Handle string concatenation operation using inline function.""" - strcat_func = self._create_strcat_inline() - return self._llvm.ir_builder.call( - strcat_func, [lhs, rhs], "str_concat" - ) - - def _create_strcat_inline(self) -> ir.Function: - """Create an inline string concatenation function in LLVM IR.""" - func_name = "strcat_inline" - if func_name in self._llvm.module.globals: - return self._llvm.module.get_global(func_name) - - func_type = ir.FunctionType( - self._llvm.INT8_TYPE.as_pointer(), - [ - self._llvm.INT8_TYPE.as_pointer(), - self._llvm.INT8_TYPE.as_pointer(), - ], - ) - func = ir.Function(self._llvm.module, func_type, func_name) - - entry = func.append_basic_block("entry") - builder = ir.IRBuilder(entry) - - strlen_func = self._create_strlen_inline() - len1 = builder.call(strlen_func, [func.args[0]], "len1") - len2 = builder.call(strlen_func, [func.args[1]], "len2") - - # Total length = len1 + len2 + 1 (for null terminator) - total_len = builder.add(len1, len2, "total_len") - total_len = builder.add( - total_len, - ir.Constant(self._llvm.INT32_TYPE, 1), - "total_len_with_null", - ) - - # Allocate on heap to avoid use-after-return - malloc = self._create_malloc_decl() - total_len_szt = builder.zext(total_len, self._llvm.SIZE_T_TYPE) - result_ptr = builder.call(malloc, [total_len_szt], "result") - - self._generate_strcpy(builder, result_ptr, func.args[0]) - - result_end = builder.gep(result_ptr, [len1], inbounds=True) - self._generate_strcpy(builder, result_end, func.args[1]) - - builder.ret(result_ptr) - return func - - def _generate_strcpy( - self, builder: ir.IRBuilder, dest: ir.Value, src: ir.Value - ) -> None: - """Generate inline string copy code.""" - loop_bb = builder.function.append_basic_block("strcpy_loop") - end_bb = builder.function.append_basic_block("strcpy_end") - - index = builder.alloca(self._llvm.INT32_TYPE, name="strcpy_index") - builder.store(ir.Constant(self._llvm.INT32_TYPE, 0), index) - builder.branch(loop_bb) - - builder.position_at_start(loop_bb) - idx_val = builder.load(index, "idx_val") - - src_char_ptr = builder.gep(src, [idx_val], inbounds=True) - char_val = builder.load(src_char_ptr, "char_val") - - dest_char_ptr = builder.gep(dest, [idx_val], inbounds=True) - builder.store(char_val, dest_char_ptr) - - is_null = builder.icmp_signed( - "==", char_val, ir.Constant(self._llvm.INT8_TYPE, 0) - ) - - next_idx = builder.add(idx_val, ir.Constant(self._llvm.INT32_TYPE, 1)) - builder.store(next_idx, index) - - builder.cbranch(is_null, end_bb, loop_bb) - - builder.position_at_start(end_bb) - - def _create_strcmp_inline(self) -> ir.Function: - """Create an inline strcmp function in LLVM IR.""" - func_name = "strcmp_inline" - if func_name in self._llvm.module.globals: - return self._llvm.module.get_global(func_name) - - func_type = ir.FunctionType( - self._llvm.BOOLEAN_TYPE, - [ - self._llvm.INT8_TYPE.as_pointer(), - self._llvm.INT8_TYPE.as_pointer(), - ], - ) - func = ir.Function(self._llvm.module, func_type, func_name) - - entry = func.append_basic_block("entry") - loop = func.append_basic_block("loop") - not_equal = func.append_basic_block("not_equal") - equal = func.append_basic_block("equal") - - builder = ir.IRBuilder(entry) - - index = builder.alloca(self._llvm.INT32_TYPE, name="index") - - builder.store(ir.Constant(self._llvm.INT32_TYPE, 0), index) - builder.branch(loop) - - builder.position_at_start(loop) - idx_val = builder.load(index, "idx_val") - - char1_ptr = builder.gep(func.args[0], [idx_val], inbounds=True) - char2_ptr = builder.gep(func.args[1], [idx_val], inbounds=True) - - char1 = builder.load(char1_ptr, "char1") - char2 = builder.load(char2_ptr, "char2") - - chars_equal = builder.icmp_signed("==", char1, char2) - - char1_null = builder.icmp_signed( - "==", char1, ir.Constant(self._llvm.INT8_TYPE, 0) - ) - - builder.cbranch( - chars_equal, - builder.function.append_basic_block("check_null"), - not_equal, - ) - - check_null_bb = builder.function.basic_blocks[-1] - builder.position_at_start(check_null_bb) - builder.cbranch( - char1_null, - equal, - builder.function.append_basic_block("continue_loop"), - ) - - continue_bb = builder.function.basic_blocks[-1] - builder.position_at_start(continue_bb) - next_idx = builder.add(idx_val, ir.Constant(self._llvm.INT32_TYPE, 1)) - builder.store(next_idx, index) - builder.branch(loop) - - builder.position_at_start(not_equal) - builder.ret(ir.Constant(self._llvm.BOOLEAN_TYPE, 0)) - - builder.position_at_start(equal) - builder.ret(ir.Constant(self._llvm.BOOLEAN_TYPE, 1)) - - return func - - def _create_strlen_inline(self) -> ir.Function: - """Create an inline strlen function in LLVM IR.""" - func_name = "strlen_inline" - if func_name in self._llvm.module.globals: - return self._llvm.module.get_global(func_name) - - # Function signature: strlen_inline(char* str) -> i32 - func_type = ir.FunctionType( - self._llvm.INT32_TYPE, [self._llvm.INT8_TYPE.as_pointer()] - ) - func = ir.Function(self._llvm.module, func_type, func_name) - - entry = func.append_basic_block("entry") - loop = func.append_basic_block("loop") - end = func.append_basic_block("end") - - builder = ir.IRBuilder(entry) - - counter = builder.alloca(self._llvm.INT32_TYPE, name="counter") - builder.store(ir.Constant(self._llvm.INT32_TYPE, 0), counter) - - builder.branch(loop) - - builder.position_at_start(loop) - count_val = builder.load(counter, "count_val") - - char_ptr = builder.gep(func.args[0], [count_val], inbounds=True) - char_val = builder.load(char_ptr, "char_val") - - is_null = builder.icmp_signed( - "==", char_val, ir.Constant(self._llvm.INT8_TYPE, 0) - ) - - next_count = builder.add( - count_val, ir.Constant(self._llvm.INT32_TYPE, 1) - ) - builder.store(next_count, counter) - - builder.cbranch(is_null, end, loop) - - builder.position_at_start(end) - final_count = builder.load(counter, "final_count") - builder.ret(final_count) - - return func - - def _handle_string_comparison( - self, lhs: ir.Value, rhs: ir.Value, op: str - ) -> ir.Value: - """Handle string comparison operations using inline functions.""" - if op == "==": - strcmp_func = self._create_strcmp_inline() - return self._llvm.ir_builder.call( - strcmp_func, [lhs, rhs], "str_equals" - ) - elif op == "!=": - strcmp_func = self._create_strcmp_inline() - equals_result = self._llvm.ir_builder.call( - strcmp_func, [lhs, rhs], "str_equals" - ) - return self._llvm.ir_builder.xor( - equals_result, - ir.Constant(self._llvm.BOOLEAN_TYPE, 1), - "str_not_equals", - ) - else: - raise Exception(f"String comparison operator {op} not implemented") - - @dispatch # type: ignore[no-redef] - def visit(self, node: astx.FunctionCall) -> None: - """Translate Function FunctionCall.""" - # callee_f = self.get_function(node.fn) - fn_name = node.fn - - callee_f = self.get_function(fn_name) - if not callee_f: - raise Exception("Unknown function referenced") - - if len(callee_f.args) != len(node.args): - raise Exception("codegen: Incorrect # arguments passed.") - - llvm_args = [] - for arg in node.args: - self.visit(arg) - llvm_arg = self.result_stack.pop() - if not llvm_arg: - raise Exception("codegen: Invalid callee argument.") - llvm_args.append(llvm_arg) - - result = self._llvm.ir_builder.call(callee_f, llvm_args, "calltmp") - self.result_stack.append(result) - - @dispatch # type: ignore[no-redef] - def visit(self, node: astx.FunctionDef) -> None: - """Translate ASTx Function to LLVM-IR.""" - proto = node.prototype - self.function_protos[proto.name] = proto - fn = self.get_function(proto.name) - - if not fn: - raise Exception("Invalid function.") - - # Create a new basic block to start insertion into. - basic_block = fn.append_basic_block("entry") - self._llvm.ir_builder = ir.IRBuilder(basic_block) - - for idx, llvm_arg in enumerate(fn.args): - arg_ast = proto.args.nodes[idx] - type_str = arg_ast.type_.__class__.__name__.lower() - arg_type = self._llvm.get_data_type(type_str) - - # Create an alloca for this variable. - alloca = self._llvm.ir_builder.alloca(arg_type, name=llvm_arg.name) - - # Store the initial value into the alloca. - self._llvm.ir_builder.store(llvm_arg, alloca) - - # Add arguments to variable symbol table. - self.named_values[llvm_arg.name] = alloca - - self.visit(node.body) - self.result_stack.append(fn) - - @dispatch # type: ignore[no-redef] - def visit(self, node: astx.FunctionPrototype) -> None: - """Translate ASTx Function Prototype to LLVM-IR.""" - args_type = [] - for arg in node.args.nodes: - type_str = arg.type_.__class__.__name__.lower() - args_type.append(self._llvm.get_data_type(type_str)) - # note: it should be dynamic - return_type = self._llvm.get_data_type( - node.return_type.__class__.__name__.lower() - ) - fn_type = ir.FunctionType(return_type, args_type, False) - - fn = ir.Function(self._llvm.module, fn_type, node.name) - - # Set names for all arguments. - for idx, llvm_arg in enumerate(fn.args): - llvm_arg.name = node.args.nodes[idx].name - - self.result_stack.append(fn) - - @dispatch # type: ignore[no-redef] - def visit(self, node: astx.FunctionReturn) -> None: - """Translate ASTx FunctionReturn to LLVM-IR.""" - if node.value is not None: - self.visit(node.value) - try: - retval = self.result_stack.pop() - except IndexError: - retval = None - else: - retval = None - - if retval is not None: - fn_return_type = ( - self._llvm.ir_builder.function.function_type.return_type - ) - if ( - isinstance(fn_return_type, ir.IntType) - and fn_return_type.width == 1 - ): - # Force cast retval to i1 if not already - if ( - isinstance(retval.type, ir.IntType) - and retval.type.width != 1 - ): - retval = self._llvm.ir_builder.trunc(retval, ir.IntType(1)) - self._llvm.ir_builder.ret(retval) - return - self._llvm.ir_builder.ret_void() - - @dispatch # type: ignore[no-redef] - def visit(self, node: astx.InlineVariableDeclaration) -> None: - """Translate an ASTx InlineVariableDeclaration expression.""" - if self.named_values.get(node.name): - raise Exception(f"Identifier already declared: {node.name}") - - type_str = node.type_.__class__.__name__.lower() - - # Emit the initializer - if node.value is not None: - self.visit(node.value) - init_val = self.result_stack.pop() - if init_val is None: - raise Exception("Initializer code generation failed.") - # Default zero value based on type - elif "float" in type_str: - init_val = ir.Constant(self._llvm.get_data_type(type_str), 0.0) - else: - init_val = ir.Constant(self._llvm.get_data_type(type_str), 0) - - if type_str == "string": - alloca = self.create_entry_block_alloca(node.name, "stringascii") - else: - alloca = self.create_entry_block_alloca(node.name, type_str) - - self._llvm.ir_builder.store(init_val, alloca) - self.named_values[node.name] = alloca - - self.result_stack.append(init_val) - - def _normalize_int_for_printf(self, v: ir.Value) -> tuple[ir.Value, str]: - """Promote/truncate integer to match printf format.""" - INT64_WIDTH = 64 - if not isinstance(v.type, ir.IntType): - raise Exception("Expected integer value") - w = v.type.width - if w < INT64_WIDTH: - # i1 uses zero-extension to print as 1/0, not -1/0 - if w == 1: - arg = self._llvm.ir_builder.zext(v, self._llvm.INT64_TYPE) - else: - arg = self._llvm.ir_builder.sext(v, self._llvm.INT64_TYPE) - return arg, "%lld" - if w == INT64_WIDTH: - return v, "%lld" - raise Exception( - "Casting integers wider than 64 bits to string is not supported" - ) - - def _create_malloc_decl(self) -> ir.Function: - """Declare malloc.""" - name = "malloc" - if name in self._llvm.module.globals: - return self._llvm.module.get_global(name) - ty = ir.FunctionType( - self._llvm.INT8_TYPE.as_pointer(), [self._llvm.SIZE_T_TYPE] - ) - fn = ir.Function(self._llvm.module, ty, name=name) - fn.linkage = "external" - return fn - - def _snprintf_heap( - self, fmt_gv: ir.GlobalVariable, args: list[ir.Value] - ) -> ir.Value: - """Format into a heap buffer and return i8* (char*).""" - snprintf = self._create_snprintf_decl() - malloc = self._create_malloc_decl() - - zero_size = ir.Constant(self._llvm.SIZE_T_TYPE, 0) - null_ptr = ir.Constant(self._llvm.INT8_TYPE.as_pointer(), None) - - fmt_ptr = self._llvm.ir_builder.gep( - fmt_gv, - [ - ir.Constant(self._llvm.INT32_TYPE, 0), - ir.Constant(self._llvm.INT32_TYPE, 0), - ], - inbounds=True, - ) - - needed_i32 = self._llvm.ir_builder.call( - snprintf, [null_ptr, zero_size, fmt_ptr, *args] - ) - - # Guard: snprintf returns negative on error; clamp to 1 - zero_i32 = ir.Constant(self._llvm.INT32_TYPE, 0) - min_needed = self._llvm.ir_builder.select( - self._llvm.ir_builder.icmp_signed("<", needed_i32, zero_i32), - ir.Constant(self._llvm.INT32_TYPE, 1), - needed_i32, - ) - need_plus_1 = self._llvm.ir_builder.add( - min_needed, ir.Constant(self._llvm.INT32_TYPE, 1) - ) - need_szt = self._llvm.ir_builder.zext( - need_plus_1, self._llvm.SIZE_T_TYPE - ) - - # allocate and print - mem = self._llvm.ir_builder.call(malloc, [need_szt]) - _ = self._llvm.ir_builder.call( - snprintf, [mem, need_szt, fmt_ptr, *args] - ) - return mem - - def _create_snprintf_decl(self) -> ir.Function: - """Declare (or return) the external snprintf (varargs).""" - name = "snprintf" - if name in self._llvm.module.globals: - return self._llvm.module.get_global(name) - - snprintf_ty = ir.FunctionType( - self._llvm.INT32_TYPE, - [ - self._llvm.INT8_TYPE.as_pointer(), - self._llvm.SIZE_T_TYPE, - self._llvm.INT8_TYPE.as_pointer(), - ], - var_arg=True, - ) - fn = ir.Function(self._llvm.module, snprintf_ty, name=name) - fn.linkage = "external" - return fn - - def _get_or_create_format_global(self, fmt: str) -> ir.GlobalVariable: - """Create a constant global format string.""" - # safe unique name for the format - name = f"fmt_{abs(hash(fmt))}" - if name in self._llvm.module.globals: - gv = self._llvm.module.get_global(name) - # compute pointer (gep) at use time; return gv (array) here - return gv - - data = bytearray(fmt + "\0", "utf8") - arr_ty = ir.ArrayType(self._llvm.INT8_TYPE, len(data)) - gv = ir.GlobalVariable(self._llvm.module, arr_ty, name=name) - gv.linkage = "internal" - gv.global_constant = True - gv.initializer = ir.Constant(arr_ty, data) - return gv - - @dispatch # type: ignore[no-redef] - def visit(self, node: system.Cast) -> None: - """Translate Cast expression to LLVM-IR.""" - self.visit(node.value) - value = self.result_stack.pop() - target_type_str = node.target_type.__class__.__name__.lower() - target_type = self._llvm.get_data_type(target_type_str) - - if value.type == target_type: - self.result_stack.append(value) - return - - result: ir.Value - - if isinstance(value.type, ir.IntType) and isinstance( - target_type, ir.IntType - ): - if value.type.width < target_type.width: - result = self._llvm.ir_builder.sext( - value, target_type, "cast_int_up" - ) - else: - result = self._llvm.ir_builder.trunc( - value, target_type, "cast_int_down" - ) - elif isinstance(value.type, ir.IntType) and isinstance( - target_type, ir.FloatType - ): - result = self._llvm.ir_builder.sitofp( - value, target_type, "cast_int_to_fp" - ) - - elif isinstance(value.type, ir.FloatType) and isinstance( - target_type, ir.IntType - ): - result = self._llvm.ir_builder.fptosi( - value, target_type, "cast_fp_to_int" - ) - - elif isinstance(value.type, ir.FloatType) and isinstance( - target_type, ir.HalfType - ): - result = self._llvm.ir_builder.fptrunc( - value, target_type, "cast_fp_to_half" - ) - - elif isinstance(value.type, ir.HalfType) and isinstance( - target_type, ir.FloatType - ): - result = self._llvm.ir_builder.fpext( - value, target_type, "cast_half_to_fp" - ) - - elif isinstance(value.type, ir.FloatType) and isinstance( - target_type, ir.FloatType - ): - if value.type.width < target_type.width: - result = self._llvm.ir_builder.fpext( - value, target_type, "cast_fp_up" - ) - - else: - result = self._llvm.ir_builder.fptrunc( - value, target_type, "cast_fp_down" - ) - - elif target_type in ( - self._llvm.ASCII_STRING_TYPE, - self._llvm.STRING_TYPE, - ): - if isinstance(value.type, ir.IntType): - arg, fmt_str = self._normalize_int_for_printf(value) - fmt_gv = self._get_or_create_format_global(fmt_str) - ptr = self._snprintf_heap(fmt_gv, [arg]) - self.result_stack.append(ptr) - return - - # floats / doubles / half -> print as double with fixed format - if isinstance( - value.type, (ir.FloatType, ir.DoubleType, ir.HalfType) - ): - if isinstance(value.type, (ir.FloatType, ir.HalfType)): - value_prom = self._llvm.ir_builder.fpext( - value, self._llvm.DOUBLE_TYPE, "to_double" - ) - else: - value_prom = value - fmt_gv = self._get_or_create_format_global("%.6f") - ptr = self._snprintf_heap(fmt_gv, [value_prom]) - self.result_stack.append(ptr) - return - - else: - raise Exception( - f"Unsupported cast from {value.type} to {target_type}" - ) - - self.result_stack.append(result) - - @dispatch # type: ignore[no-redef] - def visit(self, node: system.PrintExpr) -> None: - """Generate LLVM IR for a PrintExpr node.""" - if hasattr(node.message, "value"): - # For literal strings/values - message = node.message.value - msg_length = len(message) + 1 - msg_type = ir.ArrayType(self._llvm.INT8_TYPE, msg_length) - - global_msg = ir.GlobalVariable( - self._llvm.module, msg_type, name=node._name - ) - global_msg.linkage = "internal" - global_msg.global_constant = True - global_msg.initializer = ir.Constant( - msg_type, bytearray(message + "\0", "utf8") - ) - - ptr = self._llvm.ir_builder.gep( - global_msg, - [ - ir.Constant(ir.IntType(32), 0), - ir.Constant(ir.IntType(32), 0), - ], - inbounds=True, - ) - else: - # For variables and other expressions - self.visit(node.message) - ptr = safe_pop(self.result_stack) - if not ptr: - raise Exception("Invalid message in PrintExpr") - - puts_fn = self._llvm.module.globals.get("puts") - if puts_fn is None: - puts_ty = ir.FunctionType( - self._llvm.INT32_TYPE, [ir.PointerType(self._llvm.INT8_TYPE)] - ) - puts_fn = ir.Function(self._llvm.module, puts_ty, name="puts") - - self._llvm.ir_builder.call(puts_fn, [ptr]) - - self.result_stack.append(ir.Constant(self._llvm.INT32_TYPE, 0)) - - @dispatch # type: ignore[no-redef] - def visit(self, node: astx.Identifier) -> None: - """Translate ASTx Identifier to LLVM-IR.""" - expr_var = self.named_values.get(node.name) - - if not expr_var: - raise Exception(f"Unknown variable name: {node.name}") - - result = self._llvm.ir_builder.load(expr_var, node.name) - self.result_stack.append(result) - - @dispatch # type: ignore[no-redef] - def visit(self, node: astx.VariableDeclaration) -> None: - """Translate ASTx Identifier to LLVM-IR.""" - if self.named_values.get(node.name): - raise Exception(f"Identifier already declared: {node.name}") - - type_str = node.type_.__class__.__name__.lower() - - # Emit the initializer - if node.value is not None: - self.visit(node.value) - init_val = self.result_stack.pop() - if init_val is None: - raise Exception("Initializer code generation failed.") - - if type_str == "string": - alloca = self.create_entry_block_alloca( - node.name, "stringascii" - ) - self._llvm.ir_builder.store(init_val, alloca) - else: - alloca = self.create_entry_block_alloca(node.name, type_str) - self._llvm.ir_builder.store(init_val, alloca) - - else: - if type_str == "string": - # For strings, create empty string - empty_str_type = ir.ArrayType(self._llvm.INT8_TYPE, 1) - empty_str_global = ir.GlobalVariable( - self._llvm.module, - empty_str_type, - name=f"empty_str_{node.name}", - ) - empty_str_global.linkage = "internal" - empty_str_global.global_constant = True - empty_str_global.initializer = ir.Constant( - empty_str_type, bytearray(b"\0") - ) - - init_val = self._llvm.ir_builder.gep( - empty_str_global, - [ - ir.Constant(ir.IntType(32), 0), - ir.Constant(ir.IntType(32), 0), - ], - inbounds=True, - ) - alloca = self.create_entry_block_alloca( - node.name, "stringascii" - ) - - elif "float" in type_str: - init_val = ir.Constant(self._llvm.get_data_type(type_str), 0.0) - alloca = self.create_entry_block_alloca(node.name, type_str) - - else: - # If not specified, use 0 as the initializer. - init_val = ir.Constant(self._llvm.get_data_type(type_str), 0) - alloca = self.create_entry_block_alloca(node.name, type_str) - - # Store the initial value. - self._llvm.ir_builder.store(init_val, alloca) - - # Remember this binding. - self.named_values[node.name] = alloca - - @dispatch # type: ignore[no-redef] - def visit(self, node: astx.LiteralInt16) -> None: - """Translate ASTx LiteralInt16 to LLVM-IR.""" - result = ir.Constant(self._llvm.INT16_TYPE, node.value) - self.result_stack.append(result) + self.visit(astx.LiteralUTF8String(value=expr.value)) @public @@ -2376,34 +964,21 @@ class LLVMLiteIR(Builder): """LLVM-IR transpiler and compiler.""" def __init__(self) -> None: - """Initialize LLVMIR.""" super().__init__() - self.translator: LLVMLiteIRVisitor = LLVMLiteIRVisitor() + self.translator = LLVMLiteIRVisitor() def build(self, node: astx.AST, output_file: str) -> None: - """Transpile the ASTx to LLVM-IR and build it to an executable file.""" self.translator = LLVMLiteIRVisitor() - result = self.translator.translate(node) - - result_mod = llvm.parse_assembly(result) - result_object = self.translator.target_machine.emit_object(result_mod) + ir_text = self.translator.translate(node) - with tempfile.NamedTemporaryFile(suffix="", delete=True) as temp_file: - self.tmp_path = temp_file.name + mod = llvm.parse_assembly(ir_text) + obj = self.translator.target_machine.emit_object(mod) - file_path_o = f"{self.tmp_path}.o" + with tempfile.NamedTemporaryFile(delete=True) as f: + obj_path = f.name + ".o" - with open(file_path_o, "wb") as f: - f.write(result_object) + with open(obj_path, "wb") as f: + f.write(obj) - self.output_file = output_file - - # fix xh typing - clang: Callable[..., Any] = xh.clang - - clang( - file_path_o, - "-o", - self.output_file, - ) - os.chmod(self.output_file, 0o755) + xh.clang(obj_path, "-o", output_file) + os.chmod(output_file, 0o755)