From 8e6cdc9f021564bb27512f70ad93ca8628f2b6c3 Mon Sep 17 00:00:00 2001 From: Matthias Wahl Date: Tue, 27 Oct 2020 15:08:47 +0100 Subject: [PATCH 1/3] Add custom de-/serializers to Method in order to serialize it to string. --- src/method.rs | 57 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/src/method.rs b/src/method.rs index 0d239891..52718f30 100644 --- a/src/method.rs +++ b/src/method.rs @@ -1,3 +1,5 @@ +use serde::de::{Error as DeError, Unexpected}; +use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer}; use std::fmt::{self, Display}; use std::str::FromStr; @@ -50,6 +52,44 @@ impl Method { } } +struct MethodVisitor; + +impl<'de> Visitor<'de> for MethodVisitor { + type Value = Method; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a HTTP method &str") + } + + fn visit_str(self, v: &str) -> Result + where + E: DeError, + { + match Method::from_str(v) { + Ok(method) => Ok(method), + Err(_) => Err(DeError::invalid_value(Unexpected::Str(v), &self)), + } + } +} + +impl<'de> Deserialize<'de> for Method { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_str(MethodVisitor) + } +} + +impl Serialize for Method { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&self.to_string()) + } +} + impl Display for Method { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -108,3 +148,20 @@ impl AsRef for Method { } } } + +#[cfg(test)] +mod test { + use super::Method; + + #[test] + fn serde() -> Result<(), serde_json::Error> { + assert_eq!(Method::Get, serde_json::from_str("\"GET\"")?); + assert_eq!(Some("PATCH"), serde_json::to_value(Method::Patch)?.as_str()); + Ok(()) + } + #[test] + fn serde_fail() -> Result<(), serde_json::Error> { + serde_json::from_str::("\"ABC\"").expect_err("Did deserialize from invalid string"); + Ok(()) + } +} From 2dfe020589719f00cab5c02b8e9e7605aedeb19b Mon Sep 17 00:00:00 2001 From: Matthias Wahl Date: Tue, 27 Oct 2020 15:09:35 +0100 Subject: [PATCH 2/3] Add custom de-/serializers to StatusCode in order to serialize it to u16 and deserialize it from most int types. --- src/status_code.rs | 95 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/src/status_code.rs b/src/status_code.rs index 5412c092..749e9128 100644 --- a/src/status_code.rs +++ b/src/status_code.rs @@ -1,3 +1,5 @@ +use serde::de::{Error as DeError, Unexpected, Visitor}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::fmt::{self, Display}; /// HTTP response status codes. @@ -537,6 +539,84 @@ impl StatusCode { } } +impl Serialize for StatusCode { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let value: u16 = *self as u16; + serializer.serialize_u16(value) + } +} + +struct StatusCodeU16Visitor; + +impl<'de> Visitor<'de> for StatusCodeU16Visitor { + type Value = StatusCode; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a u16 representing the status code") + } + + fn visit_i16(self, v: i16) -> Result + where + E: DeError, + { + self.visit_u16(v as u16) + } + + fn visit_i32(self, v: i32) -> Result + where + E: DeError, + { + self.visit_u16(v as u16) + } + + fn visit_i64(self, v: i64) -> Result + where + E: DeError, + { + self.visit_u16(v as u16) + } + + fn visit_u16(self, v: u16) -> Result + where + E: DeError, + { + use std::convert::TryFrom; + match StatusCode::try_from(v) { + Ok(status_code) => Ok(status_code), + Err(_) => Err(DeError::invalid_value( + Unexpected::Unsigned(v as u64), + &self, + )), + } + } + + fn visit_u32(self, v: u32) -> Result + where + E: DeError, + { + self.visit_u16(v as u16) + } + + fn visit_u64(self, v: u64) -> Result + where + E: DeError, + { + self.visit_u16(v as u16) + } +} + +impl<'de> Deserialize<'de> for StatusCode { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_any(StatusCodeU16Visitor) + } +} + impl From for u16 { fn from(code: StatusCode) -> u16 { code as u16 @@ -629,3 +709,18 @@ impl Display for StatusCode { write!(f, "{}", *self as u16) } } + +#[cfg(test)] +mod test { + use super::StatusCode; + #[test] + fn serde_as_u16() -> Result<(), serde_json::Error> { + let status_code: StatusCode = serde_json::from_str("202")?; + assert_eq!(StatusCode::Accepted, status_code); + assert_eq!( + Some(202), + serde_json::to_value(&StatusCode::Accepted)?.as_u64() + ); + Ok(()) + } +} From ad513b59d5ad3ac450d791f41ad074042d259be4 Mon Sep 17 00:00:00 2001 From: Matthias Wahl Date: Tue, 27 Oct 2020 15:10:02 +0100 Subject: [PATCH 3/3] Add custom de-/serializers to Version in order to serialize it to string. --- src/version.rs | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/src/version.rs b/src/version.rs index 2616384b..a5870543 100644 --- a/src/version.rs +++ b/src/version.rs @@ -1,3 +1,4 @@ +use serde::{de::Error, de::Visitor, Deserialize, Deserializer, Serialize, Serializer}; /// The version of the HTTP protocol in use. #[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)] #[non_exhaustive] @@ -18,6 +19,55 @@ pub enum Version { Http3_0, } +impl Serialize for Version { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&self.to_string()) + } +} + +struct VersionVisitor; + +impl<'de> Visitor<'de> for VersionVisitor { + type Value = Version; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "a HTTP version as &str") + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + match v { + "HTTP/0.9" => Ok(Version::Http0_9), + "HTTP/1.0" => Ok(Version::Http1_0), + "HTTP/1.1" => Ok(Version::Http1_1), + "HTTP/2" => Ok(Version::Http2_0), + "HTTP/3" => Ok(Version::Http3_0), + _ => Err(Error::invalid_value(serde::de::Unexpected::Str(v), &self)), + } + } + + fn visit_string(self, v: String) -> Result + where + E: Error, + { + self.visit_str(&v) + } +} + +impl<'de> Deserialize<'de> for Version { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_str(VersionVisitor) + } +} + impl std::fmt::Display for Version { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str(match self { @@ -54,4 +104,11 @@ mod test { assert!(Http1_1 > Http1_0); assert!(Http1_0 > Http0_9); } + + #[test] + fn serde() -> Result<(), serde_json::Error> { + assert_eq!("\"HTTP/3\"", serde_json::to_string(&Version::Http3_0)?); + assert_eq!(Version::Http1_1, serde_json::from_str("\"HTTP/1.1\"")?); + Ok(()) + } }