From ebab18302dd0b2d7d26685a3898e7065cbbbf2f0 Mon Sep 17 00:00:00 2001 From: Victor Nova Date: Thu, 15 Feb 2024 13:07:51 -0800 Subject: [PATCH] IComparable and IEquatable implementations for PyInt, PyFloat, and PyString for primitive .NET types --- CHANGELOG.md | 3 + src/embed_tests/TestPyFloat.cs | 27 ++++ src/embed_tests/TestPyInt.cs | 70 +++++++++ src/embed_tests/TestPyString.cs | 19 +++ .../PythonTypes/PyFloat.IComparable.cs | 34 +++++ src/runtime/PythonTypes/PyFloat.cs | 4 +- src/runtime/PythonTypes/PyInt.IComparable.cs | 136 ++++++++++++++++++ src/runtime/PythonTypes/PyInt.cs | 2 +- src/runtime/PythonTypes/PyObject.cs | 19 ++- src/runtime/PythonTypes/PyString.cs | 21 ++- 10 files changed, 331 insertions(+), 4 deletions(-) create mode 100644 src/runtime/PythonTypes/PyFloat.IComparable.cs create mode 100644 src/runtime/PythonTypes/PyInt.IComparable.cs diff --git a/CHANGELOG.md b/CHANGELOG.md index 83f9d4bd1..e6cc52d72 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ This document follows the conventions laid out in [Keep a CHANGELOG][]. - Added `ToPythonAs()` extension method to allow for explicit conversion using a specific type. ([#2311][i2311]) +- Added `IComparable` and `IEquatable` implementations to `PyInt`, `PyFloat`, and `PyString` + to compare with primitive .NET types like `long`. + ### Changed ### Fixed diff --git a/src/embed_tests/TestPyFloat.cs b/src/embed_tests/TestPyFloat.cs index 36531cb6a..89e29e5fd 100644 --- a/src/embed_tests/TestPyFloat.cs +++ b/src/embed_tests/TestPyFloat.cs @@ -126,5 +126,32 @@ public void AsFloatBad() StringAssert.StartsWith("could not convert string to float", ex.Message); Assert.IsNull(a); } + + [Test] + public void CompareTo() + { + var v = new PyFloat(42); + + Assert.AreEqual(0, v.CompareTo(42f)); + Assert.AreEqual(0, v.CompareTo(42d)); + + Assert.AreEqual(1, v.CompareTo(41f)); + Assert.AreEqual(1, v.CompareTo(41d)); + + Assert.AreEqual(-1, v.CompareTo(43f)); + Assert.AreEqual(-1, v.CompareTo(43d)); + } + + [Test] + public void Equals() + { + var v = new PyFloat(42); + + Assert.IsTrue(v.Equals(42f)); + Assert.IsTrue(v.Equals(42d)); + + Assert.IsFalse(v.Equals(41f)); + Assert.IsFalse(v.Equals(41d)); + } } } diff --git a/src/embed_tests/TestPyInt.cs b/src/embed_tests/TestPyInt.cs index c147e074b..d2767e664 100644 --- a/src/embed_tests/TestPyInt.cs +++ b/src/embed_tests/TestPyInt.cs @@ -210,6 +210,76 @@ public void ToBigInteger() CollectionAssert.AreEqual(expected, actual); } + [Test] + public void CompareTo() + { + var v = new PyInt(42); + + #region Signed + Assert.AreEqual(0, v.CompareTo(42L)); + Assert.AreEqual(0, v.CompareTo(42)); + Assert.AreEqual(0, v.CompareTo((short)42)); + Assert.AreEqual(0, v.CompareTo((sbyte)42)); + + Assert.AreEqual(1, v.CompareTo(41L)); + Assert.AreEqual(1, v.CompareTo(41)); + Assert.AreEqual(1, v.CompareTo((short)41)); + Assert.AreEqual(1, v.CompareTo((sbyte)41)); + + Assert.AreEqual(-1, v.CompareTo(43L)); + Assert.AreEqual(-1, v.CompareTo(43)); + Assert.AreEqual(-1, v.CompareTo((short)43)); + Assert.AreEqual(-1, v.CompareTo((sbyte)43)); + #endregion Signed + + #region Unsigned + Assert.AreEqual(0, v.CompareTo(42UL)); + Assert.AreEqual(0, v.CompareTo(42U)); + Assert.AreEqual(0, v.CompareTo((ushort)42)); + Assert.AreEqual(0, v.CompareTo((byte)42)); + + Assert.AreEqual(1, v.CompareTo(41UL)); + Assert.AreEqual(1, v.CompareTo(41U)); + Assert.AreEqual(1, v.CompareTo((ushort)41)); + Assert.AreEqual(1, v.CompareTo((byte)41)); + + Assert.AreEqual(-1, v.CompareTo(43UL)); + Assert.AreEqual(-1, v.CompareTo(43U)); + Assert.AreEqual(-1, v.CompareTo((ushort)43)); + Assert.AreEqual(-1, v.CompareTo((byte)43)); + #endregion Unsigned + } + + [Test] + public void Equals() + { + var v = new PyInt(42); + + #region Signed + Assert.True(v.Equals(42L)); + Assert.True(v.Equals(42)); + Assert.True(v.Equals((short)42)); + Assert.True(v.Equals((sbyte)42)); + + Assert.False(v.Equals(41L)); + Assert.False(v.Equals(41)); + Assert.False(v.Equals((short)41)); + Assert.False(v.Equals((sbyte)41)); + #endregion Signed + + #region Unsigned + Assert.True(v.Equals(42UL)); + Assert.True(v.Equals(42U)); + Assert.True(v.Equals((ushort)42)); + Assert.True(v.Equals((byte)42)); + + Assert.False(v.Equals(41UL)); + Assert.False(v.Equals(41U)); + Assert.False(v.Equals((ushort)41)); + Assert.False(v.Equals((byte)41)); + #endregion Unsigned + } + [Test] public void ToBigIntegerLarge() { diff --git a/src/embed_tests/TestPyString.cs b/src/embed_tests/TestPyString.cs index b12e08c23..35c6339ee 100644 --- a/src/embed_tests/TestPyString.cs +++ b/src/embed_tests/TestPyString.cs @@ -112,5 +112,24 @@ public void TestUnicodeSurrogate() Assert.AreEqual(4, actual.Length()); Assert.AreEqual(expected, actual.ToString()); } + + [Test] + public void CompareTo() + { + var a = new PyString("foo"); + + Assert.AreEqual(0, a.CompareTo("foo")); + Assert.AreEqual("foo".CompareTo("bar"), a.CompareTo("bar")); + Assert.AreEqual("foo".CompareTo("foz"), a.CompareTo("foz")); + } + + [Test] + public void Equals() + { + var a = new PyString("foo"); + + Assert.True(a.Equals("foo")); + Assert.False(a.Equals("bar")); + } } } diff --git a/src/runtime/PythonTypes/PyFloat.IComparable.cs b/src/runtime/PythonTypes/PyFloat.IComparable.cs new file mode 100644 index 000000000..c12fc283a --- /dev/null +++ b/src/runtime/PythonTypes/PyFloat.IComparable.cs @@ -0,0 +1,34 @@ +using System; + +namespace Python.Runtime; + +partial class PyFloat : IComparable, IComparable + , IEquatable, IEquatable + , IComparable, IEquatable +{ + public override bool Equals(object o) + { + using var _ = Py.GIL(); + return o switch + { + double f64 => this.Equals(f64), + float f32 => this.Equals(f32), + _ => base.Equals(o), + }; + } + + public int CompareTo(double other) => this.ToDouble().CompareTo(other); + + public int CompareTo(float other) => this.ToDouble().CompareTo(other); + + public bool Equals(double other) => this.ToDouble().Equals(other); + + public bool Equals(float other) => this.ToDouble().Equals(other); + + public int CompareTo(PyFloat? other) + { + return other is null ? 1 : this.CompareTo(other.BorrowNullable()); + } + + public bool Equals(PyFloat? other) => base.Equals(other); +} diff --git a/src/runtime/PythonTypes/PyFloat.cs b/src/runtime/PythonTypes/PyFloat.cs index c09ec93ba..50621d5c2 100644 --- a/src/runtime/PythonTypes/PyFloat.cs +++ b/src/runtime/PythonTypes/PyFloat.cs @@ -8,7 +8,7 @@ namespace Python.Runtime /// PY3: https://docs.python.org/3/c-api/float.html /// for details. /// - public class PyFloat : PyNumber + public partial class PyFloat : PyNumber { internal PyFloat(in StolenReference ptr) : base(ptr) { @@ -100,6 +100,8 @@ public static PyFloat AsFloat(PyObject value) return new PyFloat(op.Steal()); } + public double ToDouble() => Runtime.PyFloat_AsDouble(obj); + public override TypeCode GetTypeCode() => TypeCode.Double; } } diff --git a/src/runtime/PythonTypes/PyInt.IComparable.cs b/src/runtime/PythonTypes/PyInt.IComparable.cs new file mode 100644 index 000000000..a96f02e10 --- /dev/null +++ b/src/runtime/PythonTypes/PyInt.IComparable.cs @@ -0,0 +1,136 @@ +using System; + +namespace Python.Runtime; + +partial class PyInt : IComparable, IComparable, IComparable, IComparable + , IComparable, IComparable, IComparable, IComparable + , IEquatable, IEquatable, IEquatable, IEquatable + , IEquatable, IEquatable, IEquatable, IEquatable + , IComparable, IEquatable +{ + public override bool Equals(object o) + { + using var _ = Py.GIL(); + return o switch + { + long i64 => this.Equals(i64), + int i32 => this.Equals(i32), + short i16 => this.Equals(i16), + sbyte i8 => this.Equals(i8), + + ulong u64 => this.Equals(u64), + uint u32 => this.Equals(u32), + ushort u16 => this.Equals(u16), + byte u8 => this.Equals(u8), + + _ => base.Equals(o), + }; + } + + #region Signed + public int CompareTo(long other) + { + using var pyOther = Runtime.PyInt_FromInt64(other); + return this.CompareTo(pyOther.BorrowOrThrow()); + } + + public int CompareTo(int other) + { + using var pyOther = Runtime.PyInt_FromInt32(other); + return this.CompareTo(pyOther.BorrowOrThrow()); + } + + public int CompareTo(short other) + { + using var pyOther = Runtime.PyInt_FromInt32(other); + return this.CompareTo(pyOther.BorrowOrThrow()); + } + + public int CompareTo(sbyte other) + { + using var pyOther = Runtime.PyInt_FromInt32(other); + return this.CompareTo(pyOther.BorrowOrThrow()); + } + + public bool Equals(long other) + { + using var pyOther = Runtime.PyInt_FromInt64(other); + return this.Equals(pyOther.BorrowOrThrow()); + } + + public bool Equals(int other) + { + using var pyOther = Runtime.PyInt_FromInt32(other); + return this.Equals(pyOther.BorrowOrThrow()); + } + + public bool Equals(short other) + { + using var pyOther = Runtime.PyInt_FromInt32(other); + return this.Equals(pyOther.BorrowOrThrow()); + } + + public bool Equals(sbyte other) + { + using var pyOther = Runtime.PyInt_FromInt32(other); + return this.Equals(pyOther.BorrowOrThrow()); + } + #endregion Signed + + #region Unsigned + public int CompareTo(ulong other) + { + using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other); + return this.CompareTo(pyOther.BorrowOrThrow()); + } + + public int CompareTo(uint other) + { + using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other); + return this.CompareTo(pyOther.BorrowOrThrow()); + } + + public int CompareTo(ushort other) + { + using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other); + return this.CompareTo(pyOther.BorrowOrThrow()); + } + + public int CompareTo(byte other) + { + using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other); + return this.CompareTo(pyOther.BorrowOrThrow()); + } + + public bool Equals(ulong other) + { + using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other); + return this.Equals(pyOther.BorrowOrThrow()); + } + + public bool Equals(uint other) + { + using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other); + return this.Equals(pyOther.BorrowOrThrow()); + } + + public bool Equals(ushort other) + { + using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other); + return this.Equals(pyOther.BorrowOrThrow()); + } + + public bool Equals(byte other) + { + using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other); + return this.Equals(pyOther.BorrowOrThrow()); + } + #endregion Unsigned + + public int CompareTo(PyInt? other) + { + return other is null ? 1 : this.CompareTo(other.BorrowNullable()); + } + + public bool Equals(PyInt? other) => base.Equals(other); +} diff --git a/src/runtime/PythonTypes/PyInt.cs b/src/runtime/PythonTypes/PyInt.cs index e71462b74..0d00f5a13 100644 --- a/src/runtime/PythonTypes/PyInt.cs +++ b/src/runtime/PythonTypes/PyInt.cs @@ -9,7 +9,7 @@ namespace Python.Runtime /// Represents a Python integer object. /// See the documentation at https://docs.python.org/3/c-api/long.html /// - public class PyInt : PyNumber, IFormattable + public partial class PyInt : PyNumber, IFormattable { internal PyInt(in StolenReference ptr) : base(ptr) { diff --git a/src/runtime/PythonTypes/PyObject.cs b/src/runtime/PythonTypes/PyObject.cs index bda2d9c02..cf0c2a03f 100644 --- a/src/runtime/PythonTypes/PyObject.cs +++ b/src/runtime/PythonTypes/PyObject.cs @@ -1136,6 +1136,23 @@ public long Refcount } } + internal int CompareTo(BorrowedReference other) + { + int greater = Runtime.PyObject_RichCompareBool(this.Reference, other, Runtime.Py_GT); + Debug.Assert(greater != -1); + if (greater > 0) + return 1; + int less = Runtime.PyObject_RichCompareBool(this.Reference, other, Runtime.Py_LT); + Debug.Assert(less != -1); + return less > 0 ? -1 : 0; + } + + internal bool Equals(BorrowedReference other) + { + int equal = Runtime.PyObject_RichCompareBool(this.Reference, other, Runtime.Py_EQ); + Debug.Assert(equal != -1); + return equal > 0; + } public override bool TryGetMember(GetMemberBinder binder, out object? result) { @@ -1325,7 +1342,7 @@ private bool TryCompare(PyObject arg, int op, out object @out) } return true; } - + public override bool TryBinaryOperation(BinaryOperationBinder binder, object arg, out object? result) { using var _ = Py.GIL(); diff --git a/src/runtime/PythonTypes/PyString.cs b/src/runtime/PythonTypes/PyString.cs index d54397fcf..6fed25c3e 100644 --- a/src/runtime/PythonTypes/PyString.cs +++ b/src/runtime/PythonTypes/PyString.cs @@ -1,4 +1,5 @@ using System; +using System.Diagnostics; using System.Runtime.Serialization; namespace Python.Runtime @@ -13,7 +14,7 @@ namespace Python.Runtime /// 2011-01-29: ...Then why does the string constructor call PyUnicode_FromUnicode()??? /// [Serializable] - public class PyString : PySequence + public class PyString : PySequence, IComparable, IEquatable { internal PyString(in StolenReference reference) : base(reference) { } internal PyString(BorrowedReference reference) : base(reference) { } @@ -61,5 +62,23 @@ public static bool IsStringType(PyObject value) } public override TypeCode GetTypeCode() => TypeCode.String; + + internal string ToStringUnderGIL() + { + string? result = Runtime.GetManagedString(this.Reference); + Debug.Assert(result is not null); + return result!; + } + + public bool Equals(string? other) + => this.ToStringUnderGIL().Equals(other, StringComparison.CurrentCulture); + public int CompareTo(string? other) + => string.Compare(this.ToStringUnderGIL(), other, StringComparison.CurrentCulture); + + public override string ToString() + { + using var _ = Py.GIL(); + return this.ToStringUnderGIL(); + } } }