Skip to content

Commit

Permalink
improving special indexing string representation (cvxpy#2560)
Browse files Browse the repository at this point in the history
* improving special indexing string representation

* changing array dtype to int to pass test

* removing test with numpy array due to weird failure for python3.9 str output

---------

Co-authored-by: William Zijie Zhang <william@gridmatic.com>
  • Loading branch information
Transurgeon and William Zijie Zhang authored Sep 12, 2024
1 parent 0f09be0 commit d6a8377
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 6 deletions.
12 changes: 6 additions & 6 deletions cvxpy/atoms/affine/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,13 @@ def is_atom_log_log_concave(self) -> bool:
"""
return True

# The string representation of the atom.
def name(self):
# TODO string should be orig_key
"""String representation of the index expression."""
inner_str = "[%s" + ", %s"*(len(self.key)-1) + "]"
return self.args[0].name() + inner_str % ku.to_str(self.key)

def numeric(self, values):
""" Returns the index/slice into the given value.
"""Returns the index/slice into the given value.
"""
return values[0][self._orig_key]

Expand Down Expand Up @@ -138,12 +137,13 @@ def is_atom_log_log_concave(self) -> bool:
"""
return True

# The string representation of the atom.
def name(self):
return self.args[0].name() + str(self.key)
"""String representation of the special index expression."""
key_str = ku.special_key_to_str(self.key)
return f"{self.args[0].name()}[{key_str}]"

def numeric(self, values):
""" Returns the index/slice into the given value.
"""Returns the index/slice into the given value.
"""
return values[0][self.key]

Expand Down
5 changes: 5 additions & 0 deletions cvxpy/tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,11 @@ def test_index_expression(self) -> None:
self.assertEqual(exp.curvature, s.AFFINE)
self.assertEqual(exp.shape, (1,))

def test_special_idx_str_repr(self) -> None:
idx = [i for i in range(178)]
exp = cp.Variable((200, 10), name="exp")[idx, 6]
self.assertEqual("exp[[0, 1, 2, '...', 175, 176, 177], 6]", str(exp))

def test_none_idx(self) -> None:
"""Test None as index.
"""
Expand Down
20 changes: 20 additions & 0 deletions cvxpy/utilities/key_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,23 @@ def is_special_slice(key) -> bool:
return True

return False


def special_key_to_str(key):
"""Converts a special key to a string representation."""
key_strs = []
for k in key:
if isinstance(k, (np.ndarray, list)):
key_strs.append(str(pprint_sequence(k)))
elif isinstance(k, slice):
key_strs.append(slice_to_str(k))
else:
key_strs.append(str(k))
return ", ".join(key_strs)


def pprint_sequence(seq, max_elems=6):
"""Shorten the sequence (array or list) for pretty-printing."""
if len(seq) > max_elems:
return list(seq[:3]) + ['...'] + list(seq[-3:])
return seq

0 comments on commit d6a8377

Please sign in to comment.