diff --git a/heir_py/mlir_emitter.py b/heir_py/mlir_emitter.py index d828977a2..16c698aa5 100644 --- a/heir_py/mlir_emitter.py +++ b/heir_py/mlir_emitter.py @@ -24,6 +24,19 @@ def mlirType(numba_type): return "tensor<" + "?x" * numba_type.ndim + mlirType(numba_type.dtype) + ">" raise NotImplementedError("Unsupported type: " + str(numba_type)) +def arithSuffix(numba_type): + if isinstance(numba_type, types.Integer): + return "i" + if isinstance(numba_type, types.Boolean): + return "i" + if isinstance(numba_type, types.Float): + return "f" + if isinstance(numba_type, types.Complex): + raise NotImplementedError("Complex numbers not supported in `arith` dialect") + if isinstance(numba_type, types.Array): + return arithSuffix(numba_type.dtype) + raise NotImplementedError("Unsupported type: " + str(numba_type)) + class TextualMlirEmitter: def __init__(self, ssa_ir, typemap, retty): @@ -228,16 +241,18 @@ def emit_expr(self, expr): def emit_binop(self, binop): lhs_ssa = self.get_name(binop.lhs) rhs_ssa = self.get_name(binop.rhs) + # This should be the same, otherwise MLIR will complain + suffix = arithSuffix(self.typemap.get(str(binop.lhs))) match binop.fn: case operator.lt: - return f"arith.cmpi slt, {lhs_ssa}, {rhs_ssa}" + return f"arith.cmp{suffix} slt, {lhs_ssa}, {rhs_ssa}" case operator.add: - return f"arith.addi {lhs_ssa}, {rhs_ssa}" + return f"arith.add{suffix} {lhs_ssa}, {rhs_ssa}" case operator.mul: - return f"arith.muli {lhs_ssa}, {rhs_ssa}" + return f"arith.mul{suffix} {lhs_ssa}, {rhs_ssa}" case operator.sub: - return f"arith.subi {lhs_ssa}, {rhs_ssa}" + return f"arith.sub{suffix} {lhs_ssa}, {rhs_ssa}" raise NotImplementedError("Unsupported binop: " + binop.fn.__name__)