Skip to content

Commit

Permalink
Raise a ValueError when calling bool() on an ArraySymbol
Browse files Browse the repository at this point in the history
  • Loading branch information
arcondello committed Jul 19, 2024
1 parent 786640f commit a84fbee
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 1 deletion.
5 changes: 5 additions & 0 deletions dwave/optimization/model.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1625,6 +1625,11 @@ cdef class ArraySymbol(Symbol):

return NotImplemented

def __bool__(self):
# In the future we might want to return a Bool symbol, but __bool__ is so
# fundamental that I am hesitant to do even that.
raise ValueError("the truth value of an array symbol is ambiguous")

def __eq__(self, rhs):
if isinstance(rhs, ArraySymbol):
# We could consider returning a Constant(True) is the case that self is rhs
Expand Down
2 changes: 2 additions & 0 deletions releasenotes/notes/improve-operators-54e60d66200526cc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ features:
In-place multiplication with a ``NaryMultiply`` symbol will no longer create a new symbol.
fixes:
- Return ``NotImplemented`` from ``ArraySymbol`` operator methods for unknown types.
upgrade:
- Raise a ``ValueError`` when calling ``bool()`` on an ``ArraySymbol``.
4 changes: 3 additions & 1 deletion tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def smart_reduce(f, items):
[shifts[day][shift] for day in range(num_days)]
)
minutes_worked = total_shifts_worked * minutes_per_shift[shift]
total_minutes_worked = total_minutes_worked + minutes_worked if total_minutes_worked else minutes_worked
total_minutes_worked = (total_minutes_worked + minutes_worked
if total_minutes_worked is not None
else minutes_worked)

min_minutes = model.constant(min_minutes_per_week)
max_minutes = model.constant(max_minutes_per_week)
Expand Down
10 changes: 10 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ def test_abstract(self):
with self.assertRaisesRegex(ValueError, "ArraySymbols cannot be constructed directly"):
ArraySymbol()

def test_bool(self):
from dwave.optimization.model import ArraySymbol

# bypass the init, this should be done very carefully as it can lead to
# segfaults dependig on what methods are accessed!
symbol = ArraySymbol.__new__(ArraySymbol)

with self.assertRaises(ValueError):
bool(symbol)

def test_operator_types(self):
# For each, test that we get the right class from the operator and that
# incorrect types returns NotImplemented
Expand Down

0 comments on commit a84fbee

Please sign in to comment.