Skip to content

Commit

Permalink
refactor: consolidate util.ensure_join_suffixed into join_ensure_named
Browse files Browse the repository at this point in the history
This makes it more general purpose too.
  • Loading branch information
NickCrews committed Feb 3, 2025
1 parent 0526fb7 commit 7f99296
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 43 deletions.
87 changes: 61 additions & 26 deletions mismo/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,32 +433,67 @@ def _warn():
f()


def ensure_join_suffixed(
    original_left_cols: Iterable[str],
    original_right_cols: Iterable[str],
    t: ir.Table,
    lsuffix: str = "_l",
    rsuffix: str = "_r",
) -> ir.Table:
    """Return `t` with every column ending in either `lsuffix` or `rsuffix`.

    Parameters
    ----------
    original_left_cols
        Column names of the left table that was joined.
    original_right_cols
        Column names of the right table that was joined.
    t
        The joined table to fix up.
    lsuffix
        Suffix marking columns that belong to the left side.
    rsuffix
        Suffix marking columns that belong to the right side.

    Returns
    -------
    The table with one-sided columns renamed to carry their side's suffix,
    and every remaining unsuffixed column replaced by two suffixed copies.
    """
    left_names = set(original_left_cols)
    right_names = set(original_right_cols)

    # A column present on only one side is renamed to carry that side's suffix.
    renames = {name + lsuffix: name for name in left_names - right_names}
    renames |= {name + rsuffix: name for name in right_names - left_names}
    t = t.rename(renames)

    # With an equality condition such as `left.name == right.name` on an inner
    # join, ibis leaves those columns without a suffix. Each of them is
    # duplicated under both suffixes and the bare original is then dropped.
    bare = [name for name in t.columns if not name.endswith((lsuffix, rsuffix))]
    left_copies = {name + lsuffix: _[name] for name in bare}
    right_copies = {name + rsuffix: _[name] for name in bare}
    return t.mutate(**(left_copies | right_copies)).drop(*bare)
def join_ensure_named(
    left: ir.Table,
    right: ir.Table,
    predicates: str
    | Sequence[
        str
        | ir.BooleanColumn
        | tuple[str | ir.Column | Deferred, str | ir.Column | Deferred]
        | bool
    ] = (),
    how: str = "inner",
    *,
    lname: str = "{name}",
    rname: str = "{name}_right",
) -> ir.Table:
    """
    Ibis.join, but ALWAYS apply lname and rname to all columns, not just on conflict.

    Parameters
    ----------
    left
        The left table of the join.
    right
        The right table of the join.
    predicates
        The join condition(s), forwarded unchanged to `ibis.join`.
    how
        The join type, forwarded unchanged to `ibis.join`.
    lname
        Format string containing `{name}`, applied to every column of `left`.
    rname
        Format string containing `{name}`, applied to every column of `right`.

    Returns
    -------
    The joined table, holding the renamed columns of `left` followed by the
    renamed columns of `right`.

    Raises
    ------
    ValueError
        If a column of `left` or `right` cannot be found in the join result,
        or if two input columns would end up with the same output name.
    """
    joined = ibis.join(
        left,
        right,
        how=how,
        lname=lname,
        rname=rname,
        predicates=predicates,
    )
    # Read the result's column names once instead of on every membership test.
    joined_cols = set(joined.columns)

    def _selections(table, spec: str, side: str):
        """Yield (output name, column of `joined`) for each column of `table`."""
        for col in table.columns:
            new_name = spec.format(name=col)
            if new_name in joined_cols:
                # The join already produced the column under its final name.
                yield new_name, new_name
            elif col in joined_cols:
                # The join kept the original name, so it gets renamed below.
                yield new_name, col
            else:
                # Raised explicitly rather than asserted, so that the check
                # still runs when Python is started with -O.
                raise ValueError(
                    f"Column {col!r} of the {side} table is not in the join "
                    f"result, neither as {new_name!r} nor as {col!r}"
                )

    selections = [
        *_selections(left, lname, "left"),
        *_selections(right, rname, "right"),
    ]

    # check for dupe output columns
    sources: dict[str, list[str]] = {}
    for new_name, old_name in selections:
        sources.setdefault(new_name, []).append(old_name)
    dupes = [
        f"Column {new_name} is produced by {old_names}"
        for new_name, old_names in sources.items()
        if len(old_names) > 1
    ]
    if dupes:
        raise ValueError("\n".join(dupes))

    return joined.select(**dict(selections))


