Skip to content

Commit

Permalink
feat: upcast schemas if needed during set ops
Browse files Browse the repository at this point in the history
  • Loading branch information
NickCrews committed Jan 26, 2025
1 parent 0044845 commit d500718
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 13 deletions.
30 changes: 30 additions & 0 deletions ibis/backends/tests/test_set_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,36 @@ def test_union(backend, union_subsets, distinct):
backend.assert_frame_equal(result, expected)


@pytest.mark.parametrize("distinct", [False, True], ids=["all", "distinct"])
@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
def test_unified_schemas(backend, con, distinct):
    """Union two projections of ``functional_alltypes`` whose dtypes differ.

    Both inputs expose the columns ``id``, ``i`` and ``s``, but ``i`` and
    ``s`` are built from different source expressions on each side. The test
    checks that the union's ``i`` and ``s`` columns report the same types as
    the second input, and that the executed result matches a pandas
    concatenation of the two inputs (dtypes are not compared).
    """
    sort_keys = ["id", "i", "s"]

    narrow = con.table("functional_alltypes").select(
        "id",
        i="tinyint_col",
        s=_.string_col.cast("!string"),
    )
    wide = con.table("functional_alltypes").select(
        "id",
        # Adding 256 guarantees the values cannot fit in a tinyint.
        i=_.bigint_col + 256,
        s=_.string_col.cast("string"),
    )

    unioned = ibis.union(narrow, wide, distinct=distinct).order_by(*sort_keys)
    # The unified schema is expected to take the types of the second input.
    assert unioned.i.type() == wide.i.type()
    assert unioned.s.type() == wide.s.type()
    result = unioned.execute()

    # Build the reference result in pandas from the separately executed inputs.
    stacked = pd.concat([narrow.execute(), wide.execute()], axis=0)
    expected = stacked.sort_values(sort_keys).reset_index(drop=True)
    if distinct:
        expected = expected.drop_duplicates(sort_keys)

    backend.assert_frame_equal(result, expected, check_dtype=False)


@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
def test_union_mixed_distinct(backend, union_subsets):
(a, b, c), (da, db, dc) = union_subsets
Expand Down
51 changes: 38 additions & 13 deletions ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,19 +338,44 @@ class Set(Relation):
values = FrozenOrderedDict()

def __init__(self, left, right, **kwargs):
err_msg = "Table schemas must be equal for set operations."
try:
missing_from_left = right.schema - left.schema
missing_from_right = left.schema - right.schema
except ConflictingValuesError as e:
raise RelationError(err_msg + "\n" + str(e)) from e
if missing_from_left or missing_from_right:
msgs = [err_msg]
if missing_from_left:
msgs.append(f"Columns missing from the left:\n{missing_from_left}.")
if missing_from_right:
msgs.append(f"Columns missing from the right:\n{missing_from_right}.")
raise RelationError("\n".join(msgs))
# TODO: hoist this up into the user facing API so we can see
# all the tables at once and give a better error message
errs = ["Table schemas must be unifiable for set operations."]
missing_from_left = set(right.schema.names) - set(left.schema.names)
missing_from_right = set(left.schema.names) - set(right.schema.names)
if missing_from_left:
errs.append(f"Columns missing from the left:\n{missing_from_left}.")
if missing_from_right:
errs.append(f"Columns missing from the right:\n{missing_from_right}.")
if len(errs) > 1:
raise RelationError("\n".join(errs))

# For every column, find the dtype both sides can be cast to. Columns whose
# dtype differs from the unified one on either side are recorded in
# ``upcasts``; columns with no common dtype are collected in ``errs`` so that
# all conflicts are reported in a single exception.
upcasts = {}
for name in left.schema.names:
    ltype, rtype = left.schema[name], right.schema[name]
    try:
        unified_dt = dt.highest_precedence([ltype, rtype])
    except IbisTypeError:
        errs.append(f"Unable to find a common dtype for column {name}")
        errs.append(f"Left dtype: {ltype!s}")
        errs.append(f"Right dtype: {rtype!s}")
    else:
        # Only compare when unification succeeded. Falling through from the
        # ``except`` branch would read ``unified_dt`` while it is unbound
        # (first column) or stale from the previous column.
        if unified_dt != ltype or unified_dt != rtype:
            upcasts[name] = unified_dt
# ``errs`` always starts with the one-line header, so more than one entry
# means at least one column could not be unified.
if len(errs) > 1:
    raise ConflictingValuesError("\n".join(errs))

if upcasts:
from ibis.expr.operations.generic import Cast

def get_new_val(relation, name):
if name not in upcasts:
return Field(relation, name)
return Cast(Field(relation, name), upcasts[name])

lcols = {name: get_new_val(left, name) for name in left.schema.names}
rcols = {name: get_new_val(right, name) for name in left.schema.names}
left = Project(left, lcols)
right = Project(right, rcols)

if left.schema.names != right.schema.names:
# rewrite so that both sides have the columns in the same order making it
Expand Down

0 comments on commit d500718

Please sign in to comment.