From 0a532104aa6b82cab58fa54f9451a1336c0b07df Mon Sep 17 00:00:00 2001
From: Ralph Filho
Date: Mon, 16 Sep 2024 09:25:23 -0300
Subject: [PATCH] fix: performance improvements

---
 butterfree/_cli/migrate.py | 14 +++++++--
 butterfree/extract/source.py | 14 ++++++---
 butterfree/pipelines/feature_set_pipeline.py | 29 ++++++++++++-----
 .../transform/aggregated_feature_set.py | 31 +++++++++++--------
 butterfree/transform/feature_set.py | 7 ++---
 .../butterfree/transform/test_feature_set.py | 2 +-
 6 files changed, 64 insertions(+), 33 deletions(-)

diff --git a/butterfree/_cli/migrate.py b/butterfree/_cli/migrate.py
index f5161509..6bd5ca08 100644
--- a/butterfree/_cli/migrate.py
+++ b/butterfree/_cli/migrate.py
@@ -4,7 +4,7 @@
 import os
 import pkgutil
 import sys
-from typing import Set
+from typing import Set, Type
 
 import boto3
 import setuptools
@@ -90,8 +90,18 @@ def __fs_objects(path: str) -> Set[FeatureSetPipeline]:
                 instances.add(value)
 
+    def create_instance(cls: Type[FeatureSetPipeline]) -> FeatureSetPipeline:
+        sig = inspect.signature(cls.__init__)
+        parameters = sig.parameters
+
+        if "run_date" in parameters:
+            run_date = datetime.datetime.today().strftime("%Y-%m-%d")
+            return cls(run_date)
+
+        return cls()
+
     logger.info("Creating instances...")
-    return set(value() for value in instances)  # type: ignore
+    return set(create_instance(value) for value in instances)  # type: ignore
 
 
 PATH = typer.Argument(
diff --git a/butterfree/extract/source.py b/butterfree/extract/source.py
index bfc15271..9d50e94c 100644
--- a/butterfree/extract/source.py
+++ b/butterfree/extract/source.py
@@ -3,6 +3,7 @@
 from typing import List, Optional
 
 from pyspark.sql import DataFrame
+from pyspark.storagelevel import StorageLevel
 
 from butterfree.clients import SparkClient
 from butterfree.extract.readers.reader import Reader
@@ -95,16 +96,21 @@ def construct(
             DataFrame with the query result against all readers.
 
         """
+        # Step 1: Build temporary views for each reader
         for reader in self.readers:
-            reader.build(
-                client=client, start_date=start_date, end_date=end_date
-            )  # create temporary views for each reader
+            reader.build(client=client, start_date=start_date, end_date=end_date)
 
+        # Step 2: Execute the SQL query on the combined readers
         dataframe = client.sql(self.query)
 
+        # Step 3: Cache the batch dataframe when eager evaluation is enabled
        if not dataframe.isStreaming and self.eager_evaluation:
-            dataframe.cache().count()
+            # Persist in memory, spilling to disk when memory is insufficient
+            dataframe.persist(StorageLevel.MEMORY_AND_DISK)
+            # Trigger the persist operation by running an action
+            dataframe.count()
 
+        # Step 4: Run post-processing hooks on the dataframe
         post_hook_df = self.run_post_hooks(dataframe)
 
         return post_hook_df
diff --git a/butterfree/pipelines/feature_set_pipeline.py b/butterfree/pipelines/feature_set_pipeline.py
index 8ba1a636..d57459f3 100644
--- a/butterfree/pipelines/feature_set_pipeline.py
+++ b/butterfree/pipelines/feature_set_pipeline.py
@@ -2,6 +2,8 @@
 
 from typing import List, Optional
 
+from pyspark.storagelevel import StorageLevel
+
 from butterfree.clients import SparkClient
 from butterfree.dataframe_service import repartition_sort_df
 from butterfree.extract import Source
@@ -209,19 +211,25 @@ def run(
             soon. Use only if strictly necessary.
 
         """
+        # Step 1: Construct the input dataframe from the source.
         dataframe = self.source.construct(
             client=self.spark_client,
             start_date=self.feature_set.define_start_date(start_date),
             end_date=end_date,
         )
 
+        # Step 2: Repartition and sort only if the partition count is not optimal.
         if partition_by:
             order_by = order_by or partition_by
-            dataframe = repartition_sort_df(
-                dataframe, partition_by, order_by, num_processors
-            )
-
-        dataframe = self.feature_set.construct(
+            current_partitions = dataframe.rdd.getNumPartitions()
+            optimal_partitions = num_processors or current_partitions
+            if current_partitions != optimal_partitions:
+                dataframe = repartition_sort_df(
+                    dataframe, partition_by, order_by, num_processors
+                )
+
+        # Step 3: Construct the feature set dataframe using defined transformations.
+        transformed_dataframe = self.feature_set.construct(
             dataframe=dataframe,
             client=self.spark_client,
             start_date=start_date,
@@ -229,15 +237,20 @@ def run(
             num_processors=num_processors,
         )
 
+        if dataframe.storageLevel != StorageLevel.NONE:
+            dataframe.unpersist()  # Release the cached data from memory and disk
+
+        # Step 4: Load the data into the configured sink.
         self.sink.flush(
-            dataframe=dataframe,
+            dataframe=transformed_dataframe,
             feature_set=self.feature_set,
             spark_client=self.spark_client,
         )
 
-        if not dataframe.isStreaming:
+        # Step 5: Validate the output when the dataframe is not streaming.
+        if not transformed_dataframe.isStreaming:
             self.sink.validate(
-                dataframe=dataframe,
+                dataframe=transformed_dataframe,
                 feature_set=self.feature_set,
                 spark_client=self.spark_client,
             )
diff --git a/butterfree/transform/aggregated_feature_set.py b/butterfree/transform/aggregated_feature_set.py
index 6706bf8c..9f55ae93 100644
--- a/butterfree/transform/aggregated_feature_set.py
+++ b/butterfree/transform/aggregated_feature_set.py
@@ -387,6 +387,7 @@ def _aggregate(
         ]
 
         groupby = self.keys_columns.copy()
+
         if window is not None:
             dataframe = dataframe.withColumn("window", window.get())
             groupby.append("window")
@@ -410,19 +411,23 @@ def _aggregate(
                 "keep_rn", functions.row_number().over(partition_window)
             ).filter("keep_rn = 1")
 
-        # repartition to have all rows for each group at the same partition
-        # by doing that, we won't have to shuffle data on grouping by id
-        dataframe = repartition_df(
-            dataframe,
-            partition_by=groupby,
-            num_processors=num_processors,
-        )
+        current_partitions = dataframe.rdd.getNumPartitions()
+        optimal_partitions = num_processors or current_partitions
+
+        if current_partitions != optimal_partitions:
+            dataframe = repartition_df(
+                dataframe,
+                partition_by=groupby,
+                num_processors=optimal_partitions,
+            )
+
         grouped_data = dataframe.groupby(*groupby)
 
-        if self._pivot_column:
+        if self._pivot_column and self._pivot_values:
             grouped_data = grouped_data.pivot(self._pivot_column, self._pivot_values)
 
         aggregated = grouped_data.agg(*aggregations)
+
         return self._with_renamed_columns(aggregated, features, window)
@@ -637,12 +642,12 @@ def construct(
         output_df = output_df.select(*self.columns).replace(  # type: ignore
             float("nan"), None
         )
-        if not output_df.isStreaming:
-            if self.deduplicate_rows:
-                output_df = self._filter_duplicated_rows(output_df)
-            if self.eager_evaluation:
-                output_df.cache().count()
+        if not output_df.isStreaming and self.deduplicate_rows:
+            output_df = self._filter_duplicated_rows(output_df)
 
         post_hook_df = self.run_post_hooks(output_df)
 
+        if not output_df.isStreaming and self.eager_evaluation:
+            post_hook_df.cache().count()
+
         return post_hook_df
diff --git a/butterfree/transform/feature_set.py b/butterfree/transform/feature_set.py
index 369eaf29..2c4b9b51 100644
--- a/butterfree/transform/feature_set.py
+++ b/butterfree/transform/feature_set.py
@@ -436,11 +436,8 @@ def construct(
             pre_hook_df,
         ).select(*self.columns)
 
-        if not output_df.isStreaming:
-            if self.deduplicate_rows:
-                output_df = self._filter_duplicated_rows(output_df)
-            if self.eager_evaluation:
-                output_df.cache().count()
+        if not output_df.isStreaming and self.deduplicate_rows:
+            output_df = self._filter_duplicated_rows(output_df)
 
         output_df = self.incremental_strategy.filter_with_incremental_strategy(
             dataframe=output_df, start_date=start_date, end_date=end_date
diff --git a/tests/unit/butterfree/transform/test_feature_set.py b/tests/unit/butterfree/transform/test_feature_set.py
index e907dc0a..37a69be2 100644
--- a/tests/unit/butterfree/transform/test_feature_set.py
+++ b/tests/unit/butterfree/transform/test_feature_set.py
@@ -220,7 +220,7 @@ def test_construct(
             + feature_divide.get_output_columns()
         )
         assert_dataframe_equality(result_df, feature_set_dataframe)
-        assert result_df.is_cached
+        assert not result_df.is_cached
 
     def test_construct_invalid_df(
         self, key_id, timestamp_c, feature_add, feature_divide
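
Note on the migrate.py change above: the CLI now inspects each discovered pipeline class before instantiating it and passes a run date only when the constructor declares a run_date parameter. The sketch below is a minimal standalone illustration of that pattern, not part of the patch; DatedPipeline and PlainPipeline are hypothetical stand-ins for real FeatureSetPipeline subclasses.

import datetime
import inspect
from typing import Any, Type


class DatedPipeline:
    """Hypothetical pipeline whose constructor expects a run_date string."""

    def __init__(self, run_date: str) -> None:
        self.run_date = run_date


class PlainPipeline:
    """Hypothetical pipeline with a no-argument constructor."""


def create_instance(cls: Type[Any]) -> Any:
    # Inspect the constructor; only pass a run date when it is declared.
    parameters = inspect.signature(cls.__init__).parameters
    if "run_date" in parameters:
        return cls(datetime.datetime.today().strftime("%Y-%m-%d"))
    return cls()


print(create_instance(DatedPipeline).run_date)        # e.g. 2024-09-16
print(type(create_instance(PlainPipeline)).__name__)  # PlainPipeline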
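
Note on the caching changes: Source.construct now persists the source dataframe with StorageLevel.MEMORY_AND_DISK and forces materialization with count(), while FeatureSetPipeline.run unpersists it once the feature set dataframe has been built. The snippet below is a rough sketch of that lifecycle, assuming a local Spark session; the sample data and column names are illustrative and do not come from the patch.

from pyspark.sql import SparkSession
from pyspark.storagelevel import StorageLevel

# Assumes a local Spark installation; data and column names are made up.
spark = SparkSession.builder.master("local[*]").appName("persist-sketch").getOrCreate()

source_df = spark.createDataFrame([(1, 10.0), (2, 20.0)], ["id", "feature_value"])

# Persist to memory, spilling to disk if needed, then materialize with an
# action (mirrors the change in Source.construct).
source_df.persist(StorageLevel.MEMORY_AND_DISK)
source_df.count()

# Downstream transformations reuse the persisted data instead of recomputing it.
transformed_df = source_df.withColumn("feature_value_double", source_df.feature_value * 2)
transformed_df.show()

# After the derived dataframe has been produced and written, release the source
# (mirrors the unpersist call added to FeatureSetPipeline.run).
if source_df.storageLevel != StorageLevel.NONE:
    source_df.unpersist()

spark.stop()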