diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index a07b5d175..245093057 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -30,6 +30,7 @@
 from datafusion.record_batch import RecordBatchStream
 from datafusion.udf import ScalarUDF, AggregateUDF, WindowUDF
+import pathlib

 from typing import Any, TYPE_CHECKING, Protocol
 from typing_extensions import deprecated
@@ -37,7 +38,6 @@
     import pyarrow
     import pandas
     import polars
-    import pathlib

     from datafusion.plan import LogicalPlan, ExecutionPlan
@@ -523,9 +523,18 @@ def register_listing_table(
             file_sort_order_raw,
         )

-    def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
+    def sql(
+        self, query: str, options: SQLOptions | None = None, **named_dfs: DataFrame
+    ) -> DataFrame:
         """Create a :py:class:`~datafusion.DataFrame` from SQL query text.

+        The query string can optionally take a DataFrame as a parameter by placing
+        a variable name inside curly braces. In the following example, if we have a
+        DataFrame called `my_df`, then the DataFrame's logical plan will be
+        converted into an SQL query string and inserted as a substitution::
+
+            ctx.sql("SELECT name from {df}", df=my_df)
+
         Note: This API implements DDL statements such as ``CREATE TABLE`` and
         ``CREATE VIEW`` and DML statements such as ``INSERT INTO`` with in-memory
         default implementation. See
@@ -534,12 +543,20 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
         Args:
             query: SQL query text.
             options: If provided, the query will be validated against these options.
+            named_dfs: When provided, used to replace parameterized query variables
+                in the query string.

         Returns:
             DataFrame representation of the SQL query.
         """
+        if named_dfs:
+            for alias, df in named_dfs.items():
+                df_sql = f"({df.logical_plan().to_sql()})"
+                query = query.replace(f"{{{alias}}}", df_sql)
+
         if options is None:
             return DataFrame(self.ctx.sql(query))
+
         return DataFrame(self.ctx.sql_with_options(query, options.options_internal))

     def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
@@ -753,7 +770,7 @@ def register_parquet(
     def register_csv(
         self,
         name: str,
-        path: str | pathlib.Path | list[str | pathlib.Path],
+        path: str | pathlib.Path | list[str] | list[pathlib.Path],
         schema: pyarrow.Schema | None = None,
         has_header: bool = True,
         delimiter: str = ",",
@@ -917,6 +934,7 @@ def read_json(
         file_extension: str = ".json",
         table_partition_cols: list[tuple[str, str]] | None = None,
         file_compression_type: str | None = None,
+        table_name: str | None = None,
     ) -> DataFrame:
         """Read a line-delimited JSON data source.

@@ -929,22 +947,23 @@
                 selected for data input.
             table_partition_cols: Partition columns.
             file_compression_type: File compression type.
+            table_name: Name to register the table as for SQL queries

         Returns:
             DataFrame representation of the read JSON files.
""" - if table_partition_cols is None: - table_partition_cols = [] - return DataFrame( - self.ctx.read_json( - str(path), - schema, - schema_infer_max_records, - file_extension, - table_partition_cols, - file_compression_type, - ) + if table_name is None: + table_name = self.generate_table_name(path) + self.register_json( + table_name, + path, + schema=schema, + schema_infer_max_records=schema_infer_max_records, + file_extension=file_extension, + table_partition_cols=table_partition_cols, + file_compression_type=file_compression_type, ) + return self.table(table_name) def read_csv( self, @@ -956,6 +975,7 @@ def read_csv( file_extension: str = ".csv", table_partition_cols: list[tuple[str, str]] | None = None, file_compression_type: str | None = None, + table_name: str | None = None, ) -> DataFrame: """Read a CSV data source. @@ -973,27 +993,24 @@ def read_csv( selected for data input. table_partition_cols: Partition columns. file_compression_type: File compression type. + table_name: Name to register the table as for SQL queries Returns: DataFrame representation of the read CSV files """ - if table_partition_cols is None: - table_partition_cols = [] - - path = [str(p) for p in path] if isinstance(path, list) else str(path) - - return DataFrame( - self.ctx.read_csv( - path, - schema, - has_header, - delimiter, - schema_infer_max_records, - file_extension, - table_partition_cols, - file_compression_type, - ) + if table_name is None: + table_name = self.generate_table_name(path) + self.register_csv( + table_name, + path, + schema=schema, + has_header=has_header, + delimiter=delimiter, + schema_infer_max_records=schema_infer_max_records, + file_extension=file_extension, + file_compression_type=file_compression_type, ) + return self.table(table_name) def read_parquet( self, @@ -1004,6 +1021,7 @@ def read_parquet( skip_metadata: bool = True, schema: pyarrow.Schema | None = None, file_sort_order: list[list[Expr]] | None = None, + table_name: str | None = None, ) -> DataFrame: """Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`. @@ -1021,23 +1039,24 @@ def read_parquet( the parquet reader will try to infer it based on data in the file. file_sort_order: Sort order for the file. + table_name: Name to register the table as for SQL queries Returns: DataFrame representation of the read Parquet files """ - if table_partition_cols is None: - table_partition_cols = [] - return DataFrame( - self.ctx.read_parquet( - str(path), - table_partition_cols, - parquet_pruning, - file_extension, - skip_metadata, - schema, - file_sort_order, - ) + if table_name is None: + table_name = self.generate_table_name(path) + self.register_parquet( + table_name, + path, + table_partition_cols=table_partition_cols, + parquet_pruning=parquet_pruning, + file_extension=file_extension, + skip_metadata=skip_metadata, + schema=schema, + file_sort_order=file_sort_order, ) + return self.table(table_name) def read_avro( self, @@ -1045,6 +1064,7 @@ def read_avro( schema: pyarrow.Schema | None = None, file_partition_cols: list[tuple[str, str]] | None = None, file_extension: str = ".avro", + table_name: str | None = None, ) -> DataFrame: """Create a :py:class:`DataFrame` for reading Avro data source. @@ -1053,15 +1073,21 @@ def read_avro( schema: The data source schema. file_partition_cols: Partition columns. file_extension: File extension to select. 
+            table_name: Name to register the table as for SQL queries

         Returns:
             DataFrame representation of the read Avro file
         """
-        if file_partition_cols is None:
-            file_partition_cols = []
-        return DataFrame(
-            self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
+        if table_name is None:
+            table_name = self.generate_table_name(path)
+        self.register_avro(
+            table_name,
+            path,
+            schema=schema,
+            file_extension=file_extension,
+            table_partition_cols=file_partition_cols,
         )
+        return self.table(table_name)

     def read_table(self, table: Table) -> DataFrame:
         """Creates a :py:class:`~datafusion.dataframe.DataFrame` from a table.
@@ -1075,3 +1101,22 @@
     def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
         """Execute the ``plan`` and return the results."""
         return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions))
+
+    def generate_table_name(
+        self, path: str | pathlib.Path | list[str] | list[pathlib.Path]
+    ) -> str:
+        """Generate a table name based on the file name or a uuid."""
+        import uuid
+
+        if isinstance(path, list):
+            path = path[0]
+
+        if isinstance(path, str):
+            path = pathlib.Path(path)
+
+        table_name = path.stem.replace(".", "_")
+
+        if self.table_exist(table_name):
+            table_name = uuid.uuid4().hex
+
+        return table_name
diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py
index 15ad8822f..f3ee5c092 100644
--- a/python/datafusion/functions.py
+++ b/python/datafusion/functions.py
@@ -431,6 +431,7 @@ def window(
     partition_by = expr_list_to_raw_expr_list(partition_by)
     order_by_raw = sort_list_to_raw_sort_list(order_by)
     window_frame = window_frame.window_frame if window_frame is not None else None
+    ctx = ctx.ctx if ctx is not None else None
     return Expr(f.window(name, args, partition_by, order_by_raw, window_frame, ctx))


diff --git a/python/datafusion/plan.py b/python/datafusion/plan.py
index a71965f41..bbc7e44f8 100644
--- a/python/datafusion/plan.py
+++ b/python/datafusion/plan.py
@@ -98,6 +98,10 @@ def to_proto(self) -> bytes:
         """
         return self._raw_plan.to_proto()

+    def to_sql(self) -> str:
+        """Return the SQL equivalent statement for this logical plan."""
+        return self._raw_plan.to_sql()
+

 class ExecutionPlan:
     """Represent nodes in the DataFusion Physical Plan."""
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py
index a2521dd09..f3de8a4a3 100644
--- a/python/tests/test_sql.py
+++ b/python/tests/test_sql.py
@@ -159,6 +159,16 @@ def test_register_parquet(ctx, tmp_path):
     assert result.to_pydict() == {"cnt": [100]}


+def test_parameterized_sql(ctx, tmp_path) -> None:
+    path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
+    df = ctx.read_parquet(path)
+    result = ctx.sql(
+        "SELECT COUNT(a) AS cnt FROM {replaced_df}", replaced_df=df
+    ).collect()
+    result = pa.Table.from_batches(result)
+    assert result.to_pydict() == {"cnt": [100]}
+
+
 @pytest.mark.parametrize("path_to_str", (True, False))
 def test_register_parquet_partitioned(ctx, tmp_path, path_to_str):
     dir_root = tmp_path / "dataset_parquet_partitioned"
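Taken together, the Python-side changes above let a DataFrame be spliced into SQL text and make every `read_*` helper register its source as a table. The following is a minimal usage sketch, not part of the patch; the CSV path and the `orders` table/column names are invented for illustration:

```python
from datafusion import SessionContext

ctx = SessionContext()

# read_csv() now registers the file as a table (named from the file stem via
# generate_table_name(), or from table_name when given), so plain SQL can
# refer to it directly.
df = ctx.read_csv("data/orders.csv", table_name="orders")  # hypothetical file
ctx.sql("SELECT COUNT(*) AS cnt FROM orders").show()

# A DataFrame can also be substituted into the query text: its logical plan is
# unparsed back to SQL via LogicalPlan.to_sql() and spliced in as a subquery.
ctx.sql("SELECT COUNT(*) AS cnt FROM {frame}", frame=df).show()

# The unparsed SQL of the substituted plan can be inspected directly.
print(df.logical_plan().to_sql())
```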
diff --git a/src/functions.rs b/src/functions.rs
index e29c57f9b..5c450286f 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -16,6 +16,7 @@
 // under the License.

 use datafusion::functions_aggregate::all_default_aggregate_functions;
+use datafusion::functions_window::all_default_window_functions;
 use datafusion::logical_expr::ExprFunctionExt;
 use datafusion::logical_expr::WindowFrame;
 use pyo3::{prelude::*, wrap_pyfunction};
@@ -282,6 +283,16 @@ fn find_window_fn(name: &str, ctx: Option<PySessionContext>) -> PyResult<

+    pub fn to_sql(&self) -> PyResult<String> {
+        plan_to_sql(&self.plan)
+            .map(|v| v.to_string())
+            .map_err(|err| PyRuntimeError::new_err(err.to_string()))
+    }
 }

 impl From<PyLogicalPlan> for LogicalPlan {
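On the Rust side, the new `all_default_window_functions` import suggests that `find_window_fn` now also searches DataFusion's built-in window functions, and the Python `window()` helper forwards the raw session context to it. A rough, illustrative sketch of how that could surface in Python, assuming a hypothetical Parquet file and `amount` column (the exact set of names that resolve depends on DataFusion's defaults):

```python
from datafusion import SessionContext, col, functions as f

ctx = SessionContext()
df = ctx.read_parquet("data/orders.parquet", table_name="orders")  # hypothetical file

# window() now passes ctx.ctx (the raw session context) down to find_window_fn,
# so built-in window functions such as "rank" should resolve by name alongside
# any UDWFs registered on this context.
rank_expr = f.window("rank", [], order_by=[col("amount").sort()], ctx=ctx)
df.select(col("amount"), rank_expr.alias("amount_rank")).show()
```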