def check_schemas_equal(a: ibis.Schema | ibis.Table, b: ibis.Schema | ibis.Table):
Expand Down
3 changes: 1 addition & 2 deletions mismo/block/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,7 @@ def block(
pred = pred & (left.record_id < right.record_id)

_sql_analyze.check_join_algorithm(left, right, pred, on_slow=on_slow)
j = ibis.join(left, right, pred, lname="{name}_l", rname="{name}_r")
j = _util.ensure_join_suffixed(left.columns, right.columns, j)
j = _util.join_ensure_named(left, right, pred, lname="{name}_l", rname="{name}_r")
j = fix_blocked_column_order(j)
return j

Expand Down
46 changes: 46 additions & 0 deletions mismo/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,49 @@ def test_optional_import():

assert False, "should not get here"
assert "foo" in str(excinfo.value)


def test_join_ensure_named():
    """Check the output schema of `join_ensure_named` for several name specs.

    Covers a string predicate, a boolean-expression predicate, a literal
    `True` predicate, custom `lname`/`rname` specs, and the error raised when
    both specs produce the same output name.
    """
    a = ibis.table({"a": "int64", "b": "string"}, name="t")
    b = ibis.table({"a": "int64", "c": "string"}, name="u")
    x = ibis.table({"y": "int64", "z": "string"}, name="u")

    def _schema(t):
        return {name: str(dtype) for name, dtype in t.schema().items()}

    # Default specs: left names are untouched, right names get "_right",
    # whatever form the predicate takes.
    default_names = {
        "a": "int64",
        "b": "string",
        "a_right": "int64",
        "c_right": "string",
    }
    assert _schema(_util.join_ensure_named(a, b, "a")) == default_names
    assert _schema(_util.join_ensure_named(a, b, a.a > b.a)) == default_names

    # Tables with no column names in common still get the right-hand suffix.
    joined = _util.join_ensure_named(a, x, True)
    assert _schema(joined) == {
        "a": "int64",
        "b": "string",
        "y_right": "int64",
        "z_right": "string",
    }

    # A custom lname combined with the default rname.
    joined = _util.join_ensure_named(a, b, a.a > b.a, lname="{name}_l")
    assert _schema(joined) == {
        "a_l": "int64",
        "b_l": "string",
        "a_right": "int64",
        "c_right": "string",
    }

    # The right side may keep its bare names when the left side is suffixed.
    joined = _util.join_ensure_named(
        a, b, a.a > b.a, lname="{name}_l", rname="{name}"
    )
    assert _schema(joined) == {
        "a_l": "int64",
        "b_l": "string",
        "a": "int64",
        "c": "string",
    }

    # Identical specs make both "a" columns map to "a_x".
    with pytest.raises(ValueError):
        _util.join_ensure_named(a, b, a.a > b.a, lname="{name}_x", rname="{name}_x")
11 changes: 4 additions & 7 deletions mismo/types/_linkage.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,14 +468,11 @@ def from_predicates(
left = left.mutate(ibis.row_number().name("record_id"))
if "record_id" not in right.columns:
right = right.mutate(ibis.row_number().name("record_id"))
links = ibis.join(
left,
right,
predicates,
lname="{name}_l",
rname="{name}_r",
if isinstance(predicates, tuple) and len(predicates) == 2:
predicates = [predicates]
links = _util.join_ensure_named(
left, right, predicates, lname="{name}_l", rname="{name}_r"
)
links = _util.ensure_join_suffixed(left.columns, right.columns, links)
links = links.select("record_id_l", "record_id_r")
return cls(left, right, links)

Expand Down
10 changes: 2 additions & 8 deletions mismo/types/_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,9 @@ def from_tables(
# 1. all the columns in after
# 2. any extra columns in before are tacked on the end
all_columns = (dict(before.schema()) | dict(after.schema())).keys()
joined = ibis.join(
before,
after,
how="inner",
lname="{name}_l",
rname="{name}_r",
predicates=join_on,
joined = _util.join_ensure_named(
before, after, join_on, lname="{name}_l", rname="{name}_r"
)
joined = _util.ensure_join_suffixed(before.columns, after.columns, joined)

def make_diff_col(col: str) -> ir.StructColumn:
d = {}
Expand Down

0 comments on commit 7f99296

Please sign in to comment.