From 7f99296747d2ac3926936cdaed94742fe29e86a7 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sun, 2 Feb 2025 20:27:22 -0900 Subject: [PATCH] refactor: consolidate util.ensure_join_suffixed into join_ensure_named This makes it more general purpose too. --- mismo/_util.py | 87 ++++++++++++++++++++++++++++------------ mismo/block/_core.py | 3 +- mismo/tests/test_util.py | 46 +++++++++++++++++++++ mismo/types/_linkage.py | 11 ++--- mismo/types/_updates.py | 10 +---- 5 files changed, 114 insertions(+), 43 deletions(-) diff --git a/mismo/_util.py b/mismo/_util.py index 2af569da..09ce7eef 100644 --- a/mismo/_util.py +++ b/mismo/_util.py @@ -433,32 +433,67 @@ def _warn(): f() -def ensure_join_suffixed( - original_left_cols: Iterable[str], - original_right_cols: Iterable[str], - t: ir.Table, - lsuffix: str = "_l", - rsuffix: str = "_r", -) -> ir.Table: - """Ensure that all columns in `t` have a "_l" or "_r" suffix.""" - lc = set(original_left_cols) - rc = set(original_right_cols) - just_left = lc - rc - just_right = rc - lc - m = {c + lsuffix: c for c in just_left} | {c + rsuffix: c for c in just_right} - t = t.rename(m) - - # If the condition is an equality condition, like `left.name == right.name`, - # then since we are doing an inner join ibis doesn't add suffixes to these - # columns. So we need duplicate these columns and add suffixes. - un_suffixed = [ - c for c in t.columns if not c.endswith(lsuffix) and not c.endswith(rsuffix) - ] - m = {c + lsuffix: _[c] for c in un_suffixed} | { - c + rsuffix: _[c] for c in un_suffixed - } - t = t.mutate(**m).drop(*un_suffixed) - return t +def join_ensure_named( + left: ir.Table, + right: ir.Table, + predicates: str + | Sequence[ + str + | ir.BooleanColumn + | tuple[str | ir.Column | Deferred, str | ir.Column | Deferred] + | bool + ] = (), + how: str = "inner", + *, + lname: str = "{name}", + rname: str = "{name}_right", +): + """ + Ibis.join, but ALWAYS apply lname and rname to all columns, not just on conflict. 
+ """ + joined = ibis.join( + left, + right, + how=how, + lname=lname, + rname=rname, + predicates=predicates, + ) + + def _rename(spec: str, name: str): + return spec.format(name=name) + + selections = [] + for col in left.columns: + new_name = _rename(lname, col) + if new_name in joined.columns: + selections.append((new_name, new_name)) + else: + assert col in joined.columns + selections.append((new_name, col)) + for col in right.columns: + new_name = _rename(rname, col) + if new_name in joined.columns: + selections.append((new_name, new_name)) + else: + assert col in joined.columns + selections.append((new_name, col)) + + # check for dupe output columns + by_new_name = {} + for new_name, old_name in selections: + if new_name not in by_new_name: + by_new_name[new_name] = [old_name] + else: + by_new_name[new_name].append(old_name) + dupes = [] + for new_name, old_names in by_new_name.items(): + if len(old_names) > 1: + dupes.append(f"Column {new_name} is produced by {old_names}") + if dupes: + raise ValueError("\n".join(dupes)) + + return joined.select(**{new_name: old_name for new_name, old_name in selections}) def check_schemas_equal(a: ibis.Schema | ibis.Table, b: ibis.Schema | ibis.Table): diff --git a/mismo/block/_core.py b/mismo/block/_core.py index 422ce5f8..5af98f83 100644 --- a/mismo/block/_core.py +++ b/mismo/block/_core.py @@ -124,8 +124,7 @@ def block( pred = pred & (left.record_id < right.record_id) _sql_analyze.check_join_algorithm(left, right, pred, on_slow=on_slow) - j = ibis.join(left, right, pred, lname="{name}_l", rname="{name}_r") - j = _util.ensure_join_suffixed(left.columns, right.columns, j) + j = _util.join_ensure_named(left, right, pred, lname="{name}_l", rname="{name}_r") j = fix_blocked_column_order(j) return j diff --git a/mismo/tests/test_util.py b/mismo/tests/test_util.py index 398f5a0f..1b9c05bd 100644 --- a/mismo/tests/test_util.py +++ b/mismo/tests/test_util.py @@ -98,3 +98,49 @@ def test_optional_import(): assert False, "should 
not get here" assert "foo" in str(excinfo.value) + + +def test_join_ensure_named(): + a = ibis.table({"a": "int64", "b": "string"}, name="t") + b = ibis.table({"a": "int64", "c": "string"}, name="u") + x = ibis.table({"y": "int64", "z": "string"}, name="u") + + def _schema(t): + return {k: str(v) for k, v in t.schema().items()} + + assert { + "a": "int64", + "b": "string", + "a_right": "int64", + "c_right": "string", + } == _schema(_util.join_ensure_named(a, b, "a")) + assert { + "a": "int64", + "b": "string", + "a_right": "int64", + "c_right": "string", + } == _schema(_util.join_ensure_named(a, b, a.a > b.a)) + assert { + "a": "int64", + "b": "string", + "y_right": "int64", + "z_right": "string", + } == _schema(_util.join_ensure_named(a, x, True)) + + assert { + "a_l": "int64", + "b_l": "string", + "a_right": "int64", + "c_right": "string", + } == _schema(_util.join_ensure_named(a, b, a.a > b.a, lname="{name}_l")) + assert { + "a_l": "int64", + "b_l": "string", + "a": "int64", + "c": "string", + } == _schema( + _util.join_ensure_named(a, b, a.a > b.a, lname="{name}_l", rname="{name}") + ) + + with pytest.raises(ValueError): + _util.join_ensure_named(a, b, a.a > b.a, lname="{name}_x", rname="{name}_x") diff --git a/mismo/types/_linkage.py b/mismo/types/_linkage.py index 763801ca..eb7757c4 100644 --- a/mismo/types/_linkage.py +++ b/mismo/types/_linkage.py @@ -468,14 +468,11 @@ def from_predicates( left = left.mutate(ibis.row_number().name("record_id")) if "record_id" not in right.columns: right = right.mutate(ibis.row_number().name("record_id")) - links = ibis.join( - left, - right, - predicates, - lname="{name}_l", - rname="{name}_r", + if isinstance(predicates, tuple) and len(predicates) == 2: + predicates = [predicates] + links = _util.join_ensure_named( + left, right, predicates, lname="{name}_l", rname="{name}_r" ) - links = _util.ensure_join_suffixed(left.columns, right.columns, links) links = links.select("record_id_l", "record_id_r") return cls(left, right, 
links) diff --git a/mismo/types/_updates.py b/mismo/types/_updates.py index ba67f31d..3d1bce33 100644 --- a/mismo/types/_updates.py +++ b/mismo/types/_updates.py @@ -178,15 +178,9 @@ def from_tables( # 1. all the columns in after # 2. any extra columns in before are tacked on the end all_columns = (dict(before.schema()) | dict(after.schema())).keys() - joined = ibis.join( - before, - after, - how="inner", - lname="{name}_l", - rname="{name}_r", - predicates=join_on, + joined = _util.join_ensure_named( + before, after, join_on, lname="{name}_l", rname="{name}_r" ) - joined = _util.ensure_join_suffixed(before.columns, after.columns, joined) def make_diff_col(col: str) -> ir.StructColumn: d = {}