Skip to content

Commit 3dfc0af

Browse files
committed
Use pointers in token reader deserializations
Instead of invalidating references.

Reddit credit: https://www.reddit.com/r/rust/comments/18oe075/comment/kehs8jq/?utm_source=share&utm_medium=web2x&context=3

This allows us to re-enable Miri stacked borrows testing, which found an invalid reference in the binary tape parsing.
1 parent c059ee3 commit 3dfc0af

File tree

8 files changed

+86
-83
lines changed

8 files changed

+86
-83
lines changed

.github/workflows/rust.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,7 @@ jobs:
8080
run: |
8181
rustup toolchain install nightly --component miri
8282
cargo miri setup
83-
MIRIFLAGS="-Zmiri-disable-stacked-borrows" cargo miri test
83+
cargo miri test
8484
8585
- name: Compile fuzz
8686
if: matrix.build == 'nightly'

src/binary/de.rs

Lines changed: 39 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,8 @@ impl<'a, 'de, 'res: 'de, RES: TokenResolver, F: BinaryFlavor, R: Read> de::Deser
4848
where
4949
V: Visitor<'de>,
5050
{
51-
visitor.visit_map(BinaryReaderMap::new(self, true))
51+
let me = std::ptr::addr_of!(self);
52+
visitor.visit_map(BinaryReaderMap::new(me, true))
5253
}
5354

5455
fn deserialize_struct<V>(
@@ -71,12 +72,12 @@ impl<'a, 'de, 'res: 'de, RES: TokenResolver, F: BinaryFlavor, R: Read> de::Deser
7172
}
7273

7374
struct BinaryReaderMap<'a: 'a, 'res, RES: 'a, F, R> {
74-
de: &'a mut BinaryReaderDeserializer<'res, RES, F, R>,
75+
de: *const &'a mut BinaryReaderDeserializer<'res, RES, F, R>,
7576
root: bool,
7677
}
7778

7879
impl<'a, 'res, RES: 'a, F, R> BinaryReaderMap<'a, 'res, RES, F, R> {
79-
fn new(de: &'a mut BinaryReaderDeserializer<'res, RES, F, R>, root: bool) -> Self {
80+
fn new(de: *const &'a mut BinaryReaderDeserializer<'res, RES, F, R>, root: bool) -> Self {
8081
BinaryReaderMap { de, root }
8182
}
8283
}
@@ -91,20 +92,23 @@ impl<'de, 'a, 'res: 'de, RES: TokenResolver, F: BinaryFlavor, R: Read> MapAccess
9192
where
9293
K: DeserializeSeed<'de>,
9394
{
94-
let de = unsafe { &mut *(self.de as *mut _) };
9595
loop {
96-
match self.de.reader.next() {
96+
match unsafe { self.de.read() }.reader.next() {
9797
Ok(Some(Token::Close)) => return Ok(None),
9898
Ok(Some(Token::Open)) => {
99-
let _ = self.de.reader.read();
99+
let _ = unsafe { self.de.read() }.reader.read();
100100
}
101101
Ok(Some(token)) => {
102102
return seed
103-
.deserialize(BinaryReaderTokenDeserializer { de, token })
103+
.deserialize(BinaryReaderTokenDeserializer { de: self.de, token })
104104
.map(Some)
105105
}
106106
Ok(None) if self.root => return Ok(None),
107-
Ok(None) => return Err(LexError::Eof.at(self.de.reader.position()).into()),
107+
Ok(None) => {
108+
return Err(LexError::Eof
109+
.at(unsafe { self.de.read() }.reader.position())
110+
.into())
111+
}
108112
Err(e) => return Err(e.into()),
109113
}
110114
}
@@ -115,18 +119,17 @@ impl<'de, 'a, 'res: 'de, RES: TokenResolver, F: BinaryFlavor, R: Read> MapAccess
115119
where
116120
V: DeserializeSeed<'de>,
117121
{
118-
let de = unsafe { &mut *(self.de as *mut _) };
119-
let mut token = self.de.reader.read()?;
122+
let mut token = unsafe { self.de.read() }.reader.read()?;
120123
if matches!(token, Token::Equal) {
121-
token = self.de.reader.read()?;
124+
token = unsafe { self.de.read() }.reader.read()?;
122125
}
123126

124-
seed.deserialize(BinaryReaderTokenDeserializer { de, token })
127+
seed.deserialize(BinaryReaderTokenDeserializer { de: self.de, token })
125128
}
126129
}
127130

128131
struct BinaryReaderTokenDeserializer<'a, 'res, RES: 'a, F, R> {
129-
de: &'a mut BinaryReaderDeserializer<'res, RES, F, R>,
132+
de: *const &'a mut BinaryReaderDeserializer<'res, RES, F, R>,
130133
token: Token<'a>,
131134
}
132135

