Skip to content

Commit

Permalink
IComparable and IEquatable implementations for PyInt, PyFloat, and Py…
Browse files Browse the repository at this point in the history
…String for primitive .NET types
  • Loading branch information
lostmsu committed Feb 28, 2024
1 parent 9d18a24 commit f738505
Show file tree
Hide file tree
Showing 10 changed files with 331 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ This document follows the conventions laid out in [Keep a CHANGELOG][].

- Added `ToPythonAs<T>()` extension method to allow for explicit conversion using a specific type. ([#2311][i2311])

- Added `IComparable` and `IEquatable` implementations to `PyInt`, `PyFloat`, and `PyString`
to compare with primitive .NET types like `long`.

### Changed

### Fixed
Expand Down
27 changes: 27 additions & 0 deletions src/embed_tests/TestPyFloat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,32 @@ public void AsFloatBad()
StringAssert.StartsWith("could not convert string to float", ex.Message);
Assert.IsNull(a);
}

[Test]
public void CompareTo()
{
var v = new PyFloat(42);

Assert.AreEqual(0, v.CompareTo(42f));
Assert.AreEqual(0, v.CompareTo(42d));

Assert.AreEqual(1, v.CompareTo(41f));
Assert.AreEqual(1, v.CompareTo(41d));

Assert.AreEqual(-1, v.CompareTo(43f));
Assert.AreEqual(-1, v.CompareTo(43d));
}

[Test]
public void Equals()
{
var v = new PyFloat(42);

Assert.IsTrue(v.Equals(42f));
Assert.IsTrue(v.Equals(42d));

Assert.IsFalse(v.Equals(41f));
Assert.IsFalse(v.Equals(41d));
}
}
}
70 changes: 70 additions & 0 deletions src/embed_tests/TestPyInt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,76 @@ public void ToBigInteger()
CollectionAssert.AreEqual(expected, actual);
}

[Test]
public void CompareTo()
{
var v = new PyInt(42);

#region Signed
Assert.AreEqual(0, v.CompareTo(42L));
Assert.AreEqual(0, v.CompareTo(42));
Assert.AreEqual(0, v.CompareTo((short)42));
Assert.AreEqual(0, v.CompareTo((sbyte)42));

Assert.AreEqual(1, v.CompareTo(41L));
Assert.AreEqual(1, v.CompareTo(41));
Assert.AreEqual(1, v.CompareTo((short)41));
Assert.AreEqual(1, v.CompareTo((sbyte)41));

Assert.AreEqual(-1, v.CompareTo(43L));
Assert.AreEqual(-1, v.CompareTo(43));
Assert.AreEqual(-1, v.CompareTo((short)43));
Assert.AreEqual(-1, v.CompareTo((sbyte)43));
#endregion Signed

#region Unsigned
Assert.AreEqual(0, v.CompareTo(42UL));
Assert.AreEqual(0, v.CompareTo(42U));
Assert.AreEqual(0, v.CompareTo((ushort)42));
Assert.AreEqual(0, v.CompareTo((byte)42));

Assert.AreEqual(1, v.CompareTo(41UL));
Assert.AreEqual(1, v.CompareTo(41U));
Assert.AreEqual(1, v.CompareTo((ushort)41));
Assert.AreEqual(1, v.CompareTo((byte)41));

Assert.AreEqual(-1, v.CompareTo(43UL));
Assert.AreEqual(-1, v.CompareTo(43U));
Assert.AreEqual(-1, v.CompareTo((ushort)43));
Assert.AreEqual(-1, v.CompareTo((byte)43));
#endregion Unsigned
}

[Test]
public void Equals()
{
var v = new PyInt(42);

#region Signed
Assert.True(v.Equals(42L));
Assert.True(v.Equals(42));
Assert.True(v.Equals((short)42));
Assert.True(v.Equals((sbyte)42));

Assert.False(v.Equals(41L));
Assert.False(v.Equals(41));
Assert.False(v.Equals((short)41));
Assert.False(v.Equals((sbyte)41));
#endregion Signed

#region Unsigned
Assert.True(v.Equals(42UL));
Assert.True(v.Equals(42U));
Assert.True(v.Equals((ushort)42));
Assert.True(v.Equals((byte)42));

Assert.False(v.Equals(41UL));
Assert.False(v.Equals(41U));
Assert.False(v.Equals((ushort)41));
Assert.False(v.Equals((byte)41));
#endregion Unsigned
}

