From 65d04257b59060729885a6f0f7c0380b95956c24 Mon Sep 17 00:00:00 2001 From: Laurence Tratt Date: Wed, 15 Jan 2025 21:30:59 +0000 Subject: [PATCH] Optimise `Select` instructions. --- ykrt/src/compile/jitc_yk/jit_ir/mod.rs | 10 ++++ ykrt/src/compile/jitc_yk/opt/mod.rs | 71 ++++++++++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/ykrt/src/compile/jitc_yk/jit_ir/mod.rs b/ykrt/src/compile/jitc_yk/jit_ir/mod.rs index ca226c9dd..70c0fa654 100644 --- a/ykrt/src/compile/jitc_yk/jit_ir/mod.rs +++ b/ykrt/src/compile/jitc_yk/jit_ir/mod.rs @@ -415,6 +415,16 @@ impl Module { self.insts[usize::from(iidx)] = inst; } + /// Replace the instruction it `iidx` with an instruction that will generate `op`. In other + /// words, `Operand::Var(...)` will become `Inst::Copy` and `Operand::Const` will become + /// `Inst::Const`. This is a convenience function over [Self::replace]. + pub(crate) fn replace_with_op(&mut self, iidx: InstIdx, op: Operand) { + match op { + Operand::Var(op_iidx) => self.replace(iidx, Inst::Copy(op_iidx)), + Operand::Const(cidx) => self.replace(iidx, Inst::Const(cidx)), + } + } + /// Push an instruction to the end of the [Module] and create a local variable [Operand] out of /// the value that the instruction defines. /// diff --git a/ykrt/src/compile/jitc_yk/opt/mod.rs b/ykrt/src/compile/jitc_yk/opt/mod.rs index 7bdd2474d..989ae2be9 100644 --- a/ykrt/src/compile/jitc_yk/opt/mod.rs +++ b/ykrt/src/compile/jitc_yk/opt/mod.rs @@ -571,6 +571,24 @@ impl Opt { } } }, + Inst::Select(sinst) => { + if let Operand::Const(cidx) = self.an.op_map(&self.m, sinst.cond(&self.m)) { + let Const::Int(_, v) = self.m.const_(cidx) else { + panic!() + }; + let op = match v { + 0 => sinst.falseval(&self.m), + 1 => sinst.trueval(&self.m), + _ => panic!(), + }; + self.m.replace_with_op(iidx, op); + } else if self.an.op_map(&self.m, sinst.trueval(&self.m)) + == self.an.op_map(&self.m, sinst.falseval(&self.m)) + { + // Both true and false operands are equal, so it doesn't matter which we use. + self.m.replace_with_op(iidx, sinst.trueval(&self.m)); + } + } Inst::SExt(x) => { if let Operand::Const(cidx) = self.an.op_map(&self.m, x.val(&self.m)) { let Const::Int(src_ty, src_val) = self.m.const_(cidx) else { @@ -1437,6 +1455,59 @@ mod test { ); } + #[test] + fn opt_select() { + // Test constant condition. + Module::assert_ir_transform_eq( + " + entry: + %0: i8 = param 0 + %1: i8 = param 1 + %2: i8 = 1i1 ? %0 : %1 + %3: i8 = 0i1 ? %0 : %1 + black_box %2 + black_box %3 + ", + |m| opt(m).unwrap(), + " + ... + entry: + %0: i8 = param ... + %1: i8 = param ... + black_box %0 + black_box %1 + ", + ); + + // Test equivalent true/false values. + Module::assert_ir_transform_eq( + " + entry: + %0: i1 = param 0 + %1: i8 = param 1 + %2: i8 = param 2 + %3: i8 = %0 ? 0i8 : 0i8 + %4: i8 = %0 ? %1 : %1 + %5: i8 = %0 ? %1 : %2 + black_box %3 + black_box %4 + black_box %5 + ", + |m| opt(m).unwrap(), + " + ... + entry: + %0: i1 = param ... + %1: i8 = param ... + %2: i8 = param ... + %5: i8 = %0 ? %1 : %2 + black_box 0i8 + black_box %1 + black_box %5 + ", + ); + } + #[test] fn opt_zext_const() { Module::assert_ir_transform_eq(