Skip to content

Commit 8198e6d

Browse files
Merge pull request #14 from florian-huber/fix_merge_errors
Fix merge errors
2 parents 225a87b + 5ac4da6 commit 8198e6d

File tree

3 files changed

+47
-5
lines changed

3 files changed

+47
-5
lines changed

sparsestack/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.4.0'
1+
__version__ = '0.4.1'

sparsestack/utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def get_idx_inner(left_row, left_col, right_row, right_col,
120120
low = 0
121121
counter = 0
122122
for i in idx1:
123-
for j in idx2[low:]:
123+
for count, j in enumerate(idx2[low:]):
124124
if (left_row[i] == right_row[j]) and (left_col[i] == right_col[j]):
125125
idx_left.append(i)
126126
idx_left_new.append(counter)
@@ -129,8 +129,9 @@ def get_idx_inner(left_row, left_col, right_row, right_col,
129129
row_new.append(left_row[i])
130130
col_new.append(left_col[i])
131131
counter += 1
132+
low = count
132133
if left_row[i] > right_row[j]:
133-
low = j
134+
low = count
134135
if left_row[i] < right_row[j]:
135136
break
136137
return idx_left, idx_right, idx_left_new, idx_right_new, row_new, col_new
@@ -159,12 +160,12 @@ def get_idx_outer(left_row, left_col, right_row, right_col,
159160
counter = 0
160161
for i in idx1:
161162
current_match = False
162-
for j in idx2[low:]:
163+
for count, j in enumerate(idx2[low:]):
163164
if (left_row[i] == right_row[j]) and (left_col[i] == right_col[j]):
164165
right_in_inner.append(j)
165166
current_match = True
166167
if left_row[i] > right_row[j]:
167-
low = j
168+
low = count
168169
if left_row[i] < right_row[j]:
169170
break
170171
if current_match:

tests/test_utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,44 @@ def test_join_arrays_join_types(join_type, expected_data, expected_row):
4343
join_type=join_type)
4444
assert np.allclose(np.array([[x[0], x[1]] for x in data]), expected_data)
4545
assert np.allclose(row, expected_row)
46+
47+
48+
@pytest.mark.parametrize("join_type, expected_data, expected_row", [
49+
["left", np.array([[0, 0], [1, 0], [2, 2], [4, 0], [5, 5]]), np.array([0, 1, 2, 4, 5])],
50+
["right", np.array([[2, 2],[0, 3], [5, 5], [0, 6], [0, 7],]), np.array([2, 3, 5, 6, 7])],
51+
["inner", np.array([[2, 2], [5, 5]]), np.array([2, 5])],
52+
["outer", np.array([[0, 0], [1, 0], [2, 2], [0, 3], [4, 0], [5, 5], [0, 6], [0, 7]]),
53+
np.array([0, 1, 2, 3, 4, 5, 6, 7])],
54+
])
55+
def test_join_arrays_join_types(join_type, expected_data, expected_row):
56+
row1 = np.array([0, 1, 2, 4, 5])
57+
col1 = np.array([0, 1, 2, 4, 5])
58+
row2 = np.array([7, 5, 3, 6, 2])
59+
col2 = np.array([7, 5, 3, 6, 2])
60+
data1 = np.array(col1, dtype=[("layer1", col1.dtype)])
61+
data2 = np.array(col2, dtype=[("layer2", col2.dtype)])
62+
63+
row, col, data = join_arrays(row1, col1, data1, row2, col2, data2, "test1",
64+
join_type=join_type)
65+
assert np.allclose(np.array([[x[0], x[1]] for x in data]), expected_data)
66+
assert np.allclose(row, expected_row)
67+
68+
69+
@pytest.mark.parametrize("join_type", [
70+
"left", "right", "inner", "outer"
71+
])
72+
def test_join_arrays_larger(join_type):
73+
"""Joining two identical arrays should always give the same result."""
74+
row = np.array([ 0, 1, 2, 3, 4, 5, 6, 6, 15, 7, 7, 8, 7, 9, 7, 10, 7,
75+
11, 8, 8, 9, 8, 10, 8, 11, 9, 9, 10, 9, 11, 10, 10, 11, 11,
76+
12, 12, 13, 13, 14, 15, 16, 17, 18, 19])
77+
78+
col = np.array([ 0, 1, 2, 3, 4, 5, 6, 15, 6, 7, 8, 7, 9, 7, 10, 7, 11,
79+
7, 8, 9, 8, 10, 8, 11, 8, 9, 10, 9, 11, 9, 10, 11, 10, 11,
80+
12, 13, 12, 13, 14, 15, 16, 17, 18, 19])
81+
data1 = np.array(np.arange(0, len(col)), dtype=[("layer1", col.dtype)])
82+
data2 = np.array(np.arange(0, len(col)), dtype=[("layer2", col.dtype)])
83+
84+
row_out, col_out, data_out = join_arrays(row, col, data1, row, col, data2, "test1",
85+
join_type=join_type)
86+
assert np.allclose(sorted(data_out["test1_layer2"]), sorted(np.array([x[0] for x in data2])))

0 commit comments

Comments
 (0)