SNOW-1662105, SNOW-1662657: Support by, left_by, right_by for pd.merge_asof #2284

Merged · 7 commits · Sep 13, 2024

Changes from 3 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@
#### New Features

- Added support for `TimedeltaIndex.mean` method.
- Added support for `by`, `left_by`, and `right_by` for `pd.merge_asof`.

## 1.22.0 (2024-09-10)

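The CHANGELOG entry above is the user-facing change. Below is a minimal sketch of the newly supported call, mirroring the pandas `merge_asof` API; it assumes an active Snowflake connection for the Snowpark pandas backend, and the data and column names are illustrative only.

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # registers the Snowflake backend

# An active Snowpark session (e.g. Session.builder.create()) is assumed here.
quotes = pd.DataFrame(
    {
        "time": pd.to_datetime(
            ["2024-09-13 09:30:00", "2024-09-13 09:30:00", "2024-09-13 09:30:02"]
        ),
        "ticker": ["AAPL", "MSFT", "AAPL"],
        "bid": [228.10, 417.55, 228.20],
    }
)
trades = pd.DataFrame(
    {
        "time": pd.to_datetime(["2024-09-13 09:30:01", "2024-09-13 09:30:03"]),
        "ticker": ["AAPL", "MSFT"],
        "price": [228.15, 417.60],
    }
)

# `by="ticker"` restricts each backward (default) asof match on `time` to rows
# with the same ticker; `left_by`/`right_by` do the same when the grouping
# columns have different labels on each side.
result = pd.merge_asof(trades, quotes, on="time", by="ticker")
```

As with pandas, both frames must already be sorted on the `on` key.
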
3 changes: 1 addition & 2 deletions docs/source/modin/supported/general_supported.rst
@@ -38,8 +38,7 @@ Data manipulations
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``merge`` | P | ``validate`` | ``N`` if param ``validate`` is given |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``merge_asof`` | P | ``by``, ``left_by``, ``right_by``| ``N`` if param ``direction`` is ``nearest``. |
| | | , ``left_index``, ``right_index``| |
| ``merge_asof`` | P | ``left_index``, ``right_index``, | ``N`` if param ``direction`` is ``nearest``. |
| | | , ``suffixes``, ``tolerance`` | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``merge_ordered`` | N | | |
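A hedged illustration of the remaining ``N`` cell in the ``merge_asof`` row above, reusing the `trades`/`quotes` frames from the CHANGELOG sketch earlier: `direction="nearest"` is still listed as unsupported, so this call is expected to raise `NotImplementedError`, while the default `direction="backward"` (and `"forward"`) go through.

```python
# Expected to raise NotImplementedError on the Snowflake backend per the table above;
# `trades` and `quotes` are the illustrative frames from the earlier sketch.
pd.merge_asof(trades, quotes, on="time", by="ticker", direction="nearest")
```
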
2 changes: 0 additions & 2 deletions src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py
@@ -189,8 +189,6 @@ def compute_bin_indices(
values_frame,
cuts_frame,
how="asof",
left_on=[],
right_on=[],
left_match_col=values_frame.data_column_snowflake_quoted_identifiers[0],
right_match_col=cuts_frame.data_column_snowflake_quoted_identifiers[0],
match_comparator=MatchComparator.LESS_THAN_OR_EQUAL_TO
@@ -584,8 +584,6 @@ def _get_adjusted_key_frame_by_row_pos_int_frame(
key,
count_frame,
"cross",
left_on=[],
right_on=[],
inherit_join_index=InheritJoinIndex.FROM_LEFT,
)

124 changes: 88 additions & 36 deletions src/snowflake/snowpark/modin/plugin/_internal/join_utils.py
@@ -103,12 +103,57 @@ class JoinOrAlignInternalFrameResult(NamedTuple):
result_column_mapper: JoinOrAlignResultColumnMapper


def assert_snowpark_pandas_types_match(
left: InternalFrame,
right: InternalFrame,
left_join_identifiers: list[str],
right_join_identifiers: list[str],
) -> None:
"""
If Snowpark pandas types do not match for the given identifiers, then a ValueError will be raised.

Args:
left: An internal frame to use on left side of join.
right: An internal frame to use on right side of join.
left_join_identifiers: List of snowflake identifiers to check types from 'left' frame.
right_join_identifiers: List of snowflake identifiers to check types from 'right' frame.
left_join_identifiers and right_join_identifiers must be lists of equal length.

Returns: None

Raises: ValueError
"""
left_types = [
left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
for id in left_join_identifiers
]
right_types = [
right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
for id in right_join_identifiers
]
for i, (lt, rt) in enumerate(zip(left_types, right_types)):
if lt != rt:
left_on_id = left_join_identifiers[i]
idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id)
key = left.data_column_pandas_labels[idx]
lt = lt if lt is not None else left.get_snowflake_type(left_on_id)
rt = (
rt
if rt is not None
else right.get_snowflake_type(right_join_identifiers[i])
)
raise ValueError(
f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. "
f"If you wish to proceed you should use pd.concat"
)


def join(
left: InternalFrame,
right: InternalFrame,
how: JoinTypeLit,
left_on: list[str],
right_on: list[str],
left_on: Optional[list[str]] = None,
right_on: Optional[list[str]] = None,
left_match_col: Optional[str] = None,
right_match_col: Optional[str] = None,
match_comparator: Optional[MatchComparator] = None,
@@ -161,40 +206,48 @@ def join(
include mapping for index + data columns, ordering columns and row position column
if exists.
"""
assert len(left_on) == len(
right_on
), "left_on and right_on must be of same length or both be None"
if join_key_coalesce_config is not None:
assert len(join_key_coalesce_config) == len(
left_on
), "join_key_coalesce_config must be of same length as left_on and right_on"
assert how in get_args(
JoinTypeLit
), f"Invalid join type: {how}. Allowed values are {get_args(JoinTypeLit)}"

def assert_snowpark_pandas_types_match() -> None:
"""If Snowpark pandas types do not match, then a ValueError will be raised."""
left_types = [
left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
for id in left_on
]
right_types = [
right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
for id in right_on
]
for i, (lt, rt) in enumerate(zip(left_types, right_types)):
if lt != rt:
left_on_id = left_on[i]
idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id)
key = left.data_column_pandas_labels[idx]
lt = lt if lt is not None else left.get_snowflake_type(left_on_id)
rt = rt if rt is not None else right.get_snowflake_type(right_on[i])
raise ValueError(
f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. "
f"If you wish to proceed you should use pd.concat"
)
left_on = left_on or []
right_on = right_on or []
assert len(left_on) == len(
right_on
), "left_on and right_on must be of same length or both be None"

assert_snowpark_pandas_types_match()
if how == "asof":
assert (
left_match_col
), "ASOF join was not provided a column identifier to match on for the left table"
assert (
right_match_col
), "ASOF join was not provided a column identifier to match on for the right table"
assert (
match_comparator
), "ASOF join was not provided a comparator for the match condition"
left_join_key = [left_match_col]
right_join_key = [right_match_col]
left_join_key.extend(left_on)
right_join_key.extend(right_on)
if join_key_coalesce_config is not None:
assert len(join_key_coalesce_config) == len(
left_join_key
), "ASOF join join_key_coalesce_config must be of same length as left_join_key and right_join_key"
assert_snowpark_pandas_types_match(left, right, left_join_key, right_join_key)
else:
left_join_key = left_on
right_join_key = right_on
assert (
left_match_col is None
and right_match_col is None
and match_comparator is None
), f"match condition should not be provided for {how} join"
if join_key_coalesce_config is not None:
assert len(join_key_coalesce_config) == len(
left_join_key
), "join_key_coalesce_config must be of same length as left_on and right_on"
assert_snowpark_pandas_types_match(left, right, left_join_key, right_join_key)

# Re-project the active columns to make sure all active columns of the internal frame participate
# in the join operation, and unnecessary columns are dropped from the projected columns.
@@ -210,14 +263,13 @@ def assert_snowpark_pandas_types_match() -> None:
match_comparator=match_comparator,
how=how,
)

return _create_internal_frame_with_join_or_align_result(
joined_ordered_dataframe,
left,
right,
how,
left_on,
right_on,
left_join_key,
right_join_key,
sort,
join_key_coalesce_config,
inherit_join_index,
@@ -259,7 +311,6 @@ def _create_internal_frame_with_join_or_align_result(
Returns:
InternalFrame for the join/aligned result with all fields set accordingly.
"""

result_helper = JoinOrAlignOrderedDataframeResultHelper(
left.ordered_dataframe,
right.ordered_dataframe,
@@ -1402,6 +1453,8 @@ def _sort_on_join_keys(self) -> None:
)
elif self._how == "right":
ordering_column_identifiers = mapped_right_on
elif self._how == "asof":
ordering_column_identifiers = [mapped_left_on[0]]
else: # left join, inner join, left align, coalesce align
ordering_column_identifiers = mapped_left_on

@@ -1414,7 +1467,6 @@ def _sort_on_join_keys(self) -> None:
ordering_columns = [
OrderingColumn(key) for key in ordering_column_identifiers
] + join_or_align_result.ordering_columns

# reset the order of the ordered_dataframe to the final order
self.join_or_align_result = join_or_align_result.sort(ordering_columns)

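The type check moved out of `join` into the module-level `assert_snowpark_pandas_types_match` above, and for ASOF joins it now also covers the match columns (`[left_match_col] + left_on` against `[right_match_col] + right_on`). A rough sketch of how its `ValueError` surfaces to a user, assuming an active Snowpark pandas session; the column names and data are made up, and the exact type names in the message depend on the backend types of the keys.

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # registers the Snowflake backend

left = pd.DataFrame({"key": [1, 2, 3], "value": [10, 20, 30]})
right = pd.DataFrame(
    {"key": pd.to_timedelta([1, 2, 3], unit="s"), "other": [1.0, 2.0, 3.0]}
)

# An integer key against a Timedelta key should trip the check with something like:
#   ValueError: You are trying to merge on <left type> and TimedeltaType columns
#   for key 'key'. If you wish to proceed you should use pd.concat
try:
    pd.merge(left, right, on="key")
except ValueError as exc:
    print(exc)
```
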
@@ -1197,22 +1197,29 @@ def join(
# get the new mapped right on identifier
right_on_cols = [right_identifiers_rename_map[key] for key in right_on_cols]

# Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...'
on = None
for left_col, right_col in zip(left_on_cols, right_on_cols):
eq = Column(left_col).equal_null(Column(right_col))
on = eq if on is None else on & eq

if how == "asof":
assert left_match_col, "left_match_col was not provided to ASOF Join"
assert (
left_match_col
), "ASOF join was not provided a column identifier to match on for the left table"
left_match_col = Column(left_match_col)
# Get the new mapped right match condition identifier
assert right_match_col, "right_match_col was not provided to ASOF Join"
assert (
right_match_col
), "ASOF join was not provided a column identifier to match on for the right table"
right_match_col = Column(right_identifiers_rename_map[right_match_col])
# ASOF Join requires the use of match_condition
assert match_comparator, "match_comparator was not provided to ASOF Join"
assert (
match_comparator
), "ASOF join was not provided a comparator for the match condition"

on = None
for left_col, right_col in zip(left_on_cols, right_on_cols):
eq = Column(left_col).__eq__(Column(right_col))
on = eq if on is None else on & eq

snowpark_dataframe = left_snowpark_dataframe_ref.snowpark_dataframe.join(
right=right_snowpark_dataframe_ref.snowpark_dataframe,
on=on,
how=how,
match_condition=getattr(left_match_col, match_comparator.value)(
right_match_col
Expand All @@ -1224,6 +1231,12 @@ def join(
right_snowpark_dataframe_ref.snowpark_dataframe, how=how
)
else:
# Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...'
on = None
for left_col, right_col in zip(left_on_cols, right_on_cols):
eq = Column(left_col).equal_null(Column(right_col))
on = eq if on is None else on & eq

snowpark_dataframe = left_snowpark_dataframe_ref.snowpark_dataframe.join(
right_snowpark_dataframe_ref.snowpark_dataframe, on, how
)
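In the lower-level Snowpark `join` wrapper above, the ASOF branch keeps the equality on the `by` keys in the `ON` clause using plain `==` and puts the inequality on the match columns into `match_condition`; only the non-ASOF branch still builds the `EQUAL_NULL(...)` clause. A rough hand-written Snowpark equivalent, assuming a Snowpark Python version with ASOF join support and purely illustrative table and column names:

```python
from snowflake.snowpark import Session

# Assumes connection parameters are available (e.g. a default connections.toml profile).
session = Session.builder.create()

trades = session.table("TRADES")  # columns: TICKER, TRADE_TIME, PRICE
quotes = session.table("QUOTES")  # columns: TICKER, QUOTE_TIME, BID

joined = trades.join(
    quotes,
    on=trades["TICKER"] == quotes["TICKER"],  # `by` keys -> ON clause (plain ==)
    how="asof",
    match_condition=trades["TRADE_TIME"] >= quotes["QUOTE_TIME"],  # MATCH_CONDITION
    lsuffix="_T",
    rsuffix="_Q",
)
joined.show()
```

The `>=` match condition here corresponds to `merge_asof`'s default `direction="backward"`.
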
@@ -649,8 +649,6 @@ def perform_asof_join_on_frame(
left=preserving_frame,
right=referenced_frame,
how="asof",
left_on=[],
right_on=[],
left_match_col=left_timecol_snowflake_quoted_identifier,
right_match_col=right_timecol_snowflake_quoted_identifier,
match_comparator=(