Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More trace optimisations #1552

Merged
merged 9 commits into from
Jan 20, 2025
234 changes: 211 additions & 23 deletions ykrt/src/compile/jitc_yk/opt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ impl Opt {
)?;
self.m.replace(iidx, Inst::Const(cidx));
}
_ => todo!(),
_ => panic!(),
}
}
(Operand::Var(_), Operand::Var(_)) => (),
Expand All @@ -259,26 +259,18 @@ impl Opt {
self.an.op_map(&self.m, x.rhs(&self.m)),
) {
(Operand::Var(op_iidx), Operand::Const(op_cidx)) => {
match self.m.const_(op_cidx) {
Const::Int(_, 0) => {
// Replace `x >> 0` with `x`.
self.m.replace(iidx, Inst::Copy(op_iidx));
}
_ => {
// Canonicalise to (Var, Const).
self.m.replace(
iidx,
BinOpInst::new(
Operand::Var(op_iidx),
BinOp::LShr,
Operand::Const(op_cidx),
)
.into(),
);
}
if let Const::Int(_, 0) = self.m.const_(op_cidx) {
// Replace `x >> 0` with `x`.
self.m.replace(iidx, Inst::Copy(op_iidx));
}
}
(Operand::Const(op_cidx), Operand::Var(_)) => {
if let Const::Int(tyidx, 0) = self.m.const_(op_cidx) {
// Replace `0 >> x` with `0`.
let new_cidx = self.m.insert_const_int(*tyidx, 0)?;
self.m.replace(iidx, Inst::Const(new_cidx));
}
}
(Operand::Const(_), Operand::Var(_)) => (),
(Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => {
match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) {
(Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => {
Expand All @@ -292,7 +284,7 @@ impl Opt {
)?;
self.m.replace(iidx, Inst::Const(cidx));
}
_ => todo!(),
_ => panic!(),
}
}
(Operand::Var(_), Operand::Var(_)) => (),
Expand Down Expand Up @@ -392,7 +384,42 @@ impl Opt {
)?;
self.m.replace(iidx, Inst::Const(cidx));
}
_ => todo!(),
_ => panic!(),
}
}
(Operand::Var(_), Operand::Var(_)) => (),
},
BinOp::Shl => match (
self.an.op_map(&self.m, x.lhs(&self.m)),
self.an.op_map(&self.m, x.rhs(&self.m)),
) {
(Operand::Var(op_iidx), Operand::Const(op_cidx)) => {
if let Const::Int(_, 0) = self.m.const_(op_cidx) {
// Replace `x << 0` with `x`.
self.m.replace(iidx, Inst::Copy(op_iidx));
}
}
(Operand::Const(op_cidx), Operand::Var(_)) => {
if let Const::Int(tyidx, 0) = self.m.const_(op_cidx) {
// Replace `0 << x` with `0`.
let new_cidx = self.m.insert_const_int(*tyidx, 0)?;
self.m.replace(iidx, Inst::Const(new_cidx));
}
}
(Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => {
match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) {
(Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => {
debug_assert_eq!(lhs_tyidx, rhs_tyidx);
let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else {
panic!()
};
let cidx = self.m.insert_const_int(
*lhs_tyidx,
(lhs_v << rhs_v).truncate(*bits),
)?;
self.m.replace(iidx, Inst::Const(cidx));
}
_ => panic!(),
}
}
(Operand::Var(_), Operand::Var(_)) => (),
Expand Down Expand Up @@ -425,7 +452,57 @@ impl Opt {
}
(Operand::Const(_), Operand::Var(_)) | (Operand::Var(_), Operand::Var(_)) => (),
},
_ => (),
BinOp::Xor => match (
self.an.op_map(&self.m, x.lhs(&self.m)),
self.an.op_map(&self.m, x.rhs(&self.m)),
) {
(Operand::Const(op_cidx), Operand::Var(op_iidx))
| (Operand::Var(op_iidx), Operand::Const(op_cidx)) => {
match self.m.const_(op_cidx) {
Const::Int(_, 0) => {
// Replace `x ^ 0` with `x`.
self.m.replace(iidx, Inst::Copy(op_iidx));
}
_ => {
// Canonicalise to (Var, Const).
self.m.replace(
iidx,
BinOpInst::new(
Operand::Var(op_iidx),
BinOp::Xor,
Operand::Const(op_cidx),
)
.into(),
);
}
}
}
(Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => {
match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) {
(Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => {
debug_assert_eq!(lhs_tyidx, rhs_tyidx);
let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else {
panic!()
};
let cidx = self.m.insert_const_int(
*lhs_tyidx,
(lhs_v ^ rhs_v).truncate(*bits),
)?;
self.m.replace(iidx, Inst::Const(cidx));
}
_ => panic!(),
}
}
(Operand::Var(_), Operand::Var(_)) => (),
},
_ => {
if let (Operand::Const(_), Operand::Const(_)) = (
self.an.op_map(&self.m, x.lhs(&self.m)),
self.an.op_map(&self.m, x.rhs(&self.m)),
) {
todo!("{:?}", x.binop());
}
}
},
Inst::DynPtrAdd(x) => {
if let Operand::Const(cidx) = self.an.op_map(&self.m, x.num_elems(&self.m)) {
Expand Down Expand Up @@ -963,6 +1040,22 @@ mod test {
black_box %0
",
);

Module::assert_ir_transform_eq(
"
entry:
%0: i8 = param 0
%1: i8 = lshr 0i8, %0
black_box %1
",
|m| opt(m).unwrap(),
"
...
entry:
%0: i8 = param ...
black_box 0i8
",
);
}

#[test]
Expand All @@ -984,6 +1077,60 @@ mod test {
);
}

#[test]
fn opt_shl_zero() {
Module::assert_ir_transform_eq(
"
entry:
%0: i8 = param 0
%1: i8 = shl %0, 0i8
black_box %1
",
|m| opt(m).unwrap(),
"
...
entry:
%0: i8 = param ...
black_box %0
",
);

Module::assert_ir_transform_eq(
"
entry:
%0: i8 = param 0
%1: i8 = shl 0i8, %0
black_box %1
",
|m| opt(m).unwrap(),
"
...
entry:
%0: i8 = param ...
black_box 0i8
",
);
}

#[test]
fn opt_shl_const() {
Module::assert_ir_transform_eq(
"
entry:
%0: i8 = 2i8
%1: i8 = 1i8
%2: i8 = shl %0, %1
black_box %2
",
|m| opt(m).unwrap(),
"
...
entry:
black_box 4i8
",
);
}

#[test]
fn opt_or_zero() {
Module::assert_ir_transform_eq(
Expand Down Expand Up @@ -1012,7 +1159,7 @@ mod test {
"
entry:
%0: i8 = 2i8
%1: i8 = 1i8
%1: i8 = 3i8
ptersilie marked this conversation as resolved.
Show resolved Hide resolved
%2: i8 = or %0, %1
black_box %2
",
Expand Down Expand Up @@ -1137,6 +1284,47 @@ mod test {
);
}

#[test]
fn opt_xor_zero() {
Module::assert_ir_transform_eq(
"
entry:
%0: i8 = param 0
%1: i8 = xor %0, 0i8
%2: i8 = xor 0i8, %0
black_box %1
black_box %2
",
|m| opt(m).unwrap(),
"
...
entry:
%0: i8 = param ...
black_box %0
black_box %0
",
);
}

#[test]
fn opt_xor_const() {
Module::assert_ir_transform_eq(
"
entry:
%0: i8 = 2i8
%1: i8 = 3i8
%2: i8 = xor %0, %1
black_box %2
",
|m| opt(m).unwrap(),
"
...
entry:
black_box 1i8
",
);
}

#[test]
fn opt_icmp_const() {
Module::assert_ir_transform_eq(
Expand Down