Skip to content

Commit

Permalink
Refactor - Sanitization of memory accesses in JIT (#6)
Browse files Browse the repository at this point in the history
* All the immediate values can use the same encryption key.
There is no need to generate a new one for each.

* Optimizes memory access instructions.

* Moves second half of emit_sanitized_load_immediate(REGISTER_SCRATCH, vm_addr) into ANCHOR_TRANSLATE_MEMORY_ADDRESS.

* Moves second half of emit_sanitized_load_immediate(stack_slot_of_value_to_store, constant) into ANCHOR_TRANSLATE_MEMORY_ADDRESS.
  • Loading branch information
Lichtso authored Dec 20, 2024
1 parent a5cdab3 commit cad781a
Showing 1 changed file with 58 additions and 56 deletions.
114 changes: 58 additions & 56 deletions src/jit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ use crate::{
memory_management::{
allocate_pages, free_pages, get_system_page_size, protect_pages, round_to_page_size,
},
memory_region::{AccessType, MemoryMapping},
memory_region::MemoryMapping,
program::BuiltinFunction,
vm::{get_runtime_environment_key, Config, ContextObject, EbpfVm},
x86::*,
};

// Size constants for the emitted machine code.
// NOTE(review): presumably byte counts feeding the code-length estimate of the
// JIT output buffer — confirm against where `code_length_estimate` is computed.
const MAX_EMPTY_PROGRAM_MACHINE_CODE_LENGTH: usize = 4096;
const MAX_MACHINE_CODE_LENGTH_PER_INSTRUCTION: usize = 110;
// Single definition only: a second `const` with the same name (the stale value
// 23) is a duplicate-definition error. 24 is the value that belongs with the
// shared `immediate_value_key` memory-access sequence.
const MACHINE_CODE_PER_INSTRUCTION_METER_CHECKPOINT: usize = 24;
const MAX_START_PADDING_LENGTH: usize = 256;

pub struct JitProgram {
Expand Down Expand Up @@ -200,7 +200,7 @@ const ANCHOR_INTERNAL_FUNCTION_CALL_PROLOGUE: usize = 12;
const ANCHOR_INTERNAL_FUNCTION_CALL_REG: usize = 13;
const ANCHOR_CALL_REG_UNSUPPORTED_INSTRUCTION: usize = 14;
// Base index of the address-translation anchor table. Entries are addressed as
// `ANCHOR_TRANSLATE_MEMORY_ADDRESS + anchor_base + log2(access length)` with
// anchor_base 0 (load), 4 (store from register) or 8 (store of a constant) and
// access lengths 1, 2, 4 and 8, so the highest index used is 21 + 8 + 3 = 32.
const ANCHOR_TRANSLATE_MEMORY_ADDRESS: usize = 21;
// Single definition only: the stale value 30 would not cover index 32 and a
// second `const` with the same name is a duplicate-definition error.
const ANCHOR_COUNT: usize = 34; // Update me when adding or removing anchors

const REGISTER_MAP: [u8; 11] = [
CALLER_SAVED_REGISTERS[0], // RAX
Expand Down Expand Up @@ -328,6 +328,7 @@ pub struct JitCompiler<'a, C: ContextObject> {
next_noop_insertion: u32,
noop_range: Uniform<u32>,
runtime_environment_key: i32,
immediate_value_key: i64,
diversification_rng: SmallRng,
stopwatch_is_active: bool,
}
Expand Down Expand Up @@ -365,6 +366,7 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> {

let runtime_environment_key = get_runtime_environment_key();
let mut diversification_rng = SmallRng::from_rng(thread_rng()).map_err(|_| EbpfError::JitNotCompiled)?;
let immediate_value_key = diversification_rng.gen::<i64>();

Ok(Self {
result: JitProgram::new(pc, code_length_estimate)?,
Expand All @@ -380,6 +382,7 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> {
next_noop_insertion: if config.noop_instruction_rate == 0 { u32::MAX } else { diversification_rng.gen_range(0..config.noop_instruction_rate * 2) },
noop_range: Uniform::new_inclusive(0, config.noop_instruction_rate * 2),
runtime_environment_key,
immediate_value_key,
diversification_rng,
stopwatch_is_active: false,
})
Expand Down Expand Up @@ -873,29 +876,24 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> {

#[inline]
fn emit_sanitized_load_immediate(&mut self, destination: u8, value: i64) {
let lower_key = self.immediate_value_key as i32 as i64;
if value >= i32::MIN as i64 && value <= i32::MAX as i64 {
let key = self.diversification_rng.gen::<i32>() as i64;
self.emit_ins(X86Instruction::load_immediate(destination, value.wrapping_sub(key)));
self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, destination, key, None));
self.emit_ins(X86Instruction::load_immediate(destination, value.wrapping_sub(lower_key)));
self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, destination, lower_key, None)); // wrapping_add(lower_key)
} else if value as u64 & u32::MAX as u64 == 0 {
let key = self.diversification_rng.gen::<i32>() as i64;
self.emit_ins(X86Instruction::load_immediate(destination, value.rotate_right(32).wrapping_sub(key)));
self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, destination, key, None)); // wrapping_add(key)
self.emit_ins(X86Instruction::load_immediate(destination, value.rotate_right(32).wrapping_sub(lower_key)));
self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, destination, lower_key, None)); // wrapping_add(lower_key)
self.emit_ins(X86Instruction::alu(OperandSize::S64, 0xc1, 4, destination, 32, None)); // shift_left(32)
} else if destination != REGISTER_SCRATCH {
self.emit_ins(X86Instruction::load_immediate(destination, value.wrapping_sub(self.immediate_value_key)));
self.emit_ins(X86Instruction::load_immediate(REGISTER_SCRATCH, self.immediate_value_key));
self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x01, REGISTER_SCRATCH, destination, 0, None)); // wrapping_add(immediate_value_key)
} else {
let key = self.diversification_rng.gen::<i64>();
if destination != REGISTER_SCRATCH {
self.emit_ins(X86Instruction::load_immediate(destination, value.wrapping_sub(key)));
self.emit_ins(X86Instruction::load_immediate(REGISTER_SCRATCH, key));
self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x01, REGISTER_SCRATCH, destination, 0, None));
} else {
let lower_key = key as i32 as i64;
let upper_key = (key >> 32) as i32 as i64;
self.emit_ins(X86Instruction::load_immediate(destination, value.wrapping_sub(lower_key).rotate_right(32).wrapping_sub(upper_key)));
self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, destination, upper_key, None)); // wrapping_add(upper_key)
self.emit_ins(X86Instruction::alu(OperandSize::S64, 0xc1, 1, destination, 32, None)); // rotate_right(32)
self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, destination, lower_key, None)); // wrapping_add(lower_key)
}
let upper_key = (self.immediate_value_key >> 32) as i32 as i64;
self.emit_ins(X86Instruction::load_immediate(destination, value.wrapping_sub(lower_key).rotate_right(32).wrapping_sub(upper_key)));
self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, destination, upper_key, None)); // wrapping_add(upper_key)
self.emit_ins(X86Instruction::alu(OperandSize::S64, 0xc1, 1, destination, 32, None)); // rotate_right(32)
self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, destination, lower_key, None)); // wrapping_add(lower_key)
}
}

