Skip to content

Commit

Permalink
emit arith suffix (addi/addf) based on type
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderViand-Intel committed Jan 17, 2025
1 parent 9c10579 commit b3bf2d0
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions heir_py/mlir_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@ def mlirType(numba_type):
return "tensor<" + "?x" * numba_type.ndim + mlirType(numba_type.dtype) + ">"
raise NotImplementedError("Unsupported type: " + str(numba_type))

def arithSuffix(numba_type):
if isinstance(numba_type, types.Integer):
return "i"
if isinstance(numba_type, types.Boolean):
return "i"
if isinstance(numba_type, types.Float):
return "f"
if isinstance(numba_type, types.Complex):
raise NotImplementedError("Complex numbers not supported in `arith` dialect")
if isinstance(numba_type, types.Array):
return arithSuffix(numba_type.dtype)
raise NotImplementedError("Unsupported type: " + str(numba_type))


class TextualMlirEmitter:
def __init__(self, ssa_ir, typemap, retty):
Expand Down Expand Up @@ -228,16 +241,18 @@ def emit_expr(self, expr):
def emit_binop(self, binop):
lhs_ssa = self.get_name(binop.lhs)
rhs_ssa = self.get_name(binop.rhs)
# This should be the same, otherwise MLIR will complain
suffix = arithSuffix(self.typemap.get(str(binop.lhs)))

match binop.fn:
case operator.lt:
return f"arith.cmpi slt, {lhs_ssa}, {rhs_ssa}"
return f"arith.cmp{suffix} slt, {lhs_ssa}, {rhs_ssa}"
case operator.add:
return f"arith.addi {lhs_ssa}, {rhs_ssa}"
return f"arith.add{suffix} {lhs_ssa}, {rhs_ssa}"
case operator.mul:
return f"arith.muli {lhs_ssa}, {rhs_ssa}"
return f"arith.mul{suffix} {lhs_ssa}, {rhs_ssa}"
case operator.sub:
return f"arith.subi {lhs_ssa}, {rhs_ssa}"
return f"arith.sub{suffix} {lhs_ssa}, {rhs_ssa}"

raise NotImplementedError("Unsupported binop: " + binop.fn.__name__)

Expand Down

0 comments on commit b3bf2d0

Please sign in to comment.