From 887fbb2f262951f46d62a632755d3b71be8ee3de Mon Sep 17 00:00:00 2001 From: Ralph Rassweiler Date: Wed, 29 May 2024 16:32:14 -0300 Subject: [PATCH] feat(mlop-2269): bump versions (#355) * fix: bump versions adjust tests * add checklist * chore: bump python * bump pyspark * chore: java version all steps modified --- .checklist.yaml | 30 ++++ .github/workflows/publish.yml | 13 ++ .github/workflows/staging.yml | 16 +- .github/workflows/test.yml | 16 +- .gitignore | 1 + Makefile | 6 +- butterfree/_cli/migrate.py | 12 +- butterfree/clients/cassandra_client.py | 4 +- butterfree/clients/spark_client.py | 6 +- butterfree/extract/source.py | 5 +- .../historical_feature_store_writer.py | 5 +- .../writers/online_feature_store_writer.py | 10 +- butterfree/load/writers/writer.py | 5 +- .../database_migration/database_migration.py | 5 +- .../database_migration/metastore_migration.py | 5 +- .../transform/aggregated_feature_set.py | 4 +- .../transformations/aggregated_transform.py | 6 +- .../transformations/custom_transform.py | 4 +- .../transform/transformations/h3_transform.py | 5 +- .../sql_expression_transform.py | 3 +- docs/requirements.txt | 3 +- examples/test_examples.py | 4 +- mypy.ini | 41 +++++- requirements.dev.txt | 10 +- requirements.lint.txt | 11 +- requirements.test.txt | 2 +- requirements.txt | 6 +- setup.cfg | 2 +- setup.py | 2 +- .../butterfree/extract/test_source.py | 13 +- tests/integration/butterfree/load/conftest.py | 2 +- .../integration/butterfree/load/test_sink.py | 7 +- .../butterfree/pipelines/conftest.py | 3 +- .../pipelines/test_feature_set_pipeline.py | 72 ++++++--- .../transform/test_aggregated_feature_set.py | 16 +- .../butterfree/transform/test_feature_set.py | 10 +- tests/mocks/entities/first/first_pipeline.py | 18 ++- .../entities/second/deeper/second_pipeline.py | 16 +- .../butterfree/clients/test_spark_client.py | 14 +- .../pre_processing/test_filter_transform.py | 3 +- .../pre_processing/test_pivot_transform.py | 36 ++++- .../extract/readers/test_file_reader.py | 10 +- .../butterfree/extract/readers/test_reader.py | 3 +- .../extract/readers/test_table_reader.py | 9 +- tests/unit/butterfree/extract/test_source.py | 6 +- tests/unit/butterfree/load/conftest.py | 6 +- .../load/processing/test_json_transform.py | 4 +- .../migrations/database_migration/conftest.py | 12 +- tests/unit/butterfree/pipelines/conftest.py | 13 +- .../pipelines/test_feature_set_pipeline.py | 38 ++++- .../unit/butterfree/reports/test_metadata.py | 139 +++++++----------- tests/unit/butterfree/transform/conftest.py | 8 +- .../transform/features/test_feature.py | 4 +- .../transform/test_aggregated_feature_set.py | 14 +- .../butterfree/transform/test_feature_set.py | 21 ++- .../transform/transformations/conftest.py | 2 +- .../test_aggregated_transform.py | 5 +- .../transformations/test_custom_transform.py | 12 +- .../transformations/test_h3_transform.py | 6 +- .../test_spark_function_transform.py | 4 +- .../test_sql_expression_transform.py | 10 +- 61 files changed, 547 insertions(+), 231 deletions(-) create mode 100644 .checklist.yaml diff --git a/.checklist.yaml b/.checklist.yaml new file mode 100644 index 000000000..f0c211714 --- /dev/null +++ b/.checklist.yaml @@ -0,0 +1,30 @@ +apiVersion: quintoandar.com.br/checklist/v2 +kind: ServiceChecklist +metadata: + name: butterfree +spec: + description: >- + A solution for Feature Stores. + + costCenter: C055 + department: engineering + lifecycle: production + docs: true + + ownership: + team: data_products_mlops + line: tech_platform + owner: otavio.cals@quintoandar.com.br + + libraries: + - name: butterfree + type: common-usage + path: https://quintoandar.github.io/python-package-server/ + description: A lib to build Feature Stores. + registries: + - github-packages + tier: T0 + + channels: + squad: 'mlops' + alerts: 'data-products-reports' diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index f981921e6..0957a958a 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -14,6 +14,19 @@ jobs: steps: - uses: actions/checkout@v2 + - uses: actions/setup-python@v5 + with: + python-version: '3.9' + + - uses: actions/setup-java@v4 + with: + java-version: '11' + distribution: microsoft + + - uses: vemonet/setup-spark@v1 + with: + spark-version: '3.5.1' + hadoop-version: '3' - name: Install dependencies run: make ci-install diff --git a/.github/workflows/staging.yml b/.github/workflows/staging.yml index 77127820e..573049cac 100644 --- a/.github/workflows/staging.yml +++ b/.github/workflows/staging.yml @@ -8,11 +8,23 @@ jobs: Pipeline: if: github.ref == 'refs/heads/staging' - runs-on: ubuntu-22.04 - container: quintoandar/python-3-7-java + runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 + - uses: actions/setup-python@v5 + with: + python-version: '3.9' + + - uses: actions/setup-java@v4 + with: + java-version: '11' + distribution: microsoft + + - uses: vemonet/setup-spark@v1 + with: + spark-version: '3.5.1' + hadoop-version: '3' - name: Install dependencies run: make ci-install diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d7c1c3acc..d588c8533 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,11 +9,23 @@ on: jobs: Pipeline: - runs-on: ubuntu-22.04 - container: quintoandar/python-3-7-java + runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 + - uses: actions/setup-python@v5 + with: + python-version: '3.9' + + - uses: actions/setup-java@v4 + with: + java-version: '11' + distribution: microsoft + + - uses: vemonet/setup-spark@v1 + with: + spark-version: '3.5.1' + hadoop-version: '3' - name: Install dependencies run: make ci-install diff --git a/.gitignore b/.gitignore index 62434612f..0c59b49ab 100644 --- a/.gitignore +++ b/.gitignore @@ -66,6 +66,7 @@ instance/ # PyBuilder target/ +pip/ # Jupyter Notebook .ipynb_checkpoints diff --git a/Makefile b/Makefile index 4109504f6..ba0d0ead4 100644 --- a/Makefile +++ b/Makefile @@ -76,7 +76,7 @@ style-check: @echo "Code Style" @echo "==========" @echo "" - @python -m black --check -t py36 --exclude="build/|buck-out/|dist/|_build/|pip/|\.pip/|\.git/|\.hg/|\.mypy_cache/|\.tox/|\.venv/" . && echo "\n\nSuccess" || (echo "\n\nFailure\n\nYou need to run \"make apply-style\" to apply style formatting to your code"; exit 1) + @python -m black --check -t py39 --exclude="build/|buck-out/|dist/|_build/|pip/|\.pip/|\.git/|\.hg/|\.mypy_cache/|\.tox/|\.venv/" . && echo "\n\nSuccess" || (echo "\n\nFailure\n\nYou need to run \"make apply-style\" to apply style formatting to your code"; exit 1) .PHONY: quality-check ## run code quality checks with flake8 @@ -104,8 +104,8 @@ checks: style-check quality-check type-check .PHONY: apply-style ## fix stylistic errors with black apply-style: - @python -m black -t py36 --exclude="build/|buck-out/|dist/|_build/|pip/|\.pip/|\.git/|\.hg/|\.mypy_cache/|\.tox/|\.venv/" . - @python -m isort -rc --atomic butterfree/ tests/ + @python -m black -t py39 --exclude="build/|buck-out/|dist/|_build/|pip/|\.pip/|\.git/|\.hg/|\.mypy_cache/|\.tox/|\.venv/" . + @python -m isort --atomic butterfree/ tests/ .PHONY: clean ## clean unused artifacts diff --git a/butterfree/_cli/migrate.py b/butterfree/_cli/migrate.py index 277ecf3c6..ed62f1a24 100644 --- a/butterfree/_cli/migrate.py +++ b/butterfree/_cli/migrate.py @@ -46,13 +46,13 @@ def __fs_objects(path: str) -> Set[FeatureSetPipeline]: logger.error(f"Path: {path} not found!") return set() - logger.info(f"Importing modules...") + logger.info("Importing modules...") package = ".".join(path.strip("/").split("/")) imported = set( importlib.import_module(f".{name}", package=package) for name in modules ) - logger.info(f"Scanning modules...") + logger.info("Scanning modules...") content = { module: set( filter( @@ -93,7 +93,8 @@ def __fs_objects(path: str) -> Set[FeatureSetPipeline]: PATH = typer.Argument( - ..., help="Full or relative path to where feature set pipelines are being defined.", + ..., + help="Full or relative path to where feature set pipelines are being defined.", ) GENERATE_LOGS = typer.Option( @@ -113,7 +114,10 @@ class Migrate: pipelines: list of Feature Set Pipelines to use to migration. """ - def __init__(self, pipelines: Set[FeatureSetPipeline],) -> None: + def __init__( + self, + pipelines: Set[FeatureSetPipeline], + ) -> None: self.pipelines = pipelines def _send_logs_to_s3(self, file_local: bool, debug_mode: bool) -> None: diff --git a/butterfree/clients/cassandra_client.py b/butterfree/clients/cassandra_client.py index 4c6f96fe0..5a7231555 100644 --- a/butterfree/clients/cassandra_client.py +++ b/butterfree/clients/cassandra_client.py @@ -129,7 +129,9 @@ def get_schema(self, table: str, database: str = None) -> List[Dict[str, str]]: return response def _get_create_table_query( - self, columns: List[CassandraColumn], table: str, + self, + columns: List[CassandraColumn], + table: str, ) -> str: """Creates CQL statement to create a table.""" parsed_columns = [] diff --git a/butterfree/clients/spark_client.py b/butterfree/clients/spark_client.py index bfa31d2a3..e2b868caf 100644 --- a/butterfree/clients/spark_client.py +++ b/butterfree/clients/spark_client.py @@ -61,9 +61,9 @@ def read( if path and not isinstance(path, (str, list)): raise ValueError("path needs to be a string or a list of string") - df_reader: Union[ - DataStreamReader, DataFrameReader - ] = self.conn.readStream if stream else self.conn.read + df_reader: Union[DataStreamReader, DataFrameReader] = ( + self.conn.readStream if stream else self.conn.read + ) df_reader = df_reader.schema(schema) if schema else df_reader diff --git a/butterfree/extract/source.py b/butterfree/extract/source.py index 1209e9162..281ed15ad 100644 --- a/butterfree/extract/source.py +++ b/butterfree/extract/source.py @@ -58,7 +58,10 @@ class Source(HookableComponent): """ def __init__( - self, readers: List[Reader], query: str, eager_evaluation: bool = True, + self, + readers: List[Reader], + query: str, + eager_evaluation: bool = True, ) -> None: super().__init__() self.enable_pre_hooks = False diff --git a/butterfree/load/writers/historical_feature_store_writer.py b/butterfree/load/writers/historical_feature_store_writer.py index 0ea9b50c8..1a64afdf3 100644 --- a/butterfree/load/writers/historical_feature_store_writer.py +++ b/butterfree/load/writers/historical_feature_store_writer.py @@ -130,7 +130,10 @@ def __init__( self.check_schema_hook = check_schema_hook def write( - self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient, + self, + feature_set: FeatureSet, + dataframe: DataFrame, + spark_client: SparkClient, ) -> None: """Loads the data from a feature set into the Historical Feature Store. diff --git a/butterfree/load/writers/online_feature_store_writer.py b/butterfree/load/writers/online_feature_store_writer.py index 17dc8af4b..d0bcde948 100644 --- a/butterfree/load/writers/online_feature_store_writer.py +++ b/butterfree/load/writers/online_feature_store_writer.py @@ -116,7 +116,10 @@ def filter_latest(dataframe: DataFrame, id_columns: List[Any]) -> DataFrame: window = Window.partitionBy(*id_columns).orderBy(col(TIMESTAMP_COLUMN).desc()) return ( - dataframe.select(col("*"), row_number().over(window).alias("rn"),) + dataframe.select( + col("*"), + row_number().over(window).alias("rn"), + ) .filter(col("rn") == 1) .drop("rn") ) @@ -162,7 +165,10 @@ def _write_in_debug_mode( ) def write( - self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient, + self, + feature_set: FeatureSet, + dataframe: DataFrame, + spark_client: SparkClient, ) -> Union[StreamingQuery, None]: """Loads the latest data from a feature set into the Feature Store. diff --git a/butterfree/load/writers/writer.py b/butterfree/load/writers/writer.py index 5073f4726..1dae795c6 100644 --- a/butterfree/load/writers/writer.py +++ b/butterfree/load/writers/writer.py @@ -72,7 +72,10 @@ def _apply_transformations(self, df: DataFrame) -> DataFrame: @abstractmethod def write( - self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient, + self, + feature_set: FeatureSet, + dataframe: DataFrame, + spark_client: SparkClient, ) -> Any: """Loads the data from a feature set into the Feature Store. diff --git a/butterfree/migrations/database_migration/database_migration.py b/butterfree/migrations/database_migration/database_migration.py index aeec4a6e7..468c028ec 100644 --- a/butterfree/migrations/database_migration/database_migration.py +++ b/butterfree/migrations/database_migration/database_migration.py @@ -180,7 +180,8 @@ def create_query( @staticmethod def _get_diff( - fs_schema: List[Dict[str, Any]], db_schema: List[Dict[str, Any]], + fs_schema: List[Dict[str, Any]], + db_schema: List[Dict[str, Any]], ) -> Set[Diff]: """Gets schema difference between feature set and the table of a given db. @@ -296,7 +297,7 @@ def apply_migration( logger.info(f"Applying this query: {q} ...") self._client.sql(q) - logger.info(f"Feature Set migration finished successfully.") + logger.info("Feature Set migration finished successfully.") # inform in drone console which feature set was migrated print(f"The {feature_set.name} feature set was migrated.") diff --git a/butterfree/migrations/database_migration/metastore_migration.py b/butterfree/migrations/database_migration/metastore_migration.py index daa0afd3d..8c6c211ae 100644 --- a/butterfree/migrations/database_migration/metastore_migration.py +++ b/butterfree/migrations/database_migration/metastore_migration.py @@ -30,7 +30,10 @@ class MetastoreMigration(DatabaseMigration): data is being loaded into an entity table, then users can drop columns manually. """ - def __init__(self, database: str = None,) -> None: + def __init__( + self, + database: str = None, + ) -> None: self._db_config = MetastoreConfig() self.database = database or environment.get_variable( "FEATURE_STORE_HISTORICAL_DATABASE" diff --git a/butterfree/transform/aggregated_feature_set.py b/butterfree/transform/aggregated_feature_set.py index 0bff33c65..c86a95c3d 100644 --- a/butterfree/transform/aggregated_feature_set.py +++ b/butterfree/transform/aggregated_feature_set.py @@ -412,7 +412,9 @@ def _aggregate( # repartition to have all rows for each group at the same partition # by doing that, we won't have to shuffle data on grouping by id dataframe = repartition_df( - dataframe, partition_by=groupby, num_processors=num_processors, + dataframe, + partition_by=groupby, + num_processors=num_processors, ) grouped_data = dataframe.groupby(*groupby) diff --git a/butterfree/transform/transformations/aggregated_transform.py b/butterfree/transform/transformations/aggregated_transform.py index 7304f34b6..a9581ef00 100644 --- a/butterfree/transform/transformations/aggregated_transform.py +++ b/butterfree/transform/transformations/aggregated_transform.py @@ -76,7 +76,11 @@ def aggregations(self) -> List[Tuple]: Function = namedtuple("Function", ["function", "data_type"]) return [ - Function(f.func(expression), f.data_type.spark,) for f in self.functions + Function( + f.func(expression), + f.data_type.spark, + ) + for f in self.functions ] def _get_output_name(self, function: object) -> str: diff --git a/butterfree/transform/transformations/custom_transform.py b/butterfree/transform/transformations/custom_transform.py index 9b5ae23b1..7860fdc20 100644 --- a/butterfree/transform/transformations/custom_transform.py +++ b/butterfree/transform/transformations/custom_transform.py @@ -89,6 +89,8 @@ def transform(self, dataframe: DataFrame) -> DataFrame: """ dataframe = self.transformer( - dataframe, self.parent, **self.transformer__kwargs, + dataframe, + self.parent, + **self.transformer__kwargs, ) return dataframe diff --git a/butterfree/transform/transformations/h3_transform.py b/butterfree/transform/transformations/h3_transform.py index 8ccd3bb38..7a98294ec 100644 --- a/butterfree/transform/transformations/h3_transform.py +++ b/butterfree/transform/transformations/h3_transform.py @@ -84,7 +84,10 @@ class H3HashTransform(TransformComponent): """ def __init__( - self, h3_resolutions: List[int], lat_column: str, lng_column: str, + self, + h3_resolutions: List[int], + lat_column: str, + lng_column: str, ): super().__init__() self.h3_resolutions = h3_resolutions diff --git a/butterfree/transform/transformations/sql_expression_transform.py b/butterfree/transform/transformations/sql_expression_transform.py index 0199c23ae..80cd41ea9 100644 --- a/butterfree/transform/transformations/sql_expression_transform.py +++ b/butterfree/transform/transformations/sql_expression_transform.py @@ -54,7 +54,8 @@ class SQLExpressionTransform(TransformComponent): """ def __init__( - self, expression: str, + self, + expression: str, ): super().__init__() self.expression = expression diff --git a/docs/requirements.txt b/docs/requirements.txt index a20ab18ff..7eaabf11a 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,5 +4,4 @@ sphinxemoji==0.1.6 typing-extensions==3.7.4.2 cmake==3.18.4 h3==3.7.0 -pyarrow==0.15.1 - +pyarrow==16.1.0 diff --git a/examples/test_examples.py b/examples/test_examples.py index b40b6e1a4..7180e080d 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -36,9 +36,9 @@ _, error = p.communicate() if p.returncode != 0: errors.append({"notebook": path, "error": error}) - print(f" >>> Error in execution!\n") + print(" >>> Error in execution!\n") else: - print(f" >>> Successful execution\n") + print(" >>> Successful execution\n") if errors: print(">>> Errors in the following notebooks:") diff --git a/mypy.ini b/mypy.ini index c67bd3a89..fc2931493 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,5 +1,5 @@ [mypy] -python_version = 3.7 +python_version = 3.9 ignore_missing_imports = True disallow_untyped_calls = False disallow_untyped_defs = True @@ -9,3 +9,42 @@ show_error_codes = True show_error_context = True disable_error_code = attr-defined, list-item, operator pretty = True + +[mypy-butterfree.pipelines.*] +ignore_errors = True + +[mypy-butterfree.load.*] +ignore_errors = True + +[mypy-butterfree.transform.*] +ignore_errors = True + +[mypy-butterfree.extract.*] +ignore_errors = True + +[mypy-butterfree.config.*] +ignore_errors = True + +[mypy-butterfree.clients.*] +ignore_errors = True + +[mypy-butterfree.configs.*] +ignore_errors = True + +[mypy-butterfree.dataframe_service.*] +ignore_errors = True + +[mypy-butterfree.validations.*] +ignore_errors = True + +[mypy-butterfree.migrations.*] +ignore_errors = True + +[mypy-butterfree.testing.*] +ignore_errors = True + +[mypy-butterfree.hooks.*] +ignore_errors = True + +[mypy-butterfree._cli.*] +ignore_errors = True diff --git a/requirements.dev.txt b/requirements.dev.txt index 4e164c83f..89025669c 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -1,11 +1,11 @@ -h3==3.7.4 +h3==3.7.7 jupyter==1.0.0 twine==3.1.1 -mypy==0.790 +mypy==1.10.0 sphinx==3.5.4 sphinxemoji==0.1.8 sphinx-rtd-theme==0.5.2 recommonmark==0.7.1 -pyarrow>=1.0.0 -setuptools -wheel +pyarrow==16.1.0 +setuptools==70.0.0 +wheel==0.43.0 diff --git a/requirements.lint.txt b/requirements.lint.txt index 7c51f4b37..66641a952 100644 --- a/requirements.lint.txt +++ b/requirements.lint.txt @@ -1,8 +1,7 @@ -black==19.10b0 -flake8==3.7.9 -flake8-isort==2.8.0 -isort<5 # temporary fix +black==21.12b0 +flake8==4.0.1 +flake8-isort==4.1.1 flake8-docstrings==1.5.0 flake8-bugbear==20.1.0 -flake8-bandit==3.0.0 - +flake8-bandit==2.1.2 +bandit==1.7.2 diff --git a/requirements.test.txt b/requirements.test.txt index b0c4032a8..651700b80 100644 --- a/requirements.test.txt +++ b/requirements.test.txt @@ -2,4 +2,4 @@ pytest==5.3.2 pytest-cov==2.8.1 pytest-xdist==1.31.0 pytest-mock==2.0.0 -pytest-spark==0.5.2 +pytest-spark==0.6.0 diff --git a/requirements.txt b/requirements.txt index d61d125bc..f3af42540 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ -cassandra-driver>=3.22.0,<4.0 +cassandra-driver==3.24.0 mdutils>=1.2.2,<2.0 pandas>=0.24,<2.0 parameters-validation>=1.1.5,<2.0 -pyspark==3.* -typer>=0.3,<0.4 +pyspark==3.5.1 +typer==0.3.2 typing-extensions>3.7.4,<5 boto3==1.17.* diff --git a/setup.cfg b/setup.cfg index c58c2df3e..849d35cf3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,13 +10,13 @@ per-file-ignores = setup.py:D,S101 [isort] +profile = black line_length = 88 known_first_party = butterfree default_section = THIRDPARTY multi_line_output = 3 indent = ' ' skip_glob = pip -use_parantheses = True include_trailing_comma = True [tool:pytest] diff --git a/setup.py b/setup.py index 6fa35751d..42ef57c85 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ author="QuintoAndar", install_requires=requirements, extras_require={"h3": ["h3>=3.7.4,<4"]}, - python_requires=">=3.7, <4", + python_requires=">=3.9, <4", entry_points={"console_scripts": ["butterfree=butterfree._cli.main:app"]}, include_package_data=True, ) diff --git a/tests/integration/butterfree/extract/test_source.py b/tests/integration/butterfree/extract/test_source.py index c465ebd05..3ab991ab2 100644 --- a/tests/integration/butterfree/extract/test_source.py +++ b/tests/integration/butterfree/extract/test_source.py @@ -1,11 +1,11 @@ from typing import List from pyspark.sql import DataFrame -from tests.integration import INPUT_PATH from butterfree.clients import SparkClient from butterfree.extract import Source from butterfree.extract.readers import FileReader, TableReader +from tests.integration import INPUT_PATH def create_temp_view(dataframe: DataFrame, name): @@ -13,10 +13,11 @@ def create_temp_view(dataframe: DataFrame, name): def create_db_and_table(spark, table_reader_id, table_reader_db, table_reader_table): - spark.sql(f"create database if not exists {table_reader_db}") + spark.sql(f"drop schema if exists {table_reader_db} cascade") + spark.sql(f"create database {table_reader_db}") spark.sql(f"use {table_reader_db}") spark.sql( - f"create table if not exists {table_reader_db}.{table_reader_table} " # noqa + f"create table {table_reader_db}.{table_reader_table} " # noqa f"as select * from {table_reader_id}" # noqa ) @@ -33,7 +34,10 @@ def compare_dataframes( class TestSource: def test_source( - self, target_df_source, target_df_table_reader, spark_session, + self, + target_df_source, + target_df_table_reader, + spark_session, ): # given spark_client = SparkClient() @@ -66,6 +70,7 @@ def test_source( query=f"select a.*, b.feature2 " # noqa f"from {table_reader_id} a " # noqa f"inner join {file_reader_id} b on a.id = b.id ", # noqa + eager_evaluation=False, ) result_df = source.construct(client=spark_client) diff --git a/tests/integration/butterfree/load/conftest.py b/tests/integration/butterfree/load/conftest.py index 418b6d2ac..60101f1ac 100644 --- a/tests/integration/butterfree/load/conftest.py +++ b/tests/integration/butterfree/load/conftest.py @@ -51,7 +51,7 @@ def feature_set(): ] ts_feature = TimestampFeature(from_column="timestamp") features = [ - Feature(name="feature", description="Description", dtype=DataType.FLOAT), + Feature(name="feature", description="Description", dtype=DataType.INTEGER), ] return FeatureSet( "test_sink_feature_set", diff --git a/tests/integration/butterfree/load/test_sink.py b/tests/integration/butterfree/load/test_sink.py index b5f97879b..f73f5f7ce 100644 --- a/tests/integration/butterfree/load/test_sink.py +++ b/tests/integration/butterfree/load/test_sink.py @@ -24,10 +24,13 @@ def test_sink(input_dataframe, feature_set): s3config.mode = "overwrite" s3config.format_ = "parquet" s3config.get_options = Mock( - return_value={"path": "test_folder/historical/entity/feature_set"} + return_value={ + "path": "test_folder/historical/entity/feature_set", + "mode": "overwrite", + } ) s3config.get_path_with_partitions = Mock( - return_value="test_folder/historical/entity/feature_set" + return_value="spark-warehouse/test.db/test_folder/historical/entity/feature_set" ) historical_writer = HistoricalFeatureStoreWriter( diff --git a/tests/integration/butterfree/pipelines/conftest.py b/tests/integration/butterfree/pipelines/conftest.py index 73da163e6..5f304972d 100644 --- a/tests/integration/butterfree/pipelines/conftest.py +++ b/tests/integration/butterfree/pipelines/conftest.py @@ -132,7 +132,8 @@ def fixed_windows_output_feature_set_date_dataframe(spark_context, spark_session @pytest.fixture() def feature_set_pipeline( - spark_context, spark_session, + spark_context, + spark_session, ): feature_set_pipeline = FeatureSetPipeline( diff --git a/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py b/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py index d67e0a387..791253398 100644 --- a/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py +++ b/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py @@ -50,10 +50,11 @@ def create_temp_view(dataframe: DataFrame, name): def create_db_and_table(spark, table_reader_id, table_reader_db, table_reader_table): - spark.sql(f"create database if not exists {table_reader_db}") + spark.sql(f"drop schema {table_reader_db} cascade") + spark.sql(f"create database {table_reader_db}") spark.sql(f"use {table_reader_db}") spark.sql( - f"create table if not exists {table_reader_db}.{table_reader_table} " # noqa + f"create table {table_reader_db}.{table_reader_table} " # noqa f"as select * from {table_reader_id}" # noqa ) @@ -74,7 +75,10 @@ def create_ymd(dataframe): class TestFeatureSetPipeline: def test_feature_set_pipeline( - self, mocked_df, spark_session, fixed_windows_output_feature_set_dataframe, + self, + mocked_df, + spark_session, + fixed_windows_output_feature_set_dataframe, ): # arrange @@ -90,7 +94,7 @@ def test_feature_set_pipeline( table_reader_table=table_reader_table, ) - path = "test_folder/historical/entity/feature_set" + path = "spark-warehouse/test.db/test_folder/historical/entity/feature_set" dbconfig = MetastoreConfig() dbconfig.get_options = Mock( @@ -138,7 +142,9 @@ def test_feature_set_pipeline( description="unit test", dtype=DataType.FLOAT, transformation=CustomTransform( - transformer=divide, column1="feature1", column2="feature2", + transformer=divide, + column1="feature1", + column2="feature2", ), ), ], @@ -237,7 +243,12 @@ def test_pipeline_with_hooks(self, spark_session): test_pipeline = FeatureSetPipeline( source=Source( - readers=[TableReader(id="reader", table="test",).add_post_hook(hook1)], + readers=[ + TableReader( + id="reader", + table="test", + ).add_post_hook(hook1) + ], query="select * from reader", ).add_post_hook(hook1), feature_set=FeatureSet( @@ -263,7 +274,9 @@ def test_pipeline_with_hooks(self, spark_session): ) .add_pre_hook(hook1) .add_post_hook(hook1), - sink=Sink(writers=[historical_writer],).add_pre_hook(hook1), + sink=Sink( + writers=[historical_writer], + ).add_pre_hook(hook1), ) # act @@ -325,11 +338,13 @@ def test_pipeline_interval_run( db = environment.get_variable("FEATURE_STORE_HISTORICAL_DATABASE") path = "test_folder/historical/entity/feature_set" + read_path = "spark-warehouse/test.db/" + path spark_session.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic") - spark_session.sql(f"create database if not exists {db}") + spark_session.sql(f"drop schema {db} cascade") + spark_session.sql(f"create database {db}") spark_session.sql( - f"create table if not exists {db}.feature_set_interval " + f"create table {db}.feature_set_interval " f"(id int, timestamp timestamp, feature int, " f"run_id int, year int, month int, day int);" ) @@ -340,7 +355,7 @@ def test_pipeline_interval_run( ) historical_writer = HistoricalFeatureStoreWriter( - db_config=dbconfig, interval_mode=True + db_config=dbconfig, interval_mode=True, row_count_validation=False ) first_run_hook = RunHook(id=1) @@ -356,9 +371,10 @@ def test_pipeline_interval_run( test_pipeline = FeatureSetPipeline( source=Source( readers=[ - TableReader(id="id", table="input_data",).with_incremental_strategy( - IncrementalStrategy("ts") - ), + TableReader( + id="id", + table="input_data", + ).with_incremental_strategy(IncrementalStrategy("ts")), ], query="select * from id ", ), @@ -366,48 +382,56 @@ def test_pipeline_interval_run( name="feature_set_interval", entity="entity", description="", - keys=[KeyFeature(name="id", description="", dtype=DataType.INTEGER,)], + keys=[ + KeyFeature( + name="id", + description="", + dtype=DataType.INTEGER, + ) + ], timestamp=TimestampFeature(from_column="ts"), features=[ Feature(name="feature", description="", dtype=DataType.INTEGER), Feature(name="run_id", description="", dtype=DataType.INTEGER), ], ), - sink=Sink([historical_writer],), + sink=Sink( + [historical_writer], + ), ) # act and assert dbconfig.get_path_with_partitions = Mock( return_value=[ - "test_folder/historical/entity/feature_set/year=2016/month=4/day=11", - "test_folder/historical/entity/feature_set/year=2016/month=4/day=12", - "test_folder/historical/entity/feature_set/year=2016/month=4/day=13", + "spark-warehouse/test.db/test_folder/historical/entity/feature_set/year=2016/month=4/day=11", # noqa + "spark-warehouse/test.db/test_folder/historical/entity/feature_set/year=2016/month=4/day=12", # noqa + "spark-warehouse/test.db/test_folder/historical/entity/feature_set/year=2016/month=4/day=13", # noqa ] ) test_pipeline.feature_set.add_pre_hook(first_run_hook) test_pipeline.run(end_date="2016-04-13", start_date="2016-04-11") - first_run_output_df = spark_session.read.parquet(path) + first_run_output_df = spark_session.read.parquet(read_path) assert_dataframe_equality(first_run_output_df, first_run_target_df) dbconfig.get_path_with_partitions = Mock( return_value=[ - "test_folder/historical/entity/feature_set/year=2016/month=4/day=14", + "spark-warehouse/test.db/test_folder/historical/entity/feature_set/year=2016/month=4/day=14", # noqa ] ) test_pipeline.feature_set.add_pre_hook(second_run_hook) test_pipeline.run_for_date("2016-04-14") - second_run_output_df = spark_session.read.parquet(path) + second_run_output_df = spark_session.read.parquet(read_path) assert_dataframe_equality(second_run_output_df, second_run_target_df) dbconfig.get_path_with_partitions = Mock( return_value=[ - "test_folder/historical/entity/feature_set/year=2016/month=4/day=11", + "spark-warehouse/test.db/test_folder/historical/entity/feature_set/year=2016/month=4/day=11", # noqa ] ) test_pipeline.feature_set.add_pre_hook(third_run_hook) test_pipeline.run_for_date("2016-04-11") - third_run_output_df = spark_session.read.parquet(path) + third_run_output_df = spark_session.read.parquet(read_path) assert_dataframe_equality(third_run_output_df, third_run_target_df) # tear down - shutil.rmtree("test_folder") + shutil.rmtree("spark-warehouse/test.db/test_folder") diff --git a/tests/integration/butterfree/transform/test_aggregated_feature_set.py b/tests/integration/butterfree/transform/test_aggregated_feature_set.py index bc3ebb6c7..413077619 100644 --- a/tests/integration/butterfree/transform/test_aggregated_feature_set.py +++ b/tests/integration/butterfree/transform/test_aggregated_feature_set.py @@ -19,7 +19,9 @@ def divide(df, fs, column1, column2): class TestAggregatedFeatureSet: def test_construct_without_window( - self, feature_set_dataframe, target_df_without_window, + self, + feature_set_dataframe, + target_df_without_window, ): # given @@ -157,7 +159,9 @@ def test_construct_rolling_windows_without_end_date( ) ], timestamp=TimestampFeature(), - ).with_windows(definitions=["1 day", "1 week"],) + ).with_windows( + definitions=["1 day", "1 week"], + ) # act & assert with pytest.raises(ValueError): @@ -201,7 +205,9 @@ def test_h3_feature_set(self, h3_input_df, h3_target_df): assert_dataframe_equality(output_df, h3_target_df) def test_construct_with_pivot( - self, feature_set_df_pivot, target_df_pivot_agg, + self, + feature_set_df_pivot, + target_df_pivot_agg, ): # given @@ -243,7 +249,9 @@ def test_construct_with_pivot( assert_dataframe_equality(output_df, target_df_pivot_agg) def test_construct_rolling_windows_with_date_boundaries( - self, feature_set_dates_dataframe, rolling_windows_output_date_boundaries, + self, + feature_set_dates_dataframe, + rolling_windows_output_date_boundaries, ): # given diff --git a/tests/integration/butterfree/transform/test_feature_set.py b/tests/integration/butterfree/transform/test_feature_set.py index 25f70b6e2..6c5f7f1d8 100644 --- a/tests/integration/butterfree/transform/test_feature_set.py +++ b/tests/integration/butterfree/transform/test_feature_set.py @@ -51,7 +51,9 @@ def test_construct( description="unit test", dtype=DataType.FLOAT, transformation=CustomTransform( - transformer=divide, column1="feature1", column2="feature2", + transformer=divide, + column1="feature1", + column2="feature2", ), ), ], @@ -92,7 +94,11 @@ def test_construct_with_date_boundaries( entity="entity", description="description", features=[ - Feature(name="feature", description="test", dtype=DataType.FLOAT,), + Feature( + name="feature", + description="test", + dtype=DataType.FLOAT, + ), ], keys=[ KeyFeature( diff --git a/tests/mocks/entities/first/first_pipeline.py b/tests/mocks/entities/first/first_pipeline.py index 90cfba96f..938c880c7 100644 --- a/tests/mocks/entities/first/first_pipeline.py +++ b/tests/mocks/entities/first/first_pipeline.py @@ -15,7 +15,13 @@ class FirstPipeline(FeatureSetPipeline): def __init__(self): super(FirstPipeline, self).__init__( source=Source( - readers=[TableReader(id="t", database="db", table="table",)], + readers=[ + TableReader( + id="t", + database="db", + table="table", + ) + ], query=f"select * from t", # noqa ), feature_set=FeatureSet( @@ -23,7 +29,11 @@ def __init__(self): entity="entity", description="description", features=[ - Feature(name="feature1", description="test", dtype=DataType.FLOAT,), + Feature( + name="feature1", + description="test", + dtype=DataType.FLOAT, + ), Feature( name="feature2", description="another test", @@ -32,7 +42,9 @@ def __init__(self): ], keys=[ KeyFeature( - name="id", description="identifier", dtype=DataType.BIGINT, + name="id", + description="identifier", + dtype=DataType.BIGINT, ) ], timestamp=TimestampFeature(), diff --git a/tests/mocks/entities/second/deeper/second_pipeline.py b/tests/mocks/entities/second/deeper/second_pipeline.py index 12c53cf30..a59ba2e5d 100644 --- a/tests/mocks/entities/second/deeper/second_pipeline.py +++ b/tests/mocks/entities/second/deeper/second_pipeline.py @@ -15,7 +15,13 @@ class SecondPipeline(FeatureSetPipeline): def __init__(self): super(SecondPipeline, self).__init__( source=Source( - readers=[TableReader(id="t", database="db", table="table",)], + readers=[ + TableReader( + id="t", + database="db", + table="table", + ) + ], query=f"select * from t", # noqa ), feature_set=FeatureSet( @@ -24,7 +30,9 @@ def __init__(self): description="description", features=[ Feature( - name="feature1", description="test", dtype=DataType.STRING, + name="feature1", + description="test", + dtype=DataType.STRING, ), Feature( name="feature2", @@ -34,7 +42,9 @@ def __init__(self): ], keys=[ KeyFeature( - name="id", description="identifier", dtype=DataType.BIGINT, + name="id", + description="identifier", + dtype=DataType.BIGINT, ) ], timestamp=TimestampFeature(), diff --git a/tests/unit/butterfree/clients/test_spark_client.py b/tests/unit/butterfree/clients/test_spark_client.py index 12d8ac9d6..b2418a7c6 100644 --- a/tests/unit/butterfree/clients/test_spark_client.py +++ b/tests/unit/butterfree/clients/test_spark_client.py @@ -69,7 +69,8 @@ def test_read( assert target_df.collect() == result_df.collect() @pytest.mark.parametrize( - "format, path", [(None, "path/to/file"), ("csv", 123)], + "format, path", + [(None, "path/to/file"), ("csv", 123)], ) def test_read_invalid_params(self, format: Optional[str], path: Any) -> None: # arrange @@ -115,7 +116,8 @@ def test_read_table( assert target_df == result_df @pytest.mark.parametrize( - "database, table", [("database", None), ("database", 123)], + "database, table", + [("database", None), ("database", 123)], ) def test_read_table_invalid_params( self, database: str, table: Optional[int] @@ -128,7 +130,8 @@ def test_read_table_invalid_params( spark_client.read_table(table, database) # type: ignore @pytest.mark.parametrize( - "format, mode", [("parquet", "append"), ("csv", "overwrite")], + "format, mode", + [("parquet", "append"), ("csv", "overwrite")], ) def test_write_dataframe( self, format: str, mode: str, mocked_spark_write: Mock @@ -137,7 +140,8 @@ def test_write_dataframe( mocked_spark_write.save.assert_called_with(format=format, mode=mode) @pytest.mark.parametrize( - "format, mode", [(None, "append"), ("parquet", 1)], + "format, mode", + [(None, "append"), ("parquet", 1)], ) def test_write_dataframe_invalid_params( self, target_df: DataFrame, format: Optional[str], mode: Union[str, int] @@ -266,7 +270,7 @@ def test_create_temporary_view( def test_add_table_partitions(self, mock_spark_sql: Mock): # arrange target_command = ( - f"ALTER TABLE `db`.`table` ADD IF NOT EXISTS " + f"ALTER TABLE `db`.`table` ADD IF NOT EXISTS " # noqa f"PARTITION ( year = 2020, month = 8, day = 14 ) " f"PARTITION ( year = 2020, month = 8, day = 15 ) " f"PARTITION ( year = 2020, month = 8, day = 16 )" diff --git a/tests/unit/butterfree/extract/pre_processing/test_filter_transform.py b/tests/unit/butterfree/extract/pre_processing/test_filter_transform.py index 669fd0336..fed20f2d4 100644 --- a/tests/unit/butterfree/extract/pre_processing/test_filter_transform.py +++ b/tests/unit/butterfree/extract/pre_processing/test_filter_transform.py @@ -28,7 +28,8 @@ def test_filter(self, feature_set_dataframe, spark_context, spark_session): assert result_df.collect() == target_df.collect() @pytest.mark.parametrize( - "condition", [None, 100], + "condition", + [None, 100], ) def test_filter_with_invalidations( self, feature_set_dataframe, condition, spark_context, spark_session diff --git a/tests/unit/butterfree/extract/pre_processing/test_pivot_transform.py b/tests/unit/butterfree/extract/pre_processing/test_pivot_transform.py index e716f9d65..cfe730d3a 100644 --- a/tests/unit/butterfree/extract/pre_processing/test_pivot_transform.py +++ b/tests/unit/butterfree/extract/pre_processing/test_pivot_transform.py @@ -9,7 +9,9 @@ class TestPivotTransform: def test_pivot_transformation( - self, input_df, pivot_df, + self, + input_df, + pivot_df, ): result_df = pivot( dataframe=input_df, @@ -20,10 +22,15 @@ def test_pivot_transformation( ) # assert - assert compare_dataframes(actual_df=result_df, expected_df=pivot_df,) + assert compare_dataframes( + actual_df=result_df, + expected_df=pivot_df, + ) def test_pivot_transformation_with_forward_fill( - self, input_df, pivot_ffill_df, + self, + input_df, + pivot_ffill_df, ): result_df = pivot( dataframe=input_df, @@ -35,10 +42,15 @@ def test_pivot_transformation_with_forward_fill( ) # assert - assert compare_dataframes(actual_df=result_df, expected_df=pivot_ffill_df,) + assert compare_dataframes( + actual_df=result_df, + expected_df=pivot_ffill_df, + ) def test_pivot_transformation_with_forward_fill_and_mock( - self, input_df, pivot_ffill_mock_df, + self, + input_df, + pivot_ffill_mock_df, ): result_df = pivot( dataframe=input_df, @@ -52,10 +64,15 @@ def test_pivot_transformation_with_forward_fill_and_mock( ) # assert - assert compare_dataframes(actual_df=result_df, expected_df=pivot_ffill_mock_df,) + assert compare_dataframes( + actual_df=result_df, + expected_df=pivot_ffill_mock_df, + ) def test_pivot_transformation_mock_without_type( - self, input_df, pivot_ffill_mock_df, + self, + input_df, + pivot_ffill_mock_df, ): with pytest.raises(AttributeError): _ = pivot( @@ -83,4 +100,7 @@ def test_apply_pivot_transformation(self, input_df, pivot_df): result_df = file_reader._apply_transformations(input_df) # assert - assert compare_dataframes(actual_df=result_df, expected_df=pivot_df,) + assert compare_dataframes( + actual_df=result_df, + expected_df=pivot_df, + ) diff --git a/tests/unit/butterfree/extract/readers/test_file_reader.py b/tests/unit/butterfree/extract/readers/test_file_reader.py index 9e1c42bce..136c8fd62 100644 --- a/tests/unit/butterfree/extract/readers/test_file_reader.py +++ b/tests/unit/butterfree/extract/readers/test_file_reader.py @@ -7,7 +7,15 @@ class TestFileReader: @pytest.mark.parametrize( - "path, format", [(None, "parquet"), ("path/to/file.json", 123), (123, None,)], + "path, format", + [ + (None, "parquet"), + ("path/to/file.json", 123), + ( + 123, + None, + ), + ], ) def test_init_invalid_params(self, path, format): # act and assert diff --git a/tests/unit/butterfree/extract/readers/test_reader.py b/tests/unit/butterfree/extract/readers/test_reader.py index 78160553f..bcceacbd1 100644 --- a/tests/unit/butterfree/extract/readers/test_reader.py +++ b/tests/unit/butterfree/extract/readers/test_reader.py @@ -148,7 +148,8 @@ def test_build_with_columns( # act file_reader.build( - client=spark_client, columns=[("col1", "new_col1"), ("col2", "new_col2")], + client=spark_client, + columns=[("col1", "new_col1"), ("col2", "new_col2")], ) result_df = spark_session.sql("select * from test") diff --git a/tests/unit/butterfree/extract/readers/test_table_reader.py b/tests/unit/butterfree/extract/readers/test_table_reader.py index 65f3be236..1a2f56f23 100644 --- a/tests/unit/butterfree/extract/readers/test_table_reader.py +++ b/tests/unit/butterfree/extract/readers/test_table_reader.py @@ -5,7 +5,14 @@ class TestTableReader: @pytest.mark.parametrize( - "database, table", [("database", 123), (123, None,)], + "database, table", + [ + ("database", 123), + ( + 123, + None, + ), + ], ) def test_init_invalid_params(self, database, table): # act and assert diff --git a/tests/unit/butterfree/extract/test_source.py b/tests/unit/butterfree/extract/test_source.py index 53af8b658..842d2210f 100644 --- a/tests/unit/butterfree/extract/test_source.py +++ b/tests/unit/butterfree/extract/test_source.py @@ -14,7 +14,8 @@ def test_construct(self, mocker, target_df): # when source_selector = Source( - readers=[reader], query=f"select * from {reader_id}", # noqa + readers=[reader], + query=f"select * from {reader_id}", # noqa ) result_df = source_selector.construct(spark_client) @@ -32,7 +33,8 @@ def test_is_cached(self, mocker, target_df): # when source_selector = Source( - readers=[reader], query=f"select * from {reader_id}", # noqa + readers=[reader], + query=f"select * from {reader_id}", # noqa ) result_df = source_selector.construct(spark_client) diff --git a/tests/unit/butterfree/load/conftest.py b/tests/unit/butterfree/load/conftest.py index 4dcf25c94..d0bb2c3be 100644 --- a/tests/unit/butterfree/load/conftest.py +++ b/tests/unit/butterfree/load/conftest.py @@ -20,7 +20,11 @@ def feature_set(): ] ts_feature = TimestampFeature(from_column=TIMESTAMP_COLUMN) features = [ - Feature(name="feature", description="Description", dtype=DataType.BIGINT,) + Feature( + name="feature", + description="Description", + dtype=DataType.BIGINT, + ) ] return FeatureSet( "feature_set", diff --git a/tests/unit/butterfree/load/processing/test_json_transform.py b/tests/unit/butterfree/load/processing/test_json_transform.py index 73949eea7..78320d108 100644 --- a/tests/unit/butterfree/load/processing/test_json_transform.py +++ b/tests/unit/butterfree/load/processing/test_json_transform.py @@ -3,7 +3,9 @@ class TestJsonTransform: def test_json_transformation( - self, input_df, json_df, + self, + input_df, + json_df, ): result_df = json_transform(dataframe=input_df) diff --git a/tests/unit/butterfree/migrations/database_migration/conftest.py b/tests/unit/butterfree/migrations/database_migration/conftest.py index dcd96714f..237158b7b 100644 --- a/tests/unit/butterfree/migrations/database_migration/conftest.py +++ b/tests/unit/butterfree/migrations/database_migration/conftest.py @@ -45,10 +45,18 @@ def feature_set(): entity="entity", description="description", features=[ - Feature(name="feature_float", description="test", dtype=DataType.FLOAT,), + Feature( + name="feature_float", + description="test", + dtype=DataType.FLOAT, + ), ], keys=[ - KeyFeature(name="id", description="The device ID", dtype=DataType.BIGINT,) + KeyFeature( + name="id", + description="The device ID", + dtype=DataType.BIGINT, + ) ], timestamp=TimestampFeature(), ) diff --git a/tests/unit/butterfree/pipelines/conftest.py b/tests/unit/butterfree/pipelines/conftest.py index 47e65efb7..f17e5f41e 100644 --- a/tests/unit/butterfree/pipelines/conftest.py +++ b/tests/unit/butterfree/pipelines/conftest.py @@ -23,7 +23,13 @@ def feature_set_pipeline(): spark_client=SparkClient(), source=Mock( spec=Source, - readers=[TableReader(id="source_a", database="db", table="table",)], + readers=[ + TableReader( + id="source_a", + database="db", + table="table", + ) + ], query="select * from source_a", ), feature_set=Mock( @@ -57,7 +63,10 @@ def feature_set_pipeline(): ), ], ), - sink=Mock(spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)],), + sink=Mock( + spec=Sink, + writers=[HistoricalFeatureStoreWriter(db_config=None)], + ), ) return test_pipeline diff --git a/tests/unit/butterfree/pipelines/test_feature_set_pipeline.py b/tests/unit/butterfree/pipelines/test_feature_set_pipeline.py index 7bae6606b..5a67e77d4 100644 --- a/tests/unit/butterfree/pipelines/test_feature_set_pipeline.py +++ b/tests/unit/butterfree/pipelines/test_feature_set_pipeline.py @@ -22,6 +22,20 @@ from butterfree.transform.utils import Function +def get_reader(): + table_reader = TableReader( + id="source_a", + database="db", + table="table", + ) + + return table_reader + + +def get_historical_writer(): + return HistoricalFeatureStoreWriter(db_config=None) + + class TestFeatureSetPipeline: def test_feature_set_args(self): # arrange and act @@ -38,8 +52,12 @@ def test_feature_set_args(self): pipeline = FeatureSetPipeline( source=Source( readers=[ - TableReader(id="source_a", database="db", table="table",), - FileReader(id="source_b", path="path", format="parquet",), + get_reader(), + FileReader( + id="source_b", + path="path", + format="parquet", + ), ], query="select a.*, b.specific_feature " "from source_a left join source_b on a.id=b.id", @@ -131,7 +149,7 @@ def test_source_raise(self): source=Mock( spark_client=SparkClient(), readers=[ - TableReader(id="source_a", database="db", table="table",), + get_reader(), ], query="select * from source_a", ), @@ -167,7 +185,8 @@ def test_source_raise(self): ], ), sink=Mock( - spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)], + spec=Sink, + writers=[get_historical_writer()], ), ) @@ -180,7 +199,7 @@ def test_feature_set_raise(self): source=Mock( spec=Source, readers=[ - TableReader(id="source_a", database="db", table="table",), + get_reader(), ], query="select * from source_a", ), @@ -215,7 +234,8 @@ def test_feature_set_raise(self): ], ), sink=Mock( - spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)], + spec=Sink, + writers=[get_historical_writer()], ), ) @@ -226,7 +246,7 @@ def test_sink_raise(self): source=Mock( spec=Source, readers=[ - TableReader(id="source_a", database="db", table="table",), + get_reader(), ], query="select * from source_a", ), @@ -250,7 +270,9 @@ def test_sink_raise(self): key_columns=["user_id"], timestamp_column="ts", ), - sink=Mock(writers=[HistoricalFeatureStoreWriter(db_config=None)],), + sink=Mock( + writers=[get_historical_writer()], + ), ) def test_run_agg_with_end_date(self, spark_session, feature_set_pipeline): diff --git a/tests/unit/butterfree/reports/test_metadata.py b/tests/unit/butterfree/reports/test_metadata.py index 6f26cc558..093721df1 100644 --- a/tests/unit/butterfree/reports/test_metadata.py +++ b/tests/unit/butterfree/reports/test_metadata.py @@ -16,49 +16,63 @@ from butterfree.transform.utils import Function +def get_pipeline(): + + return FeatureSetPipeline( + source=Source( + readers=[ + TableReader( + id="source_a", + database="db", + table="table", + ), + FileReader( + id="source_b", + path="path", + format="parquet", + ), + ], + query="select a.*, b.specific_feature " + "from source_a left join source_b on a.id=b.id", + ), + feature_set=FeatureSet( + name="feature_set", + entity="entity", + description="description", + keys=[ + KeyFeature( + name="user_id", + description="The user's Main ID or device ID", + dtype=DataType.INTEGER, + ) + ], + timestamp=TimestampFeature(from_column="ts"), + features=[ + Feature( + name="page_viewed__rent_per_month", + description="Average of something.", + transformation=SparkFunctionTransform( + functions=[ + Function(functions.avg, DataType.FLOAT), + Function(functions.stddev_pop, DataType.DOUBLE), + ], + ), + ), + ], + ), + sink=Sink( + writers=[ + HistoricalFeatureStoreWriter(db_config=None), + OnlineFeatureStoreWriter(db_config=None), + ], + ), + ) + + class TestMetadata: def test_json(self): - pipeline = FeatureSetPipeline( - source=Source( - readers=[ - TableReader(id="source_a", database="db", table="table",), - FileReader(id="source_b", path="path", format="parquet",), - ], - query="select a.*, b.specific_feature " - "from source_a left join source_b on a.id=b.id", - ), - feature_set=FeatureSet( - name="feature_set", - entity="entity", - description="description", - keys=[ - KeyFeature( - name="user_id", - description="The user's Main ID or device ID", - dtype=DataType.INTEGER, - ) - ], - timestamp=TimestampFeature(from_column="ts"), - features=[ - Feature( - name="page_viewed__rent_per_month", - description="Average of something.", - transformation=SparkFunctionTransform( - functions=[ - Function(functions.avg, DataType.FLOAT), - Function(functions.stddev_pop, DataType.DOUBLE), - ], - ), - ), - ], - ), - sink=Sink( - writers=[ - HistoricalFeatureStoreWriter(db_config=None), - OnlineFeatureStoreWriter(db_config=None), - ], - ), - ) + + pipeline = get_pipeline() target_json = [ { @@ -102,47 +116,8 @@ def test_json(self): assert json == target_json def test_markdown(self): - pipeline = FeatureSetPipeline( - source=Source( - readers=[ - TableReader(id="source_a", database="db", table="table",), - FileReader(id="source_b", path="path", format="parquet",), - ], - query="select a.*, b.specific_feature " - "from source_a left join source_b on a.id=b.id", - ), - feature_set=FeatureSet( - name="feature_set", - entity="entity", - description="description", - keys=[ - KeyFeature( - name="user_id", - description="The user's Main ID or device ID", - dtype=DataType.INTEGER, - ) - ], - timestamp=TimestampFeature(from_column="ts"), - features=[ - Feature( - name="page_viewed__rent_per_month", - description="Average of something.", - transformation=SparkFunctionTransform( - functions=[ - Function(functions.avg, DataType.FLOAT), - Function(functions.stddev_pop, DataType.DOUBLE), - ], - ), - ), - ], - ), - sink=Sink( - writers=[ - HistoricalFeatureStoreWriter(db_config=None), - OnlineFeatureStoreWriter(db_config=None), - ], - ), - ) + + pipeline = get_pipeline() target_md = ( "\n# Feature_set\n\n## Description\n\n\ndescription \n\n" diff --git a/tests/unit/butterfree/transform/conftest.py b/tests/unit/butterfree/transform/conftest.py index ab7606407..fcf601328 100644 --- a/tests/unit/butterfree/transform/conftest.py +++ b/tests/unit/butterfree/transform/conftest.py @@ -455,6 +455,12 @@ def agg_feature_set(): ), ), ], - keys=[KeyFeature(name="id", description="description", dtype=DataType.BIGINT,)], + keys=[ + KeyFeature( + name="id", + description="description", + dtype=DataType.BIGINT, + ) + ], timestamp=TimestampFeature(), ) diff --git a/tests/unit/butterfree/transform/features/test_feature.py b/tests/unit/butterfree/transform/features/test_feature.py index 14a89f2cf..01bb41e5a 100644 --- a/tests/unit/butterfree/transform/features/test_feature.py +++ b/tests/unit/butterfree/transform/features/test_feature.py @@ -98,7 +98,9 @@ def test_feature_transform_with_from_column_and_column_name_exists( def test_feature_transform_with_dtype(self, feature_set_dataframe): test_feature = Feature( - name="feature", description="unit test", dtype=DataType.TIMESTAMP, + name="feature", + description="unit test", + dtype=DataType.TIMESTAMP, ) df = test_feature.transform(feature_set_dataframe) diff --git a/tests/unit/butterfree/transform/test_aggregated_feature_set.py b/tests/unit/butterfree/transform/test_aggregated_feature_set.py index 73320cf57..38ec249aa 100644 --- a/tests/unit/butterfree/transform/test_aggregated_feature_set.py +++ b/tests/unit/butterfree/transform/test_aggregated_feature_set.py @@ -44,7 +44,10 @@ def test_feature_set_with_invalid_feature(self, key_id, timestamp_c, dataframe): ).construct(dataframe, spark_client) def test_agg_feature_set_with_window( - self, dataframe, rolling_windows_agg_dataframe, agg_feature_set, + self, + dataframe, + rolling_windows_agg_dataframe, + agg_feature_set, ): spark_client = SparkClient() @@ -61,7 +64,10 @@ def test_agg_feature_set_with_window( assert_dataframe_equality(output_df, rolling_windows_agg_dataframe) def test_agg_feature_set_with_smaller_slide( - self, dataframe, rolling_windows_hour_slide_agg_dataframe, agg_feature_set, + self, + dataframe, + rolling_windows_hour_slide_agg_dataframe, + agg_feature_set, ): spark_client = SparkClient() @@ -366,7 +372,9 @@ def test_define_start_date(self, agg_feature_set): assert start_date == "2020-07-27" def test_feature_set_start_date( - self, timestamp_c, feature_set_with_distinct_dataframe, + self, + timestamp_c, + feature_set_with_distinct_dataframe, ): fs = AggregatedFeatureSet( name="name", diff --git a/tests/unit/butterfree/transform/test_feature_set.py b/tests/unit/butterfree/transform/test_feature_set.py index 43d937bec..e907dc0a8 100644 --- a/tests/unit/butterfree/transform/test_feature_set.py +++ b/tests/unit/butterfree/transform/test_feature_set.py @@ -3,12 +3,6 @@ import pytest from pyspark.sql import functions as F from pyspark.sql.types import DoubleType, FloatType, LongType, TimestampType -from tests.unit.butterfree.transform.conftest import ( - feature_add, - feature_divide, - key_id, - timestamp_c, -) from butterfree.clients import SparkClient from butterfree.constants import DataType @@ -20,6 +14,12 @@ SQLExpressionTransform, ) from butterfree.transform.utils import Function +from tests.unit.butterfree.transform.conftest import ( + feature_add, + feature_divide, + key_id, + timestamp_c, +) class TestFeatureSet: @@ -70,7 +70,14 @@ class TestFeatureSet: None, [feature_add, feature_divide], ), - ("name", "entity", "description", [key_id], timestamp_c, [None],), + ( + "name", + "entity", + "description", + [key_id], + timestamp_c, + [None], + ), ], ) def test_cannot_instantiate( diff --git a/tests/unit/butterfree/transform/transformations/conftest.py b/tests/unit/butterfree/transform/transformations/conftest.py index 8f3c13bff..41bc63d5c 100644 --- a/tests/unit/butterfree/transform/transformations/conftest.py +++ b/tests/unit/butterfree/transform/transformations/conftest.py @@ -62,7 +62,7 @@ def target_df_spark(spark_context, spark_session): "timestamp": "2016-04-11 11:31:11", "feature1": 200, "feature2": 200, - "feature__cos": 0.4871876750070059, + "feature__cos": 0.48718767500700594, }, { "id": 1, diff --git a/tests/unit/butterfree/transform/transformations/test_aggregated_transform.py b/tests/unit/butterfree/transform/transformations/test_aggregated_transform.py index 6cdebf74d..f0ae2f854 100644 --- a/tests/unit/butterfree/transform/transformations/test_aggregated_transform.py +++ b/tests/unit/butterfree/transform/transformations/test_aggregated_transform.py @@ -44,7 +44,10 @@ def test_output_columns(self): assert all( [ a == b - for a, b in zip(df_columns, ["feature1__avg", "feature1__stddev_pop"],) + for a, b in zip( + df_columns, + ["feature1__avg", "feature1__stddev_pop"], + ) ] ) diff --git a/tests/unit/butterfree/transform/transformations/test_custom_transform.py b/tests/unit/butterfree/transform/transformations/test_custom_transform.py index 4198d9bda..d87cc7cb1 100644 --- a/tests/unit/butterfree/transform/transformations/test_custom_transform.py +++ b/tests/unit/butterfree/transform/transformations/test_custom_transform.py @@ -21,7 +21,9 @@ def test_feature_transform(self, feature_set_dataframe): description="unit test", dtype=DataType.BIGINT, transformation=CustomTransform( - transformer=divide, column1="feature1", column2="feature2", + transformer=divide, + column1="feature1", + column2="feature2", ), ) @@ -44,7 +46,9 @@ def test_output_columns(self, feature_set_dataframe): description="unit test", dtype=DataType.BIGINT, transformation=CustomTransform( - transformer=divide, column1="feature1", column2="feature2", + transformer=divide, + column1="feature1", + column2="feature2", ), ) @@ -59,7 +63,9 @@ def test_custom_transform_output(self, feature_set_dataframe): description="unit test", dtype=DataType.BIGINT, transformation=CustomTransform( - transformer=divide, column1="feature1", column2="feature2", + transformer=divide, + column1="feature1", + column2="feature2", ), ) diff --git a/tests/unit/butterfree/transform/transformations/test_h3_transform.py b/tests/unit/butterfree/transform/transformations/test_h3_transform.py index 4b3308ebe..d4ad6493e 100644 --- a/tests/unit/butterfree/transform/transformations/test_h3_transform.py +++ b/tests/unit/butterfree/transform/transformations/test_h3_transform.py @@ -64,9 +64,9 @@ def test_import_error(self): for m in modules: del sys.modules[m] with pytest.raises(ModuleNotFoundError, match="you must install"): - from butterfree.transform.transformations.h3_transform import ( # noqa - H3HashTransform, # noqa - ) # noqa + from butterfree.transform.transformations.h3_transform import ( # noqa; noqa + H3HashTransform, + ) def test_with_stack(self, h3_input_df, h3_with_stack_target_df): # arrange diff --git a/tests/unit/butterfree/transform/transformations/test_spark_function_transform.py b/tests/unit/butterfree/transform/transformations/test_spark_function_transform.py index fe8bca85c..cf88657a0 100644 --- a/tests/unit/butterfree/transform/transformations/test_spark_function_transform.py +++ b/tests/unit/butterfree/transform/transformations/test_spark_function_transform.py @@ -126,7 +126,9 @@ def test_feature_transform_output_row_windows( transformation=SparkFunctionTransform( functions=[Function(functions.avg, DataType.DOUBLE)], ).with_window( - partition_by="id", mode="row_windows", window_definition=["2 events"], + partition_by="id", + mode="row_windows", + window_definition=["2 events"], ), ) diff --git a/tests/unit/butterfree/transform/transformations/test_sql_expression_transform.py b/tests/unit/butterfree/transform/transformations/test_sql_expression_transform.py index 9cc2e687e..814f83012 100644 --- a/tests/unit/butterfree/transform/transformations/test_sql_expression_transform.py +++ b/tests/unit/butterfree/transform/transformations/test_sql_expression_transform.py @@ -43,7 +43,15 @@ def test_output_columns(self): df_columns = test_feature.get_output_columns() - assert all([a == b for a, b in zip(df_columns, ["feature1_over_feature2"],)]) + assert all( + [ + a == b + for a, b in zip( + df_columns, + ["feature1_over_feature2"], + ) + ] + ) def test_feature_transform_output(self, feature_set_dataframe): test_feature = Feature(