From 2205ebaef1ba2999be7aa58c31c1753c4e9c679d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Fri, 13 Dec 2024 16:55:31 +0100 Subject: [PATCH] Refactor - Sanitize IP/PC immediate values (#1) * Always sanitize emit_profile_instruction_count() and emit_undo_profile_instruction_count(). * Sanitize emit_validate_instruction_count() as well. --- src/jit.rs | 62 ++++++++++++++++++++++++------------------------------ 1 file changed, 27 insertions(+), 35 deletions(-) diff --git a/src/jit.rs b/src/jit.rs index 7d0e4570..f151d2c3 100644 --- a/src/jit.rs +++ b/src/jit.rs @@ -38,7 +38,7 @@ use crate::{ const MAX_EMPTY_PROGRAM_MACHINE_CODE_LENGTH: usize = 4096; const MAX_MACHINE_CODE_LENGTH_PER_INSTRUCTION: usize = 110; -const MACHINE_CODE_PER_INSTRUCTION_METER_CHECKPOINT: usize = 13; +const MACHINE_CODE_PER_INSTRUCTION_METER_CHECKPOINT: usize = 23; const MAX_START_PADDING_LENGTH: usize = 256; pub struct JitProgram { @@ -423,7 +423,7 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { match insn.opc { ebpf::LD_DW_IMM if !self.executable.get_sbpf_version().disable_lddw() => { - self.emit_validate_and_profile_instruction_count(false, Some(self.pc + 2)); + self.emit_validate_and_profile_instruction_count(Some(self.pc + 2)); self.pc += 1; self.result.pc_section[self.pc] = self.anchors[ANCHOR_CALL_UNSUPPORTED_INSTRUCTION] as usize; ebpf::augment_lddw_unchecked(self.program, &mut insn); @@ -702,7 +702,7 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { // BPF_JMP class ebpf::JA => { - self.emit_validate_and_profile_instruction_count(true, Some(target_pc)); + self.emit_validate_and_profile_instruction_count(Some(target_pc)); self.emit_ins(X86Instruction::load_immediate(REGISTER_SCRATCH, target_pc as i64)); let jump_offset = self.relative_to_target_pc(target_pc, 5); self.emit_ins(X86Instruction::jump_immediate(jump_offset)); @@ -788,7 +788,7 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 5, REGISTER_PTR_TO_VM, 1, Some(call_depth_access))); // env.call_depth -= 1; // and return - self.emit_profile_instruction_count(false, Some(0)); + self.emit_profile_instruction_count(Some(0)); self.emit_ins(X86Instruction::return_near()); }, @@ -802,7 +802,7 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { if self.offset_in_text_section + MAX_MACHINE_CODE_LENGTH_PER_INSTRUCTION * 2 >= self.result.text_section.len() { return Err(EbpfError::ExhaustedTextSegment(self.pc)); } - self.emit_validate_and_profile_instruction_count(false, Some(self.pc + 1)); + self.emit_validate_and_profile_instruction_count(Some(self.pc + 1)); self.emit_ins(X86Instruction::load_immediate(REGISTER_SCRATCH, self.pc as i64)); // Save pc self.emit_set_exception_kind(EbpfError::ExecutionOverrun); self.emit_ins(X86Instruction::jump_immediate(self.relative_to_anchor(ANCHOR_THROW_EXCEPTION, 5))); @@ -941,49 +941,35 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { // Update `MACHINE_CODE_PER_INSTRUCTION_METER_CHECKPOINT` if you change the code generation here if let Some(pc) = pc { self.last_instruction_meter_validation_pc = pc; - // instruction_meter >= self.pc - self.emit_ins(X86Instruction::cmp_immediate(OperandSize::S64, REGISTER_INSTRUCTION_METER, pc as i64, None)); - } else { - // instruction_meter >= scratch_register - self.emit_ins(X86Instruction::cmp(OperandSize::S64, REGISTER_SCRATCH, REGISTER_INSTRUCTION_METER, None)); + self.emit_sanitized_load_immediate(REGISTER_SCRATCH, pc as i64); } + // If instruction_meter >= pc, throw ExceededMaxInstructions + self.emit_ins(X86Instruction::cmp(OperandSize::S64, REGISTER_SCRATCH, REGISTER_INSTRUCTION_METER, None)); self.emit_ins(X86Instruction::conditional_jump_immediate(0x86, self.relative_to_anchor(ANCHOR_THROW_EXCEEDED_MAX_INSTRUCTIONS, 6))); } #[inline] - fn emit_profile_instruction_count(&mut self, user_provided: bool, target_pc: Option) { + fn emit_profile_instruction_count(&mut self, target_pc: Option) { if !self.config.enable_instruction_meter { return; } match target_pc { Some(target_pc) => { - // instruction_meter += target_pc - (self.pc + 1); - let immediate = target_pc as i64 - self.pc as i64 - 1; - if user_provided { - self.emit_sanitized_alu(OperandSize::S64, 0x01, 0, REGISTER_INSTRUCTION_METER, immediate); - } else { - self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, REGISTER_INSTRUCTION_METER, immediate, None)); - } + self.emit_sanitized_alu(OperandSize::S64, 0x01, 0, REGISTER_INSTRUCTION_METER, target_pc as i64 - self.pc as i64 - 1); // instruction_meter += target_pc - (self.pc + 1); }, None => { - self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 5, REGISTER_INSTRUCTION_METER, self.pc as i64 + 1, None)); // instruction_meter -= self.pc + 1; self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x01, REGISTER_SCRATCH, REGISTER_INSTRUCTION_METER, 0, None)); // instruction_meter += target_pc; + self.emit_sanitized_alu(OperandSize::S64, 0x81, 5, REGISTER_INSTRUCTION_METER, self.pc as i64 + 1); // instruction_meter -= self.pc + 1; }, } } - #[inline] - fn emit_validate_and_profile_instruction_count(&mut self, user_provided: bool, target_pc: Option) { - self.emit_validate_instruction_count(Some(self.pc)); - self.emit_profile_instruction_count(user_provided, target_pc); - } - #[inline] fn emit_undo_profile_instruction_count(&mut self, target_pc: Value) { if self.config.enable_instruction_meter { match target_pc { Value::Constant64(target_pc, _) => { - self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, REGISTER_INSTRUCTION_METER, self.pc as i64 + 1 - target_pc, None)); // instruction_meter += (self.pc + 1) - target_pc; + self.emit_sanitized_alu(OperandSize::S64, 0x01, 0, REGISTER_INSTRUCTION_METER, self.pc as i64 + 1 - target_pc); // instruction_meter += (self.pc + 1) - target_pc; } Value::Register(target_pc) => { self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x29, target_pc, REGISTER_INSTRUCTION_METER, 0, None)); // instruction_meter -= guest_target_pc @@ -995,6 +981,12 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { } } + #[inline] + fn emit_validate_and_profile_instruction_count(&mut self, target_pc: Option) { + self.emit_validate_instruction_count(Some(self.pc)); + self.emit_profile_instruction_count(target_pc); + } + fn emit_rust_call(&mut self, target: Value, arguments: &[Argument], result_reg: Option) { let mut saved_registers = CALLER_SAVED_REGISTERS.to_vec(); if let Some(reg) = result_reg { @@ -1123,7 +1115,7 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { }, Value::Constant64(target_pc, user_provided) => { debug_assert!(user_provided); - self.emit_profile_instruction_count(user_provided, Some(target_pc as usize)); + self.emit_profile_instruction_count(Some(target_pc as usize)); if user_provided && self.should_sanitize_constant(target_pc) { self.emit_sanitized_load_immediate(REGISTER_SCRATCH, target_pc); } else { @@ -1149,7 +1141,7 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { #[inline] fn emit_syscall_dispatch(&mut self, function: BuiltinFunction) { - self.emit_validate_and_profile_instruction_count(false, Some(0)); + self.emit_validate_and_profile_instruction_count(Some(0)); self.emit_ins(X86Instruction::load_immediate(REGISTER_SCRATCH, function as usize as i64)); self.emit_ins(X86Instruction::call_immediate(self.relative_to_anchor(ANCHOR_EXTERNAL_FUNCTION_CALL, 5))); self.emit_undo_profile_instruction_count(Value::Constant64(0, false)); @@ -1228,7 +1220,7 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { #[inline] fn emit_conditional_branch_reg(&mut self, op: u8, bitwise: bool, first_operand: u8, second_operand: u8, target_pc: usize) { - self.emit_validate_and_profile_instruction_count(true, Some(target_pc)); + self.emit_validate_and_profile_instruction_count(Some(target_pc)); if bitwise { // Logical self.emit_ins(X86Instruction::test(OperandSize::S64, first_operand, second_operand, None)); } else { // Arithmetic @@ -1242,7 +1234,7 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { #[inline] fn emit_conditional_branch_imm(&mut self, op: u8, bitwise: bool, immediate: i64, second_operand: u8, target_pc: usize) { - self.emit_validate_and_profile_instruction_count(true, Some(target_pc)); + self.emit_validate_and_profile_instruction_count(Some(target_pc)); if self.should_sanitize_constant(immediate) { self.emit_sanitized_load_immediate(REGISTER_SCRATCH, immediate); if bitwise { // Logical @@ -1578,7 +1570,7 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { let number_of_instructions = self.result.pc_section.len(); self.emit_ins(X86Instruction::cmp_immediate(OperandSize::S64, REGISTER_SCRATCH, (number_of_instructions * INSN_SIZE) as i64, None)); // guest_target_address.cmp(number_of_instructions * INSN_SIZE) self.emit_ins(X86Instruction::conditional_jump_immediate(0x83, self.relative_to_anchor(ANCHOR_CALL_OUTSIDE_TEXT_SEGMENT, 6))); - // First half of self.emit_profile_instruction_count(false, None); + // First half of self.emit_profile_instruction_count(None); self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x2b, REGISTER_INSTRUCTION_METER, RSP, 0, Some(X86IndirectAccess::OffsetIndexShift(-8, RSP, 0)))); // instruction_meter -= guest_current_pc; self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 5, REGISTER_INSTRUCTION_METER, 1, None)); // instruction_meter -= 1; // Load host target_address from self.result.pc_section @@ -1591,7 +1583,7 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { let shift_amount = INSN_SIZE.trailing_zeros(); debug_assert_eq!(INSN_SIZE, 1 << shift_amount); self.emit_ins(X86Instruction::alu(OperandSize::S64, 0xc1, 5, REGISTER_SCRATCH, shift_amount as i64, None)); // guest_target_pc /= INSN_SIZE; - // Second half of self.emit_profile_instruction_count(false, None); + // Second half of self.emit_profile_instruction_count(None); self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x01, REGISTER_SCRATCH, REGISTER_INSTRUCTION_METER, 0, None)); // instruction_meter += guest_target_pc; // Restore the clobbered REGISTER_MAP[0] self.emit_ins(X86Instruction::xchg(OperandSize::S64, REGISTER_MAP[0], RSP, Some(X86IndirectAccess::OffsetIndexShift(0, RSP, 0)))); // Swap REGISTER_MAP[0] and host_target_address @@ -1841,9 +1833,9 @@ mod tests { let instruction_meter_checkpoint_machine_code_length = instruction_meter_checkpoint_machine_code_length[0] - instruction_meter_checkpoint_machine_code_length[1]; - assert_eq!( - instruction_meter_checkpoint_machine_code_length, - MACHINE_CODE_PER_INSTRUCTION_METER_CHECKPOINT + assert!( + instruction_meter_checkpoint_machine_code_length + <= MACHINE_CODE_PER_INSTRUCTION_METER_CHECKPOINT ); for sbpf_version in [SBPFVersion::V0, SBPFVersion::V3] {