Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed NaN comparison #647

Merged
merged 3 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions numbat/src/interpreter/assert_eq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use compact_str::{format_compact, CompactString};
use std::fmt::Display;
use thiserror::Error;

#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[derive(Debug, Clone, Error, PartialEq)]
pub struct AssertEq2Error {
pub span_lhs: Span,
pub lhs: Value,
Expand All @@ -28,7 +28,7 @@ impl Display for AssertEq2Error {
}
}

#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[derive(Debug, Clone, Error, PartialEq)]
pub struct AssertEq3Error {
pub span_lhs: Span,
pub lhs_original: Quantity,
Expand Down
4 changes: 2 additions & 2 deletions numbat/src/interpreter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use thiserror::Error;

pub use crate::value::Value;

#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[derive(Debug, Clone, Error, PartialEq)]
pub enum RuntimeError {
#[error("Division by zero")]
DivisionByZero,
Expand Down Expand Up @@ -66,7 +66,7 @@ pub enum RuntimeError {
FileWrite(std::path::PathBuf),
}

#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, PartialEq)]
#[must_use]
pub enum InterpreterResult {
Value(Value),
Expand Down
31 changes: 29 additions & 2 deletions numbat/src/quantity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,6 @@ impl PartialEq for Quantity {
}
}

impl Eq for Quantity {}

impl PartialOrd for Quantity {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
let other_converted = other.convert_to(self.unit()).ok()?;
Expand All @@ -347,7 +345,36 @@ impl PrettyPrint for Quantity {
}
}

pub(crate) enum QuantityOrdering {
IncompatibleUnits,
NanOperand,
Less,
Equal,
Greater,
}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this lead to slightly less code?

Suggested change
pub(crate) enum QuantityOrdering {
IncompatibleUnits,
NanOperand,
Less,
Equal,
Greater,
}
pub(crate) enum QuantityOrderingResult {
IncompatibleUnits,
NanOperand,
Ok(std::cmp::Ordering),
}


impl Quantity {
/// partial_cmp that encodes whether comparison fails because its arguments have
/// incompatible units, or because one of them is NaN
pub(crate) fn partial_cmp_preserve_nan(&self, other: &Self) -> QuantityOrdering {
if self.value.to_f64().is_nan() || other.value.to_f64().is_nan() {
return QuantityOrdering::NanOperand;
}

let Ok(other_converted) = other.convert_to(self.unit()) else {
return QuantityOrdering::IncompatibleUnits;
};

match self.value.partial_cmp(&other_converted.value) {
Some(cmp) => match cmp {
std::cmp::Ordering::Less => QuantityOrdering::Less,
std::cmp::Ordering::Equal => QuantityOrdering::Equal,
std::cmp::Ordering::Greater => QuantityOrdering::Greater,
},
None => unreachable!("unexpectedly got a None partial_cmp from non-NaN arguments"),
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe catch this with a .expect(…) above, using the same message?

}
}

/// Pretty prints with the given options.
/// If options is None, default options will be used.
fn pretty_print_with_options(&self, options: Option<FmtFloatConfig>) -> crate::markup::Markup {
Expand Down
2 changes: 1 addition & 1 deletion numbat/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl std::fmt::Display for FunctionReference {
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
Quantity(Quantity),
Boolean(bool),
Expand Down
34 changes: 20 additions & 14 deletions numbat/src/vm.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::{HashMap, VecDeque};
use std::fmt::Display;
use std::sync::Arc;
use std::{cmp::Ordering, fmt::Display};

use compact_str::{CompactString, ToCompactString};
use indexmap::IndexMap;
Expand Down Expand Up @@ -753,22 +753,28 @@ impl Vm {
self.push(ret);
}
op @ (Op::LessThan | Op::GreaterThan | Op::LessOrEqual | Op::GreatorOrEqual) => {
use crate::quantity::QuantityOrdering;

let rhs = self.pop_quantity();
let lhs = self.pop_quantity();

let result = lhs.partial_cmp(&rhs).ok_or_else(|| {
RuntimeError::QuantityError(QuantityError::IncompatibleUnits(
lhs.unit().clone(),
rhs.unit().clone(),
))
})?;

let result = match op {
Op::LessThan => result == Ordering::Less,
Op::GreaterThan => result == Ordering::Greater,
Op::LessOrEqual => result != Ordering::Greater,
Op::GreatorOrEqual => result != Ordering::Less,
_ => unreachable!(),
let result = match lhs.partial_cmp_preserve_nan(&rhs) {
QuantityOrdering::IncompatibleUnits => {
return Err(Box::new(RuntimeError::QuantityError(
QuantityError::IncompatibleUnits(
lhs.unit().clone(),
rhs.unit().clone(),
),
)))
}
QuantityOrdering::NanOperand => false,
QuantityOrdering::Less => matches!(op, Op::LessThan | Op::LessOrEqual),
QuantityOrdering::Equal => {
matches!(op, Op::LessOrEqual | Op::GreatorOrEqual)
}
QuantityOrdering::Greater => {
matches!(op, Op::GreaterThan | Op::GreatorOrEqual)
}
};

self.push(Value::Boolean(result));
Expand Down
28 changes: 28 additions & 0 deletions numbat/tests/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,34 @@ fn test_comparisons() {
expect_output("2 >= 2", "true");
expect_output("2 >= 2.1", "false");

// NaN comparison; all false

expect_output("NaN < NaN", "false");
expect_output("NaN < 0", "false");
expect_output("NaN < 0m", "false");
expect_output("0 < NaN", "false");
expect_output("0m < NaN", "false");

expect_output("NaN <= NaN", "false");
expect_output("NaN <= 0", "false");
expect_output("NaN <= 0m", "false");
expect_output("0 <= NaN", "false");
expect_output("0m <= NaN", "false");

expect_output("NaN > NaN", "false");
expect_output("NaN > 0", "false");
expect_output("NaN > 0m", "false");
expect_output("0 > NaN", "false");
expect_output("0m > NaN", "false");

expect_output("NaN >= NaN", "false");
expect_output("NaN >= 0", "false");
expect_output("NaN >= 0m", "false");
expect_output("0 >= NaN", "false");
expect_output("0m >= NaN", "false");

// equality

expect_output("200 cm == 2 m", "true");
expect_output("201 cm == 2 m", "false");

Expand Down
Loading