201 changes: 147 additions & 54 deletions src/irx/builders/llvmliteir.py
@@ -28,6 +28,11 @@
from irx.builders.base import Builder, BuilderVisitor
from irx.tools.typing import typechecked

FLOAT16_BITS = 16
FLOAT32_BITS = 32
FLOAT64_BITS = 64
FLOAT128_BITS = 128


def is_fp_type(t: "ir.Type") -> bool:
"""Return True if t is any floating-point LLVM type."""
@@ -452,6 +457,142 @@ def _apply_fast_math(self, inst: ir.Instruction) -> None:
    except (AttributeError, TypeError):
        return

def _is_numeric_value(self, value: ir.Value) -> bool:
    """Return True if value represents an int/float scalar or vector."""
    if is_vector(value):
        elem_ty = value.type.element
        return isinstance(elem_ty, ir.IntType) or is_fp_type(elem_ty)
    base_ty = value.type
    return isinstance(base_ty, ir.IntType) or is_fp_type(base_ty)

def _unify_numeric_operands(
    self, lhs: ir.Value, rhs: ir.Value
) -> tuple[ir.Value, ir.Value]:
    """Ensure numeric operands share shape and scalar type."""
Copilot AI Jan 19, 2026

The _unify_numeric_operands method would benefit from more detailed documentation. The current docstring "Ensure numeric operands share shape and scalar type" is minimal. Consider documenting: (1) the promotion rules (e.g., int promotes to float, narrower types promote to wider), (2) parameter types and constraints, (3) return value guarantees, (4) what exceptions can be raised, and (5) examples of transformations. This is a critical function for type safety and clear documentation would help maintainers understand the promotion semantics.

Suggested change
"""Ensure numeric operands share shape and scalar type."""
"""
Normalize two numeric LLVM values to a common scalar type and shape.
This helper is used before emitting arithmetic or comparison
instructions so that both operands are type-compatible. It supports
scalar and vector integers / floating-point values and performs both
scalar type promotion and optional scalar-to-vector splatting.
Promotion rules
---------------
* Shape:
- If both operands are vectors, they must have the same number of
lanes; otherwise an Exception is raised.
- If exactly one operand is a vector, its lane count is used as the
target shape and the scalar operand is splatted to a vector of the
same lane count after type promotion.
- If both operands are scalars, the result operands remain scalars.
* Scalar type:
- If either operand has a floating-point scalar type, both operands
are promoted to a common floating-point type selected via
``self._select_float_type`` from the floating-point candidates.
- If both operands have integer scalar types, both are promoted to an
integer type with ``width = max(lhs.width, rhs.width)`` (at least
1 bit), preserving signedness semantics as implemented by
``_cast_value_to_type``.
Parameters
----------
lhs : llvmlite.ir.Value
Left-hand numeric operand. May be a scalar or vector of integer or
floating-point type.
rhs : llvmlite.ir.Value
Right-hand numeric operand. May be a scalar or vector of integer or
floating-point type.
Returns
-------
(llvmlite.ir.Value, llvmlite.ir.Value)
A pair ``(lhs', rhs')`` where:
* ``lhs'.type`` and ``rhs'.type`` have the same scalar element type.
* If either operand is a vector, both results are vectors with the
same lane count.
Raises
------
Exception
If both operands are vectors and their lane counts (``.count``)
differ.
Any exception raised by ``_cast_value_to_type`` may also propagate if
the operands cannot be safely cast to the selected target type.
Examples
--------
* ``i32 + i64`` -> both operands promoted to ``i64``.
* ``float + i32`` -> both operands promoted to ``float``.
* ``<4 x i16> + i32`` -> scalar ``i32`` cast to ``i32`` then splatted
to ``<4 x i32>`` to match the vector operand.
"""

    lhs_is_vec = is_vector(lhs)
    rhs_is_vec = is_vector(rhs)

    if lhs_is_vec and rhs_is_vec and lhs.type.count != rhs.type.count:
        raise Exception(
            f"Vector size mismatch: {lhs.type.count} vs {rhs.type.count}"
Copilot AI Jan 19, 2026

The error message could be more informative by including the operation context. Instead of just "Vector size mismatch: X vs Y", consider including information about what operation was being attempted (e.g., "Binary operation '+' requires matching vector sizes, but got X vs Y"). This would help developers debug issues more quickly.

Suggested change
f"Vector size mismatch: {lhs.type.count} vs {rhs.type.count}"
"Numeric operation requires matching vector sizes, "
f"but got {lhs.type} (size {lhs.type.count}) vs "
f"{rhs.type} (size {rhs.type.count})"

        )

    target_lanes = None
    if lhs_is_vec:
        target_lanes = lhs.type.count
    elif rhs_is_vec:
        target_lanes = rhs.type.count

    lhs_base_ty = lhs.type.element if lhs_is_vec else lhs.type
    rhs_base_ty = rhs.type.element if rhs_is_vec else rhs.type

    lhs_is_float = is_fp_type(lhs_base_ty)
    rhs_is_float = is_fp_type(rhs_base_ty)

    if lhs_is_float or rhs_is_float:
        float_candidates = [
            ty for ty in (lhs_base_ty, rhs_base_ty) if is_fp_type(ty)
        ]
        target_scalar_ty = self._select_float_type(float_candidates)
Copilot AI Jan 19, 2026

When mixing integer and floating-point operands, the integer width is not considered when selecting the target float type. For example, an int64 combined with a float32 will promote both to float32, which can cause precision loss since float32 cannot accurately represent all int64 values. Consider promoting to at least double (float64) when the integer operand has width > 32 bits, or document this behavior if the precision loss is acceptable for your use case.

Suggested change
        target_scalar_ty = self._select_float_type(float_candidates)
        target_scalar_ty = self._select_float_type(float_candidates)
        # If we are mixing an integer with a floating-point value, ensure that
        # wide integers (> 32 bits) are promoted to at least double precision
        # to avoid excessive precision loss when the selected float type is
        # narrower than 64 bits.
        if lhs_is_float != rhs_is_float:
            int_base_ty = lhs_base_ty if not lhs_is_float else rhs_base_ty
            int_width = getattr(int_base_ty, "width", 0)
            # Determine the bit-width of the selected floating-point type.
            float_bits = 0
            if isinstance(target_scalar_ty, HalfType):
                float_bits = FLOAT16_BITS
            elif isinstance(target_scalar_ty, FloatType):
                float_bits = FLOAT32_BITS
            elif isinstance(target_scalar_ty, DoubleType):
                float_bits = FLOAT64_BITS
            elif FP128Type is not None and isinstance(target_scalar_ty, FP128Type):
                float_bits = FLOAT128_BITS
            # Upgrade to double precision when combining a wide integer with
            # a float type that is narrower than 64 bits.
            if int_width > 32 and float_bits and float_bits < FLOAT64_BITS:
                target_scalar_ty = DoubleType()
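For intuition, here is a minimal, self-contained sketch of the precision loss this comment describes (plain Python via `struct`, no llvmlite; the helper name `to_float32` is illustrative): binary32 has a 24-bit significand, so int64 values just above 2**24 cannot round-trip through float32.

```python
# Hypothetical illustration, not part of the PR: round-trip an integer
# through IEEE-754 binary32, roughly what an sitofp to float32 would do.
import struct


def to_float32(x: int) -> float:
    # pack/unpack through binary32 to force float32 rounding
    return struct.unpack("f", struct.pack("f", float(x)))[0]


n = (1 << 24) + 1          # 16777217 fits comfortably in int64
print(to_float32(n))       # 16777216.0 -- the trailing +1 is rounded away
print(to_float32(n) == n)  # False: float32 cannot represent this value
```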

    else:
        lhs_width = getattr(lhs_base_ty, "width", 0)
        rhs_width = getattr(rhs_base_ty, "width", 0)
        target_scalar_ty = ir.IntType(max(lhs_width, rhs_width, 1))

    lhs = self._cast_value_to_type(lhs, target_scalar_ty)
    rhs = self._cast_value_to_type(rhs, target_scalar_ty)

    if target_lanes:
        vec_ty = ir.VectorType(target_scalar_ty, target_lanes)
        if not is_vector(lhs):
            lhs = splat_scalar(self._llvm.ir_builder, lhs, vec_ty)
        if not is_vector(rhs):
            rhs = splat_scalar(self._llvm.ir_builder, rhs, vec_ty)

    return lhs, rhs
Comment on lines +468 to +512
Copilot AI Jan 19, 2026

The new unification logic changes the type promotion behavior compared to the removed code. Previously, when combining a float vector with a double scalar, the scalar would be truncated (fptrunc) to match the vector's element type. Now, both are promoted to the wider type (double). This is generally better for precision, but represents a behavior change that could affect existing code relying on the old behavior. Ensure this is intentional and documented, especially since it could impact numerical precision in existing computations.
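A small sketch of the behavior change described above, for a `<2 x float>` vector combined with a `double` scalar (assumes llvmlite is installed; the module/function name `demo` is illustrative, and only the relevant cast is shown for each behavior):

```python
from llvmlite import ir

module = ir.Module(name="demo")
fn = ir.Function(module, ir.FunctionType(ir.VoidType(), []), name="demo")
builder = ir.IRBuilder(fn.append_basic_block("entry"))

vec = ir.Constant(ir.VectorType(ir.FloatType(), 2), [1.0, 2.0])
scalar = ir.Constant(ir.DoubleType(), 4.0)

# Old behavior: narrow the scalar to the vector's element type (lossy).
old_scalar = builder.fptrunc(scalar, ir.FloatType(), "vec_promote_scalar")

# New behavior: widen the vector's lanes to double instead, preserving
# the scalar's precision; the scalar is then splatted across the lanes.
new_vec = builder.fpext(vec, ir.VectorType(ir.DoubleType(), 2), "fpext")
builder.ret_void()
print(module)
```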


def _select_float_type(self, candidates: list[ir.Type]) -> ir.Type:
    """Choose the widest float type from provided candidates."""
    if not candidates:
        return self._llvm.FLOAT_TYPE

    width = max(self._float_bit_width(ty) for ty in candidates)
    return self._float_type_from_width(width)

def _float_type_from_width(self, width: int) -> ir.Type:
    if width <= FLOAT16_BITS and hasattr(self._llvm, "FLOAT16_TYPE"):
        return self._llvm.FLOAT16_TYPE
    if width <= FLOAT32_BITS:
        return self._llvm.FLOAT_TYPE
    if width <= FLOAT64_BITS:
        return self._llvm.DOUBLE_TYPE
    if FP128Type is not None and width >= FLOAT128_BITS:
        return FP128Type()
    return self._llvm.FLOAT_TYPE

def _float_bit_width(self, ty: ir.Type) -> int:
    if isinstance(ty, DoubleType):
        return FLOAT64_BITS
    if isinstance(ty, FloatType):
        return FLOAT32_BITS
    if isinstance(ty, HalfType):
        return FLOAT16_BITS
    if FP128Type is not None and isinstance(ty, FP128Type):
        return FLOAT128_BITS
    return getattr(ty, "width", 0)

def _cast_value_to_type(
    self, value: ir.Value, target_scalar_ty: ir.Type
) -> ir.Value:
    """Cast scalars or vectors to the target scalar type."""
    builder = self._llvm.ir_builder
    value_is_vec = is_vector(value)
    if value_is_vec:
        lanes = value.type.count
        current_scalar_ty = value.type.element
        target_ty = ir.VectorType(target_scalar_ty, lanes)
    else:
        lanes = None
Comment on lines +551 to +555
Copilot AI Jan 19, 2026

Variable lanes is not used.

Suggested change
        lanes = value.type.count
        current_scalar_ty = value.type.element
        target_ty = ir.VectorType(target_scalar_ty, lanes)
    else:
        lanes = None
        current_scalar_ty = value.type.element
        target_ty = ir.VectorType(target_scalar_ty, value.type.count)
    else:

        current_scalar_ty = value.type
        target_ty = target_scalar_ty

    if current_scalar_ty == target_scalar_ty and value.type == target_ty:
        return value

    current_is_float = is_fp_type(current_scalar_ty)
    target_is_float = is_fp_type(target_scalar_ty)

    if target_is_float:
        if current_is_float:
            current_bits = self._float_bit_width(current_scalar_ty)
            target_bits = self._float_bit_width(target_scalar_ty)
            if current_bits == target_bits:
                if value.type != target_ty:
                    return builder.bitcast(value, target_ty)
                return value
            if current_bits < target_bits:
                return builder.fpext(value, target_ty, "fpext")
            return builder.fptrunc(value, target_ty, "fptrunc")
        return builder.sitofp(value, target_ty, "sitofp")

    if current_is_float:
        raise Exception(
            "Cannot implicitly convert floating-point to integer"
        )

    current_width = getattr(current_scalar_ty, "width", 0)
    target_width = getattr(target_scalar_ty, "width", 0)

    if current_width == target_width:
        if value.type != target_ty:
            return builder.bitcast(value, target_ty)
        return value

    if current_width < target_width:
        return builder.sext(value, target_ty, "sext")

    return builder.trunc(value, target_ty, "trunc")

@dispatch.abstract
def visit(self, node: astx.AST) -> None:
    """Translate an ASTx expression."""
@@ -551,60 +692,12 @@ def visit(self, node: astx.BinaryOp) -> None:
    if not llvm_lhs or not llvm_rhs:
        raise Exception("codegen: Invalid lhs/rhs")

    # Scalar-vector promotion: one vector + matching scalar -> splat scalar
    lhs_is_vec = is_vector(llvm_lhs)
    rhs_is_vec = is_vector(llvm_rhs)
    if lhs_is_vec and not rhs_is_vec:
        elem_ty = llvm_lhs.type.element
        if llvm_rhs.type == elem_ty:
            llvm_rhs = splat_scalar(
                self._llvm.ir_builder, llvm_rhs, llvm_lhs.type
            )
        elif is_fp_type(elem_ty) and is_fp_type(llvm_rhs.type):
            if isinstance(elem_ty, FloatType) and isinstance(
                llvm_rhs.type, DoubleType
            ):
                llvm_rhs = self._llvm.ir_builder.fptrunc(
                    llvm_rhs, elem_ty, "vec_promote_scalar"
                )
                llvm_rhs = splat_scalar(
                    self._llvm.ir_builder, llvm_rhs, llvm_lhs.type
                )
            elif isinstance(elem_ty, DoubleType) and isinstance(
                llvm_rhs.type, FloatType
            ):
                llvm_rhs = self._llvm.ir_builder.fpext(
                    llvm_rhs, elem_ty, "vec_promote_scalar"
                )
                llvm_rhs = splat_scalar(
                    self._llvm.ir_builder, llvm_rhs, llvm_lhs.type
                )
    elif rhs_is_vec and not lhs_is_vec:
        elem_ty = llvm_rhs.type.element
        if llvm_lhs.type == elem_ty:
            llvm_lhs = splat_scalar(
                self._llvm.ir_builder, llvm_lhs, llvm_rhs.type
            )
        elif is_fp_type(elem_ty) and is_fp_type(llvm_lhs.type):
            if isinstance(elem_ty, FloatType) and isinstance(
                llvm_lhs.type, DoubleType
            ):
                llvm_lhs = self._llvm.ir_builder.fptrunc(
                    llvm_lhs, elem_ty, "vec_promote_scalar"
                )
                llvm_lhs = splat_scalar(
                    self._llvm.ir_builder, llvm_lhs, llvm_rhs.type
                )
            elif isinstance(elem_ty, DoubleType) and isinstance(
                llvm_lhs.type, FloatType
            ):
                llvm_lhs = self._llvm.ir_builder.fpext(
                    llvm_lhs, elem_ty, "vec_promote_scalar"
                )
                llvm_lhs = splat_scalar(
                    self._llvm.ir_builder, llvm_lhs, llvm_rhs.type
                )

    if self._is_numeric_value(llvm_lhs) and self._is_numeric_value(
        llvm_rhs
    ):
        llvm_lhs, llvm_rhs = self._unify_numeric_operands(
            llvm_lhs, llvm_rhs
        )
    # If both operands are LLVM vectors, handle as vector ops
    if is_vector(llvm_lhs) and is_vector(llvm_rhs):
        if llvm_lhs.type.count != llvm_rhs.type.count:
59 changes: 59 additions & 0 deletions tests/test_llvmlite_helpers.py
@@ -7,6 +7,7 @@
from irx.builders.llvmliteir import (
    LLVMLiteIRVisitor,
    emit_int_div,
    is_fp_type,
    splat_scalar,
)
from llvmlite import ir
@@ -97,6 +98,64 @@ def test_emit_int_div_signed_and_unsigned() -> None:
    assert getattr(unsigned, "opname", "") == "udiv"


def test_unify_promotes_scalar_int_to_vector() -> None:
    """Scalar ints splat to match vector operands and widen width."""
    visitor = LLVMLiteIRVisitor()
    _prime_builder(visitor)

    vec_ty = ir.VectorType(ir.IntType(32), 2)
    vec = ir.Constant(vec_ty, [ir.Constant(ir.IntType(32), 1)] * 2)
    scalar = ir.Constant(ir.IntType(16), 5)

    promoted_vec, promoted_scalar = visitor._unify_numeric_operands(
        vec, scalar
    )

    assert isinstance(promoted_vec.type, ir.VectorType)
    assert isinstance(promoted_scalar.type, ir.VectorType)
    assert promoted_vec.type == vec_ty
    assert promoted_scalar.type == vec_ty


def test_unify_vector_float_rank_matches_double() -> None:
    """Float vectors upgrade to match double scalars."""
    visitor = LLVMLiteIRVisitor()
    _prime_builder(visitor)

    float_vec_ty = ir.VectorType(visitor._llvm.FLOAT_TYPE, 2)
    float_vec = ir.Constant(
        float_vec_ty,
        [
            ir.Constant(visitor._llvm.FLOAT_TYPE, 1.0),
            ir.Constant(visitor._llvm.FLOAT_TYPE, 2.0),
        ],
    )
    double_scalar = ir.Constant(visitor._llvm.DOUBLE_TYPE, 4.0)

    widened_vec, widened_scalar = visitor._unify_numeric_operands(
        float_vec, double_scalar
    )

    assert widened_vec.type.element == visitor._llvm.DOUBLE_TYPE
    assert widened_scalar.type.element == visitor._llvm.DOUBLE_TYPE


def test_unify_int_and_float_scalars_returns_float() -> None:
    """Scalar int + float promotes to float for both operands."""
    visitor = LLVMLiteIRVisitor()
    _prime_builder(visitor)

    int_scalar = ir.Constant(visitor._llvm.INT32_TYPE, 7)
    float_scalar = ir.Constant(visitor._llvm.FLOAT_TYPE, 1.25)

    widened_int, widened_float = visitor._unify_numeric_operands(
        int_scalar, float_scalar
    )

    assert is_fp_type(widened_int.type)
    assert widened_float.type == visitor._llvm.FLOAT_TYPE
Copilot AI Jan 19, 2026

Missing blank line before function definition. According to PEP 8, there should be two blank lines before top-level function definitions to maintain consistency with the rest of the file.

Suggested change
    assert widened_float.type == visitor._llvm.FLOAT_TYPE
    assert widened_float.type == visitor._llvm.FLOAT_TYPE

Comment on lines +101 to +156
Copilot AI Jan 19, 2026

Test coverage is missing several important edge cases for _unify_numeric_operands: (1) two vectors with mismatched element types (e.g., int32 vector vs float vector), (2) truncation scenarios where a wider type needs to be narrowed to match another operand, (3) FP128 type handling if available, (4) error case where vectors have different sizes, and (5) scalar-to-scalar integer promotion with different widths. Consider adding tests for these scenarios to ensure the unification logic handles all cases correctly.
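As a starting point, here is a hedged sketch of case (4), the mismatched-lane-count error. It assumes the same `_prime_builder` helper and `ir` import used by the existing tests, that pytest is available, and that the current "Vector size mismatch" message is still in place (the `match` string would need updating if the other suggestion in this review is applied):

```python
import pytest


def test_unify_rejects_mismatched_vector_sizes() -> None:
    """Two vectors with different lane counts should raise."""
    visitor = LLVMLiteIRVisitor()
    _prime_builder(visitor)

    vec2 = ir.Constant(ir.VectorType(ir.IntType(32), 2), [1, 2])
    vec4 = ir.Constant(ir.VectorType(ir.IntType(32), 4), [1, 2, 3, 4])

    # Lane counts 2 vs 4 cannot be unified, so the helper must raise.
    with pytest.raises(Exception, match="Vector size mismatch"):
        visitor._unify_numeric_operands(vec2, vec4)
```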



def test_set_fast_math_marks_float_ops() -> None:
    """set_fast_math should add fast flag to floating instructions."""
    visitor = LLVMLiteIRVisitor()