File tree Expand file tree Collapse file tree 2 files changed +31
-13
lines changed Expand file tree Collapse file tree 2 files changed +31
-13
lines changed Original file line number Diff line number Diff line change 1
1
from __future__ import annotations
2
2
3
- from typing import Any
4
-
5
3
import ibis
6
4
from ibis .expr import types as ir
7
5
import pandas as pd
8
6
import pytest
9
7
10
8
from mismo .cluster import connected_components
9
+ from mismo .tests .util import get_clusters
11
10
12
11
13
12
@pytest .fixture (params = ["component" , "cluster" ])
@@ -145,14 +144,5 @@ def test_cc_max_iterations(table_factory):
145
144
146
145
def _labels_to_clusters (
147
146
labels : ir .Table , label_as : str = "component"
148
- ) -> set [frozenset [Any ]]:
149
- labels = labels .rename (component = label_as )
150
- assert labels .component .type () == ibis .dtype ("int64" )
151
- df = labels .to_pandas ()
152
- cid_to_rid = {c : set () for c in set (df .component )}
153
- for row in df .itertuples ():
154
- record_id = row .record_id
155
- if isinstance (record_id , dict ):
156
- record_id = tuple (record_id .values ())
157
- cid_to_rid [row .component ].add (record_id )
158
- return {frozenset (records ) for records in cid_to_rid .values ()}
147
+ ) -> set [frozenset [int ]]:
148
+ return get_clusters (labels [label_as ], label = labels .record_id )
Original file line number Diff line number Diff line change 2
2
3
3
from typing import Literal
4
4
5
+ import ibis
5
6
from ibis .expr import types as ir
6
7
import pandas as pd
7
8
import pytest
@@ -50,3 +51,30 @@ def make_float_comparable(val):
50
51
if isinstance (val , float ) and not pd .isna (val ):
51
52
return pytest .approx (val , rel = 1e-4 )
52
53
return val
54
+
55
+
56
+ def get_clusters (
57
+ cluster_id : ibis .Column , * , label : ibis .Column | None = None
58
+ ) -> set [frozenset ]:
59
+ """Convert a label column into a set of clusters.
60
+
61
+ Say you have a table of records, and one of the columns
62
+ is an ID that groups records together.
63
+ This function will return a set of frozensets, where each
64
+ frozenset is a cluster of record IDs.
65
+
66
+ You can either provide a label column to act as the record IDs,
67
+ or if not given, it will use `ibis.row_number()`.
68
+ """
69
+ if label is None :
70
+ label = ibis .row_number ()
71
+ clusters = label .collect ().over (group_by = cluster_id )
72
+
73
+ def make_hashable (cluster ):
74
+ for record_id in cluster :
75
+ if isinstance (record_id , dict ):
76
+ yield tuple (record_id .values ())
77
+ else :
78
+ yield record_id
79
+
80
+ return {frozenset (make_hashable (c )) for c in clusters .execute ()}
You can’t perform that action at this time.
0 commit comments