Skip to content

Commit 755abd5

Browse files
committed
refactor: Streamline join internals, encode indices via two separate arrays
1 parent e5f3eb5 commit 755abd5

File tree

2 files changed

+293
-384
lines changed

2 files changed

+293
-384
lines changed

bioframe/core/arrops.py

Lines changed: 117 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -298,13 +298,10 @@ def overlap_intervals(starts1, ends1, starts2, ends2, closed=False, sort=False):
298298
If True, then treat intervals as closed and report single-point overlaps.
299299
Returns
300300
-------
301-
overlap_ids : numpy.ndarray
302-
An Nx2 array containing the indices of pairs of overlapping intervals.
303-
The 1st column contains ids from the 1st set, the 2nd column has ids
304-
from the 2nd set.
305-
301+
ovids1, ovids2 : numpy.ndarray
302+
Two 1D arrays containing the indices of pairs of overlapping intervals.
303+
The 1st contains ids from the 1st set, the 2nd has ids from the 2nd set.
306304
"""
307-
308305
for vec in [starts1, ends1, starts2, ends2]:
309306
if isinstance(vec, pd.Series):
310307
warnings.warn(
@@ -353,28 +350,21 @@ def overlap_intervals(starts1, ends1, starts2, ends2, closed=False, sort=False):
353350
)
354351

355352
# Generate IDs of pairs of overlapping intervals
356-
overlap_ids = np.block(
357-
[
358-
[
359-
np.repeat(ids1[match_2in1_mask], match_2in1_ends - match_2in1_starts)[
360-
:, None
361-
],
362-
ids2[arange_multi(match_2in1_starts, match_2in1_ends)][:, None],
363-
],
364-
[
365-
ids1[arange_multi(match_1in2_starts, match_1in2_ends)][:, None],
366-
np.repeat(ids2[match_1in2_mask], match_1in2_ends - match_1in2_starts)[
367-
:, None
368-
],
369-
],
370-
]
371-
)
353+
ovids1 = np.concatenate([
354+
np.repeat(ids1[match_2in1_mask], match_2in1_ends - match_2in1_starts),
355+
ids1[arange_multi(match_1in2_starts, match_1in2_ends)],
356+
])
357+
ovids2 = np.concatenate([
358+
ids2[arange_multi(match_2in1_starts, match_2in1_ends)],
359+
np.repeat(ids2[match_1in2_mask], match_1in2_ends - match_1in2_starts),
360+
])
372361

373362
if sort:
374-
# Sort overlaps according to the 1st
375-
overlap_ids = overlap_ids[np.lexsort([overlap_ids[:, 1], overlap_ids[:, 0]])]
363+
idx = np.lexsort([ovids2, ovids1])
364+
ovids1 = ovids1[idx]
365+
ovids2 = ovids2[idx]
376366

377-
return overlap_ids
367+
return ovids1, ovids2
378368

379369

380370
def overlap_intervals_outer(starts1, ends1, starts2, ends2, closed=False):
@@ -393,25 +383,25 @@ def overlap_intervals_outer(starts1, ends1, starts2, ends2, closed=False):
393383
394384
Returns
395385
-------
396-
overlap_ids : numpy.ndarray
397-
An Nx2 array containing the indices of pairs of overlapping intervals.
398-
The 1st column contains ids from the 1st set, the 2nd column has ids
399-
from the 2nd set.
386+
ovids1, ovids2 : numpy.ndarray
387+
Two 1D arrays containing the indices of pairs of overlapping intervals.
388+
The 1st contains ids from the 1st set, the 2nd has ids from the 2nd set.
400389
401390
no_overlap_ids1, no_overlap_ids2 : numpy.ndarray
402391
Two 1D arrays containing the indices of intervals in sets 1 and 2
403392
respectively that do not overlap with any interval in the other set.
404-
405393
"""
406-
407-
ovids = overlap_intervals(starts1, ends1, starts2, ends2, closed=closed)
408-
no_overlap_ids1 = np.where(
409-
np.bincount(ovids[:, 0], minlength=starts1.shape[0]) == 0
410-
)[0]
411-
no_overlap_ids2 = np.where(
412-
np.bincount(ovids[:, 1], minlength=starts2.shape[0]) == 0
413-
)[0]
414-
return ovids, no_overlap_ids1, no_overlap_ids2
394+
n1, n2 = len(starts1), len(starts2)
395+
ovids1, ovids2 = overlap_intervals(starts1, ends1, starts2, ends2, closed=closed)
396+
if n1 > 0:
397+
no_overlap_ids1 = np.setdiff1d(np.arange(len(starts1)), ovids1)
398+
else:
399+
no_overlap_ids1 = np.array([], dtype=int)
400+
if n2 > 0:
401+
no_overlap_ids2 = np.setdiff1d(np.arange(len(starts2)), ovids2)
402+
else:
403+
no_overlap_ids2 = np.array([], dtype=int)
404+
return ovids1, ovids2, no_overlap_ids1, no_overlap_ids2
415405

