Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add to_arrow #2768

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/snowflake/snowpark/_internal/server_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,8 @@ def run_query(
num_statements: Optional[int] = None,
ignore_results: bool = False,
async_post_actions: Optional[List[Query]] = None,
*,
to_arrow: bool = False,
**kwargs,
) -> Union[Dict[str, Any], AsyncJob]:
try:
Expand Down Expand Up @@ -524,7 +526,10 @@ def run_query(
if ignore_results:
return {"data": None, "sfqid": results_cursor.sfqid}
return self._to_data_or_iter(
results_cursor=results_cursor, to_pandas=to_pandas, to_iter=to_iter
results_cursor=results_cursor,
to_pandas=to_pandas,
to_iter=to_iter,
to_arrow=to_arrow,
)
else:
return AsyncJob(
Expand All @@ -544,6 +549,8 @@ def _to_data_or_iter(
results_cursor: SnowflakeCursor,
to_pandas: bool = False,
to_iter: bool = False,
*,
to_arrow: bool = False,
) -> Dict[str, Any]:
qid = results_cursor.sfqid
if to_iter:
Expand Down Expand Up @@ -576,6 +583,8 @@ def _to_data_or_iter(
raise SnowparkClientExceptionMessages.SERVER_FAILED_FETCH_PANDAS(
str(ex)
)
elif to_arrow:
data_or_iter = results_cursor.fetch_arrow_all()
else:
data_or_iter = (
iter(results_cursor) if to_iter else results_cursor.fetchall()
Expand All @@ -592,6 +601,8 @@ def execute(
data_type: _AsyncResultType = _AsyncResultType.ROW,
log_on_exception: bool = False,
case_sensitive: bool = True,
*,
to_arrow: bool = False,
**kwargs,
) -> Union[
List[Row], "pandas.DataFrame", Iterator[Row], Iterator["pandas.DataFrame"]
Expand All @@ -615,10 +626,11 @@ def execute(
data_type=data_type,
log_on_exception=log_on_exception,
case_sensitive=case_sensitive,
to_arrow=to_arrow,
)
if not block:
return result_set
elif to_pandas:
elif to_pandas or to_arrow:
return result_set["data"]
else:
if to_iter:
Expand All @@ -641,6 +653,8 @@ def get_result_set(
log_on_exception: bool = False,
case_sensitive: bool = True,
ignore_results: bool = False,
*,
to_arrow: bool = False,
**kwargs,
) -> Tuple[
Dict[
Expand Down Expand Up @@ -698,6 +712,7 @@ def get_result_set(
params=params,
ignore_results=ignore_results,
async_post_actions=post_actions,
to_arrow=to_arrow,
**kwargs,
)

Expand Down Expand Up @@ -737,6 +752,7 @@ def get_result_set(
params=query.params,
ignore_results=ignore_results,
async_post_actions=post_actions,
to_arrow=to_arrow,
**kwargs,
)
placeholders[query.query_id_place_holder] = (
Expand Down
1 change: 1 addition & 0 deletions src/snowflake/snowpark/async_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class _AsyncResultType(Enum):
ROW = "row"
ITERATOR = "row_iterator"
PANDAS = "pandas"
ARROW = "arrow"
PANDAS_BATCH = "pandas_batches"
COUNT = "count"
NO_RESULT = "no_result"
Expand Down
71 changes: 71 additions & 0 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
overload,
)


import snowflake.snowpark
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
from snowflake.connector.options import installed_pandas, pandas
Expand Down Expand Up @@ -207,6 +208,7 @@

if TYPE_CHECKING:
import modin.pandas # pragma: no cover
import pyarrow
from table import Table # pragma: no cover

_logger = getLogger(__name__)
Expand Down Expand Up @@ -1284,6 +1286,75 @@ def to_snowpark_pandas(

return snowpandas_df

@publicapi
def to_arrow(
    self,
    *,
    statement_params: Optional[Dict[str, str]] = None,
    block: bool = True,
    _emit_ast: bool = True,
    **kwargs: Dict[str, Any],
) -> Union["pyarrow.Table", AsyncJob]:
    """
    Executes the query representing this DataFrame and returns the result as a
    `pyarrow Table <https://arrow.apache.org/docs/python/generated/pyarrow.Table.html>`_.

    Args:
        statement_params: Dictionary of statement level parameters to be set while executing this action.
        block: A bool value indicating whether this function will wait until the result is available.
            When it is ``False``, this function executes the underlying queries of the dataframe
            asynchronously and returns an :class:`AsyncJob`.

    Returns:
        A :class:`pyarrow.Table` when ``block=True``; otherwise an :class:`AsyncJob`
        handle for the asynchronously-running query.

    Note:
        1. This method is only available if pyarrow is installed and available
           (it relies on the connector's ``fetch_arrow_all``).

        2. If you use :func:`Session.sql` with this method, the input query of
           :func:`Session.sql` can only be a SELECT statement.
    """
    if _emit_ast:
        # NOTE(review): there is no dedicated to_arrow AST entity visible yet;
        # this reuses sp_dataframe_to_pandas so the statement is still recorded.
        # TODO: confirm whether a dedicated proto message should be added.
        stmt = self._session._ast_batch.assign()
        ast = with_src_position(stmt.expr.sp_dataframe_to_pandas, stmt)
        debug_check_missing_ast(self._ast_id, self)
        ast.id.bitfield1 = self._ast_id
        if statement_params is not None:
            build_expr_from_dict_str_str(ast.statement_params, statement_params)
        ast.block = block
        self._session._ast_batch.eval(stmt)

    # Flush the AST and encode it as part of the query.
    _, kwargs[DATAFRAME_AST_PARAMETER] = self._session._ast_batch.flush()

    # Fix: instrument this method (to_arrow), not to_pandas, so open-telemetry
    # spans are attributed to the correct public API.
    with open_telemetry_context_manager(self.to_arrow, self):
        result = self._session._conn.execute(
            self._plan,
            to_arrow=True,
            block=block,
            data_type=_AsyncResultType.ARROW,
            _statement_params=create_or_update_statement_params_with_query_tag(
                statement_params or self._statement_params,
                self._session.query_tag,
                SKIP_LEVELS_TWO,
            ),
            **kwargs,
        )

    # TODO(review): decide whether a non-SELECT command (e.g.
    # session.sql("create ...").to_arrow()) needs an explicit result-type
    # check mirroring check_is_pandas_dataframe_in_to_pandas in to_pandas;
    # fetch_arrow_all may return None or raise for such queries — confirm.
    return result

def __getitem__(self, item: Union[str, Column, List, Tuple, int]):

_emit_ast = self._ast_id is not None
Expand Down
Loading