Refactor - Halves memory needed for pc_section of JIT compiled programs (#11)

* Address offsetting via indirection of the load instruction.

* Groups the instructions of `self.emit_profile_instruction_count(None);`.

* Stores text_section relative offsets in pc_section instead of absolute addresses.

* Stores u32 instead of u64 elements in pc_section.

* Spills to MMX register and uses 64 bit load immediate.
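The net effect of these changes is that the per-instruction lookup table stores one u32 byte offset (4 bytes) instead of one usize host address (8 bytes on x86-64), halving pc_section's footprint; absolute addresses are reconstructed on use by adding the stored offset to the text_section base. A minimal sketch of that lookup, with a simplified signature rather than the actual `JitProgram` API:

    /// Sketch: resolve a guest program counter to a host code pointer.
    /// `pc_section[pc]` holds a byte offset into `text_section`, so the
    /// absolute address is rebuilt by offsetting the section base.
    fn host_address(text_section: &[u8], pc_section: &[u32], pc: usize) -> *const u8 {
        unsafe { text_section.as_ptr().add(pc_section[pc] as usize) }
    }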
Lichtso authored Jan 2, 2025
1 parent e09f773 commit 085a488
Showing 1 changed file with 32 additions and 29 deletions.

src/jit.rs (32 additions, 29 deletions)
@@ -48,22 +48,22 @@ const MAX_START_PADDING_LENGTH: usize = 256;
pub struct JitProgram {
/// OS page size in bytes and the alignment of the sections
page_size: usize,
-/// A `*const u8` pointer into the text_section for each BPF instruction
-pc_section: &'static mut [usize],
+/// Byte offset in the text_section for each BPF instruction
+pc_section: &'static mut [u32],
/// The x86 machinecode
text_section: &'static mut [u8],
}

impl JitProgram {
fn new(pc: usize, code_size: usize) -> Result<Self, EbpfError> {
let page_size = get_system_page_size();
-let pc_loc_table_size = round_to_page_size(pc * 8, page_size);
+let pc_loc_table_size = round_to_page_size(pc * std::mem::size_of::<u32>(), page_size);
let over_allocated_code_size = round_to_page_size(code_size, page_size);
unsafe {
let raw = allocate_pages(pc_loc_table_size + over_allocated_code_size)?;
Ok(Self {
page_size,
-pc_section: std::slice::from_raw_parts_mut(raw.cast::<usize>(), pc),
+pc_section: std::slice::from_raw_parts_mut(raw.cast::<u32>(), pc),
text_section: std::slice::from_raw_parts_mut(
raw.add(pc_loc_table_size),
over_allocated_code_size,
@@ -77,7 +77,8 @@ impl JitProgram {
return Ok(());
}
let raw = self.pc_section.as_ptr() as *mut u8;
-let pc_loc_table_size = round_to_page_size(self.pc_section.len() * 8, self.page_size);
+let pc_loc_table_size =
+    round_to_page_size(std::mem::size_of_val(self.pc_section), self.page_size);
let over_allocated_code_size = round_to_page_size(self.text_section.len(), self.page_size);
let code_size = round_to_page_size(text_section_usage, self.page_size);
unsafe {
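Replacing the hardcoded `* 8` with `std::mem::size_of_val` ties the table-size computation to the element type, so it stays correct now that the slots are 4 bytes wide. A self-contained illustration of the equivalence:

    fn main() {
        let pc_section: &[u32] = &[0; 16];
        // size_of_val of a slice is len * size_of::<T>(): 64 bytes here,
        // where the old hardcoded `len * 8` would have reported 128.
        assert_eq!(std::mem::size_of_val(pc_section), pc_section.len() * 4);
    }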
@@ -139,7 +140,7 @@ impl JitProgram {
host_stack_pointer = in(reg) &mut vm.host_stack_pointer,
inlateout("rdi") std::ptr::addr_of_mut!(*vm).cast::<u64>().offset(get_runtime_environment_key() as isize) => _,
inlateout("r10") (vm.previous_instruction_meter as i64).wrapping_add(registers[11] as i64) => _,
inlateout("rax") self.pc_section[registers[11] as usize] => _,
inlateout("rax") &self.text_section[self.pc_section[registers[11] as usize] as usize] as *const u8 => _,
inlateout("r11") &registers => _,
lateout("rsi") _, lateout("rdx") _, lateout("rcx") _, lateout("r8") _,
lateout("r9") _, lateout("r12") _, lateout("r13") _, lateout("r14") _, lateout("r15") _,
@@ -153,15 +154,17 @@ impl JitProgram {
}

pub fn mem_size(&self) -> usize {
-let pc_loc_table_size = round_to_page_size(self.pc_section.len() * 8, self.page_size);
+let pc_loc_table_size =
+    round_to_page_size(std::mem::size_of_val(self.pc_section), self.page_size);
let code_size = round_to_page_size(self.text_section.len(), self.page_size);
pc_loc_table_size + code_size
}
}

impl Drop for JitProgram {
fn drop(&mut self) {
-let pc_loc_table_size = round_to_page_size(self.pc_section.len() * 8, self.page_size);
+let pc_loc_table_size =
+    round_to_page_size(std::mem::size_of_val(self.pc_section), self.page_size);
let code_size = round_to_page_size(self.text_section.len(), self.page_size);
if pc_loc_table_size + code_size > 0 {
unsafe {
@@ -394,8 +397,6 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> {

/// Compiles the given executable, consuming the compiler
pub fn compile(mut self) -> Result<JitProgram, EbpfError> {
-let text_section_base = self.result.text_section.as_ptr();
-
// Randomized padding at the start before random intervals begin
if self.config.noop_instruction_rate != 0 {
for _ in 0..self.diversification_rng.gen_range(0..MAX_START_PADDING_LENGTH) {
@@ -411,7 +412,7 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> {
return Err(EbpfError::ExhaustedTextSegment(self.pc));
}
let mut insn = ebpf::get_insn_unchecked(self.program, self.pc);
-self.result.pc_section[self.pc] = unsafe { text_section_base.add(self.offset_in_text_section) } as usize;
+self.result.pc_section[self.pc] = self.offset_in_text_section as u32;

// Regular instruction meter checkpoints to prevent long linear runs from exceeding their budget
if self.last_instruction_meter_validation_pc + self.config.instruction_meter_checkpoint_distance <= self.pc {
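The `as u32` cast assumes the emitted text section stays below 4 GiB, so the truncation can never bite in practice. A hypothetical checked variant (not part of this commit, helper name invented for illustration) would fail loudly instead of truncating:

    /// Hypothetical checked conversion: panic rather than silently truncate
    /// if a text_section offset ever exceeded u32::MAX.
    fn offset_to_u32(offset_in_text_section: usize) -> u32 {
        u32::try_from(offset_in_text_section).expect("text_section offset exceeds u32 range")
    }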
@@ -432,7 +433,7 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> {
ebpf::LD_DW_IMM if !self.executable.get_sbpf_version().disable_lddw() => {
self.emit_validate_and_profile_instruction_count(Some(self.pc + 2));
self.pc += 1;
-self.result.pc_section[self.pc] = self.anchors[ANCHOR_CALL_UNSUPPORTED_INSTRUCTION] as usize;
+self.result.pc_section[self.pc] = unsafe { self.anchors[ANCHOR_CALL_UNSUPPORTED_INSTRUCTION].offset_from(self.result.text_section.as_ptr()) as u32 };
ebpf::augment_lddw_unchecked(self.program, &mut insn);
if self.should_sanitize_constant(insn.imm) {
self.emit_sanitized_load_immediate(dst, insn.imm);
@@ -1550,29 +1551,31 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> {
self.emit_ins(X86Instruction::push(REGISTER_MAP[0], None));
// Calculate offset relative to program_vm_addr
self.emit_ins(X86Instruction::load_immediate(REGISTER_MAP[0], self.program_vm_addr as i64));
-self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x29, REGISTER_MAP[0], REGISTER_SCRATCH, None)); // guest_target_address -= self.program_vm_addr;
-// Force alignment of guest_target_address
-self.emit_ins(X86Instruction::alu_immediate(OperandSize::S64, 0x81, 4, REGISTER_SCRATCH, !(INSN_SIZE as i64 - 1), None)); // guest_target_address &= !(INSN_SIZE - 1);
+self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x29, REGISTER_MAP[0], REGISTER_SCRATCH, None)); // guest_target_pc = guest_target_address - self.program_vm_addr;
+// Force alignment of guest_target_pc
+self.emit_ins(X86Instruction::alu_immediate(OperandSize::S64, 0x81, 4, REGISTER_SCRATCH, !(INSN_SIZE as i64 - 1), None)); // guest_target_pc &= !(INSN_SIZE - 1);
// Bound check
-// if(guest_target_address >= number_of_instructions * INSN_SIZE) throw CALL_OUTSIDE_TEXT_SEGMENT;
+// if(guest_target_pc >= number_of_instructions * INSN_SIZE) throw CALL_OUTSIDE_TEXT_SEGMENT;
let number_of_instructions = self.result.pc_section.len();
-self.emit_ins(X86Instruction::cmp_immediate(OperandSize::S64, REGISTER_SCRATCH, (number_of_instructions * INSN_SIZE) as i64, None)); // guest_target_address.cmp(number_of_instructions * INSN_SIZE)
+self.emit_ins(X86Instruction::cmp_immediate(OperandSize::S64, REGISTER_SCRATCH, (number_of_instructions * INSN_SIZE) as i64, None)); // guest_target_pc.cmp(number_of_instructions * INSN_SIZE)
self.emit_ins(X86Instruction::conditional_jump_immediate(0x83, self.relative_to_anchor(ANCHOR_CALL_OUTSIDE_TEXT_SEGMENT, 6)));
-// First half of self.emit_profile_instruction_count(None);
-self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x2b, REGISTER_INSTRUCTION_METER, RSP, Some(X86IndirectAccess::OffsetIndexShift(-8, RSP, 0)))); // instruction_meter -= guest_current_pc;
-self.emit_ins(X86Instruction::alu_immediate(OperandSize::S64, 0x81, 5, REGISTER_INSTRUCTION_METER, 1, None)); // instruction_meter -= 1;
-// Load host target_address from self.result.pc_section
-debug_assert_eq!(INSN_SIZE, 8); // Because the instruction size is also the slot size we do not need to shift the offset
-self.emit_ins(X86Instruction::load_immediate(REGISTER_MAP[0], self.result.pc_section.as_ptr() as i64)); // host_target_address = self.result.pc_section;
-self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x01, REGISTER_SCRATCH, REGISTER_MAP[0], None)); // host_target_address += guest_target_address;
-self.emit_ins(X86Instruction::load(OperandSize::S64, REGISTER_MAP[0], REGISTER_MAP[0], X86IndirectAccess::Offset(0))); // host_target_address = self.result.pc_section[host_target_address / 8];
// Calculate the guest_target_pc (dst / INSN_SIZE) to update REGISTER_INSTRUCTION_METER
+// and as target_pc for potential ANCHOR_CALL_UNSUPPORTED_INSTRUCTION
let shift_amount = INSN_SIZE.trailing_zeros();
debug_assert_eq!(INSN_SIZE, 1 << shift_amount);
self.emit_ins(X86Instruction::alu_immediate(OperandSize::S64, 0xc1, 5, REGISTER_SCRATCH, shift_amount as i64, None)); // guest_target_pc /= INSN_SIZE;
-// Second half of self.emit_profile_instruction_count(None);
+// A version of `self.emit_profile_instruction_count(None);` which reads self.pc from the stack
+self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x2b, REGISTER_INSTRUCTION_METER, RSP, Some(X86IndirectAccess::OffsetIndexShift(-8, RSP, 0)))); // instruction_meter -= guest_current_pc;
+self.emit_ins(X86Instruction::alu_immediate(OperandSize::S64, 0x81, 5, REGISTER_INSTRUCTION_METER, 1, None)); // instruction_meter -= 1;
self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x01, REGISTER_SCRATCH, REGISTER_INSTRUCTION_METER, None)); // instruction_meter += guest_target_pc;
+// Load host_target_address offset from self.result.pc_section
+self.emit_ins(X86Instruction::load_immediate(REGISTER_MAP[0], self.result.pc_section.as_ptr() as i64)); // host_target_address = self.result.pc_section;
+self.emit_ins(X86Instruction::load(OperandSize::S32, REGISTER_MAP[0], REGISTER_MAP[0], X86IndirectAccess::OffsetIndexShift(0, REGISTER_SCRATCH, 2))); // host_target_address = self.result.pc_section[guest_target_pc];
+// Offset host_target_address by self.result.text_section
+self.emit_ins(X86Instruction::mov_mmx(OperandSize::S64, REGISTER_SCRATCH, MM0));
+self.emit_ins(X86Instruction::load_immediate(REGISTER_SCRATCH, self.result.text_section.as_ptr() as i64)); // REGISTER_SCRATCH = self.result.text_section;
+self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x01, REGISTER_SCRATCH, REGISTER_MAP[0], None)); // host_target_address += self.result.text_section;
+self.emit_ins(X86Instruction::mov_mmx(OperandSize::S64, MM0, REGISTER_SCRATCH));
// Restore the clobbered REGISTER_MAP[0]
self.emit_ins(X86Instruction::xchg(OperandSize::S64, REGISTER_MAP[0], RSP, Some(X86IndirectAccess::OffsetIndexShift(0, RSP, 0)))); // Swap REGISTER_MAP[0] and host_target_address
self.emit_ins(X86Instruction::return_near()); // Tail call to host_target_address
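Taken together, the emitted `callx` path now computes the guest pc, bound-checks it, charges the instruction meter, and only then performs the extra indirection: a 32-bit scaled-index load from pc_section (shift of 2 for the 4-byte slots) followed by rebasing onto text_section, with REGISTER_SCRATCH parked in the MMX register MM0 across the 64-bit load of the base address. Roughly what the machine code computes, sketched in Rust with the meter bookkeeping omitted:

    /// Sketch of the guest-to-host dispatch performed by the emitted callx code.
    /// Returns None where the real code jumps to ANCHOR_CALL_OUTSIDE_TEXT_SEGMENT.
    fn dispatch(
        guest_target_address: u64,
        program_vm_addr: u64,
        pc_section: &[u32],
        text_section: &[u8],
    ) -> Option<*const u8> {
        const INSN_SIZE: u64 = 8; // bytes per BPF instruction
        let mut guest_target_pc = guest_target_address.wrapping_sub(program_vm_addr);
        guest_target_pc &= !(INSN_SIZE - 1); // force alignment
        if guest_target_pc >= pc_section.len() as u64 * INSN_SIZE {
            return None; // CALL_OUTSIDE_TEXT_SEGMENT
        }
        guest_target_pc /= INSN_SIZE;
        let offset = pc_section[guest_target_pc as usize] as usize; // 32-bit scaled-index load
        Some(&text_section[offset] as *const u8) // host_target_address
    }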
@@ -1668,7 +1671,7 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> {
let instruction_end = unsafe { self.result.text_section.as_ptr().add(self.offset_in_text_section).add(instruction_length) };
let destination = if self.result.pc_section[target_pc] != 0 {
// Backward jump
-self.result.pc_section[target_pc] as *const u8
+&self.result.text_section[self.result.pc_section[target_pc] as usize] as *const u8
} else {
// Forward jump, needs relocation
self.text_section_jumps.push(Jump { location: unsafe { instruction_end.sub(4) }, target_pc });
@@ -1681,14 +1684,14 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> {
fn resolve_jumps(&mut self) {
// Relocate forward jumps
for jump in &self.text_section_jumps {
-let destination = self.result.pc_section[jump.target_pc] as *const u8;
+let destination = &self.result.text_section[self.result.pc_section[jump.target_pc] as usize] as *const u8;
let offset_value =
unsafe { destination.offset_from(jump.location) } as i32 // Relative jump
- mem::size_of::<i32>() as i32; // Jump from end of instruction
unsafe { ptr::write_unaligned(jump.location as *mut i32, offset_value); }
}
// Patch addresses to which `callx` may raise an unsupported instruction error
-let call_unsupported_instruction = self.anchors[ANCHOR_CALL_REG_UNSUPPORTED_INSTRUCTION] as usize;
+let call_unsupported_instruction = unsafe { self.anchors[ANCHOR_CALL_REG_UNSUPPORTED_INSTRUCTION].offset_from(self.result.text_section.as_ptr()) as u32 };
if self.executable.get_sbpf_version().static_syscalls() {
let mut prev_pc = 0;
for current_pc in self.executable.get_function_registry().keys() {
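Jump resolution follows the same rule: a stored offset is rebased onto text_section before the rel32 displacement is computed. The displacement math for a forward-jump patch, sketched with simplified types (the real code then writes the value through `jump.location`):

    /// Sketch: the rel32 displacement written into a forward jump.
    /// `location` points at the 4 displacement bytes inside text_section.
    fn rel32_displacement(text_section: &[u8], target_offset: u32, location: *const u8) -> i32 {
        let destination = &text_section[target_offset as usize] as *const u8;
        // Displacement is relative to the end of the 4-byte immediate,
        // hence the size_of::<i32>() correction.
        (unsafe { destination.offset_from(location) }) as i32 - std::mem::size_of::<i32>() as i32
    }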