@@ -148,18 +151,22 @@ where
148151
Token::I32(x) => visitor.visit_i32(x),
149152
Token::Bool(x) => visitor.visit_bool(x),
150153
Token::Quoted(x) | Token::Unquoted(x) => {
151-
match self.de.config.flavor.decode(x.as_bytes()) {
154+
match unsafe { self.de.read() }.config.flavor.decode(x.as_bytes()) {
152155
Cow::Borrowed(x) => visitor.visit_str(x),
153156
Cow::Owned(x) => visitor.visit_string(x),
154157
}
155158
}
156-
Token::F32(x) => visitor.visit_f32(self.de.config.flavor.visit_f32(x)),
157-
Token::F64(x) => visitor.visit_f64(self.de.config.flavor.visit_f64(x)),
159+
Token::F32(x) => {
160+
visitor.visit_f32(unsafe { self.de.read() }.config.flavor.visit_f32(x))
161+
}
162+
Token::F64(x) => {
163+
visitor.visit_f64(unsafe { self.de.read() }.config.flavor.visit_f64(x))
164+
}
158165
Token::Rgb(x) => visitor.visit_seq(ColorSequence::new(x)),
159166
Token::I64(x) => visitor.visit_i64(x),
160-
Token::Id(s) => match self.de.config.resolver.resolve(s) {
167+
Token::Id(s) => match unsafe { self.de.read() }.config.resolver.resolve(s) {
161168
Some(id) => visitor.visit_borrowed_str(id),
162-
None => match self.de.config.failed_resolve_strategy {
169+
None => match unsafe { self.de.read() }.config.failed_resolve_strategy {
163170
FailedResolveStrategy::Error => Err(Error::from(DeserializeError {
164171
kind: DeserializeErrorKind::UnknownToken { token_id: s },
165172
})),
@@ -171,11 +178,11 @@ where
171178
},
172179
Token::Close => Err(Error::invalid_syntax(
173180
"did not expect end",
174-
self.de.reader.position(),
181+
unsafe { self.de.read() }.reader.position(),
175182
)),
176183
Token::Equal => Err(Error::invalid_syntax(
177184
"did not expect equal",
178-
self.de.reader.position(),
185+
unsafe { self.de.read() }.reader.position(),
179186
)),
180187
Token::Open => visitor.visit_seq(BinaryReaderSeq::new(self.de)),
181188
}
@@ -286,7 +293,7 @@ impl<'a, 'de: 'a, 'res: 'de, RES: TokenResolver, F: BinaryFlavor, R: Read> de::D
286293
V: Visitor<'de>,
287294
{
288295
if let Token::F32(x) = &self.token {
289-
visitor.visit_f32(self.de.config.flavor.visit_f32(*x))
296+
visitor.visit_f32(unsafe { self.de.read() }.config.flavor.visit_f32(*x))
290297
} else {
291298
self.deser(visitor)
292299
}
@@ -298,7 +305,7 @@ impl<'a, 'de: 'a, 'res: 'de, RES: TokenResolver, F: BinaryFlavor, R: Read> de::D
298305
V: Visitor<'de>,
299306
{
300307
if let Token::F64(x) = &self.token {
301-
visitor.visit_f64(self.de.config.flavor.visit_f64(*x))
308+
visitor.visit_f64(unsafe { self.de.read() }.config.flavor.visit_f64(*x))
302309
} else {
303310
self.deser(visitor)
304311
}
@@ -319,7 +326,7 @@ impl<'a, 'de: 'a, 'res: 'de, RES: TokenResolver, F: BinaryFlavor, R: Read> de::D
319326
{
320327
match self.token {
321328
Token::Quoted(x) | Token::Unquoted(x) => {
322-
match self.de.config.flavor.decode(x.as_bytes()) {
329+
match unsafe { self.de.read() }.config.flavor.decode(x.as_bytes()) {
323330
Cow::Borrowed(x) => visitor.visit_str(x),
324331
Cow::Owned(x) => visitor.visit_string(x),
325332
}
@@ -380,10 +387,10 @@ impl<'a, 'de: 'a, 'res: 'de, RES: TokenResolver, F: BinaryFlavor, R: Read> de::D
380387
if !seq.hit_end {
381388
// For when we are deserializing an array that doesn't read
382389
// the closing token
383-
if !matches!(self.de.reader.read()?, Token::Close) {
390+
if !matches!(unsafe { self.de.read() }.reader.read()?, Token::Close) {
384391
return Err(Error::invalid_syntax(
385392
"Expected sequence to be terminated with an end token",
386-
self.de.reader.position(),
393+
unsafe { self.de.read() }.reader.position(),
387394
));
388395
}
389396
}
@@ -459,20 +466,20 @@ impl<'a, 'de: 'a, 'res: 'de, RES: TokenResolver, F: BinaryFlavor, R: Read> de::D
459466
V: Visitor<'de>,
460467
{
461468
if matches!(self.token, Token::Open) {
462-
self.de.reader.skip_container()?;
469+
unsafe { self.de.read() }.reader.skip_container()?;
463470
}
464471

465472
visitor.visit_unit()
466473
}
467474
}
468475

469476
struct BinaryReaderSeq<'a: 'a, 'res, RES: 'a, F, R> {
470-
de: &'a mut BinaryReaderDeserializer<'res, RES, F, R>,
477+
de: *const &'a mut BinaryReaderDeserializer<'res, RES, F, R>,
471478
hit_end: bool,
472479
}
473480

474481
impl<'a, 'de: 'a, 'res: 'de, RES: 'a, F, R> BinaryReaderSeq<'a, 'res, RES, F, R> {
475-
fn new(de: &'a mut BinaryReaderDeserializer<'res, RES, F, R>) -> Self {
482+
fn new(de: *const &'a mut BinaryReaderDeserializer<'res, RES, F, R>) -> Self {
476483
BinaryReaderSeq { de, hit_end: false }
477484
}
478485
}
@@ -486,26 +493,25 @@ impl<'de, 'a, 'res: 'de, RES: TokenResolver, F: BinaryFlavor, R: Read> SeqAccess
486493
where
487494
T: DeserializeSeed<'de>,
488495
{
489-
let de = unsafe { &mut *(self.de as *mut _) };
490-
match self.de.reader.read()? {
496+
match unsafe { self.de.read() }.reader.read()? {
491497
Token::Close => {
492498
self.hit_end = true;
493499
Ok(None)
494500
}
495501
token => seed
496-
.deserialize(BinaryReaderTokenDeserializer { de, token })
502+
.deserialize(BinaryReaderTokenDeserializer { de: self.de, token })
497503
.map(Some),
498504
}
499505
}
500506
}
501507

502508
struct BinaryReaderEnum<'a, 'res, RES: 'a, F, R> {
503-
de: &'a mut BinaryReaderDeserializer<'res, RES, F, R>,
509+
de: *const &'a mut BinaryReaderDeserializer<'res, RES, F, R>,
504510
token: Token<'a>,
505511
}
506512

507513
impl<'a, 'res, RES: 'a, F, R> BinaryReaderEnum<'a, 'res, RES, F, R> {
508-
fn new(de: &'a mut BinaryReaderDeserializer<'res, RES, F, R>, token: Token<'a>) -> Self {
514+
fn new(de: *const &'a mut BinaryReaderDeserializer<'res, RES, F, R>, token: Token<'a>) -> Self {
509515
BinaryReaderEnum { de, token }
510516
}
511517
}

src/binary/reader.rs

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -222,11 +222,10 @@ where
222222
/// ```
223223
#[inline]
224224
pub fn read(&mut self) -> Result<Token, ReaderError> {
225-
// Workaround for borrow checker :(
226-
let s = unsafe { &mut *(self as *mut TokenReader<R>) };
225+
let s = std::ptr::addr_of!(self);
227226
match self.next_opt() {
228227
(Some(x), _) => Ok(x),
229-
(None, None) => Err(s.lex_error(LexError::Eof)),
228+
(None, None) => Err(unsafe { (*s).lex_error(LexError::Eof) }),
230229
(None, Some(e)) => Err(e),
231230
}
232231
}

src/binary/tape.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -263,8 +263,8 @@ impl<'a, 'b> ParserState<'a, 'b> {
263263
macro_rules! push_end {
264264
() => {
265265
let end_idx = self.token_tape.len();
266-
match unsafe { self.token_tape.get_unchecked_mut(parent_ind) } {
267-
BinaryToken::Array(end) | BinaryToken::Object(end) => {
266+
match self.token_tape.get_mut(parent_ind) {
267+
Some(BinaryToken::Array(end) | BinaryToken::Object(end)) => {
268268
let grand_ind = *end;
269269
*end = end_idx;
270270
let val = BinaryToken::End(parent_ind);

src/buffer.rs

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -82,19 +82,17 @@ impl BufferWindow {
8282
/// start.
8383
#[inline]
8484
pub fn fill_buf(&mut self, mut reader: impl Read) -> Result<usize, BufferError> {
85-
// No buffer means we are reading from a slice and there is nothing more
86-
// to fill
87-
if self.buf.len() == 0 {
85+
let carry_over = self.window_len();
86+
if carry_over >= self.buf.len() {
8887
return Ok(0);
8988
}
9089

9190
// Copy over the unconsumed bytes to the start of the buffer
92-
let carry_over = self.window_len();
9391
if carry_over != 0 {
9492
if carry_over >= self.buf.len() {
9593
return Err(BufferError::BufferFull);
9694
}
97-
unsafe { self.start.copy_to(self.buf.as_mut_ptr(), carry_over) };
95+
self.buf.copy_within(self.consumed_data().., 0);
9896
}
9997

10098
self.prior_reads += self.consumed_data();
@@ -104,6 +102,7 @@ impl BufferWindow {
104102
// Have the reader start filling in bytes after unconsumed bytes
105103
match reader.read(&mut self.buf[carry_over..]) {
106104
Ok(r) => {
105+
self.start = self.buf.as_ptr();
107106
self.end = unsafe { self.end.add(r) };
108107
Ok(r)
109108
}

0 commit comments

Comments
 (0)