Skip to content

Commit

Permalink
Merge pull request #131 from rakaly/deser-hint
Browse files Browse the repository at this point in the history
Allow text deserialization on incorrect hint
  • Loading branch information
nickbabcock authored Nov 15, 2023
2 parents 8b43ba6 + 02c006a commit 5bc5a55
Showing 1 changed file with 135 additions and 42 deletions.
177 changes: 135 additions & 42 deletions src/text/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,10 @@ where
_ => None,
}
}

fn read_scalar(&self) -> Result<crate::Scalar<'de>, Error> {
self.reader().read_scalar().map_err(Error::from)
}
}

macro_rules! deserialize_any_value {
Expand Down Expand Up @@ -524,21 +528,34 @@ where
where
V: Visitor<'de>,
{
visit_str!(self.reader().read_str()?, visitor)
if let Ok(x) = self.reader().read_str() {
visit_str!(x, visitor)
} else {
self.deserialize_any(visitor)
}
}

fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_borrowed_bytes(self.reader().read_scalar()?.as_bytes())
if let Ok(scalar) = self.read_scalar() {
visitor.visit_borrowed_bytes(scalar.as_bytes())
} else {
self.deserialize_any(visitor)
}
}

fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_byte_buf(self.reader().read_scalar()?.as_bytes().to_vec())
let val = self.read_scalar().map(|x| x.as_bytes().to_vec());
if let Ok(x) = val {
visitor.visit_byte_buf(x)
} else {
self.deserialize_any(visitor)
}
}

fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
Expand All @@ -552,14 +569,26 @@ where
where
V: Visitor<'de>,
{
visitor.visit_string(self.reader().read_string()?)
if let Ok(x) = self.reader().read_string() {
visitor.visit_string(x)
} else {
self.deserialize_any(visitor)
}
}

fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_bool(self.reader().read_scalar()?.to_bool()?)
let val = self
.read_scalar()
.and_then(|x| x.to_bool().map_err(Error::from));

if let Ok(x) = val {
visitor.visit_bool(x)
} else {
self.deserialize_any(visitor)
}
}

fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
Expand Down Expand Up @@ -587,7 +616,15 @@ where
where
V: Visitor<'de>,
{
visitor.visit_i64(self.reader().read_scalar()?.to_i64()?)
let val = self
.read_scalar()
.and_then(|x| x.to_i64().map_err(Error::from));

if let Ok(x) = val {
visitor.visit_i64(x)
} else {
self.deserialize_any(visitor)
}
}

fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
Expand Down Expand Up @@ -622,7 +659,15 @@ where
where
V: Visitor<'de>,
{
visitor.visit_u64(self.reader().read_scalar()?.to_u64()?)
let val = self
.read_scalar()
.and_then(|x| x.to_u64().map_err(Error::from));

if let Ok(x) = val {
visitor.visit_u64(x)
} else {
self.deserialize_any(visitor)
}
}

fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
Expand All @@ -636,7 +681,15 @@ where
where
V: Visitor<'de>,
{
visitor.visit_f64(self.reader().read_scalar()?.to_f64()?)
let val = self
.read_scalar()
.and_then(|x: crate::Scalar<'_>| x.to_f64().map_err(Error::from));

if let Ok(x) = val {
visitor.visit_f64(x)
} else {
self.deserialize_any(visitor)
}
}

fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
Expand All @@ -650,50 +703,33 @@ where
where
V: Visitor<'de>,
{
if let Some(x) = self.value_reader() {
if matches!(x.token(), TextToken::Unquoted(_) | TextToken::Quoted(_)) {
return visit_str!(x.read_str()?, visitor);
}

let map = MapAccess::new(x.read_object()?, self.config);
let val = self.value_reader().and_then(|x| x.read_object().ok());
if let Some(x) = val {
let map = MapAccess::new(x, self.config);
visitor.visit_map(map)
} else {
Err(Error::from(DeserializeError {
kind: DeserializeErrorKind::Unsupported(String::from(
"can only deserialize an object as a map",
)),
}))
self.deserialize_any(visitor)
}
}

fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
match self.reader() {
Reader::Value(x) => {
if matches!(x.token(), TextToken::Unquoted(_) | TextToken::Quoted(_)) {
return visit_str!(x.read_str()?, visitor);
}
let arr = match self.reader() {
Reader::Array(x) => Some(x),
Reader::Value(x) => x.read_array().ok(),
_ => None,
};

let map = SeqAccess {
config: self.config,
values: x.read_array()?.values(),
};
visitor.visit_seq(map)
}
Reader::Array(x) => {
let map = SeqAccess {
config: self.config,
values: x.values(),
};
visitor.visit_seq(map)
}
_ => Err(Error::from(DeserializeError {
kind: DeserializeErrorKind::Unsupported(String::from(
"unexpected reader for sequence",
)),
})),
if let Some(x) = arr {
let map = SeqAccess {
config: self.config,
values: x.values(),
};
visitor.visit_seq(map)
} else {
self.deserialize_any(visitor)
}
}

Expand Down Expand Up @@ -2043,6 +2079,63 @@ mod tests {
}
}

#[test]
fn test_deserialize_i32_hint() {
#[derive(Deserialize, Debug, PartialEq)]
struct MyStruct {
field1: MaybeI32,
field2: MaybeI32,
}

#[derive(Debug, PartialEq)]
enum MaybeI32 {
Val(i32),
Str(String),
}

struct MaybeI32Visitor;
impl<'de> de::Visitor<'de> for MaybeI32Visitor {
type Value = MaybeI32;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("test case")
}

fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(MaybeI32::Val(v as i32))
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(MaybeI32::Str(String::from(v)))
}
}

impl<'de> Deserialize<'de> for MaybeI32 {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_i32(MaybeI32Visitor)
}
}

let data = br#"field1=1 field2=invalid"#;
let actual: MyStruct = from_slice(&data[..]).unwrap();
assert_eq!(
actual,
MyStruct {
field1: MaybeI32::Val(1),
field2: MaybeI32::Str(String::from("invalid")),
}
);
}

#[test]
fn test_deserialize_untagged() {
#[derive(Deserialize, Debug, PartialEq)]
Expand Down

0 comments on commit 5bc5a55

Please sign in to comment.