Skip to content

Commit 884469d

Browse files
authored
Merge pull request #107 from robomics/fix/stripe-class
Improve parameter validation in Stripe class
2 parents c5c69bd + f725a87 commit 884469d

File tree

2 files changed

+40
-6
lines changed

2 files changed

+40
-6
lines changed

src/stripepy/utils/stripe.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def __init__(
9999

100100
self._left_bound = None
101101
self._right_bound = None
102+
self._bottom_bound = None
103+
self._top_bound = None
102104
if horizontal_bounds is not None:
103105
self.set_horizontal_bounds(*horizontal_bounds)
104106

@@ -127,6 +129,18 @@ def _infer_location(seed: int, top_bound: int, bottom_bound: int) -> str:
127129

128130
NotImplementedError
129131

132+
@staticmethod
133+
def _validate_vertical_bounds(left_bound: int, right_bound: int, top_bound: int, bottom_bound: int, location: str):
134+
assert location in {"upper_triangular", "lower_triangular"}
135+
if location == "lower_triangular" and not (left_bound <= top_bound <= right_bound):
136+
raise ValueError(
137+
f"top bound is not enclosed between the left and right bounds: {left_bound=}, {right_bound=}, {top_bound=}"
138+
)
139+
elif location == "upper_triangular" and not (left_bound <= bottom_bound <= right_bound):
140+
raise ValueError(
141+
f"bottom bound is not enclosed between the left and right bounds: {left_bound=}, {right_bound=}, {bottom_bound=}"
142+
)
143+
130144
def _compute_convex_comp(self) -> int:
131145
cfx1 = 0.99
132146
cfx2 = 0.01
@@ -190,7 +204,13 @@ def _compute_rmean(self, I: ss.csr_matrix, window: int) -> float:
190204
return submatrix.mean()
191205

192206
def _all_bounds_set(self) -> bool:
193-
return all((x is not None for x in [self._left_bound, self._right_bound, self._bottom_bound, self._top_bound]))
207+
return self._horizontal_bounds_set() and self._vertical_bounds_set()
208+
209+
def _horizontal_bounds_set(self) -> bool:
210+
return self._left_bound is not None and self._right_bound is not None
211+
212+
def _vertical_bounds_set(self) -> bool:
213+
return self._top_bound is not None and self._bottom_bound is not None
194214

195215
@property
196216
def seed(self) -> int:
@@ -313,11 +333,17 @@ def set_horizontal_bounds(self, left_bound: int, right_bound: int):
313333
assert self._right_bound is not None
314334
raise RuntimeError("horizontal stripe bounds have already been set")
315335

336+
if left_bound < 0 or right_bound < 0:
337+
raise ValueError("stripe bounds must be positive integers")
338+
316339
if not left_bound <= self._seed <= right_bound:
317340
raise ValueError(
318341
f"horizontal bounds must enclose the seed position: seed={self._seed}, {left_bound=}, {right_bound=}"
319342
)
320343

344+
if self._vertical_bounds_set():
345+
Stripe._validate_vertical_bounds(left_bound, right_bound, self._top_bound, self._bottom_bound, self._where)
346+
321347
self._left_bound = left_bound
322348
self._right_bound = right_bound
323349

@@ -336,21 +362,29 @@ def set_vertical_bounds(self, top_bound: int, bottom_bound: int):
336362
assert self._top_bound is not None
337363
raise RuntimeError("vertical stripe bounds have already been set")
338364

365+
if top_bound < 0 or bottom_bound < 0:
366+
raise ValueError("stripe bounds must be positive integers")
367+
339368
if top_bound > bottom_bound:
340369
raise ValueError(
341370
f"the lower vertical bound must be greater than the upper vertical bound: {top_bound=}, {bottom_bound=}"
342371
)
343372

344-
self._top_bound = top_bound
345-
self._bottom_bound = bottom_bound
346-
347-
computed_where = self._infer_location(self._seed, self._top_bound, self._bottom_bound)
373+
computed_where = self._infer_location(self._seed, top_bound, bottom_bound)
348374

349375
if self._where is not None and computed_where != self._where:
350376
raise RuntimeError(
351377
f"computed location does not match the provided stripe location: computed={computed_where}, expected={self._where}"
352378
)
353379

380+
if self._horizontal_bounds_set():
381+
Stripe._validate_vertical_bounds(
382+
self._left_bound, self._right_bound, top_bound, bottom_bound, computed_where
383+
)
384+
385+
self._top_bound = top_bound
386+
self._bottom_bound = bottom_bound
387+
354388
self._where = computed_where
355389

356390
def compute_biodescriptors(self, I: ss.csr_matrix, window: int = 3):

test/unit/test_IO.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_getters(self):
9595
def test_stripe_getters(self):
9696
res = Result("chr1", 123)
9797

98-
stripes = [Stripe(10, 1.23)]
98+
stripes = [Stripe(10, 1.23, where="upper_triangular")]
9999
stripes[0].set_vertical_bounds(5, 10)
100100
stripes[0].set_horizontal_bounds(8, 12)
101101
stripes[0].compute_biodescriptors(csr_matrix((15, 15), dtype=float))

0 commit comments

Comments
 (0)