Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 85 additions & 24 deletions src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,13 @@ struct Decoder<'a> {
metastack: Vec<Vec<PickleValue>>,
/// Tracks which memo indices are bound to each stack slot (parallel to stack).
/// When BINPUT stores from stack top, the memo index is recorded here.
/// After in-place mutations (SETITEMS, APPENDS, etc.), stale memo entries
/// are updated via sync_memo_top().
stack_memo: Vec<Vec<usize>>,
/// Saved stack_memo during MARK (parallel to metastack).
meta_stack_memo: Vec<Vec<Vec<usize>>>,
/// Dirty flags parallel to memo: true means the memo entry is stale
/// (the owning stack slot was mutated after BINPUT stored the value).
/// Resolved lazily at BINGET or eagerly when the slot is popped.
dirty_memo: Vec<bool>,
}

impl<'a> Decoder<'a> {
Expand All @@ -54,6 +56,7 @@ impl<'a> Decoder<'a> {
metastack: Vec::with_capacity(4),
stack_memo: Vec::with_capacity(16),
meta_stack_memo: Vec::with_capacity(4),
dirty_memo: Vec::with_capacity(16),
}
}

Expand Down Expand Up @@ -310,7 +313,7 @@ impl<'a> Decoder<'a> {
));
}
}
self.sync_memo_top()?;
self.mark_top_dirty();
}
APPENDS => {
let items = self.pop_mark()?;
Expand All @@ -335,7 +338,7 @@ impl<'a> Decoder<'a> {
));
}
}
self.sync_memo_top()?;
self.mark_top_dirty();
}

// -- Dict --
Expand Down Expand Up @@ -369,7 +372,7 @@ impl<'a> Decoder<'a> {
));
}
}
self.sync_memo_top()?;
self.mark_top_dirty();
}
SETITEMS => {
let items = self.pop_mark()?;
Expand All @@ -395,7 +398,7 @@ impl<'a> Decoder<'a> {
));
}
}
self.sync_memo_top()?;
self.mark_top_dirty();
}

