From b59fffccefe68d63a512c3356bc5064a290e5caf Mon Sep 17 00:00:00 2001 From: Laurence Tratt Date: Wed, 15 Jan 2025 12:43:23 +0000 Subject: [PATCH 1/9] The canonicalisation case can't possibly happen. This is clearly a copy and paste error -- though thankfully a harmless one! --- ykrt/src/compile/jitc_yk/opt/mod.rs | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/ykrt/src/compile/jitc_yk/opt/mod.rs b/ykrt/src/compile/jitc_yk/opt/mod.rs index f930bfc92..ca0bf18af 100644 --- a/ykrt/src/compile/jitc_yk/opt/mod.rs +++ b/ykrt/src/compile/jitc_yk/opt/mod.rs @@ -259,23 +259,9 @@ impl Opt { self.an.op_map(&self.m, x.rhs(&self.m)), ) { (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { - match self.m.const_(op_cidx) { - Const::Int(_, 0) => { - // Replace `x >> 0` with `x`. - self.m.replace(iidx, Inst::Copy(op_iidx)); - } - _ => { - // Canonicalise to (Var, Const). - self.m.replace( - iidx, - BinOpInst::new( - Operand::Var(op_iidx), - BinOp::LShr, - Operand::Const(op_cidx), - ) - .into(), - ); - } + if let Const::Int(_, 0) = self.m.const_(op_cidx) { + // Replace `x >> 0` with `x`. + self.m.replace(iidx, Inst::Copy(op_iidx)); } } (Operand::Const(_), Operand::Var(_)) => (), From 301c652ff366cc65ad8c94ca90ef8ab0ede21964 Mon Sep 17 00:00:00 2001 From: Laurence Tratt Date: Wed, 15 Jan 2025 12:46:43 +0000 Subject: [PATCH 2/9] Constant fold `BinOp::Shl`. --- ykrt/src/compile/jitc_yk/opt/mod.rs | 67 +++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/ykrt/src/compile/jitc_yk/opt/mod.rs b/ykrt/src/compile/jitc_yk/opt/mod.rs index ca0bf18af..7d3c07feb 100644 --- a/ykrt/src/compile/jitc_yk/opt/mod.rs +++ b/ykrt/src/compile/jitc_yk/opt/mod.rs @@ -383,6 +383,35 @@ impl Opt { } (Operand::Var(_), Operand::Var(_)) => (), }, + BinOp::Shl => match ( + self.an.op_map(&self.m, x.lhs(&self.m)), + self.an.op_map(&self.m, x.rhs(&self.m)), + ) { + (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { + if let Const::Int(_, 0) = self.m.const_(op_cidx) { + // Replace `x << 0` with `x`. + self.m.replace(iidx, Inst::Copy(op_iidx)); + } + } + (Operand::Const(_), Operand::Var(_)) => (), + (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { + match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { + (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { + debug_assert_eq!(lhs_tyidx, rhs_tyidx); + let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { + panic!() + }; + let cidx = self.m.insert_const_int( + *lhs_tyidx, + (lhs_v << rhs_v).truncate(*bits), + )?; + self.m.replace(iidx, Inst::Const(cidx)); + } + _ => todo!(), + } + } + (Operand::Var(_), Operand::Var(_)) => (), + }, BinOp::Sub => match ( self.an.op_map(&self.m, x.lhs(&self.m)), self.an.op_map(&self.m, x.rhs(&self.m)), @@ -970,6 +999,44 @@ mod test { ); } + #[test] + fn opt_shl_zero() { + Module::assert_ir_transform_eq( + " + entry: + %0: i8 = param 0 + %1: i8 = shl %0, 0i8 + black_box %1 + ", + |m| opt(m).unwrap(), + " + ... + entry: + %0: i8 = param ... + black_box %0 + ", + ); + } + + #[test] + fn opt_shl_const() { + Module::assert_ir_transform_eq( + " + entry: + %0: i8 = 2i8 + %1: i8 = 1i8 + %2: i8 = shl %0, %1 + black_box %2 + ", + |m| opt(m).unwrap(), + " + ... + entry: + black_box 4i8 + ", + ); + } + #[test] fn opt_or_zero() { Module::assert_ir_transform_eq( From eb6dbc1fc1c1d2b545984f2b31f8529bed113759 Mon Sep 17 00:00:00 2001 From: Laurence Tratt Date: Wed, 15 Jan 2025 12:54:14 +0000 Subject: [PATCH 3/9] Optimise `0 << x` and `0 >> x` to `0`. --- ykrt/src/compile/jitc_yk/opt/mod.rs | 48 +++++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/ykrt/src/compile/jitc_yk/opt/mod.rs b/ykrt/src/compile/jitc_yk/opt/mod.rs index 7d3c07feb..992971dad 100644 --- a/ykrt/src/compile/jitc_yk/opt/mod.rs +++ b/ykrt/src/compile/jitc_yk/opt/mod.rs @@ -264,7 +264,13 @@ impl Opt { self.m.replace(iidx, Inst::Copy(op_iidx)); } } - (Operand::Const(_), Operand::Var(_)) => (), + (Operand::Const(op_cidx), Operand::Var(_)) => { + if let Const::Int(tyidx, 0) = self.m.const_(op_cidx) { + // Replace `0 >> x` with `0`. + let new_cidx = self.m.insert_const_int(*tyidx, 0)?; + self.m.replace(iidx, Inst::Const(new_cidx)); + } + } (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { @@ -393,7 +399,13 @@ impl Opt { self.m.replace(iidx, Inst::Copy(op_iidx)); } } - (Operand::Const(_), Operand::Var(_)) => (), + (Operand::Const(op_cidx), Operand::Var(_)) => { + if let Const::Int(tyidx, 0) = self.m.const_(op_cidx) { + // Replace `0 << x` with `0`. + let new_cidx = self.m.insert_const_int(*tyidx, 0)?; + self.m.replace(iidx, Inst::Const(new_cidx)); + } + } (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { @@ -978,6 +990,22 @@ mod test { black_box %0 ", ); + + Module::assert_ir_transform_eq( + " + entry: + %0: i8 = param 0 + %1: i8 = lshr 0i8, %0 + black_box %1 + ", + |m| opt(m).unwrap(), + " + ... + entry: + %0: i8 = param ... + black_box 0i8 + ", + ); } #[test] @@ -1016,6 +1044,22 @@ mod test { black_box %0 ", ); + + Module::assert_ir_transform_eq( + " + entry: + %0: i8 = param 0 + %1: i8 = shl 0i8, %0 + black_box %1 + ", + |m| opt(m).unwrap(), + " + ... + entry: + %0: i8 = param ... + black_box 0i8 + ", + ); } #[test] From eec0fdada81bc5e41834c1fcf51c5a4ae96dedb3 Mon Sep 17 00:00:00 2001 From: Laurence Tratt Date: Wed, 15 Jan 2025 12:54:34 +0000 Subject: [PATCH 4/9] Bork when we find obvious constant folding we haven't yet implemented. --- ykrt/src/compile/jitc_yk/opt/mod.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ykrt/src/compile/jitc_yk/opt/mod.rs b/ykrt/src/compile/jitc_yk/opt/mod.rs index 992971dad..6bf9fe464 100644 --- a/ykrt/src/compile/jitc_yk/opt/mod.rs +++ b/ykrt/src/compile/jitc_yk/opt/mod.rs @@ -452,7 +452,14 @@ impl Opt { } (Operand::Const(_), Operand::Var(_)) | (Operand::Var(_), Operand::Var(_)) => (), }, - _ => (), + _ => { + if let (Operand::Const(_), Operand::Const(_)) = ( + self.an.op_map(&self.m, x.lhs(&self.m)), + self.an.op_map(&self.m, x.rhs(&self.m)), + ) { + todo!("{:?}", x.binop()); + } + } }, Inst::DynPtrAdd(x) => { if let Operand::Const(cidx) = self.an.op_map(&self.m, x.num_elems(&self.m)) { From cdcd6fdfffb0ef077796a21312db5a0738a24d1e Mon Sep 17 00:00:00 2001 From: Laurence Tratt Date: Wed, 15 Jan 2025 12:59:30 +0000 Subject: [PATCH 5/9] If we see non-ints with bit ops (shift left etc.) we have malformed IR. These should thus be `panic`s -- we should never have anything to implement here. --- ykrt/src/compile/jitc_yk/opt/mod.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ykrt/src/compile/jitc_yk/opt/mod.rs b/ykrt/src/compile/jitc_yk/opt/mod.rs index 6bf9fe464..407af6e67 100644 --- a/ykrt/src/compile/jitc_yk/opt/mod.rs +++ b/ykrt/src/compile/jitc_yk/opt/mod.rs @@ -249,7 +249,7 @@ impl Opt { )?; self.m.replace(iidx, Inst::Const(cidx)); } - _ => todo!(), + _ => panic!(), } } (Operand::Var(_), Operand::Var(_)) => (), @@ -284,7 +284,7 @@ impl Opt { )?; self.m.replace(iidx, Inst::Const(cidx)); } - _ => todo!(), + _ => panic!(), } } (Operand::Var(_), Operand::Var(_)) => (), @@ -384,7 +384,7 @@ impl Opt { )?; self.m.replace(iidx, Inst::Const(cidx)); } - _ => todo!(), + _ => panic!(), } } (Operand::Var(_), Operand::Var(_)) => (), @@ -419,7 +419,7 @@ impl Opt { )?; self.m.replace(iidx, Inst::Const(cidx)); } - _ => todo!(), + _ => panic!(), } } (Operand::Var(_), Operand::Var(_)) => (), From 1d39af2e25737045dbb81dad04294b779c4590b8 Mon Sep 17 00:00:00 2001 From: Laurence Tratt Date: Wed, 15 Jan 2025 13:11:25 +0000 Subject: [PATCH 6/9] Optimise `BinOp::Xor`. --- ykrt/src/compile/jitc_yk/opt/mod.rs | 86 ++++++++++++++++++++++++++++- 1 file changed, 85 insertions(+), 1 deletion(-) diff --git a/ykrt/src/compile/jitc_yk/opt/mod.rs b/ykrt/src/compile/jitc_yk/opt/mod.rs index 407af6e67..c471752bc 100644 --- a/ykrt/src/compile/jitc_yk/opt/mod.rs +++ b/ykrt/src/compile/jitc_yk/opt/mod.rs @@ -452,6 +452,49 @@ impl Opt { } (Operand::Const(_), Operand::Var(_)) | (Operand::Var(_), Operand::Var(_)) => (), }, + BinOp::Xor => match ( + self.an.op_map(&self.m, x.lhs(&self.m)), + self.an.op_map(&self.m, x.rhs(&self.m)), + ) { + (Operand::Const(op_cidx), Operand::Var(op_iidx)) + | (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { + match self.m.const_(op_cidx) { + Const::Int(_, 0) => { + // Replace `x ^ 0` with `x`. + self.m.replace(iidx, Inst::Copy(op_iidx)); + } + _ => { + // Canonicalise to (Var, Const). + self.m.replace( + iidx, + BinOpInst::new( + Operand::Var(op_iidx), + BinOp::Xor, + Operand::Const(op_cidx), + ) + .into(), + ); + } + } + } + (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { + match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { + (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { + debug_assert_eq!(lhs_tyidx, rhs_tyidx); + let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { + panic!() + }; + let cidx = self.m.insert_const_int( + *lhs_tyidx, + (lhs_v ^ rhs_v).truncate(*bits), + )?; + self.m.replace(iidx, Inst::Const(cidx)); + } + _ => panic!(), + } + } + (Operand::Var(_), Operand::Var(_)) => (), + }, _ => { if let (Operand::Const(_), Operand::Const(_)) = ( self.an.op_map(&self.m, x.lhs(&self.m)), @@ -1116,7 +1159,7 @@ mod test { " entry: %0: i8 = 2i8 - %1: i8 = 1i8 + %1: i8 = 3i8 %2: i8 = or %0, %1 black_box %2 ", @@ -1241,6 +1284,47 @@ mod test { ); } + #[test] + fn opt_xor_zero() { + Module::assert_ir_transform_eq( + " + entry: + %0: i8 = param 0 + %1: i8 = xor %0, 0i8 + %2: i8 = xor 0i8, %0 + black_box %1 + black_box %2 + ", + |m| opt(m).unwrap(), + " + ... + entry: + %0: i8 = param ... + black_box %0 + black_box %0 + ", + ); + } + + #[test] + fn opt_xor_const() { + Module::assert_ir_transform_eq( + " + entry: + %0: i8 = 2i8 + %1: i8 = 3i8 + %2: i8 = xor %0, %1 + black_box %2 + ", + |m| opt(m).unwrap(), + " + ... + entry: + black_box 1i8 + ", + ); + } + #[test] fn opt_icmp_const() { Module::assert_ir_transform_eq( From e890a7bd6ce021436dbd9c120afc85d7a7e4554a Mon Sep 17 00:00:00 2001 From: Laurence Tratt Date: Wed, 15 Jan 2025 13:26:34 +0000 Subject: [PATCH 7/9] Add canonicalisation tests. I noticed (because I made a mistake!) that we weren't testing the canonicalisation cases properly. This commit looks like it changes much more than it really does: basically we add a comment to the main optimiser code which causes lots of indenting and churn without any meaningful changes. The meat of the commit is the test cases at the end. --- ykrt/src/compile/jitc_yk/opt/mod.rs | 672 +++++++++++++++------------- 1 file changed, 370 insertions(+), 302 deletions(-) diff --git a/ykrt/src/compile/jitc_yk/opt/mod.rs b/ykrt/src/compile/jitc_yk/opt/mod.rs index c471752bc..c871063b4 100644 --- a/ykrt/src/compile/jitc_yk/opt/mod.rs +++ b/ykrt/src/compile/jitc_yk/opt/mod.rs @@ -167,343 +167,348 @@ impl Opt { Inst::Const(_) | Inst::Copy(_) | Inst::Tombstone | Inst::TraceHeaderStart => { unreachable!() } - Inst::BinOp(x) => match x.binop() { - BinOp::Add => match ( - self.an.op_map(&self.m, x.lhs(&self.m)), - self.an.op_map(&self.m, x.rhs(&self.m)), - ) { - (Operand::Const(op_cidx), Operand::Var(op_iidx)) - | (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { - match self.m.const_(op_cidx) { - Const::Int(_, 0) => { - // Replace `x + 0` with `x`. - self.m.replace(iidx, Inst::Copy(op_iidx)); - } - _ => { - // Canonicalise to (Var, Const). - self.m.replace( - iidx, - BinOpInst::new( - Operand::Var(op_iidx), - BinOp::Add, - Operand::Const(op_cidx), - ) - .into(), - ); - } - } - } - (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { - match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { - (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { - debug_assert_eq!(lhs_tyidx, rhs_tyidx); - let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { - panic!() - }; - let cidx = self.m.insert_const_int( - *lhs_tyidx, - (lhs_v.wrapping_add(*rhs_v)).truncate(*bits), - )?; - self.m.replace(iidx, Inst::Const(cidx)); + Inst::BinOp(x) => { + // Don't forget to add canonicalisations to the `canonicalisation` test! + match x.binop() { + BinOp::Add => match ( + self.an.op_map(&self.m, x.lhs(&self.m)), + self.an.op_map(&self.m, x.rhs(&self.m)), + ) { + (Operand::Const(op_cidx), Operand::Var(op_iidx)) + | (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { + match self.m.const_(op_cidx) { + Const::Int(_, 0) => { + // Replace `x + 0` with `x`. + self.m.replace(iidx, Inst::Copy(op_iidx)); + } + _ => { + // Canonicalise to (Var, Const). + self.m.replace( + iidx, + BinOpInst::new( + Operand::Var(op_iidx), + BinOp::Add, + Operand::Const(op_cidx), + ) + .into(), + ); + } } - _ => todo!(), } - } - (Operand::Var(_), Operand::Var(_)) => (), - }, - BinOp::And => match ( - self.an.op_map(&self.m, x.lhs(&self.m)), - self.an.op_map(&self.m, x.rhs(&self.m)), - ) { - (Operand::Const(op_cidx), Operand::Var(op_iidx)) - | (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { - match self.m.const_(op_cidx) { - Const::Int(_, 0) => { - // Replace `x & 0` with `0`. - self.m.replace(iidx, Inst::Const(op_cidx)); - } - _ => { - // Canonicalise to (Var, Const). - self.m.replace( - iidx, - BinOpInst::new( - Operand::Var(op_iidx), - BinOp::And, - Operand::Const(op_cidx), - ) - .into(), - ); + (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { + match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { + (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { + debug_assert_eq!(lhs_tyidx, rhs_tyidx); + let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { + panic!() + }; + let cidx = self.m.insert_const_int( + *lhs_tyidx, + (lhs_v.wrapping_add(*rhs_v)).truncate(*bits), + )?; + self.m.replace(iidx, Inst::Const(cidx)); + } + _ => todo!(), } } - } - (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { - match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { - (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { - debug_assert_eq!(lhs_tyidx, rhs_tyidx); - let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { - panic!() - }; - let cidx = self.m.insert_const_int( - *lhs_tyidx, - (lhs_v & rhs_v).truncate(*bits), - )?; - self.m.replace(iidx, Inst::Const(cidx)); + (Operand::Var(_), Operand::Var(_)) => (), + }, + BinOp::And => match ( + self.an.op_map(&self.m, x.lhs(&self.m)), + self.an.op_map(&self.m, x.rhs(&self.m)), + ) { + (Operand::Const(op_cidx), Operand::Var(op_iidx)) + | (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { + match self.m.const_(op_cidx) { + Const::Int(_, 0) => { + // Replace `x & 0` with `0`. + self.m.replace(iidx, Inst::Const(op_cidx)); + } + _ => { + // Canonicalise to (Var, Const). + self.m.replace( + iidx, + BinOpInst::new( + Operand::Var(op_iidx), + BinOp::And, + Operand::Const(op_cidx), + ) + .into(), + ); + } } - _ => panic!(), } - } - (Operand::Var(_), Operand::Var(_)) => (), - }, - BinOp::LShr => match ( - self.an.op_map(&self.m, x.lhs(&self.m)), - self.an.op_map(&self.m, x.rhs(&self.m)), - ) { - (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { - if let Const::Int(_, 0) = self.m.const_(op_cidx) { - // Replace `x >> 0` with `x`. - self.m.replace(iidx, Inst::Copy(op_iidx)); - } - } - (Operand::Const(op_cidx), Operand::Var(_)) => { - if let Const::Int(tyidx, 0) = self.m.const_(op_cidx) { - // Replace `0 >> x` with `0`. - let new_cidx = self.m.insert_const_int(*tyidx, 0)?; - self.m.replace(iidx, Inst::Const(new_cidx)); - } - } - (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { - match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { - (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { - debug_assert_eq!(lhs_tyidx, rhs_tyidx); - let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { - panic!() - }; - let cidx = self.m.insert_const_int( - *lhs_tyidx, - (lhs_v >> rhs_v).truncate(*bits), - )?; - self.m.replace(iidx, Inst::Const(cidx)); + (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { + match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { + (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { + debug_assert_eq!(lhs_tyidx, rhs_tyidx); + let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { + panic!() + }; + let cidx = self.m.insert_const_int( + *lhs_tyidx, + (lhs_v & rhs_v).truncate(*bits), + )?; + self.m.replace(iidx, Inst::Const(cidx)); + } + _ => panic!(), } - _ => panic!(), } - } - (Operand::Var(_), Operand::Var(_)) => (), - }, - BinOp::Mul => match ( - self.an.op_map(&self.m, x.lhs(&self.m)), - self.an.op_map(&self.m, x.rhs(&self.m)), - ) { - (Operand::Const(op_cidx), Operand::Var(op_iidx)) - | (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { - match self.m.const_(op_cidx) { - Const::Int(_, 0) => { - // Replace `x * 0` with `0`. - self.m.replace(iidx, Inst::Const(op_cidx)); - } - Const::Int(_, 1) => { - // Replace `x * 1` with `x`. + (Operand::Var(_), Operand::Var(_)) => (), + }, + BinOp::LShr => match ( + self.an.op_map(&self.m, x.lhs(&self.m)), + self.an.op_map(&self.m, x.rhs(&self.m)), + ) { + (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { + if let Const::Int(_, 0) = self.m.const_(op_cidx) { + // Replace `x >> 0` with `x`. self.m.replace(iidx, Inst::Copy(op_iidx)); } - Const::Int(ty_idx, x) if x.is_power_of_two() => { - // Replace `x * y` with `x << ...`. - let shl = u64::from(x.ilog2()); - let shl_op = - Operand::Const(self.m.insert_const(Const::Int(*ty_idx, shl))?); - let new_inst = - BinOpInst::new(Operand::Var(op_iidx), BinOp::Shl, shl_op) - .into(); - self.m.replace(iidx, new_inst); - } - _ => { - // Canonicalise to (Var, Const). - self.m.replace( - iidx, - BinOpInst::new( - Operand::Var(op_iidx), - BinOp::Mul, - Operand::Const(op_cidx), - ) - .into(), - ); - } } - } - (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { - match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { - (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { - debug_assert_eq!(lhs_tyidx, rhs_tyidx); - let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { - panic!() - }; - let cidx = self.m.insert_const_int( - *lhs_tyidx, - (lhs_v.wrapping_mul(*rhs_v)).truncate(*bits), - )?; - self.m.replace(iidx, Inst::Const(cidx)); + (Operand::Const(op_cidx), Operand::Var(_)) => { + if let Const::Int(tyidx, 0) = self.m.const_(op_cidx) { + // Replace `0 >> x` with `0`. + let new_cidx = self.m.insert_const_int(*tyidx, 0)?; + self.m.replace(iidx, Inst::Const(new_cidx)); } - _ => todo!(), } - } - (Operand::Var(_), Operand::Var(_)) => (), - }, - BinOp::Or => match ( - self.an.op_map(&self.m, x.lhs(&self.m)), - self.an.op_map(&self.m, x.rhs(&self.m)), - ) { - (Operand::Const(op_cidx), Operand::Var(op_iidx)) - | (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { - match self.m.const_(op_cidx) { - Const::Int(_, 0) => { - // Replace `x | 0` with `x`. - self.m.replace(iidx, Inst::Copy(op_iidx)); + (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { + match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { + (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { + debug_assert_eq!(lhs_tyidx, rhs_tyidx); + let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { + panic!() + }; + let cidx = self.m.insert_const_int( + *lhs_tyidx, + (lhs_v >> rhs_v).truncate(*bits), + )?; + self.m.replace(iidx, Inst::Const(cidx)); + } + _ => panic!(), } - _ => { - // Canonicalise to (Var, Const). - self.m.replace( - iidx, - BinOpInst::new( - Operand::Var(op_iidx), - BinOp::Or, - Operand::Const(op_cidx), - ) - .into(), - ); + } + (Operand::Var(_), Operand::Var(_)) => (), + }, + BinOp::Mul => match ( + self.an.op_map(&self.m, x.lhs(&self.m)), + self.an.op_map(&self.m, x.rhs(&self.m)), + ) { + (Operand::Const(op_cidx), Operand::Var(op_iidx)) + | (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { + match self.m.const_(op_cidx) { + Const::Int(_, 0) => { + // Replace `x * 0` with `0`. + self.m.replace(iidx, Inst::Const(op_cidx)); + } + Const::Int(_, 1) => { + // Replace `x * 1` with `x`. + self.m.replace(iidx, Inst::Copy(op_iidx)); + } + Const::Int(ty_idx, x) if x.is_power_of_two() => { + // Replace `x * y` with `x << ...`. + let shl = u64::from(x.ilog2()); + let shl_op = Operand::Const( + self.m.insert_const(Const::Int(*ty_idx, shl))?, + ); + let new_inst = + BinOpInst::new(Operand::Var(op_iidx), BinOp::Shl, shl_op) + .into(); + self.m.replace(iidx, new_inst); + } + _ => { + // Canonicalise to (Var, Const). + self.m.replace( + iidx, + BinOpInst::new( + Operand::Var(op_iidx), + BinOp::Mul, + Operand::Const(op_cidx), + ) + .into(), + ); + } } } - } - (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { - match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { - (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { - debug_assert_eq!(lhs_tyidx, rhs_tyidx); - let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { - panic!() - }; - let cidx = self.m.insert_const_int( - *lhs_tyidx, - (lhs_v | rhs_v).truncate(*bits), - )?; - self.m.replace(iidx, Inst::Const(cidx)); + (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { + match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { + (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { + debug_assert_eq!(lhs_tyidx, rhs_tyidx); + let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { + panic!() + }; + let cidx = self.m.insert_const_int( + *lhs_tyidx, + (lhs_v.wrapping_mul(*rhs_v)).truncate(*bits), + )?; + self.m.replace(iidx, Inst::Const(cidx)); + } + _ => todo!(), } - _ => panic!(), } - } - (Operand::Var(_), Operand::Var(_)) => (), - }, - BinOp::Shl => match ( - self.an.op_map(&self.m, x.lhs(&self.m)), - self.an.op_map(&self.m, x.rhs(&self.m)), - ) { - (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { - if let Const::Int(_, 0) = self.m.const_(op_cidx) { - // Replace `x << 0` with `x`. - self.m.replace(iidx, Inst::Copy(op_iidx)); + (Operand::Var(_), Operand::Var(_)) => (), + }, + BinOp::Or => match ( + self.an.op_map(&self.m, x.lhs(&self.m)), + self.an.op_map(&self.m, x.rhs(&self.m)), + ) { + (Operand::Const(op_cidx), Operand::Var(op_iidx)) + | (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { + match self.m.const_(op_cidx) { + Const::Int(_, 0) => { + // Replace `x | 0` with `x`. + self.m.replace(iidx, Inst::Copy(op_iidx)); + } + _ => { + // Canonicalise to (Var, Const). + self.m.replace( + iidx, + BinOpInst::new( + Operand::Var(op_iidx), + BinOp::Or, + Operand::Const(op_cidx), + ) + .into(), + ); + } + } } - } - (Operand::Const(op_cidx), Operand::Var(_)) => { - if let Const::Int(tyidx, 0) = self.m.const_(op_cidx) { - // Replace `0 << x` with `0`. - let new_cidx = self.m.insert_const_int(*tyidx, 0)?; - self.m.replace(iidx, Inst::Const(new_cidx)); + (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { + match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { + (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { + debug_assert_eq!(lhs_tyidx, rhs_tyidx); + let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { + panic!() + }; + let cidx = self.m.insert_const_int( + *lhs_tyidx, + (lhs_v | rhs_v).truncate(*bits), + )?; + self.m.replace(iidx, Inst::Const(cidx)); + } + _ => panic!(), + } } - } - (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { - match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { - (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { - debug_assert_eq!(lhs_tyidx, rhs_tyidx); - let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { - panic!() - }; - let cidx = self.m.insert_const_int( - *lhs_tyidx, - (lhs_v << rhs_v).truncate(*bits), - )?; - self.m.replace(iidx, Inst::Const(cidx)); + (Operand::Var(_), Operand::Var(_)) => (), + }, + BinOp::Shl => match ( + self.an.op_map(&self.m, x.lhs(&self.m)), + self.an.op_map(&self.m, x.rhs(&self.m)), + ) { + (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { + if let Const::Int(_, 0) = self.m.const_(op_cidx) { + // Replace `x << 0` with `x`. + self.m.replace(iidx, Inst::Copy(op_iidx)); } - _ => panic!(), } - } - (Operand::Var(_), Operand::Var(_)) => (), - }, - BinOp::Sub => match ( - self.an.op_map(&self.m, x.lhs(&self.m)), - self.an.op_map(&self.m, x.rhs(&self.m)), - ) { - (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { - if let Const::Int(_, 0) = self.m.const_(op_cidx) { - // Replace `x - 0` with `x`. - self.m.replace(iidx, Inst::Copy(op_iidx)); + (Operand::Const(op_cidx), Operand::Var(_)) => { + if let Const::Int(tyidx, 0) = self.m.const_(op_cidx) { + // Replace `0 << x` with `0`. + let new_cidx = self.m.insert_const_int(*tyidx, 0)?; + self.m.replace(iidx, Inst::Const(new_cidx)); + } } - } - (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { - match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { - (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { - debug_assert_eq!(lhs_tyidx, rhs_tyidx); - let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { - panic!() - }; - let cidx = self.m.insert_const_int( - *lhs_tyidx, - (lhs_v.wrapping_sub(*rhs_v)).truncate(*bits), - )?; - self.m.replace(iidx, Inst::Const(cidx)); + (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { + match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { + (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { + debug_assert_eq!(lhs_tyidx, rhs_tyidx); + let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { + panic!() + }; + let cidx = self.m.insert_const_int( + *lhs_tyidx, + (lhs_v << rhs_v).truncate(*bits), + )?; + self.m.replace(iidx, Inst::Const(cidx)); + } + _ => panic!(), } - _ => todo!(), } - } - (Operand::Const(_), Operand::Var(_)) | (Operand::Var(_), Operand::Var(_)) => (), - }, - BinOp::Xor => match ( - self.an.op_map(&self.m, x.lhs(&self.m)), - self.an.op_map(&self.m, x.rhs(&self.m)), - ) { - (Operand::Const(op_cidx), Operand::Var(op_iidx)) - | (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { - match self.m.const_(op_cidx) { - Const::Int(_, 0) => { - // Replace `x ^ 0` with `x`. + (Operand::Var(_), Operand::Var(_)) => (), + }, + BinOp::Sub => match ( + self.an.op_map(&self.m, x.lhs(&self.m)), + self.an.op_map(&self.m, x.rhs(&self.m)), + ) { + (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { + if let Const::Int(_, 0) = self.m.const_(op_cidx) { + // Replace `x - 0` with `x`. self.m.replace(iidx, Inst::Copy(op_iidx)); } - _ => { - // Canonicalise to (Var, Const). - self.m.replace( - iidx, - BinOpInst::new( - Operand::Var(op_iidx), - BinOp::Xor, - Operand::Const(op_cidx), - ) - .into(), - ); - } } - } - (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { - match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { - (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { - debug_assert_eq!(lhs_tyidx, rhs_tyidx); - let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { - panic!() - }; - let cidx = self.m.insert_const_int( - *lhs_tyidx, - (lhs_v ^ rhs_v).truncate(*bits), - )?; - self.m.replace(iidx, Inst::Const(cidx)); + (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { + match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { + (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { + debug_assert_eq!(lhs_tyidx, rhs_tyidx); + let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { + panic!() + }; + let cidx = self.m.insert_const_int( + *lhs_tyidx, + (lhs_v.wrapping_sub(*rhs_v)).truncate(*bits), + )?; + self.m.replace(iidx, Inst::Const(cidx)); + } + _ => todo!(), } - _ => panic!(), } - } - (Operand::Var(_), Operand::Var(_)) => (), - }, - _ => { - if let (Operand::Const(_), Operand::Const(_)) = ( + (Operand::Const(_), Operand::Var(_)) + | (Operand::Var(_), Operand::Var(_)) => (), + }, + BinOp::Xor => match ( self.an.op_map(&self.m, x.lhs(&self.m)), self.an.op_map(&self.m, x.rhs(&self.m)), ) { - todo!("{:?}", x.binop()); + (Operand::Const(op_cidx), Operand::Var(op_iidx)) + | (Operand::Var(op_iidx), Operand::Const(op_cidx)) => { + match self.m.const_(op_cidx) { + Const::Int(_, 0) => { + // Replace `x ^ 0` with `x`. + self.m.replace(iidx, Inst::Copy(op_iidx)); + } + _ => { + // Canonicalise to (Var, Const). + self.m.replace( + iidx, + BinOpInst::new( + Operand::Var(op_iidx), + BinOp::Xor, + Operand::Const(op_cidx), + ) + .into(), + ); + } + } + } + (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { + match (self.m.const_(lhs_cidx), self.m.const_(rhs_cidx)) { + (Const::Int(lhs_tyidx, lhs_v), Const::Int(rhs_tyidx, rhs_v)) => { + debug_assert_eq!(lhs_tyidx, rhs_tyidx); + let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { + panic!() + }; + let cidx = self.m.insert_const_int( + *lhs_tyidx, + (lhs_v ^ rhs_v).truncate(*bits), + )?; + self.m.replace(iidx, Inst::Const(cidx)); + } + _ => panic!(), + } + } + (Operand::Var(_), Operand::Var(_)) => (), + }, + _ => { + if let (Operand::Const(_), Operand::Const(_)) = ( + self.an.op_map(&self.m, x.lhs(&self.m)), + self.an.op_map(&self.m, x.rhs(&self.m)), + ) { + todo!("{:?}", x.binop()); + } } } - }, + } Inst::DynPtrAdd(x) => { if let Operand::Const(cidx) = self.an.op_map(&self.m, x.num_elems(&self.m)) { let Const::Int(_, v) = self.m.const_(cidx) else { @@ -1858,4 +1863,67 @@ mod test { ", ); } + + #[test] + fn canonicalisation() { + // Those that can be canonicalised + Module::assert_ir_transform_eq( + " + entry: + %0: i8 = param 0 + %1: i8 = add 3i8, %0 + %2: i8 = and 3i8, %0 + %3: i8 = mul 3i8, %0 + %4: i8 = or 3i8, %0 + %5: i8 = xor 3i8, %0 + black_box %1 + black_box %2 + black_box %3 + black_box %4 + black_box %5 +", + |m| opt(m).unwrap(), + " + ... + entry: + %0: i8 = param ... + %1: i8 = add %0, 3i8 + %2: i8 = and %0, 3i8 + %3: i8 = mul %0, 3i8 + %4: i8 = or %0, 3i8 + %5: i8 = xor %0, 3i8 + black_box %1 + black_box %2 + black_box %3 + black_box %4 + black_box %5 +", + ); + + // Those that cannot be canonicalised + Module::assert_ir_transform_eq( + " + entry: + %0: i8 = param 0 + %1: i8 = lshr 3i8, %0 + %2: i8 = shl 3i8, %0 + %3: i8 = sub 3i8, %0 + black_box %1 + black_box %2 + black_box %3 +", + |m| opt(m).unwrap(), + " + ... + entry: + %0: i8 = param ... + %1: i8 = lshr 3i8, %0 + %2: i8 = shl 3i8, %0 + %3: i8 = sub 3i8, %0 + black_box %1 + black_box %2 + black_box %3 +", + ); + } } From 713e2403830344d4dcda855f56ff4a80acf72798 Mon Sep 17 00:00:00 2001 From: Laurence Tratt Date: Wed, 15 Jan 2025 18:09:13 +0000 Subject: [PATCH 8/9] Deal with poison in `Shl` properly. --- ykrt/src/compile/jitc_yk/opt/mod.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/ykrt/src/compile/jitc_yk/opt/mod.rs b/ykrt/src/compile/jitc_yk/opt/mod.rs index c871063b4..7bdd2474d 100644 --- a/ykrt/src/compile/jitc_yk/opt/mod.rs +++ b/ykrt/src/compile/jitc_yk/opt/mod.rs @@ -416,9 +416,18 @@ impl Opt { let Ty::Integer(bits) = self.m.type_(*lhs_tyidx) else { panic!() }; + // If checked_shl fails, we've encountered LLVM poison: we can + // now choose any value (in this case 0) and know that we're + // respecting LLVM's semantics. In case the user's program then + // has UB and uses the poison value, we make it `int::MAX` + // because there is a small chance that will make the UB more + // obvious to them. let cidx = self.m.insert_const_int( *lhs_tyidx, - (lhs_v << rhs_v).truncate(*bits), + (lhs_v + .checked_shl(u32::try_from(*rhs_v).unwrap()) + .unwrap_or(u64::MAX)) + .truncate(*bits), )?; self.m.replace(iidx, Inst::Const(cidx)); } From 4cee45535315f177c0ef307fa410830f78256275 Mon Sep 17 00:00:00 2001 From: Laurence Tratt Date: Wed, 15 Jan 2025 21:30:59 +0000 Subject: [PATCH 9/9] Optimise `Select` instructions. --- ykrt/src/compile/jitc_yk/jit_ir/mod.rs | 10 ++++ ykrt/src/compile/jitc_yk/opt/mod.rs | 71 ++++++++++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/ykrt/src/compile/jitc_yk/jit_ir/mod.rs b/ykrt/src/compile/jitc_yk/jit_ir/mod.rs index ca226c9dd..126ad2781 100644 --- a/ykrt/src/compile/jitc_yk/jit_ir/mod.rs +++ b/ykrt/src/compile/jitc_yk/jit_ir/mod.rs @@ -415,6 +415,16 @@ impl Module { self.insts[usize::from(iidx)] = inst; } + /// Replace the instruction in `iidx` with an instruction that will generate `op`. In other + /// words, `Operand::Var(...)` will become `Inst::Copy` and `Operand::Const` will become + /// `Inst::Const`. This is a convenience function over [Self::replace]. + pub(crate) fn replace_with_op(&mut self, iidx: InstIdx, op: Operand) { + match op { + Operand::Var(op_iidx) => self.replace(iidx, Inst::Copy(op_iidx)), + Operand::Const(cidx) => self.replace(iidx, Inst::Const(cidx)), + } + } + /// Push an instruction to the end of the [Module] and create a local variable [Operand] out of /// the value that the instruction defines. /// diff --git a/ykrt/src/compile/jitc_yk/opt/mod.rs b/ykrt/src/compile/jitc_yk/opt/mod.rs index 7bdd2474d..989ae2be9 100644 --- a/ykrt/src/compile/jitc_yk/opt/mod.rs +++ b/ykrt/src/compile/jitc_yk/opt/mod.rs @@ -571,6 +571,24 @@ impl Opt { } } }, + Inst::Select(sinst) => { + if let Operand::Const(cidx) = self.an.op_map(&self.m, sinst.cond(&self.m)) { + let Const::Int(_, v) = self.m.const_(cidx) else { + panic!() + }; + let op = match v { + 0 => sinst.falseval(&self.m), + 1 => sinst.trueval(&self.m), + _ => panic!(), + }; + self.m.replace_with_op(iidx, op); + } else if self.an.op_map(&self.m, sinst.trueval(&self.m)) + == self.an.op_map(&self.m, sinst.falseval(&self.m)) + { + // Both true and false operands are equal, so it doesn't matter which we use. + self.m.replace_with_op(iidx, sinst.trueval(&self.m)); + } + } Inst::SExt(x) => { if let Operand::Const(cidx) = self.an.op_map(&self.m, x.val(&self.m)) { let Const::Int(src_ty, src_val) = self.m.const_(cidx) else { @@ -1437,6 +1455,59 @@ mod test { ); } + #[test] + fn opt_select() { + // Test constant condition. + Module::assert_ir_transform_eq( + " + entry: + %0: i8 = param 0 + %1: i8 = param 1 + %2: i8 = 1i1 ? %0 : %1 + %3: i8 = 0i1 ? %0 : %1 + black_box %2 + black_box %3 + ", + |m| opt(m).unwrap(), + " + ... + entry: + %0: i8 = param ... + %1: i8 = param ... + black_box %0 + black_box %1 + ", + ); + + // Test equivalent true/false values. + Module::assert_ir_transform_eq( + " + entry: + %0: i1 = param 0 + %1: i8 = param 1 + %2: i8 = param 2 + %3: i8 = %0 ? 0i8 : 0i8 + %4: i8 = %0 ? %1 : %1 + %5: i8 = %0 ? %1 : %2 + black_box %3 + black_box %4 + black_box %5 + ", + |m| opt(m).unwrap(), + " + ... + entry: + %0: i1 = param ... + %1: i8 = param ... + %2: i8 = param ... + %5: i8 = %0 ? %1 : %2 + black_box 0i8 + black_box %1 + black_box %5 + ", + ); + } + #[test] fn opt_zext_const() { Module::assert_ir_transform_eq(