feat: unify numeric operand promotion #139
The builder module gains float bit-width constants and a floating-point type predicate:

```python
@@ -28,6 +28,11 @@
from irx.builders.base import Builder, BuilderVisitor
from irx.tools.typing import typechecked

FLOAT16_BITS = 16
FLOAT32_BITS = 32
FLOAT64_BITS = 64
FLOAT128_BITS = 128


def is_fp_type(t: "ir.Type") -> bool:
    """Return True if t is any floating-point LLVM type."""
```
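The body of `is_fp_type` is cut off by the hunk boundary. A minimal sketch of such a predicate, assuming it only needs to recognize llvmlite's scalar float types (the PR's actual implementation may differ, e.g. in how it handles `FP128Type`):

```python
from llvmlite import ir


def is_fp_type(t: "ir.Type") -> bool:
    """Return True if t is any floating-point LLVM type."""
    # HalfType, FloatType, and DoubleType always exist in llvmlite.ir;
    # FP128Type may be missing in older releases, so look it up defensively.
    fp_classes = [ir.HalfType, ir.FloatType, ir.DoubleType]
    fp128 = getattr(ir, "FP128Type", None)
    if fp128 is not None:
        fp_classes.append(fp128)
    return isinstance(t, tuple(fp_classes))
```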
In the visitor class, two new helpers check and unify numeric operands:

```python
@@ -452,6 +457,142 @@ def _apply_fast_math(self, inst: ir.Instruction) -> None:
        except (AttributeError, TypeError):
            return

    def _is_numeric_value(self, value: ir.Value) -> bool:
        """Return True if value represents an int/float scalar or vector."""
        if is_vector(value):
            elem_ty = value.type.element
            return isinstance(elem_ty, ir.IntType) or is_fp_type(elem_ty)
        base_ty = value.type
        return isinstance(base_ty, ir.IntType) or is_fp_type(base_ty)

    def _unify_numeric_operands(
        self, lhs: ir.Value, rhs: ir.Value
    ) -> tuple[ir.Value, ir.Value]:
        """Ensure numeric operands share shape and scalar type."""
        lhs_is_vec = is_vector(lhs)
        rhs_is_vec = is_vector(rhs)

        if lhs_is_vec and rhs_is_vec and lhs.type.count != rhs.type.count:
            raise Exception(
                f"Vector size mismatch: {lhs.type.count} vs {rhs.type.count}"
```
| f"Vector size mismatch: {lhs.type.count} vs {rhs.type.count}" | |
| "Numeric operation requires matching vector sizes, " | |
| f"but got {lhs.type} (size {lhs.type.count}) vs " | |
| f"{rhs.type} (size {rhs.type.count})" |
Copilot (AI) commented on Jan 19, 2026:
When mixing integer and floating-point operands, the integer width is not considered when selecting the target float type. For example, an int64 combined with a float32 will promote both to float32, which can cause precision loss since float32 cannot accurately represent all int64 values. Consider promoting to at least double (float64) when the integer operand has width > 32 bits, or document this behavior if the precision loss is acceptable for your use case.
Suggested change:

```diff
 target_scalar_ty = self._select_float_type(float_candidates)
+# If we are mixing an integer with a floating-point value, ensure that
+# wide integers (> 32 bits) are promoted to at least double precision
+# to avoid excessive precision loss when the selected float type is
+# narrower than 64 bits.
+if lhs_is_float != rhs_is_float:
+    int_base_ty = lhs_base_ty if not lhs_is_float else rhs_base_ty
+    int_width = getattr(int_base_ty, "width", 0)
+    # Determine the bit-width of the selected floating-point type.
+    float_bits = 0
+    if isinstance(target_scalar_ty, HalfType):
+        float_bits = FLOAT16_BITS
+    elif isinstance(target_scalar_ty, FloatType):
+        float_bits = FLOAT32_BITS
+    elif isinstance(target_scalar_ty, DoubleType):
+        float_bits = FLOAT64_BITS
+    elif FP128Type is not None and isinstance(target_scalar_ty, FP128Type):
+        float_bits = FLOAT128_BITS
+    # Upgrade to double precision when combining a wide integer with
+    # a float type that is narrower than 64 bits.
+    if int_width > 32 and float_bits and float_bits < FLOAT64_BITS:
+        target_scalar_ty = DoubleType()
```
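The precision concern is easy to reproduce outside of LLVM. A small illustration (plain Python, not part of the PR) of how forcing a large integer through IEEE-754 single precision changes its value:

```python
import struct


def to_float32(x: int) -> float:
    """Round-trip a value through IEEE-754 single precision."""
    return struct.unpack("f", struct.pack("f", float(x)))[0]


# float32 has a 24-bit significand, so integers above 2**24 cannot all be
# represented exactly; a 64-bit integer operand can silently change value.
big = 2**24 + 1            # 16777217
print(to_float32(big))     # 16777216.0 -- off by one after the round-trip
print(float(big))          # 16777217.0 -- double (float64) is still exact here
```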
Copilot (AI) commented on Jan 19, 2026:
The new unification logic changes the type promotion behavior compared to the removed code. Previously, when combining a float vector with a double scalar, the scalar would be truncated (fptrunc) to match the vector's element type. Now, both are promoted to the wider type (double). This is generally better for precision, but represents a behavior change that could affect existing code relying on the old behavior. Ensure this is intentional and documented, especially since it could impact numerical precision in existing computations.
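For reference, the two behaviors the comment contrasts map onto different llvmlite casts: the old path narrowed the wider operand with `fptrunc`, while the new path widens the narrower one with `fpext`. A standalone sketch (the `demo` function is hypothetical, not the PR's builder code):

```python
from llvmlite import ir

module = ir.Module(name="promotion_demo")
fn_ty = ir.FunctionType(ir.DoubleType(), [ir.FloatType(), ir.DoubleType()])
fn = ir.Function(module, fn_ty, name="demo")
builder = ir.IRBuilder(fn.append_basic_block(name="entry"))
f32_arg, f64_arg = fn.args

# Old behavior (per the review comment): truncate the double down to float,
# discarding precision before the arithmetic happens (shown only for contrast).
narrowed = builder.fptrunc(f64_arg, ir.FloatType(), name="narrowed")

# New behavior: extend the float up to double so both operands are wide.
widened = builder.fpext(f32_arg, ir.DoubleType(), name="widened")

result = builder.fadd(widened, f64_arg, name="sum")
builder.ret(result)
print(module)
```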
Copilot (AI) commented on Jan 19, 2026:
The variable `lanes` is not used.
Suggested change:

```diff
-            lanes = value.type.count
             current_scalar_ty = value.type.element
-            target_ty = ir.VectorType(target_scalar_ty, lanes)
+            target_ty = ir.VectorType(target_scalar_ty, value.type.count)
         else:
-            lanes = None
```
The test module imports the new `is_fp_type` helper alongside the existing ones:

```python
@@ -7,6 +7,7 @@
from irx.builders.llvmliteir import (
    LLVMLiteIRVisitor,
    emit_int_div,
    is_fp_type,
    splat_scalar,
)
from llvmlite import ir
```
Three new tests exercise the unification helper:

```python
@@ -97,6 +98,64 @@ def test_emit_int_div_signed_and_unsigned() -> None:
    assert getattr(unsigned, "opname", "") == "udiv"


def test_unify_promotes_scalar_int_to_vector() -> None:
    """Scalar ints splat to match vector operands and widen width."""
    visitor = LLVMLiteIRVisitor()
    _prime_builder(visitor)

    vec_ty = ir.VectorType(ir.IntType(32), 2)
    vec = ir.Constant(vec_ty, [ir.Constant(ir.IntType(32), 1)] * 2)
    scalar = ir.Constant(ir.IntType(16), 5)

    promoted_vec, promoted_scalar = visitor._unify_numeric_operands(
        vec, scalar
    )

    assert isinstance(promoted_vec.type, ir.VectorType)
    assert isinstance(promoted_scalar.type, ir.VectorType)
    assert promoted_vec.type == vec_ty
    assert promoted_scalar.type == vec_ty


def test_unify_vector_float_rank_matches_double() -> None:
    """Float vectors upgrade to match double scalars."""
    visitor = LLVMLiteIRVisitor()
    _prime_builder(visitor)

    float_vec_ty = ir.VectorType(visitor._llvm.FLOAT_TYPE, 2)
    float_vec = ir.Constant(
        float_vec_ty,
        [
            ir.Constant(visitor._llvm.FLOAT_TYPE, 1.0),
            ir.Constant(visitor._llvm.FLOAT_TYPE, 2.0),
        ],
    )
    double_scalar = ir.Constant(visitor._llvm.DOUBLE_TYPE, 4.0)

    widened_vec, widened_scalar = visitor._unify_numeric_operands(
        float_vec, double_scalar
    )

    assert widened_vec.type.element == visitor._llvm.DOUBLE_TYPE
    assert widened_scalar.type.element == visitor._llvm.DOUBLE_TYPE


def test_unify_int_and_float_scalars_returns_float() -> None:
    """Scalar int + float promotes to float for both operands."""
    visitor = LLVMLiteIRVisitor()
    _prime_builder(visitor)

    int_scalar = ir.Constant(visitor._llvm.INT32_TYPE, 7)
    float_scalar = ir.Constant(visitor._llvm.FLOAT_TYPE, 1.25)

    widened_int, widened_float = visitor._unify_numeric_operands(
        int_scalar, float_scalar
    )

    assert is_fp_type(widened_int.type)
    assert widened_float.type == visitor._llvm.FLOAT_TYPE
```
Copilot (AI) commented on Jan 19, 2026:
Test coverage is missing several important edge cases for `_unify_numeric_operands`:

1. two vectors with mismatched element types (e.g., an int32 vector vs. a float vector),
2. truncation scenarios where a wider type needs to be narrowed to match another operand,
3. FP128 type handling, if available,
4. the error case where vectors have different sizes, and
5. scalar-to-scalar integer promotion with different widths.

Consider adding tests for these scenarios to ensure the unification logic handles all cases correctly; a sketch of case 4 follows this list.
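As a sketch of item 4, a test for the mismatched-vector-size error could look roughly like this, reusing the `_prime_builder` helper from the existing tests and assuming the raised message contains the word "size":

```python
import pytest
from llvmlite import ir


def test_unify_rejects_mismatched_vector_sizes() -> None:
    """Vectors with different lane counts should raise an error."""
    visitor = LLVMLiteIRVisitor()
    _prime_builder(visitor)

    int32 = ir.IntType(32)
    vec2 = ir.Constant(ir.VectorType(int32, 2), [ir.Constant(int32, 1)] * 2)
    vec4 = ir.Constant(ir.VectorType(int32, 4), [ir.Constant(int32, 1)] * 4)

    # The helper raises a plain Exception mentioning the size mismatch.
    with pytest.raises(Exception, match="size"):
        visitor._unify_numeric_operands(vec2, vec4)
```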
The `_unify_numeric_operands` method would benefit from more detailed documentation. The current docstring, "Ensure numeric operands share shape and scalar type", is minimal. Consider documenting:

1. the promotion rules (e.g., int promotes to float, narrower types promote to wider),
2. parameter types and constraints,
3. return value guarantees,
4. what exceptions can be raised, and
5. examples of transformations.

This is a critical function for type safety, and clear documentation would help maintainers understand the promotion semantics. A possible expanded docstring is sketched after this list.
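A sketch of such an expanded docstring, reflecting only the promotion rules visible in this diff and its tests (not the final wording):

```python
from llvmlite import ir


def _unify_numeric_operands(
    self, lhs: ir.Value, rhs: ir.Value
) -> tuple[ir.Value, ir.Value]:
    """Ensure numeric operands share shape and scalar type.

    Promotion rules (as exercised by the tests in this PR):
      * vector + scalar: the scalar is splatted to a vector with the same
        number of lanes as the vector operand;
      * int + float: the integer operand is converted to the floating-point
        type;
      * narrow float + wide float: the narrower operand is extended to the
        wider type (e.g. float -> double).

    Args:
        lhs: Left operand; an int/float scalar or vector value.
        rhs: Right operand; an int/float scalar or vector value.

    Returns:
        The (lhs, rhs) pair rewritten so both values share one LLVM type.

    Raises:
        Exception: If both operands are vectors with different lane counts.
    """
```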