416406

417407
def merge_intervals(starts, ends, min_dist=0):
@@ -532,14 +522,12 @@ def _closest_intervals_nooverlap(
532522
533523
Returns
534524
-------
535-
ids: numpy.ndarray
536-
One Nx2 array containing the indices of pairs of closest intervals,
525+
ids1, ids2: numpy.ndarray
526+
Two arrays containing the indices of pairs of closest intervals,
537527
reported for the neighbors in specified direction (by genomic
538-
coordinate). The two columns are the inteval ids from set 1, ids of
528+
coordinate). The two arrays are the inteval ids from set 1, ids of
539529
the closest intevals from set 2.
540-
541530
"""
542-
543531
for vec in [starts1, ends1, starts2, ends2]:
544532
if isinstance(vec, pd.Series):
545533
warnings.warn(
@@ -556,7 +544,8 @@ def _closest_intervals_nooverlap(
556544
n1 = starts1.shape[0]
557545
n2 = starts2.shape[0]
558546

559-
ids = np.zeros((0, 2), dtype=int)
547+
ids1 = np.array([], dtype=int)
548+
ids2 = np.array([], dtype=int)
560549

561550
if k > 0 and direction == "left":
562551
if tie_arr is None:
@@ -573,14 +562,8 @@ def _closest_intervals_nooverlap(
573562
int1_ids = np.repeat(np.arange(n1), left_closest_endidx - left_closest_startidx)
574563
int2_sorted_ids = arange_multi(left_closest_startidx, left_closest_endidx)
575564

576-
ids = np.vstack(
577-
[
578-
int1_ids,
579-
ids2_endsorted[int2_sorted_ids],
580-
# ends2_sorted[int2_sorted_ids] - starts1[int1_ids],
581-
# arange_multi(left_closest_startidx - left_closest_endidx, 0)
582-
]
583-
).T
565+
ids1 = int1_ids
566+
ids2 = ids2_endsorted[int2_sorted_ids]
584567

585568
elif k > 0 and direction == "right":
586569
if tie_arr is None:
@@ -598,17 +581,11 @@ def _closest_intervals_nooverlap(
598581
np.arange(n1), right_closest_endidx - right_closest_startidx
599582
)
600583
int2_sorted_ids = arange_multi(right_closest_startidx, right_closest_endidx)
601-
ids = np.vstack(
602-
[
603-
int1_ids,
604-
ids2_startsorted[int2_sorted_ids],
605-
# starts2_sorted[int2_sorted_ids] - ends1[int1_ids],
606-
# arange_multi(1, right_closest_endidx -
607-
# right_closest_startidx + 1)
608-
]
609-
).T
610584

611-
return ids
585+
ids1 = int1_ids
586+
ids2 = ids2_startsorted[int2_sorted_ids]
587+
588+
return ids1, ids2
612589

613590

614591
def closest_intervals(
@@ -621,10 +598,11 @@ def closest_intervals(
621598
ignore_overlaps=False,
622599
ignore_upstream=False,
623600
ignore_downstream=False,
624-
direction=None,
601+
along=None,
625602
):
626603
"""
627-
For every interval in set 1, return the indices of k closest intervals from set 2.
604+
For every interval in set 1, return the indices of k closest intervals
605+
from set 2.
628606
629607
Parameters
630608
----------
@@ -637,127 +615,135 @@ def closest_intervals(
637615
The number of neighbors to report.
638616
639617
tie_arr : numpy.ndarray or None
640-
Extra data describing intervals in set 2 to break ties when multiple intervals
641-
are located at the same distance. Intervals with *lower* tie_arr values will
642-
be given priority.
618+
Extra data describing intervals in set 2 to break ties when multiple
619+
intervals are located at the same distance. Intervals with *lower*
620+
tie_arr values will be given priority.
643621
644622
ignore_overlaps : bool
645623
If True, ignore set 2 intervals that overlap with set 1 intervals.
646624
647625
ignore_upstream, ignore_downstream : bool
648626
If True, ignore set 2 intervals upstream/downstream of set 1 intervals.
649627
650-
direction : numpy.ndarray with dtype bool or None
651-
Strand vector to define the upstream/downstream orientation of the intervals.
628+
along : numpy.ndarray with dtype bool or None
629+
Strand vector to define the upstream/downstream orientation of the
630+
intervals.
652631
653632
Returns
654633
-------
655-
closest_ids : numpy.ndarray
656-
An Nx2 array containing the indices of pairs of closest intervals.
657-
The 1st column contains ids from the 1st set, the 2nd column has ids
634+
closest_ids1, closest_ids2 : numpy.ndarray
635+
Two arrays containing the indices of pairs of closest intervals.
636+
The 1st array contains ids from the 1st set, the 2nd array has ids
658637
from the 2nd set.
659-
660638
"""
661-
662-
# Get overlapping intervals:
639+
# Get overlaps
663640
if ignore_overlaps:
664-
overlap_ids = np.zeros((0, 2), dtype=int)
641+
ovids1, ovids2 = np.array([], dtype=int), np.array([], dtype=int)
665642
elif (starts2 is None) and (ends2 is None):
666643
starts2, ends2 = starts1, ends1
667-
overlap_ids = overlap_intervals(starts1, ends1, starts2, ends2)
668-
overlap_ids = overlap_ids[overlap_ids[:, 0] != overlap_ids[:, 1]]
644+
ovids1, ovids2 = overlap_intervals(starts1, ends1, starts2, ends2)
645+
mask = ovids1 != ovids2
646+
ovids1 = ovids1[mask]
647+
ovids2 = ovids2[mask]
669648
else:
670-
overlap_ids = overlap_intervals(starts1, ends1, starts2, ends2)
649+
ovids1, ovids2 = overlap_intervals(starts1, ends1, starts2, ends2)
671650

672-
# Get non-overlapping intervals:
651+
# Get non-overlapping nearest neighbors
673652
n = len(starts1)
674-
all_ids = np.arange(n)
675-
676-
# + directed intervals
677-
ids_left_upstream = _closest_intervals_nooverlap(
678-
starts1[direction],
679-
ends1[direction],
653+
if along is None:
654+
along = np.ones(n, dtype=bool)
655+
656+
# + stranded intervals
657+
pos_starts1, pos_ends1 = starts1[along], ends1[along]
658+
pos_up1, pos_up2 = _closest_intervals_nooverlap(
659+
pos_starts1,
660+
pos_ends1,
680661
starts2,
681662
ends2,
682663
direction="left",
683664
tie_arr=tie_arr,
684665
k=0 if ignore_upstream else k,
685666
)
686-
ids_right_downstream = _closest_intervals_nooverlap(
687-
starts1[direction],
688-
ends1[direction],
667+
pos_dn1, pos_dn2 = _closest_intervals_nooverlap(
668+
pos_starts1,
669+
pos_ends1,
689670
starts2,
690671
ends2,
691672
direction="right",
692673
tie_arr=tie_arr,
693674
k=0 if ignore_downstream else k,
694675
)
695-
# - directed intervals
696-
ids_right_upstream = _closest_intervals_nooverlap(
697-
starts1[~direction],
698-
ends1[~direction],
676+
677+
# - stranded intervals
678+
neg_starts1, neg_ends1 = starts1[~along], ends1[~along]
679+
neg_up1, neg_up2 = _closest_intervals_nooverlap(
680+
neg_starts1,
681+
neg_ends1,
699682
starts2,
700683
ends2,
701684
direction="right",
702685
tie_arr=tie_arr,
703686
k=0 if ignore_upstream else k,
704687
)
705-
ids_left_downstream = _closest_intervals_nooverlap(
706-
starts1[~direction],
707-
ends1[~direction],
688+
neg_dn1, neg_dn2 = _closest_intervals_nooverlap(
689+
neg_starts1,
690+
neg_ends1,
708691
starts2,
709692
ends2,
710693
direction="left",
711694
tie_arr=tie_arr,
712695
k=0 if ignore_downstream else k,
713696
)
714697

715-
# Reconstruct original indexes (b/c we split regions by direction above)
716-
ids_left_upstream[:, 0] = all_ids[direction][ids_left_upstream[:, 0]]
717-
ids_right_downstream[:, 0] = all_ids[direction][ids_right_downstream[:, 0]]
718-
ids_left_downstream[:, 0] = all_ids[~direction][ids_left_downstream[:, 0]]
719-
ids_right_upstream[:, 0] = all_ids[~direction][ids_right_upstream[:, 0]]
698+
# Reconstruct original indices (b/c we split ranges by strand above)
699+
pos_ids = np.where(along)[0]
700+
neg_ids = np.where(~along)[0]
701+
pos_up1 = pos_ids[pos_up1]
702+
pos_dn1 = pos_ids[pos_dn1]
703+
neg_dn1 = neg_ids[neg_dn1]
704+
neg_up1 = neg_ids[neg_up1]
720705

721-
left_ids = np.concatenate([ids_left_upstream, ids_left_downstream])
722-
right_ids = np.concatenate([ids_right_upstream, ids_right_downstream])
706+
# Combine by absolute search direction
707+
left_ids1 = np.concatenate([pos_up1, neg_dn1])
708+
left_ids2 = np.concatenate([pos_up2, neg_dn2])
709+
right_ids1 = np.concatenate([neg_up1, pos_dn1])
710+
right_ids2 = np.concatenate([neg_up2, pos_dn2])
723711

724712
# Increase the distance by 1 to distinguish between overlapping
725713
# and non-overlapping set 2 intervals.
726-
left_dists = starts1[left_ids[:, 0]] - ends2[left_ids[:, 1]] + 1
727-
right_dists = starts2[right_ids[:, 1]] - ends1[right_ids[:, 0]] + 1
728-
729-
closest_ids = np.vstack([left_ids, right_ids, overlap_ids])
730-
closest_dists = np.concatenate(
731-
[left_dists, right_dists, np.zeros(overlap_ids.shape[0])]
714+
left_dists = starts1[left_ids1] - ends2[left_ids2] + 1
715+
right_dists = starts2[right_ids2] - ends1[right_ids1] + 1
716+
717+
# Combine the results
718+
events1 = np.concatenate([left_ids1, right_ids1, ovids1])
719+
events2 = np.concatenate([left_ids2, right_ids2, ovids2])
720+
dists = np.concatenate(
721+
[left_dists, right_dists, np.zeros(ovids1.shape[0])]
732722
)
733723

734-
if len(closest_ids) == 0:
735-
return np.empty((0, 2), dtype=int)
724+
if len(events1) == 0:
725+
return np.array([], dtype=int), np.array([], dtype=int)
736726

737727
# Sort by distance to set 1 intervals and, if present, by the tie-breaking
738728
# data array.
739729
if tie_arr is None:
740-
order = np.lexsort([closest_ids[:, 1], closest_dists, closest_ids[:, 0]])
730+
order = np.lexsort([events2, dists, events1])
741731
else:
742-
order = np.lexsort(
743-
[closest_ids[:, 1], tie_arr, closest_dists, closest_ids[:, 0]]
744-
)
745-
746-
closest_ids = closest_ids[order, :2]
747-
748-
# For each set 1 interval, select up to k closest neighbours.
749-
interval1_run_border_mask = closest_ids[:-1, 0] != closest_ids[1:, 0]
750-
interval1_run_borders = np.where(np.r_[True, interval1_run_border_mask, True])[0]
751-
interval1_run_starts = interval1_run_borders[:-1]
752-
interval1_run_ends = interval1_run_borders[1:]
753-
closest_ids = closest_ids[
754-
arange_multi(
755-
interval1_run_starts,
756-
lengths=np.minimum(k, interval1_run_ends - interval1_run_starts),
757-
)
758-
]
732+
order = np.lexsort([events2, tie_arr, dists, events1])
733+
events1 = events1[order]
734+
events2 = events2[order]
735+
736+
# Prune the results to the k nearest neighbors
737+
# For each sorted run of set 1 intervals, select up to k entries
738+
run_borders = np.where(np.r_[True, events1[:-1] != events1[1:], True])[0]
739+
run_starts = run_borders[:-1]
740+
run_ends = run_borders[1:]
741+
idx = arange_multi(
742+
run_starts,
743+
lengths=np.minimum(k, run_ends - run_starts),
744+
)
759745

760-
return closest_ids
746+
return events1[idx], events2[idx]
761747

762748

763749
def coverage_intervals_rle(starts, ends, weights=None):

0 commit comments

Comments
 (0)