Skip to content

Commit a8480af

Browse files
Merge pull request #12 from florian-huber/issue_8_0
Improve merging
2 parents 8a5d272 + a1d7b29 commit a8480af

File tree

3 files changed

+113
-65
lines changed

3 files changed

+113
-65
lines changed

sparsestack/StackedSparseArray.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -231,30 +231,33 @@ def add_dense_matrix(self, matrix: np.ndarray,
231231
"""
232232
if matrix is None:
233233
self.data = np.array([])
234-
elif len(matrix.dtype) > 1: # if structured array
235-
for dtype_name in matrix.dtype.names:
236-
self._add_dense_matrix(matrix[dtype_name],
237-
f"{name}_{dtype_name}",
238-
join_type)
239234
else:
240235
self._add_dense_matrix(matrix, name, join_type)
241236

242237
def _add_dense_matrix(self, matrix, name, join_type):
243-
if matrix.dtype.type == np.void:
244-
input_dtype = matrix.dtype[0]
245-
else:
246-
input_dtype = matrix.dtype
238+
def get_dtype(data):
239+
if data.dtype.type == np.void:
240+
return data.dtype[0]
241+
return data.dtype
247242

248243
# Handle 1D arrays
249244
if matrix.ndim == 1:
250245
matrix = matrix.reshape(-1, 1)
251246

247+
# Handle structured arrays > 1 dimension
248+
if len(matrix.dtype) > 1:
249+
dtype_data = [(f"{name}_{dtype_name}", get_dtype(matrix[dtype_name])) for dtype_name in matrix.dtype.names]
250+
else:
251+
dtype_data = [(name, get_dtype(matrix))]
252+
253+
252254
if self.shape[2] == 0 or (self.shape[2] == 1 and name in self.score_names):
253255
# Add first (sparse) array of scores
254256
(idx_row, idx_col) = np.where(matrix)
255257
self.row = idx_row
256258
self.col = idx_col
257-
self.data = np.array(matrix[idx_row, idx_col], dtype=[(name, input_dtype)])
259+
260+
self.data = np.array(matrix[idx_row, idx_col], dtype=dtype_data)
258261
else:
259262
# Add new stack of scores
260263
(idx_row, idx_col) = np.where(matrix)

sparsestack/utils.py

Lines changed: 99 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -31,30 +31,38 @@ def _join_arrays(row1, col1, data1,
3131
#pylint: disable=too-many-arguments
3232
#pylint: disable=too-many-locals
3333

34+
idx1 = np.lexsort((col1, row1))
35+
idx2 = np.lexsort((col2, row2))
3436
# join types
3537
if join_type == "left":
36-
idx_inner_left, idx_inner_right = get_idx(row1, col1, row2, col2, join_type="inner")
38+
idx_inner_left, idx_inner_right, _, _, _, _ = get_idx(row1, col1, row2, col2,
39+
idx1, idx2, join_type="inner")
3740
data_join = set_and_fill_new_array(data1, data2, name,
3841
np.arange(0, len(row1)), np.arange(0, len(row1)),
3942
idx_inner_right, idx_inner_left,
4043
len(row1))
4144
return row1, col1, data_join
4245
if join_type == "right":
43-
idx_inner_left, idx_inner_right = get_idx(row1, col1, row2, col2, join_type="inner")
46+
idx_inner_left, idx_inner_right, _, _, _, _ = get_idx(row1, col1, row2, col2,
47+
idx1, idx2, join_type="inner")
4448
data_join = set_and_fill_new_array(data1, data2, name,
4549
idx_inner_left, idx_inner_right,
4650
np.arange(0, len(row2)), np.arange(0, len(row2)),
4751
len(row2))
4852
return row2, col2, data_join
4953
if join_type == "inner":
50-
idx_inner_left, idx_inner_right = get_idx(row1, col1, row2, col2, join_type="inner")
54+
idx_inner_left, idx_inner_right, _, _, _, _ = get_idx(row1, col1, row2, col2,
55+
idx1, idx2, join_type="inner")
5156
data_join = set_and_fill_new_array(data1, data2, name,
5257
idx_inner_left, np.arange(0, len(idx_inner_left)),
5358
idx_inner_right, np.arange(0, len(idx_inner_left)),
5459
len(idx_inner_left))
5560
return row1[idx_inner_left], col1[idx_inner_left], data_join
5661
if join_type == "outer":
57-
idx_left, idx_left_new, idx_right, idx_right_new, row_new, col_new = get_idx_outer(row1, col1, row2, col2)
62+
idx_left, idx_right, idx_left_new, idx_right_new, row_new, col_new = get_idx_outer(
63+
row1, col1, row2, col2,
64+
idx1, idx2
65+
)
5866
data_join = set_and_fill_new_array(data1, data2, name,
5967
idx_left, idx_left_new, idx_right, idx_right_new,
6068
len(row_new))
@@ -68,6 +76,7 @@ def set_and_fill_new_array(data1, data2, name,
6876
"""Create new structured numpy array and fill with data1 and data2.
6977
"""
7078
#pylint: disable=too-many-arguments
79+
7180
new_dtype = [(dname, d[0]) for dname, d in data1.dtype.fields.items()]
7281
if data2.dtype.names is None:
7382
new_dtype += [(name, data2.dtype)]
@@ -92,69 +101,104 @@ def set_and_fill_new_array(data1, data2, name,
92101

93102

94103
@numba.jit(nopython=True)
95-
def get_idx_inner_brute_force(left_row, left_col, right_row, right_col):
96-
#Get indexes for entries for a inner join.
97-
idx_inner_left = []
98-
idx_inner_right = []
99-
for i, right_row_id in enumerate(right_row):
100-
if right_row_id in left_row:
101-
idx = np.where((left_row == right_row_id)
102-
& (left_col == right_col[i]))[0]
103-
if len(idx) > 0:
104-
idx_inner_left.append(idx[0])
105-
idx_inner_right.append(i)
106-
return idx_inner_left, idx_inner_right
104+
def get_idx_inner(left_row, left_col, right_row, right_col,
105+
idx1, idx2):
106+
"""Get current and new indices for inner merge.
107107
108+
idx1, idx2
109+
Numpy array of pre-sorted (np.lexsort) indices for left/right arrays.
110+
"""
111+
#pylint: disable=too-many-arguments
112+
#pylint: disable=too-many-locals
108113

109-
@numba.jit(nopython=True)
110-
def get_idx(left_row, left_col, right_row, right_col,
111-
join_type="left"):
112-
list1 = list(zip(left_row, left_col))
113-
list2 = list(zip(right_row, right_col))
114-
if join_type == "left":
115-
uniques = set(list1)
116-
elif join_type == "right":
117-
uniques = set(list2)
118-
elif join_type == "inner":
119-
uniques = set(list1).intersection(set(list2))
120-
#elif join_type == "outer":
121-
# uniques = set(list1).union(set(list2))
122-
else:
123-
raise ValueError("Unknown join_type")
124-
uniques = sorted(list(uniques))
125114
idx_left = []
115+
idx_left_new = []
126116
idx_right = []
127-
for (r, c) in uniques:
128-
i_left = np.where((left_row == r) & (left_col == c))[0]
129-
if len(i_left) > 0:
130-
idx_left.append(i_left[0])
131-
i_right = np.where((right_row == r) & (right_col == c))[0]
132-
if len(i_right) > 0:
133-
idx_right.append(i_right[0])
134-
return idx_left, idx_right
117+
idx_right_new = []
118+
row_new = []
119+
col_new = []
120+
low = 0
121+
counter = 0
122+
for i in idx1:
123+
for j in idx2[low:]:
124+
if (left_row[i] == right_row[j]) and (left_col[i] == right_col[j]):
125+
idx_left.append(i)
126+
idx_left_new.append(counter)
127+
idx_right.append(j)
128+
idx_right_new.append(counter)
129+
row_new.append(left_row[i])
130+
col_new.append(left_col[i])
131+
counter += 1
132+
if left_row[i] > right_row[j]:
133+
low = j
134+
if left_row[i] < right_row[j]:
135+
break
136+
return idx_left, idx_right, idx_left_new, idx_right_new, row_new, col_new
135137

136138

137139
@numba.jit(nopython=True)
138-
def get_idx_outer(left_row, left_col, right_row, right_col):
140+
def get_idx_outer(left_row, left_col, right_row, right_col,
141+
idx1, idx2):
142+
"""Get current and new indices for outer merge.
143+
144+
idx1, idx2
145+
Numpy array of pre-sorted (np.lexsort) indices for left/right arrays.
146+
"""
147+
#pylint: disable=too-many-arguments
139148
#pylint: disable=too-many-locals
140-
uniques = set(zip(left_row, left_col)).union(set(zip(right_row, right_col)))
141-
uniques = sorted(list(uniques))
142149

143150
idx_left = []
144151
idx_left_new = []
145152
idx_right = []
146153
idx_right_new = []
147154
row_new = []
148155
col_new = []
149-
for i, (r, c) in enumerate(uniques):
150-
row_new.append(r)
151-
col_new.append(c)
152-
i_left = np.where((left_row == r) & (left_col == c))[0]
153-
if len(i_left) > 0:
154-
idx_left.append(i_left[0])
155-
idx_left_new.append(i)
156-
i_right = np.where((right_row == r) & (right_col == c))[0]
157-
if len(i_right) > 0:
158-
idx_right.append(i_right[0])
159-
idx_right_new.append(i)
160-
return idx_left, idx_left_new, idx_right, idx_right_new, row_new, col_new
156+
157+
right_in_inner = []
158+
low = 0
159+
counter = 0
160+
for i in idx1:
161+
current_match = False
162+
for j in idx2[low:]:
163+
if (left_row[i] == right_row[j]) and (left_col[i] == right_col[j]):
164+
right_in_inner.append(j)
165+
current_match = True
166+
if left_row[i] > right_row[j]:
167+
low = j
168+
if left_row[i] < right_row[j]:
169+
break
170+
if current_match:
171+
x = right_in_inner[-1]
172+
idx_left.append(i)
173+
idx_left_new.append(counter)
174+
idx_right.append(x)
175+
idx_right_new.append(counter)
176+
row_new.append(left_row[i])
177+
col_new.append(left_col[i])
178+
counter += 1
179+
else:
180+
idx_left.append(i)
181+
idx_left_new.append(counter)
182+
row_new.append(left_row[i])
183+
col_new.append(left_col[i])
184+
counter += 1
185+
186+
for j in set(idx2).difference(set(right_in_inner)):
187+
idx_right.append(j)
188+
idx_right_new.append(counter)
189+
row_new.append(right_row[j])
190+
col_new.append(right_col[j])
191+
counter += 1
192+
return idx_left, idx_right, idx_left_new, idx_right_new, row_new, col_new
193+
194+
195+
def get_idx(left_row, left_col, right_row, right_col, idx1, idx2,
196+
join_type="left"):
197+
#pylint: disable=too-many-arguments
198+
if join_type == "inner":
199+
return get_idx_inner(left_row, left_col, right_row, right_col,
200+
idx1, idx2)
201+
if join_type == "outer":
202+
return get_idx_outer(left_row, left_col, right_row, right_col,
203+
idx1, idx2)
204+
raise ValueError("Unknown join_type")

tests/test_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def test_join_arrays(row2, col2):
2323
assert np.allclose([x[0] for x in c], [0, 1, 2, 3, 4])
2424
assert np.allclose([x[1] for x in c], [0, 0, 10, 0, 20])
2525

26+
2627
@pytest.mark.parametrize("join_type, expected_data, expected_row", [
2728
["left", np.array([[0, 0], [1, 0], [2, 2], [4, 0], [5, 5]]), np.array([0, 1, 2, 4, 5])],
2829
["right", np.array([[2, 2],[0, 3], [5, 5], [0, 6], [0, 7],]), np.array([2, 3, 5, 6, 7])],

0 commit comments

Comments
 (0)