Skip to content
This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Generate simpler code for Int and Long division and remainder. #114

Merged
merged 1 commit into from
Apr 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 136 additions & 74 deletions wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -977,11 +977,147 @@ private class WasmExpressionBuilder private (
IRTypes.LongType
}

def genThrowArithmeticException(): Unit = {
implicit val pos = binary.pos
val divisionByZeroEx = IRTrees.Throw(
IRTrees.New(
IRNames.ArithmeticExceptionClass,
IRTrees.MethodIdent(
IRNames.MethodName.constructor(List(IRTypes.ClassRef(IRNames.BoxedStringClass)))
),
List(IRTrees.StringLiteral("/ by zero "))
)
)
genThrow(divisionByZeroEx)
}

def genDivModByConstant[T](
isDiv: Boolean,
rhsValue: T,
const: T => WasmInstr,
sub: WasmInstr,
mainOp: WasmInstr
)(implicit num: Numeric[T]): IRTypes.Type = {
/* When we statically know the value of the rhs, we can avoid the
* dynamic tests for division by zero and overflow. This is quite
* common in practice.
*/

val tpe = binary.tpe

if (rhsValue == num.zero) {
genTree(binary.lhs, tpe)
fctx.markPosition(binary)
genThrowArithmeticException()
IRTypes.NothingType
} else if (isDiv && rhsValue == num.fromInt(-1)) {
/* MinValue / -1 overflows; it traps in Wasm but we need to wrap.
* We rewrite as `0 - lhs` so that we do not need any test.
*/
fctx.markPosition(binary)
instrs += const(num.zero)
genTree(binary.lhs, tpe)
fctx.markPosition(binary)
instrs += sub
tpe
} else {
genTree(binary.lhs, tpe)
fctx.markPosition(binary.rhs)
instrs += const(rhsValue)
fctx.markPosition(binary)
instrs += mainOp
tpe
}
}

def genDivMod[T](
isDiv: Boolean,
const: T => WasmInstr,
eqz: WasmInstr,
eq: WasmInstr,
sub: WasmInstr,
mainOp: WasmInstr
)(implicit num: Numeric[T]): IRTypes.Type = {
/* Here we perform the same steps as in the static case, but using
* value tests at run-time.
*/

val tpe = binary.tpe
val wasmTyp = TypeTransformer.transformType(tpe)(ctx)

val lhsLocal = fctx.addSyntheticLocal(wasmTyp)
val rhsLocal = fctx.addSyntheticLocal(wasmTyp)
genTree(binary.lhs, tpe)
instrs += LOCAL_SET(lhsLocal)
genTree(binary.rhs, tpe)
instrs += LOCAL_TEE(rhsLocal)

fctx.markPosition(binary)

instrs += eqz
fctx.ifThen() {
genThrowArithmeticException()
}
if (isDiv) {
// Handle the MinValue / -1 corner case
instrs += LOCAL_GET(rhsLocal)
instrs += const(num.fromInt(-1))
instrs += eq
fctx.ifThenElse(wasmTyp) {
// 0 - lhs
instrs += const(num.zero)
instrs += LOCAL_GET(lhsLocal)
instrs += sub
} {
// lhs / rhs
instrs += LOCAL_GET(lhsLocal)
instrs += LOCAL_GET(rhsLocal)
instrs += mainOp
}
} else {
// lhs % rhs
instrs += LOCAL_GET(lhsLocal)
instrs += LOCAL_GET(rhsLocal)
instrs += mainOp
}

tpe
}

