Skip to content

Commit a62f4fa

Browse files
committed
test: refactor out get_clusters() into tests.util for external use
1 parent e18c083 commit a62f4fa

File tree

2 files changed

+31
-13
lines changed

2 files changed

+31
-13
lines changed

mismo/cluster/test/test_connected_components.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
from __future__ import annotations
22

3-
from typing import Any
4-
53
import ibis
64
from ibis.expr import types as ir
75
import pandas as pd
86
import pytest
97

108
from mismo.cluster import connected_components
9+
from mismo.tests.util import get_clusters
1110

1211

1312
@pytest.fixture(params=["component", "cluster"])
@@ -145,14 +144,5 @@ def test_cc_max_iterations(table_factory):
145144

146145
def _labels_to_clusters(
147146
labels: ir.Table, label_as: str = "component"
148-
) -> set[frozenset[Any]]:
149-
labels = labels.rename(component=label_as)
150-
assert labels.component.type() == ibis.dtype("int64")
151-
df = labels.to_pandas()
152-
cid_to_rid = {c: set() for c in set(df.component)}
153-
for row in df.itertuples():
154-
record_id = row.record_id
155-
if isinstance(record_id, dict):
156-
record_id = tuple(record_id.values())
157-
cid_to_rid[row.component].add(record_id)
158-
return {frozenset(records) for records in cid_to_rid.values()}
147+
) -> set[frozenset[int]]:
148+
return get_clusters(labels[label_as], label=labels.record_id)

mismo/tests/util.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import Literal
44

5+
import ibis
56
from ibis.expr import types as ir
67
import pandas as pd
78
import pytest
@@ -50,3 +51,30 @@ def make_float_comparable(val):
5051
if isinstance(val, float) and not pd.isna(val):
5152
return pytest.approx(val, rel=1e-4)
5253
return val
54+
55+
56+
def get_clusters(
57+
cluster_id: ibis.Column, *, label: ibis.Column | None = None
58+
) -> set[frozenset]:
59+
"""Convert a label column into a set of clusters.
60+
61+
Say you have a table of records, and one of the columns
62+
is an ID that groups records together.
63+
This function will return a set of frozensets, where each
64+
frozenset is a cluster of record IDs.
65+
66+
You can either provide a label column to act as the record IDs,
67+
or if not given, it will use `ibis.row_number()`.
68+
"""
69+
if label is None:
70+
label = ibis.row_number()
71+
clusters = label.collect().over(group_by=cluster_id)
72+
73+
def make_hashable(cluster):
74+
for record_id in cluster:
75+
if isinstance(record_id, dict):
76+
yield tuple(record_id.values())
77+
else:
78+
yield record_id
79+
80+
return {frozenset(make_hashable(c)) for c in clusters.execute()}

0 commit comments

Comments
 (0)