Skip to content
This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Commit

Permalink
Generate simpler code for Int and Long division and remainder.
Browse files Browse the repository at this point in the history
Notably, when the rhs is a constant, we can emit straightforward
code without any branches. This is worth it because the divisor of
integer divisions and remainders is often constant in practice.

To handle the overflow case of `MinValue / -1`, we also remove one
branch by turning `lhs / -1` into `0 - lhs`. The previous codegen
used `if (lhs == MinValue) MinValue else lhs / -1` in the same
situation.
  • Loading branch information
sjrd committed Apr 21, 2024
1 parent 7465e39 commit d045f60
Showing 1 changed file with 129 additions and 72 deletions.
201 changes: 129 additions & 72 deletions wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -965,11 +965,140 @@ private class WasmExpressionBuilder private (
IRTypes.LongType
}

def genThrowArithmeticException(): Unit = {
implicit val pos = binary.pos
val divisionByZeroEx = IRTrees.Throw(
IRTrees.New(
IRNames.ArithmeticExceptionClass,
IRTrees.MethodIdent(
IRNames.MethodName.constructor(List(IRTypes.ClassRef(IRNames.BoxedStringClass)))
),
List(IRTrees.StringLiteral("/ by zero "))
)
)
genThrow(divisionByZeroEx)
}

def genDivModByConstant[T](
isDiv: Boolean,
rhsValue: T,
const: T => WasmInstr,
sub: WasmInstr,
mainOp: WasmInstr
)(implicit num: Numeric[T]): IRTypes.Type = {
/* When we statically know the value of the rhs, we can avoid the
* dynamic tests for division by zero and overflow. This is quite
* common in practice.
*/

val tpe = binary.tpe

if (rhsValue == num.zero) {
genTree(binary.lhs, tpe)
genThrowArithmeticException()
IRTypes.NothingType
} else if (isDiv && rhsValue == num.fromInt(-1)) {
/* MinValue / -1 overflows; it traps in Wasm but we need to wrap.
* We rewrite as `0 - lhs` so that we do not need any test.
*/
instrs += const(num.zero)
genTree(binary.lhs, tpe)
instrs += sub
tpe
} else {
genTree(binary.lhs, tpe)
instrs += const(rhsValue)
instrs += mainOp
tpe
}
}

def genDivMod[T](
isDiv: Boolean,
const: T => WasmInstr,
eqz: WasmInstr,
eq: WasmInstr,
sub: WasmInstr,
mainOp: WasmInstr
)(implicit num: Numeric[T]): IRTypes.Type = {
/* Here we perform the same steps as in the static case, but using
* value tests at run-time.
*/

val tpe = binary.tpe
val wasmTyp = TypeTransformer.transformType(tpe)(ctx)

val lhsLocal = fctx.addSyntheticLocal(wasmTyp)
val rhsLocal = fctx.addSyntheticLocal(wasmTyp)
genTree(binary.lhs, tpe)
instrs += LOCAL_SET(lhsLocal)
genTree(binary.rhs, tpe)
instrs += LOCAL_TEE(rhsLocal)

instrs += eqz
fctx.ifThen() {
genThrowArithmeticException()
}
if (isDiv) {
// Handle the MinValue / -1 corner case
instrs += LOCAL_GET(rhsLocal)
instrs += const(num.fromInt(-1))
instrs += eq
fctx.ifThenElse(wasmTyp) {
// 0 - lhs
instrs += const(num.zero)
instrs += LOCAL_GET(lhsLocal)
instrs += sub
} {
// lhs / rhs
instrs += LOCAL_GET(lhsLocal)
instrs += LOCAL_GET(rhsLocal)
instrs += mainOp
}
} else {
// lhs % rhs
instrs += LOCAL_GET(lhsLocal)
instrs += LOCAL_GET(rhsLocal)
instrs += mainOp
}

tpe
}

