diff --git a/ibis/backends/tests/test_set_ops.py b/ibis/backends/tests/test_set_ops.py index 699438169e59..c663dcfc8e8c 100644 --- a/ibis/backends/tests/test_set_ops.py +++ b/ibis/backends/tests/test_set_ops.py @@ -49,6 +49,36 @@ def test_union(backend, union_subsets, distinct): backend.assert_frame_equal(result, expected) +@pytest.mark.parametrize("distinct", [False, True], ids=["all", "distinct"]) +@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError) +def test_unified_schemas(backend, con, distinct): + a = con.table("functional_alltypes").select( + "id", + i="tinyint_col", + s=_.string_col.cast("!string"), + ) + b = con.table("functional_alltypes").select( + "id", + i=_.bigint_col + 256, # ensure doesn't fit in a tinyint + s=_.string_col.cast("string"), + ) + + expr = ibis.union(a, b, distinct=distinct).order_by("id", "i", "s") + assert expr.i.type() == b.i.type() + assert expr.s.type() == b.s.type() + result = expr.execute() + + expected = ( + pd.concat([a.execute(), b.execute()], axis=0) + .sort_values(["id", "i", "s"]) + .reset_index(drop=True) + ) + if distinct: + expected = expected.drop_duplicates(["id", "i", "s"]) + + backend.assert_frame_equal(result, expected, check_dtype=False) + + @pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError) def test_union_mixed_distinct(backend, union_subsets): (a, b, c), (da, db, dc) = union_subsets diff --git a/ibis/expr/operations/relations.py b/ibis/expr/operations/relations.py index 90dc400a8ded..3920da3ac33b 100644 --- a/ibis/expr/operations/relations.py +++ b/ibis/expr/operations/relations.py @@ -338,19 +338,44 @@ class Set(Relation): values = FrozenOrderedDict() def __init__(self, left, right, **kwargs): - err_msg = "Table schemas must be equal for set operations." - try: - missing_from_left = right.schema - left.schema - missing_from_right = left.schema - right.schema - except ConflictingValuesError as e: - raise RelationError(err_msg + "\n" + str(e)) from e - if missing_from_left or missing_from_right: - msgs = [err_msg] - if missing_from_left: - msgs.append(f"Columns missing from the left:\n{missing_from_left}.") - if missing_from_right: - msgs.append(f"Columns missing from the right:\n{missing_from_right}.") - raise RelationError("\n".join(msgs)) + # TODO: hoist this up into the user facing API so we can see + # all the tables at once and give a better error message + errs = ["Table schemas must be unifiable for set operations."] + missing_from_left = set(right.schema.names) - set(left.schema.names) + missing_from_right = set(left.schema.names) - set(right.schema.names) + if missing_from_left: + errs.append(f"Columns missing from the left:\n{missing_from_left}.") + if missing_from_right: + errs.append(f"Columns missing from the right:\n{missing_from_right}.") + if len(errs) > 1: + raise RelationError("\n".join(errs)) + + upcasts = {} + for name in left.schema.names: + ltype, rtype = left.schema[name], right.schema[name] + try: + unified_dt = dt.highest_precedence([ltype, rtype]) + except IbisTypeError: + errs.append(f"Unable to find a common dtype for column {name}") + errs.append(f"Left dtype: {ltype!s}") + errs.append(f"Right dtype: {rtype!s}") + if unified_dt != ltype or unified_dt != rtype: + upcasts[name] = unified_dt + if len(errs) > 1: + raise ConflictingValuesError("\n".join(errs)) + + if upcasts: + from ibis.expr.operations.generic import Cast + + def get_new_val(relation, name): + if name not in upcasts: + return Field(relation, name) + return Cast(Field(relation, name), upcasts[name]) + + lcols = {name: get_new_val(left, name) for name in left.schema.names} + rcols = {name: get_new_val(right, name) for name in left.schema.names} + left = Project(left, lcols) + right = Project(right, rcols) if left.schema.names != right.schema.names: # rewrite so that both sides have the columns in the same order making it