Skip to content
This repository has been archived by the owner on Mar 25, 2024. It is now read-only.

Implement FromStr for serde_yaml::Number #382

Merged
merged 2 commits into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ fn parse_negative_int<T>(
from_str_radix(scalar, 10).ok()
}

fn parse_f64(scalar: &str) -> Option<f64> {
pub(crate) fn parse_f64(scalar: &str) -> Option<f64> {
let unpositive = if let Some(unpositive) = scalar.strip_prefix('+') {
if unpositive.starts_with(['+', '-']) {
return None;
Expand All @@ -1089,14 +1089,14 @@ fn parse_f64(scalar: &str) -> Option<f64> {
None
}

fn digits_but_not_number(scalar: &str) -> bool {
pub(crate) fn digits_but_not_number(scalar: &str) -> bool {
// Leading zero(s) followed by numeric characters is a string according to
// the YAML 1.2 spec. https://yaml.org/spec/1.2/spec.html#id2761292
let scalar = scalar.strip_prefix(['-', '+']).unwrap_or(scalar);
scalar.len() > 1 && scalar.starts_with('0') && scalar[1..].bytes().all(|b| b.is_ascii_digit())
}

fn visit_int<'de, V>(visitor: V, v: &str) -> Result<Result<V::Value>, V>
pub(crate) fn visit_int<'de, V>(visitor: V, v: &str) -> Result<Result<V::Value>, V>
where
V: Visitor<'de>,
{
Expand Down
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub(crate) enum ErrorImpl {
ScalarInMergeElement,
SequenceInMergeElement,
EmptyTag,
FailedToParseNumber,

Shared(Arc<ErrorImpl>),
}
Expand Down Expand Up @@ -239,6 +240,7 @@ impl ErrorImpl {
f.write_str("expected a mapping for merging, but found sequence")
}
ErrorImpl::EmptyTag => f.write_str("empty YAML tag is not allowed"),
ErrorImpl::FailedToParseNumber => f.write_str("failed to parse YAML number"),
ErrorImpl::Shared(_) => unreachable!(),
}
}
Expand Down
70 changes: 44 additions & 26 deletions src/number.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::Error;
use crate::de;
use crate::error::{self, Error, ErrorImpl};
use serde::de::{Unexpected, Visitor};
use serde::{forward_to_deserialize_any, Deserialize, Deserializer, Serialize, Serializer};
use std::cmp::Ordering;
use std::fmt::{self, Display};
use std::hash::{Hash, Hasher};
use std::str::FromStr;

/// Represents a YAML number, whether integer or floating point.
#[derive(Clone, PartialEq, PartialOrd)]
Expand Down Expand Up @@ -308,6 +310,22 @@ impl Display for Number {
}
}

impl FromStr for Number {
type Err = Error;

fn from_str(repr: &str) -> Result<Self, Self::Err> {
if let Ok(result) = de::visit_int(NumberVisitor, repr) {
return result;
}
if !de::digits_but_not_number(repr) {
if let Some(float) = de::parse_f64(repr) {
return Ok(float.into());
}
}
Err(error::new(ErrorImpl::FailedToParseNumber))
}
}

impl PartialEq for N {
fn eq(&self, other: &N) -> bool {
match (*self, *other) {
Expand Down Expand Up @@ -389,37 +407,37 @@ impl Serialize for Number {
}
}

impl<'de> Deserialize<'de> for Number {
#[inline]
fn deserialize<D>(deserializer: D) -> Result<Number, D::Error>
where
D: Deserializer<'de>,
{
struct NumberVisitor;
struct NumberVisitor;

impl<'de> Visitor<'de> for NumberVisitor {
type Value = Number;
impl<'de> Visitor<'de> for NumberVisitor {
type Value = Number;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a number")
}
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a number")
}

#[inline]
fn visit_i64<E>(self, value: i64) -> Result<Number, E> {
Ok(value.into())
}
#[inline]
fn visit_i64<E>(self, value: i64) -> Result<Number, E> {
Ok(value.into())
}

#[inline]
fn visit_u64<E>(self, value: u64) -> Result<Number, E> {
Ok(value.into())
}
#[inline]
fn visit_u64<E>(self, value: u64) -> Result<Number, E> {
Ok(value.into())
}

#[inline]
fn visit_f64<E>(self, value: f64) -> Result<Number, E> {
Ok(value.into())
}
}
#[inline]
fn visit_f64<E>(self, value: f64) -> Result<Number, E> {
Ok(value.into())
}
}

impl<'de> Deserialize<'de> for Number {
#[inline]
fn deserialize<D>(deserializer: D) -> Result<Number, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_any(NumberVisitor)
}
}
Expand Down
29 changes: 28 additions & 1 deletion tests/test_de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

use indoc::indoc;
use serde_derive::Deserialize;
use serde_yaml::{Deserializer, Value};
use serde_yaml::{Deserializer, Number, Value};
use std::collections::BTreeMap;
use std::fmt::Debug;

Expand Down Expand Up @@ -676,3 +676,30 @@ fn test_tag_resolution() {

test_de(yaml, &expected);
}

#[test]
fn test_parse_number() {
let n = "111".parse::<Number>().unwrap();
assert_eq!(n, Number::from(111));

let n = "-111".parse::<Number>().unwrap();
assert_eq!(n, Number::from(-111));

let n = "-1.1".parse::<Number>().unwrap();
assert_eq!(n, Number::from(-1.1));

let n = ".nan".parse::<Number>().unwrap();
assert_eq!(n, Number::from(f64::NAN));

let n = ".inf".parse::<Number>().unwrap();
assert_eq!(n, Number::from(f64::INFINITY));

let n = "-.inf".parse::<Number>().unwrap();
assert_eq!(n, Number::from(f64::NEG_INFINITY));

let err = "null".parse::<Number>().unwrap_err();
assert_eq!(err.to_string(), "failed to parse YAML number");

let err = " 1 ".parse::<Number>().unwrap_err();
assert_eq!(err.to_string(), "failed to parse YAML number");
}