[Test]
public void ToBigIntegerLarge()
{
Expand Down
19 changes: 19 additions & 0 deletions src/embed_tests/TestPyString.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,5 +112,24 @@ public void TestUnicodeSurrogate()
Assert.AreEqual(4, actual.Length());
Assert.AreEqual(expected, actual.ToString());
}

[Test]
public void CompareTo()
{
var a = new PyString("foo");

Assert.AreEqual(0, a.CompareTo("foo"));
Assert.AreEqual("foo".CompareTo("bar"), a.CompareTo("bar"));
Assert.AreEqual("foo".CompareTo("foz"), a.CompareTo("foz"));
}

[Test]
public void Equals()
{
var a = new PyString("foo");

Assert.True(a.Equals("foo"));
Assert.False(a.Equals("bar"));
}
}
}
34 changes: 34 additions & 0 deletions src/runtime/PythonTypes/PyFloat.IComparable.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
using System;

namespace Python.Runtime;

partial class PyFloat : IComparable<double>, IComparable<float>
, IEquatable<double>, IEquatable<float>
, IComparable<PyFloat?>, IEquatable<PyFloat?>
{
public override bool Equals(object o)
{
using var _ = Py.GIL();
return o switch
{
double f64 => this.Equals(f64),
float f32 => this.Equals(f32),
_ => base.Equals(o),
};
}

public int CompareTo(double other) => this.ToDouble().CompareTo(other);

public int CompareTo(float other) => this.ToDouble().CompareTo(other);

public bool Equals(double other) => this.ToDouble().Equals(other);

public bool Equals(float other) => this.ToDouble().Equals(other);

public int CompareTo(PyFloat? other)
{
return other is null ? 1 : this.CompareTo(other.BorrowNullable());
}

public bool Equals(PyFloat? other) => base.Equals(other);
}
4 changes: 3 additions & 1 deletion src/runtime/PythonTypes/PyFloat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Python.Runtime
/// PY3: https://docs.python.org/3/c-api/float.html
/// for details.
/// </summary>
public class PyFloat : PyNumber
public partial class PyFloat : PyNumber
{
internal PyFloat(in StolenReference ptr) : base(ptr)
{
Expand Down Expand Up @@ -100,6 +100,8 @@ public static PyFloat AsFloat(PyObject value)
return new PyFloat(op.Steal());
}

public double ToDouble() => Runtime.PyFloat_AsDouble(obj);

public override TypeCode GetTypeCode() => TypeCode.Double;
}
}
136 changes: 136 additions & 0 deletions src/runtime/PythonTypes/PyInt.IComparable.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
using System;

namespace Python.Runtime;

