diff --git a/src/jit.rs b/src/jit.rs index d5476f491..95bb88403 100644 --- a/src/jit.rs +++ b/src/jit.rs @@ -30,7 +30,7 @@ use crate::{ memory_management::{ allocate_pages, free_pages, get_system_page_size, protect_pages, round_to_page_size, }, - memory_region::{AccessType, MemoryMapping}, + memory_region::MemoryMapping, program::BuiltinFunction, vm::{get_runtime_environment_key, Config, ContextObject, EbpfVm}, x86::*, @@ -38,7 +38,7 @@ use crate::{ const MAX_EMPTY_PROGRAM_MACHINE_CODE_LENGTH: usize = 4096; const MAX_MACHINE_CODE_LENGTH_PER_INSTRUCTION: usize = 110; -const MACHINE_CODE_PER_INSTRUCTION_METER_CHECKPOINT: usize = 23; +const MACHINE_CODE_PER_INSTRUCTION_METER_CHECKPOINT: usize = 24; const MAX_START_PADDING_LENGTH: usize = 256; pub struct JitProgram { @@ -200,7 +200,7 @@ const ANCHOR_INTERNAL_FUNCTION_CALL_PROLOGUE: usize = 12; const ANCHOR_INTERNAL_FUNCTION_CALL_REG: usize = 13; const ANCHOR_CALL_REG_UNSUPPORTED_INSTRUCTION: usize = 14; const ANCHOR_TRANSLATE_MEMORY_ADDRESS: usize = 21; -const ANCHOR_COUNT: usize = 30; // Update me when adding or removing anchors +const ANCHOR_COUNT: usize = 34; // Update me when adding or removing anchors const REGISTER_MAP: [u8; 11] = [ CALLER_SAVED_REGISTERS[0], // RAX @@ -328,6 +328,7 @@ pub struct JitCompiler<'a, C: ContextObject> { next_noop_insertion: u32, noop_range: Uniform<u32>, runtime_environment_key: i32, + immediate_value_key: i64, diversification_rng: SmallRng, stopwatch_is_active: bool, } @@ -365,6 +366,7 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { let runtime_environment_key = get_runtime_environment_key(); let mut diversification_rng = SmallRng::from_rng(thread_rng()).map_err(|_| EbpfError::JitNotCompiled)?; + let immediate_value_key = diversification_rng.gen::<i64>(); Ok(Self { result: JitProgram::new(pc, code_length_estimate)?, @@ -380,6 +382,7 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { next_noop_insertion: if config.noop_instruction_rate == 0 { u32::MAX } else {
diversification_rng.gen_range(0..config.noop_instruction_rate * 2) }, noop_range: Uniform::new_inclusive(0, config.noop_instruction_rate * 2), runtime_environment_key, + immediate_value_key, diversification_rng, stopwatch_is_active: false, }) @@ -873,29 +876,24 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { #[inline] fn emit_sanitized_load_immediate(&mut self, destination: u8, value: i64) { + let lower_key = self.immediate_value_key as i32 as i64; if value >= i32::MIN as i64 && value <= i32::MAX as i64 { - let key = self.diversification_rng.gen::<i32>() as i64; - self.emit_ins(X86Instruction::load_immediate(destination, value.wrapping_sub(key))); - self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, destination, key, None)); + self.emit_ins(X86Instruction::load_immediate(destination, value.wrapping_sub(lower_key))); + self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, destination, lower_key, None)); // wrapping_add(lower_key) } else if value as u64 & u32::MAX as u64 == 0 { - let key = self.diversification_rng.gen::<i32>() as i64; - self.emit_ins(X86Instruction::load_immediate(destination, value.rotate_right(32).wrapping_sub(key))); - self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, destination, key, None)); // wrapping_add(key) + self.emit_ins(X86Instruction::load_immediate(destination, value.rotate_right(32).wrapping_sub(lower_key))); + self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, destination, lower_key, None)); // wrapping_add(lower_key) self.emit_ins(X86Instruction::alu(OperandSize::S64, 0xc1, 4, destination, 32, None)); // shift_left(32) + } else if destination != REGISTER_SCRATCH { + self.emit_ins(X86Instruction::load_immediate(destination, value.wrapping_sub(self.immediate_value_key))); + self.emit_ins(X86Instruction::load_immediate(REGISTER_SCRATCH, self.immediate_value_key)); + self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x01, REGISTER_SCRATCH, destination, 0, None)); // wrapping_add(immediate_value_key) }
else { - let key = self.diversification_rng.gen::<i64>(); - if destination != REGISTER_SCRATCH { - self.emit_ins(X86Instruction::load_immediate(destination, value.wrapping_sub(key))); - self.emit_ins(X86Instruction::load_immediate(REGISTER_SCRATCH, key)); - self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x01, REGISTER_SCRATCH, destination, 0, None)); - } else { - let lower_key = key as i32 as i64; - let upper_key = (key >> 32) as i32 as i64; - self.emit_ins(X86Instruction::load_immediate(destination, value.wrapping_sub(lower_key).rotate_right(32).wrapping_sub(upper_key))); - self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, destination, upper_key, None)); // wrapping_add(upper_key) - self.emit_ins(X86Instruction::alu(OperandSize::S64, 0xc1, 1, destination, 32, None)); // rotate_right(32) - self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, destination, lower_key, None)); // wrapping_add(lower_key) - } + let upper_key = (self.immediate_value_key >> 32) as i32 as i64; + self.emit_ins(X86Instruction::load_immediate(destination, value.wrapping_sub(lower_key).rotate_right(32).wrapping_sub(upper_key))); + self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, destination, upper_key, None)); // wrapping_add(upper_key) + self.emit_ins(X86Instruction::alu(OperandSize::S64, 0xc1, 1, destination, 32, None)); // rotate_right(32) + self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, destination, lower_key, None)); // wrapping_add(lower_key) } } @@ -1157,11 +1155,10 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { self.emit_ins(X86Instruction::store(OperandSize::S64, reg, RSP, stack_slot_of_value_to_store)); } Some(Value::Constant64(constant, user_provided)) => { - if user_provided && self.should_sanitize_constant(constant) { - self.emit_sanitized_load_immediate(REGISTER_SCRATCH, constant); - } else { - self.emit_ins(X86Instruction::load_immediate(REGISTER_SCRATCH, constant)); - } + debug_assert!(user_provided); + // First half of
emit_sanitized_load_immediate(stack_slot_of_value_to_store, constant) + let lower_key = self.immediate_value_key as i32 as i64; + self.emit_ins(X86Instruction::load_immediate(REGISTER_SCRATCH, constant.wrapping_sub(lower_key))); self.emit_ins(X86Instruction::store(OperandSize::S64, REGISTER_SCRATCH, RSP, stack_slot_of_value_to_store)); } _ => {} @@ -1169,19 +1166,16 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { match vm_addr { Value::RegisterPlusConstant64(reg, constant, user_provided) => { - if user_provided && self.should_sanitize_constant(constant) { - self.emit_sanitized_load_immediate(REGISTER_SCRATCH, constant); - } else { - self.emit_ins(X86Instruction::load_immediate(REGISTER_SCRATCH, constant)); - } - self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x01, reg, REGISTER_SCRATCH, 0, None)); - }, - Value::Constant64(constant, user_provided) => { - if user_provided && self.should_sanitize_constant(constant) { - self.emit_sanitized_load_immediate(REGISTER_SCRATCH, constant); - } else { - self.emit_ins(X86Instruction::load_immediate(REGISTER_SCRATCH, constant)); - } + debug_assert!(user_provided); + // First half of emit_sanitized_load_immediate(REGISTER_SCRATCH, vm_addr) + let lower_key = self.immediate_value_key as i32 as i64; + self.emit_ins(X86Instruction::lea(OperandSize::S64, reg, REGISTER_SCRATCH, Some( + if reg == R12 { + X86IndirectAccess::OffsetIndexShift(constant.wrapping_sub(lower_key) as i32, RSP, 0) + } else { + X86IndirectAccess::Offset(constant.wrapping_sub(lower_key) as i32) + } + ))); }, _ => { #[cfg(debug_assertions)] @@ -1190,8 +1184,12 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { } if self.config.enable_address_translation { - let access_type = if value.is_none() { AccessType::Load } else { AccessType::Store }; - let anchor = ANCHOR_TRANSLATE_MEMORY_ADDRESS + len.trailing_zeros() as usize + 4 * (access_type as usize); + let anchor_base = match value { + Some(Value::Register(_reg)) => 4, + Some(Value::Constant64(_constant, 
_user_provided)) => 8, + _ => 0, + }; + let anchor = ANCHOR_TRANSLATE_MEMORY_ADDRESS + anchor_base + len.trailing_zeros() as usize; self.emit_ins(X86Instruction::push_immediate(OperandSize::S64, self.pc as i32)); self.emit_ins(X86Instruction::call_immediate(self.relative_to_anchor(anchor, 5))); if let Some(dst) = dst { @@ -1600,20 +1598,18 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { self.emit_ins(X86Instruction::jump_immediate(self.relative_to_anchor(ANCHOR_CALL_UNSUPPORTED_INSTRUCTION, 5))); // Translates a vm memory address to a host memory address - for (access_type, len) in &[ - (AccessType::Load, 1i32), - (AccessType::Load, 2i32), - (AccessType::Load, 4i32), - (AccessType::Load, 8i32), - (AccessType::Store, 1i32), - (AccessType::Store, 2i32), - (AccessType::Store, 4i32), - (AccessType::Store, 8i32), + let lower_key = self.immediate_value_key as i32 as i64; + for (anchor_base, len) in &[ + (0, 1i32), (0, 2i32), (0, 4i32), (0, 8i32), + (4, 1i32), (4, 2i32), (4, 4i32), (4, 8i32), + (8, 1i32), (8, 2i32), (8, 4i32), (8, 8i32), ] { - let target_offset = len.trailing_zeros() as usize + 4 * (*access_type as usize); + let target_offset = *anchor_base + len.trailing_zeros() as usize; self.set_anchor(ANCHOR_TRANSLATE_MEMORY_ADDRESS + target_offset); + // Second half of emit_sanitized_load_immediate(REGISTER_SCRATCH, vm_addr) + self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, REGISTER_SCRATCH, lower_key, None)); // call MemoryMapping::(load|store) storing the result in RuntimeEnvironmentSlot::ProgramResult - if *access_type == AccessType::Load { + if *anchor_base == 0 { // AccessType::Load let load = match len { 1 => MemoryMapping::load::<u8> as *const u8 as i64, 2 => MemoryMapping::load::<u16> as *const u8 as i64, @@ -1627,7 +1623,11 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { Argument { index: 1, value: Value::RegisterPlusConstant32(REGISTER_PTR_TO_VM, self.slot_in_vm(RuntimeEnvironmentSlot::MemoryMapping), false) }, Argument { index: 0, value:
Value::RegisterPlusConstant32(REGISTER_PTR_TO_VM, self.slot_in_vm(RuntimeEnvironmentSlot::ProgramResult), false) }, ], None); - } else { + } else { // AccessType::Store + if *anchor_base == 8 { + // Second half of emit_sanitized_load_immediate(stack_slot_of_value_to_store, constant) + self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, RSP, lower_key, Some(X86IndirectAccess::OffsetIndexShift(-96, RSP, 0)))); + } let store = match len { 1 => MemoryMapping::store::<u8> as *const u8 as i64, 2 => MemoryMapping::store::<u16> as *const u8 as i64, @@ -1650,8 +1650,10 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> { self.emit_ins(X86Instruction::xchg(OperandSize::S64, REGISTER_SCRATCH, RSP, Some(X86IndirectAccess::OffsetIndexShift(0, RSP, 0)))); // Swap return address and self.pc self.emit_ins(X86Instruction::conditional_jump_immediate(0x85, self.relative_to_anchor(ANCHOR_THROW_EXCEPTION, 6))); - // unwrap() the result into REGISTER_SCRATCH - self.emit_ins(X86Instruction::load(OperandSize::S64, REGISTER_PTR_TO_VM, REGISTER_SCRATCH, X86IndirectAccess::Offset(self.slot_in_vm(RuntimeEnvironmentSlot::ProgramResult) + std::mem::size_of::<u64>() as i32))); + if *anchor_base == 0 { // AccessType::Load + // unwrap() the result into REGISTER_SCRATCH + self.emit_ins(X86Instruction::load(OperandSize::S64, REGISTER_PTR_TO_VM, REGISTER_SCRATCH, X86IndirectAccess::Offset(self.slot_in_vm(RuntimeEnvironmentSlot::ProgramResult) + std::mem::size_of::<u64>() as i32))); + } self.emit_ins(X86Instruction::return_near()); }