diff --git a/src/irx/builders/llvmliteir.py b/src/irx/builders/llvmliteir.py
index e2fe53b..ac3aba3 100644
--- a/src/irx/builders/llvmliteir.py
+++ b/src/irx/builders/llvmliteir.py
@@ -28,6 +28,11 @@
 from irx.builders.base import Builder, BuilderVisitor
 from irx.tools.typing import typechecked
 
+FLOAT16_BITS = 16
+FLOAT32_BITS = 32
+FLOAT64_BITS = 64
+FLOAT128_BITS = 128
+
 
 def is_fp_type(t: "ir.Type") -> bool:
     """Return True if t is any floating-point LLVM type."""
@@ -452,6 +457,142 @@ def _apply_fast_math(self, inst: ir.Instruction) -> None:
         except (AttributeError, TypeError):
             return
 
+    def _is_numeric_value(self, value: ir.Value) -> bool:
+        """Return True if value represents an int/float scalar or vector."""
+        if is_vector(value):
+            elem_ty = value.type.element
+            return isinstance(elem_ty, ir.IntType) or is_fp_type(elem_ty)
+        base_ty = value.type
+        return isinstance(base_ty, ir.IntType) or is_fp_type(base_ty)
+
+    def _unify_numeric_operands(
+        self, lhs: ir.Value, rhs: ir.Value
+    ) -> tuple[ir.Value, ir.Value]:
+        """Ensure numeric operands share shape and scalar type."""
+        lhs_is_vec = is_vector(lhs)
+        rhs_is_vec = is_vector(rhs)
+
+        if lhs_is_vec and rhs_is_vec and lhs.type.count != rhs.type.count:
+            raise Exception(
+                f"Vector size mismatch: {lhs.type.count} vs {rhs.type.count}"
+            )
+
+        target_lanes = None
+        if lhs_is_vec:
+            target_lanes = lhs.type.count
+        elif rhs_is_vec:
+            target_lanes = rhs.type.count
+
+        lhs_base_ty = lhs.type.element if lhs_is_vec else lhs.type
+        rhs_base_ty = rhs.type.element if rhs_is_vec else rhs.type
+
+        lhs_is_float = is_fp_type(lhs_base_ty)
+        rhs_is_float = is_fp_type(rhs_base_ty)
+
+        if lhs_is_float or rhs_is_float:
+            float_candidates = [
+                ty for ty in (lhs_base_ty, rhs_base_ty) if is_fp_type(ty)
+            ]
+            target_scalar_ty = self._select_float_type(float_candidates)
+        else:
+            lhs_width = getattr(lhs_base_ty, "width", 0)
+            rhs_width = getattr(rhs_base_ty, "width", 0)
+            target_scalar_ty = ir.IntType(max(lhs_width, rhs_width, 1))
+
+        lhs = self._cast_value_to_type(lhs, target_scalar_ty)
+        rhs = self._cast_value_to_type(rhs, target_scalar_ty)
+
+        if target_lanes:
+            vec_ty = ir.VectorType(target_scalar_ty, target_lanes)
+            if not is_vector(lhs):
+                lhs = splat_scalar(self._llvm.ir_builder, lhs, vec_ty)
+            if not is_vector(rhs):
+                rhs = splat_scalar(self._llvm.ir_builder, rhs, vec_ty)
+
+        return lhs, rhs
+
+    def _select_float_type(self, candidates: list[ir.Type]) -> ir.Type:
+        """Choose the widest float type from provided candidates."""
+        if not candidates:
+            return self._llvm.FLOAT_TYPE
+
+        width = max(self._float_bit_width(ty) for ty in candidates)
+        return self._float_type_from_width(width)
+
+    def _float_type_from_width(self, width: int) -> ir.Type:
+        if width <= FLOAT16_BITS and hasattr(self._llvm, "FLOAT16_TYPE"):
+            return self._llvm.FLOAT16_TYPE
+        if width <= FLOAT32_BITS:
+            return self._llvm.FLOAT_TYPE
+        if width <= FLOAT64_BITS:
+            return self._llvm.DOUBLE_TYPE
+        if FP128Type is not None and width >= FLOAT128_BITS:
+            return FP128Type()
+        return self._llvm.FLOAT_TYPE
+
+    def _float_bit_width(self, ty: ir.Type) -> int:
+        if isinstance(ty, DoubleType):
+            return FLOAT64_BITS
+        if isinstance(ty, FloatType):
+            return FLOAT32_BITS
+        if isinstance(ty, HalfType):
+            return FLOAT16_BITS
+        if FP128Type is not None and isinstance(ty, FP128Type):
+            return FLOAT128_BITS
+        return getattr(ty, "width", 0)
+
+    def _cast_value_to_type(
+        self, value: ir.Value, target_scalar_ty: ir.Type
+    ) -> ir.Value:
+        """Cast scalars or vectors to the target scalar type."""
+        builder = self._llvm.ir_builder
+        value_is_vec = is_vector(value)
+        if value_is_vec:
+            lanes = value.type.count
+            current_scalar_ty = value.type.element
+            target_ty = ir.VectorType(target_scalar_ty, lanes)
+        else:
+            lanes = None
+            current_scalar_ty = value.type
+            target_ty = target_scalar_ty
+
+        if current_scalar_ty == target_scalar_ty and value.type == target_ty:
+            return value
+
+        current_is_float = is_fp_type(current_scalar_ty)
+        target_is_float = is_fp_type(target_scalar_ty)
+
+        if target_is_float:
+            if current_is_float:
+                current_bits = self._float_bit_width(current_scalar_ty)
+                target_bits = self._float_bit_width(target_scalar_ty)
+                if current_bits == target_bits:
+                    if value.type != target_ty:
+                        return builder.bitcast(value, target_ty)
+                    return value
+                if current_bits < target_bits:
+                    return builder.fpext(value, target_ty, "fpext")
+                return builder.fptrunc(value, target_ty, "fptrunc")
+            return builder.sitofp(value, target_ty, "sitofp")
+
+        if current_is_float:
+            raise Exception(
+                "Cannot implicitly convert floating-point to integer"
+            )
+
+        current_width = getattr(current_scalar_ty, "width", 0)
+        target_width = getattr(target_scalar_ty, "width", 0)
+
+        if current_width == target_width:
+            if value.type != target_ty:
+                return builder.bitcast(value, target_ty)
+            return value
+
+        if current_width < target_width:
+            return builder.sext(value, target_ty, "sext")
+
+        return builder.trunc(value, target_ty, "trunc")
+
     @dispatch.abstract
     def visit(self, node: astx.AST) -> None:
         """Translate an ASTx expression."""
@@ -551,60 +692,12 @@ def visit(self, node: astx.BinaryOp) -> None:
         if not llvm_lhs or not llvm_rhs:
             raise Exception("codegen: Invalid lhs/rhs")
 
-        # Scalar-vector promotion: one vector + matching scalar -> splat scalar
-        lhs_is_vec = is_vector(llvm_lhs)
-        rhs_is_vec = is_vector(llvm_rhs)
-        if lhs_is_vec and not rhs_is_vec:
-            elem_ty = llvm_lhs.type.element
-            if llvm_rhs.type == elem_ty:
-                llvm_rhs = splat_scalar(
-                    self._llvm.ir_builder, llvm_rhs, llvm_lhs.type
-                )
-            elif is_fp_type(elem_ty) and is_fp_type(llvm_rhs.type):
-                if isinstance(elem_ty, FloatType) and isinstance(
-                    llvm_rhs.type, DoubleType
-                ):
-                    llvm_rhs = self._llvm.ir_builder.fptrunc(
-                        llvm_rhs, elem_ty, "vec_promote_scalar"
-                    )
-                    llvm_rhs = splat_scalar(
-                        self._llvm.ir_builder, llvm_rhs, llvm_lhs.type
-                    )
-                elif isinstance(elem_ty, DoubleType) and isinstance(
-                    llvm_rhs.type, FloatType
-                ):
-                    llvm_rhs = self._llvm.ir_builder.fpext(
-                        llvm_rhs, elem_ty, "vec_promote_scalar"
-                    )
-                    llvm_rhs = splat_scalar(
-                        self._llvm.ir_builder, llvm_rhs, llvm_lhs.type
-                    )
-        elif rhs_is_vec and not lhs_is_vec:
-            elem_ty = llvm_rhs.type.element
-            if llvm_lhs.type == elem_ty:
-                llvm_lhs = splat_scalar(
-                    self._llvm.ir_builder, llvm_lhs, llvm_rhs.type
-                )
-            elif is_fp_type(elem_ty) and is_fp_type(llvm_lhs.type):
-                if isinstance(elem_ty, FloatType) and isinstance(
-                    llvm_lhs.type, DoubleType
-                ):
-                    llvm_lhs = self._llvm.ir_builder.fptrunc(
-                        llvm_lhs, elem_ty, "vec_promote_scalar"
-                    )
-                    llvm_lhs = splat_scalar(
-                        self._llvm.ir_builder, llvm_lhs, llvm_rhs.type
-                    )
-                elif isinstance(elem_ty, DoubleType) and isinstance(
-                    llvm_lhs.type, FloatType
-                ):
-                    llvm_lhs = self._llvm.ir_builder.fpext(
-                        llvm_lhs, elem_ty, "vec_promote_scalar"
-                    )
-                    llvm_lhs = splat_scalar(
-                        self._llvm.ir_builder, llvm_lhs, llvm_rhs.type
-                    )
-
+        if self._is_numeric_value(llvm_lhs) and self._is_numeric_value(
+            llvm_rhs
+        ):
+            llvm_lhs, llvm_rhs = self._unify_numeric_operands(
+                llvm_lhs, llvm_rhs
+            )
         # If both operands are LLVM vectors, handle as vector ops
         if is_vector(llvm_lhs) and is_vector(llvm_rhs):
             if llvm_lhs.type.count != llvm_rhs.type.count:
diff --git a/tests/test_llvmlite_helpers.py b/tests/test_llvmlite_helpers.py
index ae63dfd..bdb6476 100644
--- a/tests/test_llvmlite_helpers.py
+++ b/tests/test_llvmlite_helpers.py
@@ -7,6 +7,7 @@
 from irx.builders.llvmliteir import (
     LLVMLiteIRVisitor,
     emit_int_div,
+    is_fp_type,
     splat_scalar,
 )
 from llvmlite import ir
@@ -97,6 +98,64 @@ def test_emit_int_div_signed_and_unsigned() -> None:
     assert getattr(unsigned, "opname", "") == "udiv"
 
 
+def test_unify_promotes_scalar_int_to_vector() -> None:
+    """Scalar ints splat to match vector operands and widen width."""
+    visitor = LLVMLiteIRVisitor()
+    _prime_builder(visitor)
+
+    vec_ty = ir.VectorType(ir.IntType(32), 2)
+    vec = ir.Constant(vec_ty, [ir.Constant(ir.IntType(32), 1)] * 2)
+    scalar = ir.Constant(ir.IntType(16), 5)
+
+    promoted_vec, promoted_scalar = visitor._unify_numeric_operands(
+        vec, scalar
+    )
+
+    assert isinstance(promoted_vec.type, ir.VectorType)
+    assert isinstance(promoted_scalar.type, ir.VectorType)
+    assert promoted_vec.type == vec_ty
+    assert promoted_scalar.type == vec_ty
+
+
+def test_unify_vector_float_rank_matches_double() -> None:
+    """Float vectors upgrade to match double scalars."""
+    visitor = LLVMLiteIRVisitor()
+    _prime_builder(visitor)
+
+    float_vec_ty = ir.VectorType(visitor._llvm.FLOAT_TYPE, 2)
+    float_vec = ir.Constant(
+        float_vec_ty,
+        [
+            ir.Constant(visitor._llvm.FLOAT_TYPE, 1.0),
+            ir.Constant(visitor._llvm.FLOAT_TYPE, 2.0),
+        ],
+    )
+    double_scalar = ir.Constant(visitor._llvm.DOUBLE_TYPE, 4.0)
+
+    widened_vec, widened_scalar = visitor._unify_numeric_operands(
+        float_vec, double_scalar
+    )
+
+    assert widened_vec.type.element == visitor._llvm.DOUBLE_TYPE
+    assert widened_scalar.type.element == visitor._llvm.DOUBLE_TYPE
+
+
+def test_unify_int_and_float_scalars_returns_float() -> None:
+    """Scalar int + float promotes to float for both operands."""
+    visitor = LLVMLiteIRVisitor()
+    _prime_builder(visitor)
+
+    int_scalar = ir.Constant(visitor._llvm.INT32_TYPE, 7)
+    float_scalar = ir.Constant(visitor._llvm.FLOAT_TYPE, 1.25)
+
+    widened_int, widened_float = visitor._unify_numeric_operands(
+        int_scalar, float_scalar
+    )
+
+    assert is_fp_type(widened_int.type)
+    assert widened_float.type == visitor._llvm.FLOAT_TYPE
+
+
 def test_set_fast_math_marks_float_ops() -> None:
     """set_fast_math should add fast flag to floating instructions."""
     visitor = LLVMLiteIRVisitor()
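
Usage note (illustrative sketch, not part of the patch): the snippet below mirrors the new tests and walks the float-vector plus double-scalar case end to end. It assumes an IR builder can be attached directly to visitor._llvm.ir_builder, which is what the test suite's _prime_builder helper arranges; the module and function names are made up for the demo.

    from llvmlite import ir

    from irx.builders.llvmliteir import LLVMLiteIRVisitor

    visitor = LLVMLiteIRVisitor()

    # Attach a builder positioned inside a throwaway function so the helper
    # can emit casts and splats (assumption: ir_builder is a plain attribute).
    module = ir.Module(name="promotion_demo")
    fn = ir.Function(module, ir.FunctionType(ir.VoidType(), []), name="demo")
    visitor._llvm.ir_builder = ir.IRBuilder(fn.append_basic_block("entry"))

    float_vec_ty = ir.VectorType(visitor._llvm.FLOAT_TYPE, 2)
    float_vec = ir.Constant(
        float_vec_ty,
        [
            ir.Constant(visitor._llvm.FLOAT_TYPE, 1.0),
            ir.Constant(visitor._llvm.FLOAT_TYPE, 2.0),
        ],
    )
    double_scalar = ir.Constant(visitor._llvm.DOUBLE_TYPE, 4.0)

    lhs, rhs = visitor._unify_numeric_operands(float_vec, double_scalar)
    # Both results now have type <2 x double>: the float vector is widened
    # lane-wise via fpext and the double scalar is splatted to a vector.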