Skip to content

Commit

Permalink
Merge pull request #262 from mfelsche/serialize
Browse files Browse the repository at this point in the history
Add custom deserialize and serialize impls for Version, StatusCode and Method
  • Loading branch information
yoshuawuyts authored Oct 28, 2020
2 parents 18b86df + ad513b5 commit a6b2a05
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 0 deletions.
57 changes: 57 additions & 0 deletions src/method.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use serde::de::{Error as DeError, Unexpected};
use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::{self, Display};
use std::str::FromStr;

Expand Down Expand Up @@ -50,6 +52,44 @@ impl Method {
}
}

struct MethodVisitor;

impl<'de> Visitor<'de> for MethodVisitor {
type Value = Method;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "a HTTP method &str")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: DeError,
{
match Method::from_str(v) {
Ok(method) => Ok(method),
Err(_) => Err(DeError::invalid_value(Unexpected::Str(v), &self)),
}
}
}

impl<'de> Deserialize<'de> for Method {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_str(MethodVisitor)
}
}

impl Serialize for Method {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&self.to_string())
}
}

impl Display for Method {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Expand Down Expand Up @@ -108,3 +148,20 @@ impl AsRef<str> for Method {
}
}
}

#[cfg(test)]
mod test {
use super::Method;

#[test]
fn serde() -> Result<(), serde_json::Error> {
assert_eq!(Method::Get, serde_json::from_str("\"GET\"")?);
assert_eq!(Some("PATCH"), serde_json::to_value(Method::Patch)?.as_str());
Ok(())
}
#[test]
fn serde_fail() -> Result<(), serde_json::Error> {
serde_json::from_str::<Method>("\"ABC\"").expect_err("Did deserialize from invalid string");
Ok(())
}
}
95 changes: 95 additions & 0 deletions src/status_code.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use serde::de::{Error as DeError, Unexpected, Visitor};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::{self, Display};

/// HTTP response status codes.
Expand Down Expand Up @@ -537,6 +539,84 @@ impl StatusCode {
}
}

impl Serialize for StatusCode {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let value: u16 = *self as u16;
serializer.serialize_u16(value)
}
}

struct StatusCodeU16Visitor;

impl<'de> Visitor<'de> for StatusCodeU16Visitor {
type Value = StatusCode;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "a u16 representing the status code")
}

fn visit_i16<E>(self, v: i16) -> Result<Self::Value, E>
where
E: DeError,
{
self.visit_u16(v as u16)
}

fn visit_i32<E>(self, v: i32) -> Result<Self::Value, E>
where
E: DeError,
{
self.visit_u16(v as u16)
}

fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
where
E: DeError,
{
self.visit_u16(v as u16)
}

fn visit_u16<E>(self, v: u16) -> Result<Self::Value, E>
where
E: DeError,
{
use std::convert::TryFrom;
match StatusCode::try_from(v) {
Ok(status_code) => Ok(status_code),
Err(_) => Err(DeError::invalid_value(
Unexpected::Unsigned(v as u64),
&self,
)),
}
}

fn visit_u32<E>(self, v: u32) -> Result<Self::Value, E>
where
E: DeError,
{
self.visit_u16(v as u16)
}

fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
where
E: DeError,
{
self.visit_u16(v as u16)
}
}

impl<'de> Deserialize<'de> for StatusCode {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_any(StatusCodeU16Visitor)
}
}

impl From<StatusCode> for u16 {
fn from(code: StatusCode) -> u16 {
code as u16
Expand Down Expand Up @@ -629,3 +709,18 @@ impl Display for StatusCode {
write!(f, "{}", *self as u16)
}
}

#[cfg(test)]
mod test {
use super::StatusCode;
#[test]
fn serde_as_u16() -> Result<(), serde_json::Error> {
let status_code: StatusCode = serde_json::from_str("202")?;
assert_eq!(StatusCode::Accepted, status_code);
assert_eq!(
Some(202),
serde_json::to_value(&StatusCode::Accepted)?.as_u64()
);
Ok(())
}
}
57 changes: 57 additions & 0 deletions src/version.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use serde::{de::Error, de::Visitor, Deserialize, Deserializer, Serialize, Serializer};
/// The version of the HTTP protocol in use.
#[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
#[non_exhaustive]
Expand All @@ -18,6 +19,55 @@ pub enum Version {
Http3_0,
}

impl Serialize for Version {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&self.to_string())
}
}

struct VersionVisitor;

impl<'de> Visitor<'de> for VersionVisitor {
type Value = Version;

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "a HTTP version as &str")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: Error,
{
match v {
"HTTP/0.9" => Ok(Version::Http0_9),
"HTTP/1.0" => Ok(Version::Http1_0),
"HTTP/1.1" => Ok(Version::Http1_1),
"HTTP/2" => Ok(Version::Http2_0),
"HTTP/3" => Ok(Version::Http3_0),
_ => Err(Error::invalid_value(serde::de::Unexpected::Str(v), &self)),
}
}

fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: Error,
{
self.visit_str(&v)
}
}

impl<'de> Deserialize<'de> for Version {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_str(VersionVisitor)
}
}

impl std::fmt::Display for Version {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(match self {
Expand Down Expand Up @@ -54,4 +104,11 @@ mod test {
assert!(Http1_1 > Http1_0);
assert!(Http1_0 > Http0_9);
}

#[test]
fn serde() -> Result<(), serde_json::Error> {
assert_eq!("\"HTTP/3\"", serde_json::to_string(&Version::Http3_0)?);
assert_eq!(Version::Http1_1, serde_json::from_str("\"HTTP/1.1\"")?);
Ok(())
}
}

0 comments on commit a6b2a05

Please sign in to comment.