binary.op match {
case BinaryOp.=== | BinaryOp.!== => genEq(binary)

case BinaryOp.String_+ => genStringConcat(binary)

case BinaryOp.Int_/ =>
binary.rhs match {
case IRTrees.IntLiteral(rhsValue) =>
genDivModByConstant(isDiv = true, rhsValue, I32_CONST(_), I32_SUB, I32_DIV_S)
case _ =>
genDivMod(isDiv = true, I32_CONST(_), I32_EQZ, I32_EQ, I32_SUB, I32_DIV_S)
}
case BinaryOp.Int_% =>
binary.rhs match {
case IRTrees.IntLiteral(rhsValue) =>
genDivModByConstant(isDiv = false, rhsValue, I32_CONST(_), I32_SUB, I32_REM_S)
case _ =>
genDivMod(isDiv = false, I32_CONST(_), I32_EQZ, I32_EQ, I32_SUB, I32_REM_S)
}
case BinaryOp.Long_/ =>
binary.rhs match {
case IRTrees.LongLiteral(rhsValue) =>
genDivModByConstant(isDiv = true, rhsValue, I64_CONST(_), I64_SUB, I64_DIV_S)
case _ =>
genDivMod(isDiv = true, I64_CONST(_), I64_EQZ, I64_EQ, I64_SUB, I64_DIV_S)
}
case BinaryOp.Long_% =>
binary.rhs match {
case IRTrees.LongLiteral(rhsValue) =>
genDivModByConstant(isDiv = false, rhsValue, I64_CONST(_), I64_SUB, I64_REM_S)
case _ =>
genDivMod(isDiv = false, I64_CONST(_), I64_EQZ, I64_EQ, I64_SUB, I64_REM_S)
}

case BinaryOp.Long_<< => genLongShiftOp(I64_SHL)
case BinaryOp.Long_>>> => genLongShiftOp(I64_SHR_U)
case BinaryOp.Long_>> => genLongShiftOp(I64_SHR_S)
Expand Down Expand Up @@ -1019,80 +1155,6 @@ private class WasmExpressionBuilder private (
instrs += CALL(WasmFunctionName.stringCharAt)
IRTypes.CharType

// Check division by zero
// (Int|Long).MinValue / -1 = (Int|Long).MinValue because of overflow
case BinaryOp.Int_/ | BinaryOp.Long_/ | BinaryOp.Int_% | BinaryOp.Long_% =>
implicit val noPos = Position.NoPosition
val divisionByZeroEx = IRTrees.Throw(
IRTrees.New(
IRNames.ArithmeticExceptionClass,
IRTrees.MethodIdent(
IRNames.MethodName.constructor(List(IRTypes.ClassRef(IRNames.BoxedStringClass)))
),
List(IRTrees.StringLiteral("/ by zero "))
)
)
val resType = TypeTransformer.transformType(binary.tpe)(ctx)

val lhs = fctx.addSyntheticLocal(TypeTransformer.transformType(binary.lhs.tpe)(ctx))
val rhs = fctx.addSyntheticLocal(TypeTransformer.transformType(binary.rhs.tpe)(ctx))
genTreeAuto(binary.lhs)
instrs += LOCAL_SET(lhs)
genTreeAuto(binary.rhs)
instrs += LOCAL_SET(rhs)

fctx.markPosition(binary)

fctx.block(resType) { done =>
fctx.block() { default =>
fctx.block() { divisionByZero =>
instrs += LOCAL_GET(rhs)
binary.op match {
case BinaryOp.Int_/ | BinaryOp.Int_% => instrs += I32_EQZ
case BinaryOp.Long_/ | BinaryOp.Long_% => instrs += I64_EQZ
}
instrs += BR_IF(divisionByZero)

// Check overflow for division
if (binary.op == BinaryOp.Int_/ || binary.op == BinaryOp.Long_/) {
fctx.block() { overflow =>
instrs += LOCAL_GET(rhs)
if (binary.op == BinaryOp.Int_/) instrs ++= List(I32_CONST(-1), I32_EQ)
else instrs ++= List(I64_CONST(-1), I64_EQ)
fctx.ifThen() { // if (rhs == -1)
instrs += LOCAL_GET(lhs)
if (binary.op == BinaryOp.Int_/)
instrs ++= List(I32_CONST(Int.MinValue), I32_EQ)
else instrs ++= List(I64_CONST(Long.MinValue), I64_EQ)
instrs += BR_IF(overflow)
}
instrs += BR(default)
}
// overflow
if (binary.op == BinaryOp.Int_/) instrs += I32_CONST(Int.MinValue)
else instrs += I64_CONST(Long.MinValue)
instrs += BR(done)
}

// remainder
instrs += BR(default)
}
// division by zero
genThrow(divisionByZeroEx)
}
// default
instrs += LOCAL_GET(lhs)
instrs += LOCAL_GET(rhs)
instrs +=
(binary.op match {
case BinaryOp.Int_/ => I32_DIV_S
case BinaryOp.Int_% => I32_REM_S
case BinaryOp.Long_/ => I64_DIV_S
case BinaryOp.Long_% => I64_REM_S
})
binary.tpe
}

case _ => genElementaryBinaryOp(binary)
}
}
Expand Down