partial class PyInt : IComparable<long>, IComparable<int>, IComparable<sbyte>, IComparable<short>
, IComparable<ulong>, IComparable<uint>, IComparable<ushort>, IComparable<byte>
, IEquatable<long>, IEquatable<int>, IEquatable<short>, IEquatable<sbyte>
, IEquatable<ulong>, IEquatable<uint>, IEquatable<ushort>, IEquatable<byte>
, IComparable<PyInt?>, IEquatable<PyInt?>
{
public override bool Equals(object o)
{
using var _ = Py.GIL();
return o switch
{
long i64 => this.Equals(i64),
int i32 => this.Equals(i32),
short i16 => this.Equals(i16),
sbyte i8 => this.Equals(i8),

ulong u64 => this.Equals(u64),
uint u32 => this.Equals(u32),
ushort u16 => this.Equals(u16),
byte u8 => this.Equals(u8),

_ => base.Equals(o),
};
}

#region Signed
public int CompareTo(long other)
{
using var pyOther = Runtime.PyInt_FromInt64(other);
return this.CompareTo(pyOther.BorrowOrThrow());
}

public int CompareTo(int other)
{
using var pyOther = Runtime.PyInt_FromInt32(other);
return this.CompareTo(pyOther.BorrowOrThrow());
}

public int CompareTo(short other)
{
using var pyOther = Runtime.PyInt_FromInt32(other);
return this.CompareTo(pyOther.BorrowOrThrow());
}

public int CompareTo(sbyte other)
{
using var pyOther = Runtime.PyInt_FromInt32(other);
return this.CompareTo(pyOther.BorrowOrThrow());
}

public bool Equals(long other)
{
using var pyOther = Runtime.PyInt_FromInt64(other);
return this.Equals(pyOther.BorrowOrThrow());
}

public bool Equals(int other)
{
using var pyOther = Runtime.PyInt_FromInt32(other);
return this.Equals(pyOther.BorrowOrThrow());
}

public bool Equals(short other)
{
using var pyOther = Runtime.PyInt_FromInt32(other);
return this.Equals(pyOther.BorrowOrThrow());
}

public bool Equals(sbyte other)
{
using var pyOther = Runtime.PyInt_FromInt32(other);
return this.Equals(pyOther.BorrowOrThrow());
}
#endregion Signed

#region Unsigned
public int CompareTo(ulong other)
{
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
return this.CompareTo(pyOther.BorrowOrThrow());
}

public int CompareTo(uint other)
{
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
return this.CompareTo(pyOther.BorrowOrThrow());
}

public int CompareTo(ushort other)
{
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
return this.CompareTo(pyOther.BorrowOrThrow());
}

public int CompareTo(byte other)
{
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
return this.CompareTo(pyOther.BorrowOrThrow());
}

public bool Equals(ulong other)
{
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
return this.Equals(pyOther.BorrowOrThrow());
}

public bool Equals(uint other)
{
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
return this.Equals(pyOther.BorrowOrThrow());
}

public bool Equals(ushort other)
{
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
return this.Equals(pyOther.BorrowOrThrow());
}

public bool Equals(byte other)
{
using var pyOther = Runtime.PyLong_FromUnsignedLongLong(other);
return this.Equals(pyOther.BorrowOrThrow());
}
#endregion Unsigned

public int CompareTo(PyInt? other)
{
return other is null ? 1 : this.CompareTo(other.BorrowNullable());
}

public bool Equals(PyInt? other) => base.Equals(other);
}
2 changes: 1 addition & 1 deletion src/runtime/PythonTypes/PyInt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace Python.Runtime
/// Represents a Python integer object.
/// See the documentation at https://docs.python.org/3/c-api/long.html
/// </summary>
public class PyInt : PyNumber, IFormattable
public partial class PyInt : PyNumber, IFormattable
{
internal PyInt(in StolenReference ptr) : base(ptr)
{
Expand Down
19 changes: 18 additions & 1 deletion src/runtime/PythonTypes/PyObject.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,23 @@ public long Refcount
}
}

internal int CompareTo(BorrowedReference other)
{
int greater = Runtime.PyObject_RichCompareBool(this.Reference, other, Runtime.Py_GT);
Debug.Assert(greater != -1);
if (greater > 0)
return 1;
int less = Runtime.PyObject_RichCompareBool(this.Reference, other, Runtime.Py_LT);
Debug.Assert(less != -1);
return less > 0 ? -1 : 0;
}

internal bool Equals(BorrowedReference other)
{
int equal = Runtime.PyObject_RichCompareBool(this.Reference, other, Runtime.Py_EQ);
Debug.Assert(equal != -1);
return equal > 0;
}

public override bool TryGetMember(GetMemberBinder binder, out object? result)
{
Expand Down Expand Up @@ -1325,7 +1342,7 @@ private bool TryCompare(PyObject arg, int op, out object @out)
}
return true;
}

public override bool TryBinaryOperation(BinaryOperationBinder binder, object arg, out object? result)
{
using var _ = Py.GIL();
Expand Down
Loading

0 comments on commit f738505

Please sign in to comment.