@@ -99,6 +99,8 @@ def __init__(
99
99
100
100
self ._left_bound = None
101
101
self ._right_bound = None
102
+ self ._bottom_bound = None
103
+ self ._top_bound = None
102
104
if horizontal_bounds is not None :
103
105
self .set_horizontal_bounds (* horizontal_bounds )
104
106
@@ -127,6 +129,18 @@ def _infer_location(seed: int, top_bound: int, bottom_bound: int) -> str:
127
129
128
130
NotImplementedError
129
131
132
+ @staticmethod
133
+ def _validate_vertical_bounds (left_bound : int , right_bound : int , top_bound : int , bottom_bound : int , location : str ):
134
+ assert location in {"upper_triangular" , "lower_triangular" }
135
+ if location == "lower_triangular" and not (left_bound <= top_bound <= right_bound ):
136
+ raise ValueError (
137
+ f"top bound is not enclosed between the left and right bounds: { left_bound = } , { right_bound = } , { top_bound = } "
138
+ )
139
+ elif location == "upper_triangular" and not (left_bound <= bottom_bound <= right_bound ):
140
+ raise ValueError (
141
+ f"bottom bound is not enclosed between the left and right bounds: { left_bound = } , { right_bound = } , { bottom_bound = } "
142
+ )
143
+
130
144
def _compute_convex_comp (self ) -> int :
131
145
cfx1 = 0.99
132
146
cfx2 = 0.01
@@ -190,7 +204,13 @@ def _compute_rmean(self, I: ss.csr_matrix, window: int) -> float:
190
204
return submatrix .mean ()
191
205
192
206
def _all_bounds_set (self ) -> bool :
193
- return all ((x is not None for x in [self ._left_bound , self ._right_bound , self ._bottom_bound , self ._top_bound ]))
207
+ return self ._horizontal_bounds_set () and self ._vertical_bounds_set ()
208
+
209
+ def _horizontal_bounds_set (self ) -> bool :
210
+ return self ._left_bound is not None and self ._right_bound is not None
211
+
212
+ def _vertical_bounds_set (self ) -> bool :
213
+ return self ._top_bound is not None and self ._bottom_bound is not None
194
214
195
215
@property
196
216
def seed (self ) -> int :
@@ -313,11 +333,17 @@ def set_horizontal_bounds(self, left_bound: int, right_bound: int):
313
333
assert self ._right_bound is not None
314
334
raise RuntimeError ("horizontal stripe bounds have already been set" )
315
335
336
+ if left_bound < 0 or right_bound < 0 :
337
+ raise ValueError ("stripe bounds must be positive integers" )
338
+
316
339
if not left_bound <= self ._seed <= right_bound :
317
340
raise ValueError (
318
341
f"horizontal bounds must enclose the seed position: seed={ self ._seed } , { left_bound = } , { right_bound = } "
319
342
)
320
343
344
+ if self ._vertical_bounds_set ():
345
+ Stripe ._validate_vertical_bounds (left_bound , right_bound , self ._top_bound , self ._bottom_bound , self ._where )
346
+
321
347
self ._left_bound = left_bound
322
348
self ._right_bound = right_bound
323
349
@@ -336,21 +362,29 @@ def set_vertical_bounds(self, top_bound: int, bottom_bound: int):
336
362
assert self ._top_bound is not None
337
363
raise RuntimeError ("vertical stripe bounds have already been set" )
338
364
365
+ if top_bound < 0 or bottom_bound < 0 :
366
+ raise ValueError ("stripe bounds must be positive integers" )
367
+
339
368
if top_bound > bottom_bound :
340
369
raise ValueError (
341
370
f"the lower vertical bound must be greater than the upper vertical bound: { top_bound = } , { bottom_bound = } "
342
371
)
343
372
344
- self ._top_bound = top_bound
345
- self ._bottom_bound = bottom_bound
346
-
347
- computed_where = self ._infer_location (self ._seed , self ._top_bound , self ._bottom_bound )
373
+ computed_where = self ._infer_location (self ._seed , top_bound , bottom_bound )
348
374
349
375
if self ._where is not None and computed_where != self ._where :
350
376
raise RuntimeError (
351
377
f"computed location does not match the provided stripe location: computed={ computed_where } , expected={ self ._where } "
352
378
)
353
379
380
+ if self ._horizontal_bounds_set ():
381
+ Stripe ._validate_vertical_bounds (
382
+ self ._left_bound , self ._right_bound , top_bound , bottom_bound , computed_where
383
+ )
384
+
385
+ self ._top_bound = top_bound
386
+ self ._bottom_bound = bottom_bound
387
+
354
388
self ._where = computed_where
355
389
356
390
def compute_biodescriptors (self , I : ss .csr_matrix , window : int = 3 ):
0 commit comments