Skip to content

Commit

Permalink
Support <,>,==,!=,<=,>=,int,bool,not operations on Constant
Browse files Browse the repository at this point in the history
  • Loading branch information
arcondello committed Jul 25, 2024
1 parent a84fbee commit 58d45aa
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 1 deletion.
19 changes: 18 additions & 1 deletion dwave/optimization/symbols.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import typing

import numpy.typing

from dwave.optimization.model import Symbol, ArraySymbol
Expand Down Expand Up @@ -46,7 +48,22 @@ class BinaryVariable(ArraySymbol):


class Constant(ArraySymbol):
...
def __bool__(self) -> bool: ...
def __index__(self) -> int: ...

# these methods don't (yet) have an ArraySymbol overload
def __ge__(self, rhs: numpy.typing.ArrayLike) -> numpy.typing.NDArray[numpy.bool]: ...
def __gt__(self, rhs: numpy.typing.ArrayLike) -> numpy.typing.NDArray[numpy.bool]: ...
def __lt__(self, rhs: numpy.typing.ArrayLike) -> numpy.typing.NDArray[numpy.bool]: ...

@typing.overload
def __eq__(self, rhs: ArraySymbol) -> Equal: ...
@typing.overload
def __eq__(self, rhs: numpy.typing.ArrayLike) -> numpy.typing.NDArray[numpy.bool]: ...
@typing.overload
def __le__(self, rhs: ArraySymbol) -> LessEqual: ...
@typing.overload
def __le__(self, rhs: numpy.typing.ArrayLike) -> numpy.typing.NDArray[numpy.bool]: ...


class DisjointBitSets(Symbol):
Expand Down
56 changes: 56 additions & 0 deletions dwave/optimization/symbols.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ import collections.abc
import json
import numbers

cimport cpython.object
import cython
import numpy as np

from cpython.ref cimport PyObject
from cython.operator cimport dereference as deref, typeid
from libc.math cimport modf
from libcpp cimport bool
from libcpp.cast cimport dynamic_cast
from libcpp.optional cimport nullopt, optional
Expand Down Expand Up @@ -781,6 +783,12 @@ cdef class Constant(ArraySymbol):
# Have the parent model hold a reference to the array, so it's kept alive
model._data_sources.append(array)

def __bool__(self):
if not self._is_scalar():
raise ValueError("the truth value of a constant with more than one element is ambiguous")

return <bool>deref(self.ptr.buff())

def __getbuffer__(self, Py_buffer *buffer, int flags):
buffer.buf = <void*>(self.ptr.buff())
buffer.format = <char*>(self.ptr.format().c_str())
Expand All @@ -794,6 +802,54 @@ cdef class Constant(ArraySymbol):
buffer.strides = <Py_ssize_t*>(self.ptr.strides().data())
buffer.suboffsets = NULL

def __index__(self):
if not self._is_integer():
# Follow NumPy's error message
# https://github.com/numpy/numpy/blob/66e1e3/numpy/_core/src/multiarray/number.c#L833
raise TypeError("only integer scalar constants can be converted to a scalar index")

return <Py_ssize_t>deref(self.ptr.buff())

def __richcmp__(self, rhs, int op):
# __richcmp__ is a special Cython method

# If rhs is another Symbol, defer to ArraySymbol to handle the
# operation. Which may or may not actually be implemented.
# Otherwise, defer to NumPy.
# We could also check if rhs is another Constant and handle that differently,
# but that might lead to confusing behavior so we treat other Constants the
# same as any other symbol.
lhs = super() if isinstance(rhs, ArraySymbol) else np.asarray(self)

if op == cpython.object.Py_EQ:
return lhs.__eq__(rhs)
elif op == cpython.object.Py_GE:
return lhs.__ge__(rhs)
elif op == cpython.object.Py_GT:
return lhs.__gt__(rhs)
elif op == cpython.object.Py_LE:
return lhs.__le__(rhs)
elif op == cpython.object.Py_LT:
return lhs.__lt__(rhs)
elif op == cpython.object.Py_NE:
return lhs.__ne__(rhs)
else:
return NotImplemented # this should never happen, but just in case