// -- Set/FrozenSet (protocol 4) --
Expand All @@ -410,7 +413,7 @@ impl<'a> Decoder<'a> {
"ADDITEMS on non-set".to_string(),
));
}
self.sync_memo_top()?;
self.mark_top_dirty();
}
FROZENSET => {
let items = self.pop_mark()?;
Expand Down Expand Up @@ -609,7 +612,7 @@ impl<'a> Decoder<'a> {
if let Some(top_bindings) = self.stack_memo.last_mut() {
top_bindings.extend(obj_bindings);
}
self.sync_memo_top()?;
self.mark_top_dirty();
}
NEWOBJ => {
let args = self.pop_value()?;
Expand Down Expand Up @@ -783,8 +786,29 @@ impl<'a> Decoder<'a> {

#[inline]
fn pop_value(&mut self) -> Result<PickleValue, CodecError> {
self.stack_memo.pop();
self.stack.pop().ok_or(CodecError::StackUnderflow)
let bindings = self.stack_memo.pop().unwrap_or_default();
let val = self.stack.pop().ok_or(CodecError::StackUnderflow)?;
// Sync any dirty memo entries before the value leaves the stack
if !bindings.is_empty() {
let mut need_clone = false;
for &idx in &bindings {
if idx < self.dirty_memo.len() && self.dirty_memo[idx] {
need_clone = true;
break;
}
}
if need_clone {
for &idx in &bindings {
if idx < self.dirty_memo.len() && self.dirty_memo[idx] {
if idx < self.memo.len() {
self.memo[idx] = val.clone();
}
self.dirty_memo[idx] = false;
}
}
}
}
Ok(val)
}

#[inline]
Expand All @@ -802,9 +826,19 @@ impl<'a> Decoder<'a> {
// Take the current stack (everything since MARK) as the result.
// This is a pointer swap — no element-by-element drain needed.
let items = std::mem::take(&mut self.stack);
let slot_memos = std::mem::take(&mut self.stack_memo);

// Also discard the memo bindings for popped items
self.stack_memo.clear();
// Sync dirty memo entries for all popped slots before values are consumed
for (val, bindings) in items.iter().zip(slot_memos.iter()) {
for &idx in bindings {
if idx < self.dirty_memo.len() && self.dirty_memo[idx] {
if idx < self.memo.len() {
self.memo[idx] = val.clone();
}
self.dirty_memo[idx] = false;
}
}
}

// Restore the previous stack from metastack
if let Some(old_stack) = self.metastack.pop() {
Expand All @@ -825,18 +859,48 @@ impl<'a> Decoder<'a> {
}
if idx >= self.memo.len() {
self.memo.resize(idx + 1, PickleValue::None);
self.dirty_memo.resize(idx + 1, false);
}
self.memo[idx] = val;
self.dirty_memo[idx] = false;
Ok(())
}

fn memo_get(&self, idx: usize) -> Result<PickleValue, CodecError> {
/// Fetch a clone of the memo entry stored at `idx`.
///
/// If the entry is flagged dirty (its owning stack slot was mutated after
/// the value was stored), it is refreshed via `resolve_dirty_memo` before
/// being read, so the caller always sees the up-to-date value.
///
/// # Errors
///
/// Returns `CodecError::InvalidData` when `idx` has no memo entry.
fn memo_get(&mut self, idx: usize) -> Result<PickleValue, CodecError> {
    // An index beyond the dirty-flag table is treated as "not dirty".
    let is_stale = self.dirty_memo.get(idx).copied().unwrap_or(false);
    if is_stale {
        self.resolve_dirty_memo(idx);
    }

    match self.memo.get(idx) {
        Some(entry) => Ok(entry.clone()),
        None => Err(CodecError::InvalidData(format!("memo index {idx} not found"))),
    }
}

/// Resolve a dirty memo entry by finding its live value on the stack.
///
/// The slot bound to `memo_idx` is searched for first in the current stack
/// (`stack` paired with `stack_memo`), then in the frames saved by MARK
/// (`metastack` paired with `meta_stack_memo`), in storage order. The first
/// slot whose bindings contain `memo_idx` supplies the refreshed value.
///
/// If no slot carries the binding (the value already left the stack), the
/// memo keeps whatever was last stored. In every case the dirty flag for
/// `memo_idx` is cleared.
///
/// This never panics: the parallel vectors are walked pairwise with `zip`
/// and the memo tables are accessed with checked lookups, so a length
/// mismatch or an out-of-range `memo_idx` degrades to a no-op instead of an
/// index-out-of-bounds panic.
fn resolve_dirty_memo(&mut self, memo_idx: usize) {
    // Search the current stack for the slot that owns this memo binding.
    let from_stack = self
        .stack_memo
        .iter()
        .zip(self.stack.iter())
        .find(|(bindings, _)| bindings.contains(&memo_idx))
        .map(|(_, val)| val.clone());

    // Fall back to the values saved by MARK, oldest frame first.
    let live = from_stack.or_else(|| {
        self.meta_stack_memo
            .iter()
            .zip(self.metastack.iter())
            .flat_map(|(frame_memo, frame)| frame_memo.iter().zip(frame.iter()))
            .find(|(bindings, _)| bindings.contains(&memo_idx))
            .map(|(_, val)| val.clone())
    });

    // Refresh the memo only when a live value was found; otherwise the
    // memo already holds the last stored value.
    if let (Some(val), Some(entry)) = (live, self.memo.get_mut(memo_idx)) {
        *entry = val;
    }

    // The entry is now as fresh as it can be, so clear the flag.
    if let Some(flag) = self.dirty_memo.get_mut(memo_idx) {
        *flag = false;
    }
}

/// Record that the current stack top was stored in memo at `idx`.
#[inline]
fn record_memo_binding(&mut self, idx: usize) {
Expand All @@ -845,21 +909,18 @@ impl<'a> Decoder<'a> {
}
}

/// After an in-place mutation of the stack top (SETITEMS, APPENDS, etc.),
/// update any memo entries that were cloned from the pre-mutation state.
fn sync_memo_top(&mut self) -> Result<(), CodecError> {
/// Mark memo entries bound to the current stack top as dirty (stale).
/// Called after in-place mutations (SETITEMS, APPENDS, etc.) instead of
/// eagerly cloning. Resolution is deferred to memo_get() or pop_value().
#[inline]
fn mark_top_dirty(&mut self) {
if let Some(bindings) = self.stack_memo.last() {
if !bindings.is_empty() {
let new_val = self.stack.last()
.ok_or(CodecError::StackUnderflow)?.clone();
for &idx in bindings {
if idx < self.memo.len() {
self.memo[idx] = new_val.clone();
}
for &idx in bindings {
if idx < self.dirty_memo.len() {
self.dirty_memo[idx] = true;
}
}
}
Ok(())
}
}

Expand Down