Expand Down Expand Up @@ -1157,31 +1155,27 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> {
self.emit_ins(X86Instruction::store(OperandSize::S64, reg, RSP, stack_slot_of_value_to_store));
}
Some(Value::Constant64(constant, user_provided)) => {
if user_provided && self.should_sanitize_constant(constant) {
self.emit_sanitized_load_immediate(REGISTER_SCRATCH, constant);
} else {
self.emit_ins(X86Instruction::load_immediate(REGISTER_SCRATCH, constant));
}
debug_assert!(user_provided);
// First half of emit_sanitized_load_immediate(stack_slot_of_value_to_store, constant)
let lower_key = self.immediate_value_key as i32 as i64;
self.emit_ins(X86Instruction::load_immediate(REGISTER_SCRATCH, constant.wrapping_sub(lower_key)));
self.emit_ins(X86Instruction::store(OperandSize::S64, REGISTER_SCRATCH, RSP, stack_slot_of_value_to_store));
}
_ => {}
}

match vm_addr {
Value::RegisterPlusConstant64(reg, constant, user_provided) => {
if user_provided && self.should_sanitize_constant(constant) {
self.emit_sanitized_load_immediate(REGISTER_SCRATCH, constant);
} else {
self.emit_ins(X86Instruction::load_immediate(REGISTER_SCRATCH, constant));
}
self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x01, reg, REGISTER_SCRATCH, 0, None));
},
Value::Constant64(constant, user_provided) => {
if user_provided && self.should_sanitize_constant(constant) {
self.emit_sanitized_load_immediate(REGISTER_SCRATCH, constant);
} else {
self.emit_ins(X86Instruction::load_immediate(REGISTER_SCRATCH, constant));
}
debug_assert!(user_provided);
// First half of emit_sanitized_load_immediate(REGISTER_SCRATCH, vm_addr)
let lower_key = self.immediate_value_key as i32 as i64;
self.emit_ins(X86Instruction::lea(OperandSize::S64, reg, REGISTER_SCRATCH, Some(
if reg == R12 {
X86IndirectAccess::OffsetIndexShift(constant.wrapping_sub(lower_key) as i32, RSP, 0)
} else {
X86IndirectAccess::Offset(constant.wrapping_sub(lower_key) as i32)
}
)));
},
_ => {
#[cfg(debug_assertions)]
Expand All @@ -1190,8 +1184,12 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> {
}

if self.config.enable_address_translation {
let access_type = if value.is_none() { AccessType::Load } else { AccessType::Store };
let anchor = ANCHOR_TRANSLATE_MEMORY_ADDRESS + len.trailing_zeros() as usize + 4 * (access_type as usize);
let anchor_base = match value {
Some(Value::Register(_reg)) => 4,
Some(Value::Constant64(_constant, _user_provided)) => 8,
_ => 0,
};
let anchor = ANCHOR_TRANSLATE_MEMORY_ADDRESS + anchor_base + len.trailing_zeros() as usize;
self.emit_ins(X86Instruction::push_immediate(OperandSize::S64, self.pc as i32));
self.emit_ins(X86Instruction::call_immediate(self.relative_to_anchor(anchor, 5)));
if let Some(dst) = dst {
Expand Down Expand Up @@ -1600,20 +1598,18 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> {
self.emit_ins(X86Instruction::jump_immediate(self.relative_to_anchor(ANCHOR_CALL_UNSUPPORTED_INSTRUCTION, 5)));

// Translates a vm memory address to a host memory address
for (access_type, len) in &[
(AccessType::Load, 1i32),
(AccessType::Load, 2i32),
(AccessType::Load, 4i32),
(AccessType::Load, 8i32),
(AccessType::Store, 1i32),
(AccessType::Store, 2i32),
(AccessType::Store, 4i32),
(AccessType::Store, 8i32),
let lower_key = self.immediate_value_key as i32 as i64;
for (anchor_base, len) in &[
(0, 1i32), (0, 2i32), (0, 4i32), (0, 8i32),
(4, 1i32), (4, 2i32), (4, 4i32), (4, 8i32),
(8, 1i32), (8, 2i32), (8, 4i32), (8, 8i32),
] {
let target_offset = len.trailing_zeros() as usize + 4 * (*access_type as usize);
let target_offset = *anchor_base + len.trailing_zeros() as usize;
self.set_anchor(ANCHOR_TRANSLATE_MEMORY_ADDRESS + target_offset);
// Second half of emit_sanitized_load_immediate(REGISTER_SCRATCH, vm_addr)
self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, REGISTER_SCRATCH, lower_key, None));
// call MemoryMapping::(load|store) storing the result in RuntimeEnvironmentSlot::ProgramResult
if *access_type == AccessType::Load {
if *anchor_base == 0 { // AccessType::Load
let load = match len {
1 => MemoryMapping::load::<u8> as *const u8 as i64,
2 => MemoryMapping::load::<u16> as *const u8 as i64,
Expand All @@ -1627,7 +1623,11 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> {
Argument { index: 1, value: Value::RegisterPlusConstant32(REGISTER_PTR_TO_VM, self.slot_in_vm(RuntimeEnvironmentSlot::MemoryMapping), false) },
Argument { index: 0, value: Value::RegisterPlusConstant32(REGISTER_PTR_TO_VM, self.slot_in_vm(RuntimeEnvironmentSlot::ProgramResult), false) },
], None);
} else {
} else { // AccessType::Store
if *anchor_base == 8 {
// Second half of emit_sanitized_load_immediate(stack_slot_of_value_to_store, constant)
self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, RSP, lower_key, Some(X86IndirectAccess::OffsetIndexShift(-96, RSP, 0))));
}
let store = match len {
1 => MemoryMapping::store::<u8> as *const u8 as i64,
2 => MemoryMapping::store::<u16> as *const u8 as i64,
Expand All @@ -1650,8 +1650,10 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> {
self.emit_ins(X86Instruction::xchg(OperandSize::S64, REGISTER_SCRATCH, RSP, Some(X86IndirectAccess::OffsetIndexShift(0, RSP, 0)))); // Swap return address and self.pc
self.emit_ins(X86Instruction::conditional_jump_immediate(0x85, self.relative_to_anchor(ANCHOR_THROW_EXCEPTION, 6)));

// unwrap() the result into REGISTER_SCRATCH
self.emit_ins(X86Instruction::load(OperandSize::S64, REGISTER_PTR_TO_VM, REGISTER_SCRATCH, X86IndirectAccess::Offset(self.slot_in_vm(RuntimeEnvironmentSlot::ProgramResult) + std::mem::size_of::<u64>() as i32)));
if *anchor_base == 0 { // AccessType::Load
// unwrap() the result into REGISTER_SCRATCH
self.emit_ins(X86Instruction::load(OperandSize::S64, REGISTER_PTR_TO_VM, REGISTER_SCRATCH, X86IndirectAccess::Offset(self.slot_in_vm(RuntimeEnvironmentSlot::ProgramResult) + std::mem::size_of::<u64>() as i32)));
}

self.emit_ins(X86Instruction::return_near());
}
Expand Down

0 comments on commit cad781a

Please sign in to comment.