binary.op match {
case BinaryOp.=== | BinaryOp.!== => genEq(binary)

case BinaryOp.String_+ => genStringConcat(binary.lhs, binary.rhs)

case BinaryOp.Int_/ =>
binary.rhs match {
case IRTrees.IntLiteral(rhsValue) =>
genDivModByConstant(isDiv = true, rhsValue, I32_CONST(_), I32_SUB, I32_DIV_S)
case _ =>
genDivMod(isDiv = true, I32_CONST(_), I32_EQZ, I32_EQ, I32_SUB, I32_DIV_S)
}
case BinaryOp.Int_% =>
binary.rhs match {
case IRTrees.IntLiteral(rhsValue) =>
genDivModByConstant(isDiv = false, rhsValue, I32_CONST(_), I32_SUB, I32_REM_S)
case _ =>
genDivMod(isDiv = false, I32_CONST(_), I32_EQZ, I32_EQ, I32_SUB, I32_REM_S)
}
case BinaryOp.Long_/ =>
binary.rhs match {
case IRTrees.LongLiteral(rhsValue) =>
genDivModByConstant(isDiv = true, rhsValue, I64_CONST(_), I64_SUB, I64_DIV_S)
case _ =>
genDivMod(isDiv = true, I64_CONST(_), I64_EQZ, I64_EQ, I64_SUB, I64_DIV_S)
}
case BinaryOp.Long_% =>
binary.rhs match {
case IRTrees.LongLiteral(rhsValue) =>
genDivModByConstant(isDiv = false, rhsValue, I64_CONST(_), I64_SUB, I64_REM_S)
case _ =>
genDivMod(isDiv = false, I64_CONST(_), I64_EQZ, I64_EQ, I64_SUB, I64_REM_S)
}

case BinaryOp.Long_<< => genLongShiftOp(I64_SHL)
case BinaryOp.Long_>>> => genLongShiftOp(I64_SHR_U)
case BinaryOp.Long_>> => genLongShiftOp(I64_SHR_S)
Expand Down Expand Up @@ -1004,78 +1133,6 @@ private class WasmExpressionBuilder private (
instrs += CALL(WasmFunctionName.stringCharAt)
IRTypes.CharType

// Check division by zero
// (Int|Long).MinValue / -1 = (Int|Long).MinValue because of overflow
case BinaryOp.Int_/ | BinaryOp.Long_/ | BinaryOp.Int_% | BinaryOp.Long_% =>
implicit val noPos = Position.NoPosition
val divisionByZeroEx = IRTrees.Throw(
IRTrees.New(
IRNames.ArithmeticExceptionClass,
IRTrees.MethodIdent(
IRNames.MethodName.constructor(List(IRTypes.ClassRef(IRNames.BoxedStringClass)))
),
List(IRTrees.StringLiteral("/ by zero "))
)
)
val resType = TypeTransformer.transformType(binary.tpe)(ctx)

val lhs = fctx.addSyntheticLocal(TypeTransformer.transformType(binary.lhs.tpe)(ctx))
val rhs = fctx.addSyntheticLocal(TypeTransformer.transformType(binary.rhs.tpe)(ctx))
genTreeAuto(binary.lhs)
instrs += LOCAL_SET(lhs)
genTreeAuto(binary.rhs)
instrs += LOCAL_SET(rhs)

fctx.block(resType) { done =>
fctx.block() { default =>
fctx.block() { divisionByZero =>
instrs += LOCAL_GET(rhs)
binary.op match {
case BinaryOp.Int_/ | BinaryOp.Int_% => instrs += I32_EQZ
case BinaryOp.Long_/ | BinaryOp.Long_% => instrs += I64_EQZ
}
instrs += BR_IF(divisionByZero)

// Check overflow for division
if (binary.op == BinaryOp.Int_/ || binary.op == BinaryOp.Long_/) {
fctx.block() { overflow =>
instrs += LOCAL_GET(rhs)
if (binary.op == BinaryOp.Int_/) instrs ++= List(I32_CONST(-1), I32_EQ)
else instrs ++= List(I64_CONST(-1), I64_EQ)
fctx.ifThen() { // if (rhs == -1)
instrs += LOCAL_GET(lhs)
if (binary.op == BinaryOp.Int_/)
instrs ++= List(I32_CONST(Int.MinValue), I32_EQ)
else instrs ++= List(I64_CONST(Long.MinValue), I64_EQ)
instrs += BR_IF(overflow)
}
instrs += BR(default)
}
// overflow
if (binary.op == BinaryOp.Int_/) instrs += I32_CONST(Int.MinValue)
else instrs += I64_CONST(Long.MinValue)
instrs += BR(done)
}

// remainder
instrs += BR(default)
}
// division by zero
genThrow(divisionByZeroEx)
}
// default
instrs += LOCAL_GET(lhs)
instrs += LOCAL_GET(rhs)
instrs +=
(binary.op match {
case BinaryOp.Int_/ => I32_DIV_S
case BinaryOp.Int_% => I32_REM_S
case BinaryOp.Long_/ => I64_DIV_S
case BinaryOp.Long_% => I64_REM_S
})
binary.tpe
}

case _ => genElementaryBinaryOp(binary)
}
}
Expand Down

0 comments on commit d045f60

Please sign in to comment.