diff --git a/docs/changelog/next_release/186.feature.rst b/docs/changelog/next_release/186.feature.rst
new file mode 100644
index 00000000..f7c77466
--- /dev/null
+++ b/docs/changelog/next_release/186.feature.rst
@@ -0,0 +1 @@
+Add transformations for **Transfers** with dataframe column filtering
\ No newline at end of file
diff --git a/syncmaster/schemas/v1/transfers/__init__.py b/syncmaster/schemas/v1/transfers/__init__.py
index 6ee713e1..147ae18a 100644
--- a/syncmaster/schemas/v1/transfers/__init__.py
+++ b/syncmaster/schemas/v1/transfers/__init__.py
@@ -29,6 +29,9 @@
     S3ReadTransferTarget,
 )
 from syncmaster.schemas.v1.transfers.strategy import FullStrategy, IncrementalStrategy
+from syncmaster.schemas.v1.transfers.transformations.dataframe_columns_filter import (
+    DataframeColumnsFilter,
+)
 from syncmaster.schemas.v1.transfers.transformations.dataframe_rows_filter import (
     DataframeRowsFilter,
 )
@@ -102,7 +105,7 @@
     | None
 )

-TransformationSchema = DataframeRowsFilter
+TransformationSchema = DataframeRowsFilter | DataframeColumnsFilter


 class CopyTransferSchema(BaseModel):
diff --git a/syncmaster/schemas/v1/transfers/transformations/dataframe_columns_filter.py b/syncmaster/schemas/v1/transfers/transformations/dataframe_columns_filter.py
new file mode 100644
index 00000000..1e7cca56
--- /dev/null
+++ b/syncmaster/schemas/v1/transfers/transformations/dataframe_columns_filter.py
@@ -0,0 +1,33 @@
+# SPDX-FileCopyrightText: 2023-2024 MTS PJSC
+# SPDX-License-Identifier: Apache-2.0
+from typing import Annotated, Literal
+
+from pydantic import BaseModel, Field
+
+from syncmaster.schemas.v1.transformation_types import DATAFRAME_COLUMNS_FILTER
+
+
+class BaseColumnsFilter(BaseModel):
+    field: str
+
+
+class IncludeFilter(BaseColumnsFilter):
+    type: Literal["include"]
+
+
+class RenameFilter(BaseColumnsFilter):
+    type: Literal["rename"]
+    to: str
+
+
+class CastFilter(BaseColumnsFilter):
+    type: Literal["cast"]
+    as_type: str
+
+
+ColumnsFilter = IncludeFilter | RenameFilter | CastFilter
+
+
+class DataframeColumnsFilter(BaseModel):
+    type: DATAFRAME_COLUMNS_FILTER
+    filters: list[Annotated[ColumnsFilter, Field(..., discriminator="type")]] = Field(default_factory=list)
diff --git a/syncmaster/schemas/v1/transfers/transformations/dataframe_rows_filter.py b/syncmaster/schemas/v1/transfers/transformations/dataframe_rows_filter.py
index b0c4914c..ef62f282 100644
--- a/syncmaster/schemas/v1/transfers/transformations/dataframe_rows_filter.py
+++ b/syncmaster/schemas/v1/transfers/transformations/dataframe_rows_filter.py
@@ -7,74 +7,74 @@
 from syncmaster.schemas.v1.transformation_types import DATAFRAME_ROWS_FILTER


-class BaseRowFilter(BaseModel):
+class BaseRowsFilter(BaseModel):
     field: str


-class IsNullFilter(BaseRowFilter):
+class IsNullFilter(BaseRowsFilter):
     type: Literal["is_null"]


-class IsNotNullFilter(BaseRowFilter):
+class IsNotNullFilter(BaseRowsFilter):
     type: Literal["is_not_null"]


-class EqualFilter(BaseRowFilter):
+class EqualFilter(BaseRowsFilter):
     type: Literal["equal"]
     value: str


-class NotEqualFilter(BaseRowFilter):
+class NotEqualFilter(BaseRowsFilter):
     type: Literal["not_equal"]
     value: str


-class GreaterThanFilter(BaseRowFilter):
+class GreaterThanFilter(BaseRowsFilter):
     type: Literal["greater_than"]
     value: str


-class GreaterOrEqualFilter(BaseRowFilter):
+class GreaterOrEqualFilter(BaseRowsFilter):
     type: Literal["greater_or_equal"]
     value: str


-class LessThanFilter(BaseRowFilter):
+class LessThanFilter(BaseRowsFilter):
     type: Literal["less_than"]
     value: str


-class LessOrEqualFilter(BaseRowFilter):
+class LessOrEqualFilter(BaseRowsFilter):
     type: Literal["less_or_equal"]
     value: str


-class LikeFilter(BaseRowFilter):
+class LikeFilter(BaseRowsFilter):
     type: Literal["like"]
     value: str


-class ILikeFilter(BaseRowFilter):
+class ILikeFilter(BaseRowsFilter):
     type: Literal["ilike"]
     value: str


-class NotLikeFilter(BaseRowFilter):
+class NotLikeFilter(BaseRowsFilter):
     type: Literal["not_like"]
     value: str


-class NotILikeFilter(BaseRowFilter):
+class NotILikeFilter(BaseRowsFilter):
     type: Literal["not_ilike"]
     value: str


-class RegexpFilter(BaseRowFilter):
+class RegexpFilter(BaseRowsFilter):
     type: Literal["regexp"]
     value: str


-RowFilter = (
+RowsFilter = (
     IsNullFilter
     | IsNotNullFilter
     | EqualFilter
@@ -93,4 +93,4 @@ class RegexpFilter(BaseRowFilter):

 class DataframeRowsFilter(BaseModel):
     type: DATAFRAME_ROWS_FILTER
-    filters: list[Annotated[RowFilter, Field(..., discriminator="type")]] = Field(default_factory=list)
+    filters: list[Annotated[RowsFilter, Field(..., discriminator="type")]] = Field(default_factory=list)
diff --git a/syncmaster/schemas/v1/transformation_types.py b/syncmaster/schemas/v1/transformation_types.py
index 9393306e..81fb56d1 100644
--- a/syncmaster/schemas/v1/transformation_types.py
+++ b/syncmaster/schemas/v1/transformation_types.py
@@ -3,3 +3,4 @@
 from typing import Literal

 DATAFRAME_ROWS_FILTER = Literal["dataframe_rows_filter"]
+DATAFRAME_COLUMNS_FILTER = Literal["dataframe_columns_filter"]
diff --git a/syncmaster/worker/handlers/db/base.py b/syncmaster/worker/handlers/db/base.py
index 9ef69a89..36e89844 100644
--- a/syncmaster/worker/handlers/db/base.py
+++ b/syncmaster/worker/handlers/db/base.py
@@ -38,7 +38,8 @@ def read(self) -> DataFrame:
         reader = DBReader(
             connection=self.connection,
             table=self.transfer_dto.table_name,
-            where=self._get_filter_expression(),
+            where=self._get_expression(transformation_type="dataframe_rows_filter"),
+            columns=self._get_expression(transformation_type="dataframe_columns_filter"),
         )
         return reader.run()

@@ -53,13 +54,39 @@ def write(self, df: DataFrame) -> None:
     def _normalize_column_names(self, df: DataFrame) -> DataFrame: ...

     @abstractmethod
-    def _make_filter_expression(self, filters: list[dict]) -> str | None: ...
+    def _make_rows_filter_expression(self, filters: list[dict]) -> str | None: ...

-    def _get_filter_expression(self) -> str | None:
-        filters = []
+    def _make_columns_filter_expression(self, filters: list[dict]) -> list[str] | None:
+        expressions = []
+        for filter in filters:
+            filter_type = filter["type"]
+            field = self._make_field_brackets(filter["field"])
+
+            if filter_type == "include":
+                expressions.append(field)
+            elif filter_type == "rename":
+                new_name = self._make_field_brackets(filter["to"])
+                expressions.append(f"{field} AS {new_name}")
+            elif filter_type == "cast":
+                cast_type = filter["as_type"]
+                expressions.append(f"CAST({field} AS {cast_type}) AS {field}")
+
+        return expressions or None
+
+    @staticmethod
+    def _make_field_brackets(field: str) -> str:
+        return f'"{field}"'
+
+    def _get_expression(self, transformation_type: str) -> str | list[str] | None:
+        expressions = []
         for transformation in self.transfer_dto.transformations:
-            if transformation["type"] == "dataframe_rows_filter":
-                filters.extend(transformation["filters"])
-        if filters:
-            return self._make_filter_expression(filters)
+            if transformation["type"] == transformation_type:
+                expressions.extend(transformation["filters"])
+
+        if expressions:
+            if transformation_type == "dataframe_rows_filter":
+                return self._make_rows_filter_expression(expressions)
+            elif transformation_type == "dataframe_columns_filter":
+                return self._make_columns_filter_expression(expressions)
+
+        return None
diff --git a/syncmaster/worker/handlers/db/clickhouse.py b/syncmaster/worker/handlers/db/clickhouse.py
index 29488e16..2d0d751e 100644
--- a/syncmaster/worker/handlers/db/clickhouse.py
+++ b/syncmaster/worker/handlers/db/clickhouse.py
@@ -64,10 +64,10 @@ def _normalize_column_names(self, df: DataFrame) -> DataFrame:
             df = df.withColumnRenamed(column_name, column_name.lower())
         return df

-    def _make_filter_expression(self, filters: list[dict]) -> str | None:
+    def _make_rows_filter_expression(self, filters: list[dict]) -> str | None:
         expressions = []
         for filter in filters:
-            field = f'"{filter["field"]}"'
+            field = self._make_field_brackets(filter["field"])
             op = self._operators[filter["type"]]
             value = filter.get("value")
diff --git a/syncmaster/worker/handlers/db/hive.py b/syncmaster/worker/handlers/db/hive.py
index 440d3394..b951f2de 100644
--- a/syncmaster/worker/handlers/db/hive.py
+++ b/syncmaster/worker/handlers/db/hive.py
@@ -40,11 +40,11 @@ def _normalize_column_names(self, df: DataFrame) -> DataFrame:
             df = df.withColumnRenamed(column_name, column_name.lower())
         return df

-    def _make_filter_expression(self, filters: list[dict]) -> str | None:
+    def _make_rows_filter_expression(self, filters: list[dict]) -> str | None:
         expressions = []
         for filter in filters:
             op = self._operators[filter["type"]]
-            field = f"`{filter["field"]}`"
+            field = self._make_field_brackets(filter["field"])
             value = filter.get("value")

             if value is None:
@@ -59,3 +59,7 @@ def _make_filter_expression(self, filters: list[dict]) -> str | None:
             expressions.append(f"{field} {op} '{value}'")

         return " AND ".join(expressions) or None
+
+    @staticmethod
+    def _make_field_brackets(field: str) -> str:
+        return f"`{field}`"
diff --git a/syncmaster/worker/handlers/db/mssql.py b/syncmaster/worker/handlers/db/mssql.py
index bbe3a5c2..29f85cd5 100644
--- a/syncmaster/worker/handlers/db/mssql.py
+++ b/syncmaster/worker/handlers/db/mssql.py
@@ -43,11 +43,11 @@ def _normalize_column_names(self, df: DataFrame) -> DataFrame:
             df = df.withColumnRenamed(column_name, column_name.lower())
         return df

-    def _make_filter_expression(self, filters: list[dict]) -> str | None:
+    def _make_rows_filter_expression(self, filters: list[dict]) -> str | None:
         expressions = []
         for filter in filters:
             op = self._operators[filter["type"]]
-            field = f'"{filter["field"]}"'
+            field = self._make_field_brackets(filter["field"])
             value = filter.get("value")

             if value is None:
diff --git a/syncmaster/worker/handlers/db/mysql.py b/syncmaster/worker/handlers/db/mysql.py
index b62a5ca6..b56628e5 100644
--- a/syncmaster/worker/handlers/db/mysql.py
+++ b/syncmaster/worker/handlers/db/mysql.py
@@ -40,11 +40,11 @@ def _normalize_column_names(self, df: DataFrame) -> DataFrame:
             df = df.withColumnRenamed(column_name, column_name.lower())
         return df

-    def _make_filter_expression(self, filters: list[dict]) -> str | None:
+    def _make_rows_filter_expression(self, filters: list[dict]) -> str | None:
         expressions = []
         for filter in filters:
             op = self._operators[filter["type"]]
-            field = f"`{filter["field"]}`"
+            field = self._make_field_brackets(filter["field"])
             value = filter.get("value")

             if value is None:
@@ -59,3 +59,7 @@ def _make_filter_expression(self, filters: list[dict]) -> str | None:
             expressions.append(f"{field} {op} '{value}'")

         return " AND ".join(expressions) or None
+
+    @staticmethod
+    def _make_field_brackets(field: str) -> str:
+        return f"`{field}`"
diff --git a/syncmaster/worker/handlers/db/oracle.py b/syncmaster/worker/handlers/db/oracle.py
index d8a87224..9a10631f 100644
--- a/syncmaster/worker/handlers/db/oracle.py
+++ b/syncmaster/worker/handlers/db/oracle.py
@@ -42,10 +42,10 @@ def _normalize_column_names(self, df: DataFrame) -> DataFrame:
             df = df.withColumnRenamed(column_name, column_name.upper())
         return df

-    def _make_filter_expression(self, filters: list[dict]) -> str | None:
+    def _make_rows_filter_expression(self, filters: list[dict]) -> str | None:
         expressions = []
         for filter in filters:
-            field = f'"{filter["field"]}"'
+            field = self._make_field_brackets(filter["field"])
             op = self._operators[filter["type"]]
             value = filter.get("value")
diff --git a/syncmaster/worker/handlers/db/postgres.py b/syncmaster/worker/handlers/db/postgres.py
index e5fcfc60..f17c4712 100644
--- a/syncmaster/worker/handlers/db/postgres.py
+++ b/syncmaster/worker/handlers/db/postgres.py
@@ -41,10 +41,10 @@ def _normalize_column_names(self, df: DataFrame) -> DataFrame:
             df = df.withColumnRenamed(column_name, column_name.lower())
         return df

-    def _make_filter_expression(self, filters: list[dict]) -> str | None:
+    def _make_rows_filter_expression(self, filters: list[dict]) -> str | None:
         expressions = []
         for filter in filters:
-            field = f'"{filter["field"]}"'
+            field = self._make_field_brackets(filter["field"])
             op = self._operators[filter["type"]]
             value = filter.get("value")
diff --git a/syncmaster/worker/handlers/file/base.py b/syncmaster/worker/handlers/file/base.py
index 08a85cf2..e2d51aee 100644
--- a/syncmaster/worker/handlers/file/base.py
+++ b/syncmaster/worker/handlers/file/base.py
@@ -48,9 +48,13 @@ def read(self) -> DataFrame:
         )
         df = reader.run()

-        filter_expression = self._get_filter_expression()
-        if filter_expression:
-            df = df.where(filter_expression)
+        row_filter_expression = self._get_expression(transformation_type="dataframe_rows_filter")
+        if row_filter_expression:
+            df = df.where(row_filter_expression)
+
+        column_filter_expression = self._get_expression(transformation_type="dataframe_columns_filter")
+        if column_filter_expression:
+            df = df.selectExpr(*column_filter_expression)

         return df

@@ -64,16 +68,21 @@ def write(self, df: DataFrame):
         return writer.run(df=df)

-    def _get_filter_expression(self) -> str | None:
-        filters = []
+    def _get_expression(self, transformation_type: str) -> str | list[str] | None:
+        expressions = []
         for transformation in self.transfer_dto.transformations:
-            if transformation["type"] == "dataframe_rows_filter":
-                filters.extend(transformation["filters"])
-        if filters:
-            return self._make_filter_expression(filters)
+            if transformation["type"] == transformation_type:
+                expressions.extend(transformation["filters"])
+
+        if expressions:
+            if transformation_type == "dataframe_rows_filter":
+                return self._make_rows_filter_expression(expressions)
+            elif transformation_type == "dataframe_columns_filter":
+                return self._make_columns_filter_expression(expressions)
+
+        return None

-    def _make_filter_expression(self, filters: list[dict]) -> str:
+    def _make_rows_filter_expression(self, filters: list[dict]) -> str:
         expressions = []
         for filter in filters:
             field = filter["field"]
@@ -83,3 +92,20 @@ def _make_filter_expression(self, filters: list[dict]) -> str:
             expressions.append(f"{field} {op} '{value}'" if value is not None else f"{field} {op}")

         return " AND ".join(expressions)
+
+    def _make_columns_filter_expression(self, filters: list[dict]) -> list[str] | None:
+        expressions = []
+        for filter in filters:
+            filter_type = filter["type"]
+            field = filter["field"]
+
+            if filter_type == "include":
+                expressions.append(field)
+            elif filter_type == "rename":
+                new_name = filter["to"]
+                expressions.append(f"{field} AS {new_name}")
+            elif filter_type == "cast":
+                cast_type = filter["as_type"]
+                expressions.append(f"CAST({field} AS {cast_type}) AS {field}")
+
+        return expressions or None
diff --git a/syncmaster/worker/handlers/file/s3.py b/syncmaster/worker/handlers/file/s3.py
index 6cf4c4e9..a1249ccf 100644
--- a/syncmaster/worker/handlers/file/s3.py
+++ b/syncmaster/worker/handlers/file/s3.py
@@ -47,8 +47,12 @@ def read(self) -> DataFrame:
         )
         df = reader.run()

-        filter_expression = self._get_filter_expression()
-        if filter_expression:
-            df = df.where(filter_expression)
+        row_filter_expression = self._get_expression(transformation_type="dataframe_rows_filter")
+        if row_filter_expression:
+            df = df.where(row_filter_expression)
+
+        column_filter_expression = self._get_expression(transformation_type="dataframe_columns_filter")
+        if column_filter_expression:
+            df = df.selectExpr(*column_filter_expression)

         return df
diff --git a/tests/test_integration/test_run_transfer/conftest.py b/tests/test_integration/test_run_transfer/conftest.py
index 5ab914e4..25338426 100644
--- a/tests/test_integration/test_run_transfer/conftest.py
+++ b/tests/test_integration/test_run_transfer/conftest.py
@@ -1317,7 +1317,8 @@ def init_df_with_mixed_column_naming(spark: SparkSession) -> DataFrame:


 @pytest.fixture
-def dataframe_rows_filter_transformations():
+def dataframe_rows_filter_transformations(source_type: str):
+    regexp_value = "[0-9!@#$.,;_]%" if source_type == "mssql" else "^[^+].*"  # MSSQL regexp is limited
     return [
         {
             "type": "dataframe_rows_filter",
@@ -1344,7 +1345,7 @@ def dataframe_rows_filter_transformations():
                 {
                     "type": "regexp",
                     "field": "PHONE_NUMBER",
-                    "value": "^[^+].*",
+                    "value": regexp_value,
                 },
             ],
         },
@@ -1353,10 +1354,69 @@ def expected_dataframe_rows_filter():
-    return lambda df: (
+    return lambda df, source_type: df.filter(
         df["BIRTH_DATE"].isNotNull()
         & (df["NUMBER"] <= "25")
         & (~df["REGION"].like("%port"))
         & (~df["REGION"].ilike("new%"))
-        & (df["PHONE_NUMBER"].rlike("^[^+].*"))
+        & (df["PHONE_NUMBER"].rlike("^[0-9!@#$.,;_]" if source_type == "mssql" else "^[^+].*")),
     )
+
+
+@pytest.fixture
+def dataframe_columns_filter_transformations(source_type: str):
+    as_type_map = {
+        "postgres": "VARCHAR(10)",
+        "oracle": "VARCHAR2(10)",
+        "clickhouse": "VARCHAR(10)",
+        "mysql": "CHAR",
+        "mssql": "VARCHAR(10)",
+        "hive": "VARCHAR(10)",
+        "s3": "STRING",
+        "hdfs": "STRING",
+    }
+    return [
+        {
+            "type": "dataframe_columns_filter",
+            "filters": [
+                {
+                    "type": "include",
+                    "field": "ID",
+                },
+                {
+                    "type": "include",
+                    "field": "REGION",
+                },
+                {
+                    "type": "include",
+                    "field": "PHONE_NUMBER",
+                },
+                {
+                    "type": "rename",
+                    "field": "REGION",
+                    "to": "REGION2",
+                },
+                {
+                    "type": "cast",
+                    "field": "NUMBER",
+                    "as_type": as_type_map[source_type],
+                },
+                {
+                    "type": "include",
+                    "field": "REGISTERED_AT",
+                },
+            ],
+        },
+    ]
+
+
+@pytest.fixture
+def expected_dataframe_columns_filter():
+    return lambda df, source_type: df.selectExpr(
+        "ID",
+        "REGION",
+        "PHONE_NUMBER",
+        "REGION AS REGION2",
+        "CAST(NUMBER AS STRING) AS NUMBER",
+        "REGISTERED_AT",
+    )
diff --git a/tests/test_integration/test_run_transfer/test_clickhouse.py b/tests/test_integration/test_run_transfer/test_clickhouse.py
index bd18f779..4378cd4f 100644
--- a/tests/test_integration/test_run_transfer/test_clickhouse.py
+++ b/tests/test_integration/test_run_transfer/test_clickhouse.py
@@ -82,12 +82,18 @@ async def clickhouse_to_postgres(


 @pytest.mark.parametrize(
-    "transformations, expected_filter",
+    "source_type, transformations, expected_filter",
     [
         (
+            "clickhouse",
             lf("dataframe_rows_filter_transformations"),
             lf("expected_dataframe_rows_filter"),
         ),
+        (
+            "clickhouse",
+            lf("dataframe_columns_filter_transformations"),
+            lf("expected_dataframe_columns_filter"),
+        ),
     ],
 )
 async def test_run_transfer_postgres_to_clickhouse(
@@ -97,6 +103,7 @@ async def test_run_transfer_postgres_to_clickhouse(
     prepare_clickhouse,
     init_df: DataFrame,
     postgres_to_clickhouse: Transfer,
+    source_type,
     transformations,
     expected_filter,
 ):
@@ -104,7 +111,7 @@ async def test_run_transfer_postgres_to_clickhouse(
     _, fill_with_data = prepare_postgres
     fill_with_data(init_df)
     clickhouse, _ = prepare_clickhouse
-    init_df = init_df.where(expected_filter(init_df))
+    init_df = expected_filter(init_df, source_type)

     # Act
     result = await client.post(
@@ -192,12 +199,18 @@ async def test_run_transfer_postgres_to_clickhouse_mixed_naming(


 @pytest.mark.parametrize(
-    "transformations, expected_filter",
+    "source_type, transformations, expected_filter",
     [
         (
+            "clickhouse",
             lf("dataframe_rows_filter_transformations"),
             lf("expected_dataframe_rows_filter"),
         ),
+        (
+            "clickhouse",
+            lf("dataframe_columns_filter_transformations"),
+            lf("expected_dataframe_columns_filter"),
+        ),
     ],
 )
 async def test_run_transfer_clickhouse_to_postgres(
@@ -206,6 +219,7 @@ async def test_run_transfer_clickhouse_to_postgres(
     prepare_clickhouse,
     prepare_postgres,
     init_df: DataFrame,
+    source_type,
     transformations,
     expected_filter,
     clickhouse_to_postgres: Transfer,
 ):
@@ -214,7 +228,7 @@ async def test_run_transfer_clickhouse_to_postgres(
     _, fill_with_data = prepare_clickhouse
     fill_with_data(init_df)
     postgres, _ = prepare_postgres
-    init_df = init_df.where(expected_filter(init_df))
+    init_df = expected_filter(init_df, source_type)

     # Act
     result = await client.post(
diff --git a/tests/test_integration/test_run_transfer/test_hive.py b/tests/test_integration/test_run_transfer/test_hive.py
index 2a25ea1c..497c4d56 100644
--- a/tests/test_integration/test_run_transfer/test_hive.py
+++ b/tests/test_integration/test_run_transfer/test_hive.py
@@ -183,9 +183,15 @@ async def test_run_transfer_postgres_to_hive_mixed_naming(
-    "transformations, expected_filter",
+    "source_type, transformations, expected_filter",
     [
         (
+            "hive",
             lf("dataframe_rows_filter_transformations"),
             lf("expected_dataframe_rows_filter"),
         ),
+        (
+            "hive",
+            lf("dataframe_columns_filter_transformations"),
+            lf("expected_dataframe_columns_filter"),
+        ),
     ],
 )
 async def test_run_transfer_hive_to_postgres(
@@ -195,6 +201,7 @@ async def test_run_transfer_hive_to_postgres(
     prepare_postgres,
     init_df: DataFrame,
     hive_to_postgres: Transfer,
+    source_type,
     transformations,
     expected_filter,
 ):
@@ -202,7 +209,7 @@ async def test_run_transfer_hive_to_postgres(
     _, fill_with_data = prepare_hive
     fill_with_data(init_df)
     postgres, _ = prepare_postgres
-    init_df = init_df.where(expected_filter(init_df))
+    init_df = expected_filter(init_df, source_type)

     # Act
     result = await client.post(
diff --git a/tests/test_integration/test_run_transfer/test_mssql.py b/tests/test_integration/test_run_transfer/test_mssql.py
index 8d845987..5a83c08b 100644
--- a/tests/test_integration/test_run_transfer/test_mssql.py
+++ b/tests/test_integration/test_run_transfer/test_mssql.py
@@ -7,6 +7,7 @@
 from onetl.db import DBReader
 from pyspark.sql import DataFrame
 from pyspark.sql.functions import col, date_trunc
+from pytest_lazy_fixtures import lf
 from sqlalchemy.ext.asyncio import AsyncSession

 from syncmaster.db.models import Connection, Group, Queue, Status, Transfer
@@ -195,48 +196,17 @@ async def test_run_transfer_postgres_to_mssql_mixed_naming(


 @pytest.mark.parametrize(
-    "transformations, expected_filter",
+    "source_type, transformations, expected_filter",
     [
         (
-            [
-                {
-                    "type": "dataframe_rows_filter",
-                    "filters": [
-                        {
-                            "type": "is_not_null",
-                            "field": "BIRTH_DATE",
-                        },
-                        {
-                            "type": "less_or_equal",
-                            "field": "NUMBER",
-                            "value": "25",
-                        },
-                        {
-                            "type": "not_like",
-                            "field": "REGION",
-                            "value": "%port",
-                        },
-                        {
-                            "type": "not_ilike",
-                            "field": "REGION",
-                            "value": "new%",
-                        },
-                        {
-                            "type": "regexp",
-                            "field": "PHONE_NUMBER",
-                            "value": "^[0-9!@#$.,;_]%",
-                            # available expressions are limited
-                        },
-                    ],
-                },
-            ],
-            lambda df: (
-                df["BIRTH_DATE"].isNotNull()
-                & (df["NUMBER"] <= "25")
-                & (~df["REGION"].like("%port"))
-                & (~df["REGION"].ilike("new%"))
-                & (df["PHONE_NUMBER"].rlike("[0-9!@#$.,;_]%"))
-            ),
+            "mssql",
+            lf("dataframe_rows_filter_transformations"),
+            lf("expected_dataframe_rows_filter"),
+        ),
+        (
+            "mssql",
+            lf("dataframe_columns_filter_transformations"),
+            lf("expected_dataframe_columns_filter"),
         ),
     ],
 )
@@ -247,6 +217,7 @@ async def test_run_transfer_mssql_to_postgres(
     prepare_postgres,
     init_df: DataFrame,
     mssql_to_postgres: Transfer,
+    source_type,
     transformations,
     expected_filter,
 ):
@@ -254,7 +225,7 @@ async def test_run_transfer_mssql_to_postgres(
     _, fill_with_data = prepare_mssql
     fill_with_data(init_df)
     postgres, _ = prepare_postgres
-    init_df = init_df.where(expected_filter(init_df))
+    init_df = expected_filter(init_df, source_type)

     # Act
     result = await client.post(
diff --git a/tests/test_integration/test_run_transfer/test_mysql.py b/tests/test_integration/test_run_transfer/test_mysql.py
index fca399ef..a9b52cf1 100644
--- a/tests/test_integration/test_run_transfer/test_mysql.py
+++ b/tests/test_integration/test_run_transfer/test_mysql.py
@@ -203,12 +203,18 @@ async def test_run_transfer_postgres_to_mysql_mixed_naming(


 @pytest.mark.parametrize(
-    "transformations, expected_filter",
+    "source_type, transformations, expected_filter",
     [
         (
+            "mysql",
lf("dataframe_rows_filter_transformations"), lf("expected_dataframe_rows_filter"), ), + ( + "mysql", + lf("dataframe_columns_filter_transformations"), + lf("expected_dataframe_columns_filter"), + ), ], ) async def test_run_transfer_mysql_to_postgres( @@ -218,6 +224,7 @@ async def test_run_transfer_mysql_to_postgres( prepare_postgres, init_df: DataFrame, mysql_to_postgres: Transfer, + source_type, transformations, expected_filter, ): @@ -225,7 +232,7 @@ async def test_run_transfer_mysql_to_postgres( _, fill_with_data = prepare_mysql fill_with_data(init_df) postgres, _ = prepare_postgres - init_df = init_df.where(expected_filter(init_df)) + init_df = expected_filter(init_df, source_type) # Act result = await client.post( diff --git a/tests/test_integration/test_run_transfer/test_oracle.py b/tests/test_integration/test_run_transfer/test_oracle.py index aae80d6f..049513cd 100644 --- a/tests/test_integration/test_run_transfer/test_oracle.py +++ b/tests/test_integration/test_run_transfer/test_oracle.py @@ -182,12 +182,18 @@ async def test_run_transfer_postgres_to_oracle_mixed_naming( @pytest.mark.parametrize( - "transformations, expected_filter", + "source_type, transformations, expected_filter", [ ( + "oracle", lf("dataframe_rows_filter_transformations"), lf("expected_dataframe_rows_filter"), ), + ( + "oracle", + lf("dataframe_columns_filter_transformations"), + lf("expected_dataframe_columns_filter"), + ), ], ) async def test_run_transfer_oracle_to_postgres( @@ -197,6 +203,7 @@ async def test_run_transfer_oracle_to_postgres( prepare_postgres, init_df: DataFrame, oracle_to_postgres: Transfer, + source_type, transformations, expected_filter, ): @@ -204,7 +211,7 @@ async def test_run_transfer_oracle_to_postgres( _, fill_with_data = prepare_oracle fill_with_data(init_df) postgres, _ = prepare_postgres - init_df = init_df.where(expected_filter(init_df)) + init_df = expected_filter(init_df, source_type) # Act result = await client.post( diff --git a/tests/test_integration/test_run_transfer/test_s3.py b/tests/test_integration/test_run_transfer/test_s3.py index 8c6fde45..15285e25 100644 --- a/tests/test_integration/test_run_transfer/test_s3.py +++ b/tests/test_integration/test_run_transfer/test_s3.py @@ -112,24 +112,28 @@ async def postgres_to_s3( @pytest.mark.parametrize( - "source_file_format, file_format_flavor, transformations, expected_filter", + "source_file_format, file_format_flavor, source_type, transformations, expected_filter", [ pytest.param( ("csv", {}), "with_header", + "s3", lf("dataframe_rows_filter_transformations"), lf("expected_dataframe_rows_filter"), + id="csv", ), pytest.param( ("json", {}), "without_compression", - [], - None, + "s3", + lf("dataframe_columns_filter_transformations"), + lf("expected_dataframe_columns_filter"), id="json", ), pytest.param( ("jsonline", {}), "without_compression", + None, [], None, id="jsonline", @@ -137,6 +141,7 @@ async def postgres_to_s3( pytest.param( ("excel", {}), "with_header", + None, [], None, id="excel", @@ -144,6 +149,7 @@ async def postgres_to_s3( pytest.param( ("orc", {}), "without_compression", + None, [], None, id="orc", @@ -151,6 +157,7 @@ async def postgres_to_s3( pytest.param( ("parquet", {}), "without_compression", + None, [], None, id="parquet", @@ -158,6 +165,7 @@ async def postgres_to_s3( pytest.param( ("xml", {}), "without_compression", + None, [], None, id="xml", @@ -173,6 +181,7 @@ async def test_run_transfer_s3_to_postgres( s3_to_postgres: Transfer, source_file_format, file_format_flavor, + source_type, transformations, 
     expected_filter,
 ):
     # Arrange
     postgres, _ = prepare_postgres
     file_format, _ = source_file_format
     if expected_filter:
-        init_df = init_df.where(expected_filter(init_df))
+        init_df = expected_filter(init_df, source_type)

     # Act
     result = await client.post(
diff --git a/tests/test_unit/test_transfers/test_create_transfer.py b/tests/test_unit/test_transfers/test_create_transfer.py
index 3c005246..770bdca0 100644
--- a/tests/test_unit/test_transfers/test_create_transfer.py
+++ b/tests/test_unit/test_transfers/test_create_transfer.py
@@ -52,6 +52,25 @@ async def test_developer_plus_can_create_transfer(
                     },
                 ],
             },
+            {
+                "type": "dataframe_columns_filter",
+                "filters": [
+                    {
+                        "type": "include",
+                        "field": "col1",
+                    },
+                    {
+                        "type": "rename",
+                        "field": "col2",
+                        "to": "new_col2",
+                    },
+                    {
+                        "type": "cast",
+                        "field": "col3",
+                        "as_type": "VARCHAR",
+                    },
+                ],
+            },
         ],
         "queue_id": group_queue.id,
     },
@@ -447,12 +466,12 @@ async def test_superuser_can_create_transfer(
                     "location": ["body", "transformations", 0],
                     "message": (
                         "Input tag 'some unknown transformation type' found using 'type' "
-                        "does not match any of the expected tags: 'dataframe_rows_filter'"
+                        "does not match any of the expected tags: 'dataframe_rows_filter', 'dataframe_columns_filter'"
                     ),
                     "code": "union_tag_invalid",
                     "context": {
                         "discriminator": "'type'",
-                        "expected_tags": "'dataframe_rows_filter'",
+                        "expected_tags": "'dataframe_rows_filter', 'dataframe_columns_filter'",
                         "tag": "some unknown transformation type",
                     },
                     "input": {
@@ -512,6 +531,47 @@ async def test_superuser_can_create_transfer(
                 },
             },
         ),
+        (
+            {
+                "transformations": [
+                    {
+                        "type": "dataframe_columns_filter",
+                        "filters": [
+                            {
+                                "type": "convert",
+                                "field": "col1",
+                                "value": "VARCHAR",
+                            },
+                        ],
+                    },
+                ],
+            },
+            {
+                "error": {
+                    "code": "invalid_request",
+                    "message": "Invalid request",
+                    "details": [
+                        {
+                            "location": ["body", "transformations", 0, "dataframe_columns_filter", "filters", 0],
+                            "message": (
+                                "Input tag 'convert' found using 'type' does not match any of the expected tags: 'include', 'rename', 'cast'"
+                            ),
+                            "code": "union_tag_invalid",
+                            "context": {
+                                "discriminator": "'type'",
+                                "tag": "convert",
+                                "expected_tags": "'include', 'rename', 'cast'",
+                            },
+                            "input": {
+                                "type": "convert",
+                                "field": "col1",
+                                "value": "VARCHAR",
+                            },
+                        },
+                    ],
+                },
+            },
+        ),
     ),
 )
 async def test_check_fields_validation_on_create_transfer(
diff --git a/tests/test_unit/test_transfers/test_file_transfers/test_update_transfer.py b/tests/test_unit/test_transfers/test_file_transfers/test_update_transfer.py
index be09ae69..da8cd82c 100644
--- a/tests/test_unit/test_transfers/test_file_transfers/test_update_transfer.py
+++ b/tests/test_unit/test_transfers/test_file_transfers/test_update_transfer.py
@@ -88,6 +88,15 @@
                     },
                 ],
             },
+            {
+                "type": "dataframe_columns_filter",
+                "filters": [
+                    {
+                        "type": "include",
+                        "field": "col1",
+                    },
+                ],
+            },
         ],
     },
 ],
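
For context, here is a minimal standalone sketch of what a `dataframe_columns_filter` payload turns into. It is not part of the patch: the `make_columns_projection` helper and the payload below are illustrative only, mirroring the projection logic of the new `_make_columns_filter_expression` in `syncmaster/worker/handlers/db/base.py`, where field quoting comes from `_make_field_brackets` (double quotes in the base handler; the Hive and MySQL handlers override it with backticks):

    # Illustrative sketch, not part of the patch: mirrors the projection logic
    # of _make_columns_filter_expression on the DB handler base class.
    def make_columns_projection(filters: list[dict]) -> list[str] | None:
        expressions = []
        for f in filters:
            field = f'"{f["field"]}"'  # base-handler quoting; Hive/MySQL use backticks
            if f["type"] == "include":
                expressions.append(field)
            elif f["type"] == "rename":
                expressions.append(f'{field} AS "{f["to"]}"')
            elif f["type"] == "cast":
                expressions.append(f'CAST({field} AS {f["as_type"]}) AS {field}')
        return expressions or None

    payload = {
        "type": "dataframe_columns_filter",
        "filters": [
            {"type": "include", "field": "ID"},
            {"type": "rename", "field": "REGION", "to": "REGION2"},
            {"type": "cast", "field": "NUMBER", "as_type": "VARCHAR(10)"},
        ],
    }

    print(make_columns_projection(payload["filters"]))
    # ['"ID"', '"REGION" AS "REGION2"', 'CAST("NUMBER" AS VARCHAR(10)) AS "NUMBER"']

The resulting expressions are handed to DBReader as `columns` for database sources, or applied with `df.selectExpr(...)` for file sources; when no columns filter is configured, `_get_expression` returns None and the dataframe is left untouched.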