cdef bool _is_integer(self) noexcept:
"""Return True if the constant encodes a single integer."""
if not self._is_scalar():
return False

# https://stackoverflow.com/q/1521607 for the integer test
cdef double dummy
return modf(deref(self.ptr.buff()), &dummy) == <double>0.0

cdef bool _is_scalar(self) noexcept:
"""Return True if the constant encodes a single value."""
# The size check is redundant, but worth checking in order to avoid segfaults
return self.ptr.size() == 1 and self.ptr.ndim() == 0

@staticmethod
def _from_symbol(Symbol symbol):
cdef cppConstantNode* ptr = dynamic_cast_ptr[cppConstantNode](symbol.node_ptr)
Expand Down
6 changes: 6 additions & 0 deletions releasenotes/notes/constant-comparisons-5258fa0aeb7f435e.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- Add support for ``<,>,==,!=,<=,>=`` operators between ``Constant`` and array-like objects.
- |
Add support for ``bool(constant)``, ``int(constant)``, and ``not constant`` when
``constant`` is an instance of ``Constant`` encoding a single scalar value.
70 changes: 70 additions & 0 deletions tests/test_symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import abc
import itertools
import math
import operator
import typing
import unittest

Expand Down Expand Up @@ -595,6 +596,60 @@ def generate_symbols(self):
yield A
yield B

def test_truthy(self):
model = Model()

self.assertTrue(model.constant(1))
self.assertFalse(model.constant(0))
self.assertTrue(model.constant(1.1))
self.assertFalse(model.constant(0.0))

self.assertTrue(not model.constant(0))
self.assertFalse(not model.constant(1))

# these are all ambiguous
with self.assertRaises(ValueError):
bool(model.constant([]))
with self.assertRaises(ValueError):
bool(model.constant([0, 1]))
with self.assertRaises(ValueError):
bool(model.constant([0]))

# the type is correct
self.assertIsInstance(model.constant(123.4).__bool__(), bool)

def test_comparisons(self):
model = Model()

# zero = model.constant(0)
one = model.constant(1)
onetwo = model.constant([1, 2])

operators = [
operator.eq,
operator.ge,
operator.gt,
operator.le,
operator.lt,
operator.ne
]

for op in operators:
with self.subTest(op):
self.assertEqual(op(one, 1), op(1, 1))
self.assertEqual(op(1, one), op(1, 1))

self.assertEqual(op(one, 2), op(1, 2))
self.assertEqual(op(2, one), op(2, 1))

self.assertEqual(op(one, 0), op(1, 0))
self.assertEqual(op(0, one), op(0, 1))

np.testing.assert_array_equal(op(onetwo, [1, 2]), op(np.asarray([1, 2]), [1, 2]))
np.testing.assert_array_equal(op(onetwo, [1, 0]), op(np.asarray([1, 2]), [1, 0]))
np.testing.assert_array_equal(op([1, 2], onetwo), op(np.asarray([1, 2]), [1, 2]))
np.testing.assert_array_equal(op([1, 0], onetwo), op(np.asarray([1, 0]), [1, 2]))

def test_copy(self):
model = Model()

Expand All @@ -604,6 +659,21 @@ def test_copy(self):
np.testing.assert_array_equal(A, arr)
self.assertTrue(np.shares_memory(A, arr))

def test_index(self):
model = Model()

self.assertEqual(list(range(model.constant(0))), [])
self.assertEqual(list(range(model.constant(4))), [0, 1, 2, 3])

with self.assertRaises(TypeError):
range(model.constant([0])) # not a scalar

self.assertEqual(int(model.constant(0)), 0)
self.assertEqual(int(model.constant(1)), 1)

with self.assertRaises(TypeError):
int(model.constant([0])) # not a scalar

def test_noncontiguous(self):
model = Model()
c = model.constant(np.arange(6)[::2])
Expand Down

0 comments on commit 58d45aa

Please sign in to comment.