From 147e7f097fa0a00b9da0170fa9dd299e10c5eef0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20W=C3=BCnsche?= Date: Tue, 7 May 2024 19:13:49 +0200 Subject: [PATCH] buffer: fix creeping nulls This happened in the CI and on the most recent rust compiler on my system in in debug mode. Previously a lot of .is_null()'s were distributed over the code this commit cleans this up and by-chance fixed whatever bug in the code flow was occurring in this case. --- betree/src/buffer.rs | 110 ++++++++++++++++++++++++------------------- 1 file changed, 61 insertions(+), 49 deletions(-) diff --git a/betree/src/buffer.rs b/betree/src/buffer.rs index ad353371..b8308d34 100644 --- a/betree/src/buffer.rs +++ b/betree/src/buffer.rs @@ -16,7 +16,8 @@ use std::{ fmt, io, mem::ManuallyDrop, ops::{Deref, Range}, - ptr, slice, + ptr::NonNull, + slice, sync::Arc, }; @@ -54,24 +55,29 @@ fn split_range_at( #[derive(Debug)] struct AlignedStorage { - ptr: *mut u8, + ptr: NonNull, capacity: Block, } -impl Default for AlignedStorage { - fn default() -> Self { - AlignedStorage { - ptr: ptr::null_mut(), - capacity: Block(0), - } - } -} +// impl Default for AlignedStorage { +// fn default() -> Self { +// AlignedStorage { +// ptr: None, +// capacity: Block(0), +// } +// } +// } impl AlignedStorage { fn zeroed(capacity: Block) -> Self { - let mut storage = Self::default(); - storage.ensure_capacity(capacity); - storage + Self { + ptr: unsafe { + let new_layout = + Layout::from_size_align_unchecked(capacity.to_bytes() as usize, BLOCK_SIZE); + NonNull::new(alloc::alloc_zeroed(new_layout)).expect("Allocation failed.") + }, + capacity, + } } fn ensure_capacity(&mut self, requested_capacity: Block) { @@ -101,25 +107,21 @@ impl AlignedStorage { // NOTE: this might not call calloc as initially thought. The default impl just allocs uninitialised // memory, and then writes 0 to it - let new_ptr = if self.ptr.is_null() { - alloc::alloc_zeroed(new_layout) - } else { - let realloc_ptr = - alloc::realloc(self.ptr, curr_layout, wanted_capacity.to_bytes() as usize); - if realloc_ptr.is_null() { - let new_ptr = alloc::alloc_zeroed(new_layout); - if !self.ptr.is_null() { - self.ptr - .copy_to_nonoverlapping(new_ptr, self.capacity.to_bytes() as usize); - alloc::dealloc(self.ptr, curr_layout); - } - new_ptr - } else { - realloc_ptr - } - }; - - self.ptr = new_ptr; + let realloc_ptr = alloc::realloc( + self.ptr.as_ptr(), + curr_layout, + wanted_capacity.to_bytes() as usize, + ); + + self.ptr = NonNull::new(realloc_ptr).unwrap_or_else(|| { + let new_ptr = + NonNull::new(alloc::alloc_zeroed(new_layout)).expect("Allocation failed."); + self.ptr + .as_ptr() + .copy_to_nonoverlapping(new_ptr.as_ptr(), self.capacity.to_bytes() as usize); + alloc::dealloc(self.ptr.as_ptr(), curr_layout); + new_ptr + }); self.capacity = wanted_capacity; } } @@ -127,14 +129,10 @@ impl AlignedStorage { impl Drop for AlignedStorage { fn drop(&mut self) { - if !self.ptr.is_null() { - unsafe { - let layout = Layout::from_size_align_unchecked( - self.capacity.to_bytes() as usize, - BLOCK_SIZE, - ); - alloc::dealloc(self.ptr, layout) - } + unsafe { + let layout = + Layout::from_size_align_unchecked(self.capacity.to_bytes() as usize, BLOCK_SIZE); + alloc::dealloc(self.ptr.as_ptr(), layout) } } } @@ -147,7 +145,9 @@ impl From> for AlignedStorage { if is_aligned(&b) { AlignedStorage { capacity: Block::from_bytes(b.len() as u32), - ptr: unsafe { (*Box::into_raw(b)).as_mut_ptr() }, + ptr: unsafe { + NonNull::new((*Box::into_raw(b)).as_mut_ptr()).expect("Assume valid pointer.") + }, } } else { assert!( @@ -157,8 +157,15 @@ impl From> for AlignedStorage { log::warn!("Unaligned buffer, copying {} bytes", b.len()); let size = Block::round_up_from_bytes(b.len() as u32); let storage = AlignedStorage::zeroed(size); + let align = std::mem::align_of::(); + assert!(!b.as_ptr().is_null()); + assert!(storage.ptr.as_ptr().align_offset(align) == 0); + assert!(b.as_ptr().align_offset(align) == 0); unsafe { - storage.ptr.copy_from_nonoverlapping(b.as_ptr(), b.len()); + storage + .ptr + .as_ptr() + .copy_from_nonoverlapping(b.as_ptr(), b.len()); } storage } @@ -281,6 +288,7 @@ impl io::Write for BufWrite { unsafe { self.buf .ptr + .as_ptr() .offset(self.size as isize) .copy_from_nonoverlapping(data.as_ptr(), data.len()); self.size = required_size; @@ -304,7 +312,7 @@ unsafe impl zstd::stream::raw::WriteBuf for BufWrite { } fn as_mut_ptr(&mut self) -> *mut u8 { - self.buf.ptr + unsafe { self.buf.ptr.as_mut() } } unsafe fn filled_until(&mut self, n: usize) { @@ -333,7 +341,8 @@ impl io::Seek for BufWrite { impl AsRef<[u8]> for BufWrite { fn as_ref(&self) -> &[u8] { unsafe { - let slice = slice::from_raw_parts(self.buf.ptr, self.buf.capacity.to_bytes() as usize); + let slice = + slice::from_raw_parts(self.buf.ptr.as_ptr(), self.buf.capacity.to_bytes() as usize); &slice[..self.size as usize] } } @@ -342,8 +351,10 @@ impl AsRef<[u8]> for BufWrite { impl AsMut<[u8]> for BufWrite { fn as_mut(&mut self) -> &mut [u8] { unsafe { - let slice = - slice::from_raw_parts_mut(self.buf.ptr, self.buf.capacity.to_bytes() as usize); + let slice = slice::from_raw_parts_mut( + self.buf.ptr.as_ptr(), + self.buf.capacity.to_bytes() as usize, + ); &mut slice[..self.size as usize] } } @@ -402,7 +413,7 @@ impl Buf { unsafe { Box::from_raw(slice::from_raw_parts_mut( - storage.ptr, + storage.ptr.as_ptr(), storage.capacity.to_bytes() as usize, )) } @@ -476,7 +487,7 @@ impl AsRef<[u8]> for Buf { let start = self.range.start.to_bytes() as usize; let end = self.range.end.to_bytes() as usize; let buf = &*self.buf.buf.get(); - let slice = slice::from_raw_parts(buf.ptr, buf.capacity.to_bytes() as usize); + let slice = slice::from_raw_parts(buf.ptr.as_ptr(), buf.capacity.to_bytes() as usize); &slice[start..end] } } @@ -495,7 +506,8 @@ impl AsMut<[u8]> for MutBuf { let start = self.range.start.to_bytes() as usize; let end = self.range.end.to_bytes() as usize; let buf = &*self.buf.buf.get(); - let slice = slice::from_raw_parts_mut(buf.ptr, buf.capacity.to_bytes() as usize); + let slice = + slice::from_raw_parts_mut(buf.ptr.as_ptr(), buf.capacity.to_bytes() as usize); &mut slice[start..end] } }