Skip to content

Commit

Permalink
Add datastructure support for complex numbers. (#163)
Browse files Browse the repository at this point in the history
* Create Complex datatype.

* Fix build.

* Polish wording.

* Implement eq and ord.
  • Loading branch information
wmedrano committed Feb 9, 2024
1 parent 4b8ceb1 commit 40cd353
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 99 deletions.
2 changes: 1 addition & 1 deletion crates/steel-core/src/parser/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ impl TryFrom<&SteelVal> for ExprKind {
BigNum(x) => Ok(ExprKind::Atom(Atom::new(SyntaxObject::default(
IntegerLiteral(MaybeBigInt::Big(x.unwrap())),
)))),

Complex(_) => unimplemented!("Complex numbers not fully supported yet. See https://github.com/mattwparas/steel/issues/62 for current details."),
VectorV(lst) => {
let items: std::result::Result<Vec<ExprKind>, &'static str> =
lst.iter().map(|x| inner_try_from(x, depth + 1)).collect();
Expand Down
97 changes: 76 additions & 21 deletions crates/steel-core/src/primitives/nums.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::rvals::{IntoSteelVal, Result, SteelVal};
use crate::steel_vm::primitives::numberp;
use crate::rvals::{IntoSteelVal, Result, SteelComplex, SteelVal};
use crate::steel_vm::primitives::{numberp, realp};
use crate::stop;
use num::{BigInt, BigRational, CheckedAdd, CheckedMul, Integer, Rational32, ToPrimitive};
use std::ops::Neg;
Expand All @@ -17,7 +17,7 @@ fn ensure_args_are_numbers(op: &str, args: &[SteelVal]) -> Result<()> {
///
/// # Precondition
/// - `x` and `y` must be valid numerical types.
fn multiply_unchecked(x: &SteelVal, y: &SteelVal) -> Result<SteelVal> {
fn multiply_two(x: &SteelVal, y: &SteelVal) -> Result<SteelVal> {
match (x, y) {
(SteelVal::NumV(x), SteelVal::NumV(y)) => (x * y).into_steelval(),
(SteelVal::NumV(x), SteelVal::IntV(y)) | (SteelVal::IntV(y), SteelVal::NumV(x)) => {
Expand Down Expand Up @@ -91,6 +91,12 @@ fn multiply_unchecked(x: &SteelVal, y: &SteelVal) -> Result<SteelVal> {
(x.as_ref() * y.as_ref()).into_steelval()
}
(SteelVal::BigNum(x), SteelVal::BigNum(y)) => (x.as_ref() * y.as_ref()).into_steelval(),
// Complex numbers.
(SteelVal::Complex(x), SteelVal::Complex(y)) => multiply_complex(x, y),
(SteelVal::Complex(x), y) | (y, SteelVal::Complex(x)) => {
let y = SteelComplex::new(y.clone(), SteelVal::IntV(0));
multiply_complex(x, &y)
}
_ => unreachable!(),
}
}
Expand All @@ -101,13 +107,13 @@ fn multiply_primitive_impl(args: &[SteelVal]) -> Result<SteelVal> {
match args {
[] => 1.into_steelval(),
[x] => x.clone().into_steelval(),
[x, y] => multiply_unchecked(x, y).into_steelval(),
[x, y] => multiply_two(x, y).into_steelval(),
[x, y, zs @ ..] => {
let mut res = multiply_unchecked(x, y)?;
let mut res = multiply_two(x, y)?;
for z in zs {
// TODO: This use case could be optimized to reuse state instead of creating a new
// object each time.
res = multiply_unchecked(&res, &z)?;
res = multiply_two(&res, &z)?;
}
res.into_steelval()
}
Expand All @@ -134,8 +140,8 @@ pub fn divide_primitive(args: &[SteelVal]) -> Result<SteelVal> {
Err(_) => BigRational::new(BigInt::from(1), BigInt::from(*n)).into_steelval(),
},
SteelVal::NumV(n) => n.recip().into_steelval(),
SteelVal::Rational(f) => f.recip().into_steelval(),
SteelVal::BigRational(f) => f.recip().into_steelval(),
SteelVal::Rational(r) => r.recip().into_steelval(),
SteelVal::BigRational(r) => r.recip().into_steelval(),
SteelVal::BigNum(n) => BigRational::new(1.into(), n.as_ref().clone()).into_steelval(),
unexpected => {
stop!(TypeMismatch => "/ expects a number, but found: {:?}", unexpected)
Expand All @@ -146,18 +152,20 @@ pub fn divide_primitive(args: &[SteelVal]) -> Result<SteelVal> {
[] => stop!(ArityMismatch => "/ requires at least one argument"),
[x] => recip(x),
// TODO: Provide custom implementation to optimize by joining the multiply and recip calls.
[x, y] => multiply_unchecked(x, &recip(y)?),
[x, y] => multiply_two(x, &recip(y)?),
[x, ys @ ..] => {
let d = multiply_primitive_impl(ys)?;
multiply_unchecked(&x, &recip(&d)?)
multiply_two(&x, &recip(&d)?)
}
}
}

#[steel_derive::native(name = "-", constant = true, arity = "AtLeast(1)")]
pub fn subtract_primitive(args: &[SteelVal]) -> Result<SteelVal> {
ensure_args_are_numbers("-", args)?;
let negate = |x: &SteelVal| match x {
/// Negate a number.
///
/// # Precondition
/// `value` must be a number.
fn negate(value: &SteelVal) -> Result<SteelVal> {
match value {
SteelVal::NumV(x) => (-x).into_steelval(),
SteelVal::IntV(x) => match x.checked_neg() {
Some(res) => res.into_steelval(),
Expand All @@ -171,18 +179,28 @@ pub fn subtract_primitive(args: &[SteelVal]) -> Result<SteelVal> {
},
SteelVal::BigRational(x) => x.as_ref().neg().into_steelval(),
SteelVal::BigNum(x) => x.as_ref().clone().neg().into_steelval(),
SteelVal::Complex(x) => negate_complex(x),
_ => unreachable!(),
};
}
}

#[steel_derive::native(name = "-", constant = true, arity = "AtLeast(1)")]
pub fn subtract_primitive(args: &[SteelVal]) -> Result<SteelVal> {
ensure_args_are_numbers("-", args)?;
match args {
[] => stop!(TypeMismatch => "- requires at least one argument"),
[x] => negate(x),
[x, ys @ ..] => {
let y = negate(&add_primitive(ys)?)?;
add_primitive(&[x.clone(), y])
add_two(x, &y)
}
}
}

/// Adds two numbers.
///
/// # Precondition
/// x and y must be valid numbers.
pub fn add_two(x: &SteelVal, y: &SteelVal) -> Result<SteelVal> {
match (x, y) {
// Simple integer case. Probably very common.
Expand Down Expand Up @@ -259,6 +277,12 @@ pub fn add_two(x: &SteelVal, y: &SteelVal) -> Result<SteelVal> {
res += *y;
res.into_steelval()
}
// Complex numbers
(SteelVal::Complex(x), SteelVal::Complex(y)) => add_complex(x, y),
(SteelVal::Complex(x), y) | (y, SteelVal::Complex(x)) => {
debug_assert!(realp(y));
add_complex(x, &SteelComplex::new(y.clone(), SteelVal::IntV(0)))
}
_ => unreachable!(),
}
}
Expand All @@ -280,17 +304,48 @@ pub fn add_primitive(args: &[SteelVal]) -> Result<SteelVal> {
}
}

#[cold]
fn multiply_complex(x: &SteelComplex, y: &SteelComplex) -> Result<SteelVal> {
// TODO: Optimize the implementation if needed.
let real = add_two(
&multiply_two(&x.re, &y.re)?,
&negate(&multiply_two(&x.im, &y.im)?)?,
)?;
let im = add_two(&multiply_two(&x.re, &y.im)?, &multiply_two(&x.im, &y.re)?)?;
SteelComplex::new(real, im).into_steelval()
}

#[cold]
fn negate_complex(x: &SteelComplex) -> Result<SteelVal> {
// TODO: Optimize the implementation if needed.
SteelComplex::new(negate(&x.re)?, negate(&x.im)?).into_steelval()
}

#[cold]
fn add_complex(x: &SteelComplex, y: &SteelComplex) -> Result<SteelVal> {
// TODO: Optimize the implementation if needed.
SteelComplex::new(add_two(&x.re, &y.re)?, add_two(&x.im, &y.im)?).into_steelval()
}

#[steel_derive::function(name = "exact?", constant = true)]
pub fn exactp(value: &SteelVal) -> bool {
matches!(
value,
SteelVal::IntV(_) | SteelVal::BigNum(_) | SteelVal::Rational(_) | SteelVal::BigRational(_)
)
match value {
SteelVal::IntV(_)
| SteelVal::BigNum(_)
| SteelVal::Rational(_)
| SteelVal::BigRational(_) => true,
SteelVal::Complex(x) => exactp(&x.re) && exactp(&x.im),
_ => false,
}
}

#[steel_derive::function(name = "inexact?", constant = true)]
pub fn inexactp(value: &SteelVal) -> bool {
matches!(value, SteelVal::NumV(_))
match value {
SteelVal::NumV(_) => true,
SteelVal::Complex(x) => inexactp(&x.re) || inexactp(&x.im),
_ => false,
}
}

pub struct NumOperations {}
Expand Down
85 changes: 73 additions & 12 deletions crates/steel-core/src/rvals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ use crate::{
tokens::TokenType,
},
rerrs::{ErrorKind, SteelErr},
steel_vm::vm::{threads::closure_into_serializable, BuiltInSignature, Continuation},
steel_vm::{
primitives::realp,
vm::{threads::closure_into_serializable, BuiltInSignature, Continuation},
},
values::port::SteelPort,
values::{
closed::{Heap, HeapRef, MarkAndSweepContext},
Expand All @@ -22,7 +25,7 @@ use crate::{
},
values::{functions::BoxedDynFunction, structs::UserDefinedStruct},
};

use std::vec::IntoIter;
use std::{
any::{Any, TypeId},
cell::{Ref, RefCell, RefMut},
Expand All @@ -40,8 +43,6 @@ use std::{
task::Context,
};

use std::vec::IntoIter;

// TODO
#[macro_export]
macro_rules! list {
Expand Down Expand Up @@ -71,7 +72,7 @@ use futures_util::future::Shared;
use futures_util::FutureExt;

use crate::values::lists::List;
use num::{BigInt, BigRational, Rational32, ToPrimitive};
use num::{BigInt, BigRational, Rational32, Signed, ToPrimitive, Zero};
use steel_parser::tokens::MaybeBigInt;

use self::cycles::{CycleDetector, IterativeDropHandler};
Expand Down Expand Up @@ -1216,6 +1217,53 @@ pub enum SteelVal {
BigNum(Gc<BigInt>),
// Like Rational but supports larger numerators and denominators.
BigRational(Gc<BigRational>),
// A complex number.
Complex(Gc<SteelComplex>),
}

/// Contains a complex number.
///
/// TODO: Optimize the contents of complex value. Holding `SteelVal` makes it easier to use existing
/// operations but a more specialized representation may be faster.
#[derive(Clone, Hash, PartialEq)]
pub struct SteelComplex {
/// The real part of the complex number.
pub re: SteelVal,
/// The imaginary part of the complex number.
pub im: SteelVal,
}

impl SteelComplex {
pub fn new(real: SteelVal, imaginary: SteelVal) -> SteelComplex {
SteelComplex {
re: real,
im: imaginary,
}
}
}

impl IntoSteelVal for SteelComplex {
fn into_steelval(self) -> Result<SteelVal> {
Ok(match self.im {
NumV(n) if n.is_zero() => self.re,
IntV(0) => self.re,
_ => SteelVal::Complex(Gc::new(self)),
})
}
}

impl SteelComplex {
/// Returns `true` if the imaginary part is negative.
fn imaginary_is_negative(&self) -> bool {
match &self.im {
NumV(x) => x.is_negative(),
IntV(x) => x.is_negative(),
Rational(x) => x.is_negative(),
BigNum(x) => x.is_negative(),
SteelVal::BigRational(x) => x.is_negative(),
_ => unreachable!(),
}
}
}

impl SteelVal {
Expand Down Expand Up @@ -1592,11 +1640,12 @@ impl Hash for SteelVal {
NumV(n) => n.to_string().hash(state),
IntV(i) => i.hash(state),
Rational(f) => f.hash(state),
BigNum(n) => n.hash(state),
BigRational(f) => f.hash(state),
Complex(x) => x.hash(state),
CharV(c) => c.hash(state),
ListV(l) => l.hash(state),
CustomStruct(s) => s.hash(state),
BigNum(n) => n.hash(state),
BigRational(f) => f.hash(state),
// Pair(cell) => {
// cell.hash(state);
// }
Expand Down Expand Up @@ -1959,10 +2008,14 @@ pub fn number_equality(left: &SteelVal, right: &SteelVal) -> Result<SteelVal> {
| (BigRational(_), BigNum(_))
| (BigNum(_), BigRational(_)) => false,
(IntV(_), BigNum(_)) | (BigNum(_), IntV(_)) => false,
(Complex(x), Complex(y)) => {
number_equality(&x.re, &y.re)? == BoolV(true)
&& number_equality(&x.im, &y.re)? == BoolV(true)
}
(Complex(_), _) | (_, Complex(_)) => false,
_ => stop!(TypeMismatch => "= expects two numbers, found: {:?} and {:?}", left, right),
};

Ok(SteelVal::BoolV(result))
Ok(BoolV(result))
}

fn partial_cmp_f64(l: &impl ToPrimitive, r: &impl ToPrimitive) -> Option<Ordering> {
Expand All @@ -1972,8 +2025,8 @@ fn partial_cmp_f64(l: &impl ToPrimitive, r: &impl ToPrimitive) -> Option<Orderin
// TODO add tests
impl PartialOrd for SteelVal {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
// TODO: Attempt to avoid converting to f64 for cases below as it may lead to precision
// loss at tiny and large values.
// TODO: Attempt to avoid converting to f64 for cases below as it may lead to precision loss
// at tiny and large values.
match (self, other) {
(IntV(l), IntV(r)) => l.partial_cmp(r),
(IntV(l), NumV(r)) => partial_cmp_f64(l, r),
Expand Down Expand Up @@ -2002,7 +2055,15 @@ impl PartialOrd for SteelVal {
(BigRational(l), BigNum(r)) => partial_cmp_f64(l.as_ref(), r.as_ref()),
(StringV(s), StringV(o)) => s.partial_cmp(o),
(CharV(l), CharV(r)) => l.partial_cmp(r),
_ => None, // unimplemented for other types
(l, r) => {
// All real numbers (not complex) should have order defined.
debug_assert!(
!(realp(l) && realp(r)),
"Numbers {l:?} and {r:?} should implement partial_cmp"
);
// Unimplemented for other types
None
}
}
}
}
Expand Down
Loading

0 comments on commit 40cd353

Please sign in to comment.