diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index d33e4aa0..8b4d9c73 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -4,11 +4,9 @@ on: paths: - 'setup.py' - jobs: Pipeline: if: github.ref == 'refs/heads/master' - runs-on: ubuntu-latest steps: @@ -19,7 +17,7 @@ jobs: - uses: actions/setup-java@v4 with: - java-version: '11' + java-version: '17' distribution: microsoft - uses: vemonet/setup-spark@v1 diff --git a/.github/workflows/staging.yml b/.github/workflows/staging.yml index 573049ca..9885ba68 100644 --- a/.github/workflows/staging.yml +++ b/.github/workflows/staging.yml @@ -7,7 +7,6 @@ on: jobs: Pipeline: if: github.ref == 'refs/heads/staging' - runs-on: ubuntu-latest steps: @@ -18,7 +17,7 @@ jobs: - uses: actions/setup-java@v4 with: - java-version: '11' + java-version: '17' distribution: microsoft - uses: vemonet/setup-spark@v1 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d588c853..96ad666f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,7 +19,7 @@ jobs: - uses: actions/setup-java@v4 with: - java-version: '11' + java-version: '17' distribution: microsoft - uses: vemonet/setup-spark@v1 diff --git a/CHANGELOG.md b/CHANGELOG.md index fe9f9a8a..19d9b5f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,9 @@ Preferably use **Added**, **Changed**, **Removed** and **Fixed** topics in each ## [Unreleased] +## [1.4.0](https://github.com/quintoandar/butterfree/releases/tag/1.4.0) +* Add Delta support ([#370](https://github.com/quintoandar/butterfree/pull/370)) + ## [1.3.5](https://github.com/quintoandar/butterfree/releases/tag/1.3.5) * Auto create feature sets ([#368](https://github.com/quintoandar/butterfree/pull/368)) diff --git a/Makefile b/Makefile index db9b561b..a93104ab 100644 --- a/Makefile +++ b/Makefile @@ -36,7 +36,7 @@ minimum-requirements: .PHONY: requirements ## install all requirements -requirements: requirements-test requirements-lint dev-requirements minimum-requirements +requirements: minimum-requirements dev-requirements requirements-test requirements-lint .PHONY: ci-install ci-install: diff --git a/butterfree/clients/spark_client.py b/butterfree/clients/spark_client.py index 933c2165..f4b6ea65 100644 --- a/butterfree/clients/spark_client.py +++ b/butterfree/clients/spark_client.py @@ -30,6 +30,7 @@ def conn(self) -> SparkSession: """ if not self._session: self._session = SparkSession.builder.getOrCreate() + return self._session def read( diff --git a/butterfree/load/writers/__init__.py b/butterfree/load/writers/__init__.py index 72945d27..f1f0e449 100644 --- a/butterfree/load/writers/__init__.py +++ b/butterfree/load/writers/__init__.py @@ -1,8 +1,9 @@ """Holds data loaders for historical and online feature store.""" +from butterfree.load.writers.delta_writer import DeltaWriter from butterfree.load.writers.historical_feature_store_writer import ( HistoricalFeatureStoreWriter, ) from butterfree.load.writers.online_feature_store_writer import OnlineFeatureStoreWriter -__all__ = ["HistoricalFeatureStoreWriter", "OnlineFeatureStoreWriter"] +__all__ = ["HistoricalFeatureStoreWriter", "OnlineFeatureStoreWriter", "DeltaWriter"] diff --git a/butterfree/load/writers/delta_writer.py b/butterfree/load/writers/delta_writer.py new file mode 100644 index 00000000..933f1adb --- /dev/null +++ b/butterfree/load/writers/delta_writer.py @@ -0,0 +1,162 @@ +from delta.tables import DeltaTable +from pyspark.sql.dataframe import DataFrame + +from butterfree.clients 
import SparkClient
+from butterfree.configs.logger import __logger
+
+logger = __logger("delta_writer", True)
+
+
+class DeltaWriter:
+    """Control operations on Delta Tables.
+
+    Responsible for merging and optimizing.
+    """
+
+    @staticmethod
+    def _get_full_table_name(table, database):
+        if database:
+            return "{}.{}".format(database, table)
+        else:
+            return table
+
+    @staticmethod
+    def _convert_to_delta(client: SparkClient, table: str):
+        logger.info(f"Converting {table} to Delta...")
+        client.conn.sql(f"CONVERT TO DELTA {table}")
+        logger.info("Conversion done.")
+
+    @staticmethod
+    def merge(
+        client: SparkClient,
+        database: str,
+        table: str,
+        merge_on: list,
+        source_df: DataFrame,
+        when_not_matched_insert_condition: str = None,
+        when_matched_update_condition: str = None,
+        when_matched_delete_condition: str = None,
+    ):
+        """
+        Merge a source dataframe into a Delta table.
+
+        By default, it will update when matched, and insert when
+        not matched (simple upsert).
+
+        You can change this behavior by setting:
+        - when_not_matched_insert_condition: it will only insert
+          when this specified condition is true
+        - when_matched_update_condition: it will only update when this
+          specified condition is true. You can refer to the columns
+          in the source dataframe as source.<column>, and the columns
+          in the target table as target.<column>.
+        - when_matched_delete_condition: it will add an operation to delete,
+          but only if this condition is true. Again, source and
+          target dataframe columns can be referred to respectively as
+          source.<column> and target.<column>.
+        """
+        try:
+            full_table_name = DeltaWriter._get_full_table_name(table, database)
+
+            table_exists = client.conn.catalog.tableExists(full_table_name)
+
+            if table_exists:
+                pd_df = client.conn.sql(
+                    f"DESCRIBE TABLE EXTENDED {full_table_name}"
+                ).toPandas()
+                provider = (
+                    pd_df.reset_index()
+                    .groupby(["col_name"])["data_type"]
+                    .aggregate("first")
+                    .Provider
+                )
+                table_is_delta = provider.lower() == "delta"
+
+                if not table_is_delta:
+                    DeltaWriter()._convert_to_delta(client, full_table_name)
+
+            # For schema evolution
+            client.conn.conf.set(
+                "spark.databricks.delta.schema.autoMerge.enabled", "true"
+            )
+
+            target_table = DeltaTable.forName(client.conn, full_table_name)
+            join_condition = " AND ".join(
+                [f"source.{col} = target.{col}" for col in merge_on]
+            )
+            merge_builder = target_table.alias("target").merge(
+                source_df.alias("source"), join_condition
+            )
+            if when_matched_delete_condition:
+                merge_builder = merge_builder.whenMatchedDelete(
+                    condition=when_matched_delete_condition
+                )
+
+            merge_builder.whenMatchedUpdateAll(
+                condition=when_matched_update_condition
+            ).whenNotMatchedInsertAll(
+                condition=when_not_matched_insert_condition
+            ).execute()
+        except Exception as e:
+            logger.error(f"Merge operation on {full_table_name} failed: {e}")
+
+    @staticmethod
+    def vacuum(table: str, retention_hours: int, client: SparkClient):
+        """Vacuum a Delta table.
+
+        Vacuum removes unused files (files not managed by Delta + files
+        that are not in the latest state).
+        After vacuum it's impossible to time travel to versions
+        older than the `retention` time.
+        Default retention is 7 days. Lower retention values will
+        raise a warning unless
+        spark.databricks.delta.retentionDurationCheck.enabled
+        is set to false.
+        https://docs.databricks.com/en/sql/language-manual/delta-vacuum.html
+        """
+
+        command = f"VACUUM {table} RETAIN {retention_hours} HOURS"
+        logger.info(f"Running vacuum with command {command}")
+        client.conn.sql(command)
+        logger.info(f"Vacuum successful for table {table}")
+
+    @staticmethod
+    def optimize(
+        client: SparkClient,
+        table: str = None,
+        z_order: list = None,
+        date_column: str = "timestamp",
+        from_date: str = None,
+        auto_compact: bool = False,
+        optimize_write: bool = False,
+    ):
+        """Optimize a Delta table.
+
+        For auto-compaction and optimized writes, DBR >= 14.3 LTS
+        and Delta >= 3.1.0 are MANDATORY.
+        For z-ordering, DBR >= 13.3 LTS and Delta >= 2.0.0 are MANDATORY.
+        Auto-compaction (recommended) reduces the small file problem
+        (overhead due to lots of metadata).
+        Z-order by columns that are commonly used in query
+        predicates and have a high cardinality.
+        https://docs.delta.io/latest/optimizations-oss.html
+        """
+
+        if auto_compact:
+            client.conn.conf.set("spark.databricks.delta.autoCompact.enabled", "true")
+
+        if optimize_write:
+            client.conn.conf.set("spark.databricks.delta.optimizeWrite.enabled", "true")
+
+        if table:
+            command = f"OPTIMIZE {table}"
+
+        if from_date:
+            command += f" WHERE {date_column} >= {from_date}"
+
+        if z_order:
+            command += f" ZORDER BY {','.join(z_order)}"
+
+        logger.info(f"Running optimize with command {command}...")
+        client.conn.sql(command)
+        logger.info(f"Optimize successful for table {table}.")
diff --git a/butterfree/load/writers/historical_feature_store_writer.py b/butterfree/load/writers/historical_feature_store_writer.py
index c01fee1d..99bfe66a 100644
--- a/butterfree/load/writers/historical_feature_store_writer.py
+++ b/butterfree/load/writers/historical_feature_store_writer.py
@@ -14,6 +14,7 @@
 from butterfree.dataframe_service import repartition_df
 from butterfree.hooks import Hook
 from butterfree.hooks.schema_compatibility import SparkTableSchemaCompatibilityHook
+from butterfree.load.writers.delta_writer import DeltaWriter
 from butterfree.load.writers.writer import Writer
 from butterfree.transform import FeatureSet
 
@@ -92,6 +93,15 @@ class HistoricalFeatureStoreWriter(Writer):
     improve queries performance. The data is stored in partition folders in AWS S3
     based on time (per year, month and day).
 
+    >>> spark_client = SparkClient()
+    >>> writer = HistoricalFeatureStoreWriter(merge_on=["id", "timestamp"])
+    >>> writer.write(feature_set=feature_set,
+    ...              dataframe=dataframe,
+    ...              spark_client=spark_client)
+
+    When merge_on is set, the writer skips the plain dataframe write and
+    performs a Delta merge (upsert) on the given keys instead. Use it when
+    the table already exists.
     """
 
     PARTITION_BY = [
@@ -114,6 +124,7 @@ def __init__(
         interval_mode: bool = False,
         check_schema_hook: Optional[Hook] = None,
         row_count_validation: bool = True,
+        merge_on: list = None,
     ):
         super(HistoricalFeatureStoreWriter, self).__init__(
             db_config or MetastoreConfig(),
@@ -121,6 +132,7 @@
             interval_mode,
             False,
             row_count_validation,
+            merge_on,
         )
         self.database = database or environment.get_variable(
             "FEATURE_STORE_HISTORICAL_DATABASE"
         )
@@ -141,6 +153,7 @@ def write(
             feature_set: object processed with feature_set informations.
             dataframe: spark dataframe containing data from a feature set.
             spark_client: client for spark connections with external services.
+            merge_on: when set, the write is performed as an upsert on a Delta table.
If the debug_mode is set to True, a temporary table with a name in the format: historical_feature_store__{feature_set.name} will be created instead of writing @@ -174,13 +187,22 @@ def write( s3_key = os.path.join("historical", feature_set.entity, feature_set.name) - spark_client.write_table( - dataframe=dataframe, - database=self.database, - table_name=feature_set.name, - partition_by=self.PARTITION_BY, - **self.db_config.get_options(s3_key), - ) + if self.merge_on: + DeltaWriter.merge( + client=spark_client, + database=self.database, + table=feature_set.name, + merge_on=self.merge_on, + source_df=dataframe, + ) + else: + spark_client.write_table( + dataframe=dataframe, + database=self.database, + table_name=feature_set.name, + partition_by=self.PARTITION_BY, + **self.db_config.get_options(s3_key), + ) def _assert_validation_count( self, table_name: str, written_count: int, dataframe_count: int diff --git a/butterfree/load/writers/writer.py b/butterfree/load/writers/writer.py index 780b9ec2..a99514ae 100644 --- a/butterfree/load/writers/writer.py +++ b/butterfree/load/writers/writer.py @@ -27,6 +27,7 @@ def __init__( interval_mode: Optional[bool] = False, write_to_entity: Optional[bool] = False, row_count_validation: Optional[bool] = True, + merge_on: Optional[list] = None, ) -> None: super().__init__() self.db_config = db_config @@ -35,6 +36,7 @@ def __init__( self.interval_mode = interval_mode self.write_to_entity = write_to_entity self.row_count_validation = row_count_validation + self.merge_on = merge_on def with_( self, transformer: Callable[..., DataFrame], *args: Any, **kwargs: Any diff --git a/docs/source/butterfree.automated.rst b/docs/source/butterfree.automated.rst index de290d9c..9c01ac54 100644 --- a/docs/source/butterfree.automated.rst +++ b/docs/source/butterfree.automated.rst @@ -4,6 +4,8 @@ butterfree.automated package Submodules ---------- +butterfree.automated.feature\_set\_creation module +-------------------------------------------------- .. automodule:: butterfree.automated.feature_set_creation :members: diff --git a/docs/source/butterfree.constants.rst b/docs/source/butterfree.constants.rst index bd721330..de6f1cee 100644 --- a/docs/source/butterfree.constants.rst +++ b/docs/source/butterfree.constants.rst @@ -54,7 +54,6 @@ butterfree.constants.spark\_constants module :undoc-members: :show-inheritance: - .. automodule:: butterfree.constants.spark_constants :members: :undoc-members: diff --git a/docs/source/butterfree.dataframe_service.rst b/docs/source/butterfree.dataframe_service.rst index 4343305b..faf9cf54 100644 --- a/docs/source/butterfree.dataframe_service.rst +++ b/docs/source/butterfree.dataframe_service.rst @@ -4,18 +4,29 @@ butterfree.dataframe\_service package Submodules ---------- +butterfree.dataframe\_service.incremental\_strategy module +---------------------------------------------------------- .. automodule:: butterfree.dataframe_service.incremental_strategy :members: :undoc-members: :show-inheritance: +butterfree.dataframe\_service.partitioning module +------------------------------------------------- .. automodule:: butterfree.dataframe_service.partitioning :members: :undoc-members: :show-inheritance: +butterfree.dataframe\_service.repartition module +------------------------------------------------ + +.. automodule:: butterfree.dataframe_service.repartition + :members: + :undoc-members: + :show-inheritance: .. 
automodule:: butterfree.dataframe_service.repartition :members: diff --git a/docs/source/butterfree.hooks.rst b/docs/source/butterfree.hooks.rst index 72f13223..c633cade 100644 --- a/docs/source/butterfree.hooks.rst +++ b/docs/source/butterfree.hooks.rst @@ -12,12 +12,16 @@ Subpackages Submodules ---------- +butterfree.hooks.hook module +---------------------------- .. automodule:: butterfree.hooks.hook :members: :undoc-members: :show-inheritance: +butterfree.hooks.hookable\_component module +------------------------------------------- .. automodule:: butterfree.hooks.hookable_component :members: diff --git a/docs/source/butterfree.hooks.schema_compatibility.rst b/docs/source/butterfree.hooks.schema_compatibility.rst index a39c5b93..2d3de66c 100644 --- a/docs/source/butterfree.hooks.schema_compatibility.rst +++ b/docs/source/butterfree.hooks.schema_compatibility.rst @@ -4,12 +4,16 @@ butterfree.hooks.schema\_compatibility package Submodules ---------- +butterfree.hooks.schema\_compatibility.cassandra\_table\_schema\_compatibility\_hook module +------------------------------------------------------------------------------------------- .. automodule:: butterfree.hooks.schema_compatibility.cassandra_table_schema_compatibility_hook :members: :undoc-members: :show-inheritance: +butterfree.hooks.schema\_compatibility.spark\_table\_schema\_compatibility\_hook module +--------------------------------------------------------------------------------------- .. automodule:: butterfree.hooks.schema_compatibility.spark_table_schema_compatibility_hook :members: diff --git a/docs/source/butterfree.load.writers.rst b/docs/source/butterfree.load.writers.rst index 2a173c9a..b20eb85e 100644 --- a/docs/source/butterfree.load.writers.rst +++ b/docs/source/butterfree.load.writers.rst @@ -4,6 +4,14 @@ butterfree.load.writers package Submodules ---------- +butterfree.load.writers.delta\_writer module +-------------------------------------------- + +.. automodule:: butterfree.load.writers.delta_writer + :members: + :undoc-members: + :show-inheritance: + butterfree.load.writers.historical\_feature\_store\_writer module ----------------------------------------------------------------- diff --git a/docs/source/butterfree.migrations.database_migration.rst b/docs/source/butterfree.migrations.database_migration.rst index 892165df..32ba4d4d 100644 --- a/docs/source/butterfree.migrations.database_migration.rst +++ b/docs/source/butterfree.migrations.database_migration.rst @@ -4,18 +4,24 @@ butterfree.migrations.database\_migration package Submodules ---------- +butterfree.migrations.database\_migration.cassandra\_migration module +--------------------------------------------------------------------- .. automodule:: butterfree.migrations.database_migration.cassandra_migration :members: :undoc-members: :show-inheritance: +butterfree.migrations.database\_migration.database\_migration module +-------------------------------------------------------------------- .. automodule:: butterfree.migrations.database_migration.database_migration :members: :undoc-members: :show-inheritance: +butterfree.migrations.database\_migration.metastore\_migration module +--------------------------------------------------------------------- .. 
automodule:: butterfree.migrations.database_migration.metastore_migration :members: diff --git a/requirements.txt b/requirements.txt index f3968c60..9c9eea64 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ typer==0.4.2 typing-extensions>3.7.4,<5 boto3==1.17.* numpy==1.26.4 +delta-spark==3.2.0 diff --git a/setup.py b/setup.py index bc4f0b45..e6b9f761 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import find_packages, setup __package_name__ = "butterfree" -__version__ = "1.3.5" +__version__ = "1.4.0" __repository_url__ = "https://github.com/quintoandar/butterfree" with open("requirements.txt") as f: diff --git a/tests/unit/butterfree/load/writers/test_delta_writer.py b/tests/unit/butterfree/load/writers/test_delta_writer.py new file mode 100644 index 00000000..550f6d05 --- /dev/null +++ b/tests/unit/butterfree/load/writers/test_delta_writer.py @@ -0,0 +1,83 @@ +import os +from unittest import mock + +import pytest + +from butterfree.clients import SparkClient +from butterfree.load.writers import DeltaWriter + +DELTA_LOCATION = "spark-warehouse" + + +class TestDeltaWriter: + + def __checkFileExists(self, file_name: str = "test_delta_table") -> bool: + return os.path.exists(os.path.join(DELTA_LOCATION, file_name)) + + @pytest.fixture + def merge_builder_mock(self): + builder = mock.MagicMock() + builder.whenMatchedDelete.return_value = builder + builder.whenMatchedUpdateAll.return_value = builder + builder.whenNotMatchedInsertAll.return_value = builder + return builder + + def test_merge(self, feature_set_dataframe, merge_builder_mock): + + client = SparkClient() + delta_writer = DeltaWriter() + delta_writer.merge = mock.MagicMock() + + DeltaWriter().merge( + client=client, + database=None, + table="test_delta_table", + merge_on=["id"], + source_df=feature_set_dataframe, + ) + + assert merge_builder_mock.execute.assert_called_once + + # Step 2 + source = client.conn.createDataFrame( + [(1, "test3"), (2, "test4"), (3, "test5")], ["id", "feature"] + ) + + DeltaWriter().merge( + client=client, + database=None, + table="test_delta_table", + merge_on=["id"], + source_df=source, + when_not_matched_insert_condition=None, + when_matched_update_condition="id > 2", + ) + + assert merge_builder_mock.execute.assert_called_once + + def test_optimize(self, mocker): + + client = SparkClient() + conn_mock = mocker.patch( + "butterfree.clients.SparkClient.conn", return_value=mock.Mock() + ) + dw = DeltaWriter() + + dw.optimize = mock.MagicMock(client) + dw.optimize(client, "a_table") + + conn_mock.assert_called_once + + def test_vacuum(self, mocker): + + client = SparkClient() + conn_mock = mocker.patch( + "butterfree.clients.SparkClient.conn", return_value=mock.Mock() + ) + dw = DeltaWriter() + retention_hours = 24 + dw.vacuum = mock.MagicMock(client) + + dw.vacuum("a_table", retention_hours, client) + + conn_mock.assert_called_once diff --git a/tests/unit/butterfree/load/writers/test_historical_feature_store_writer.py b/tests/unit/butterfree/load/writers/test_historical_feature_store_writer.py index 9e84aacd..d9d9181a 100644 --- a/tests/unit/butterfree/load/writers/test_historical_feature_store_writer.py +++ b/tests/unit/butterfree/load/writers/test_historical_feature_store_writer.py @@ -1,5 +1,6 @@ import datetime import random +from unittest import mock import pytest from pyspark.sql.functions import spark_partition_id @@ -145,6 +146,30 @@ def test_write_in_debug_mode_with_interval_mode( # then assert_dataframe_equality(historical_feature_set_dataframe, 
result_df)
 
+    def test_merge_from_historical_writer(
+        self,
+        feature_set,
+        feature_set_dataframe,
+        mocker,
+    ):
+        # given
+        spark_client = SparkClient()
+
+        spark_client.write_table = mocker.stub("write_table")
+        writer = HistoricalFeatureStoreWriter(merge_on=["id", "timestamp"])
+
+        static_mock = mocker.patch(
+            "butterfree.load.writers.DeltaWriter.merge", return_value=mock.Mock()
+        )
+
+        writer.write(
+            feature_set=feature_set,
+            dataframe=feature_set_dataframe,
+            spark_client=spark_client,
+        )
+
+        assert static_mock.call_count == 1
+
     def test_validate(self, historical_feature_set_dataframe, mocker, feature_set):
         # given
         spark_client = mocker.stub("spark_client")
diff --git a/tests/unit/butterfree/transform/conftest.py b/tests/unit/butterfree/transform/conftest.py
index fcf60132..c0ebb47a 100644
--- a/tests/unit/butterfree/transform/conftest.py
+++ b/tests/unit/butterfree/transform/conftest.py
@@ -1,7 +1,10 @@
 import json
 from unittest.mock import Mock
 
+import pyspark.pandas as ps
 from pyspark.sql import functions
+from pyspark.sql.functions import col
+from pyspark.sql.types import TimestampType
 from pytest import fixture
 
 from butterfree.constants import DataType
@@ -16,6 +19,83 @@
 from butterfree.transform.utils import Function
 
 
+def create_dataframe(data, timestamp_col="ts"):
+    pdf = ps.DataFrame.from_dict(data)
+    df = pdf.to_spark()
+    df = df.withColumn(
+        TIMESTAMP_COLUMN, df[timestamp_col].cast(DataType.TIMESTAMP.spark)
+    )
+    return df
+
+
+def create_dataframe_from_data(
+    spark_context, spark_session, data, timestamp_col="timestamp", use_json=False
+):
+    if use_json:
+        df = spark_session.read.json(
+            spark_context.parallelize(data).map(lambda x: json.dumps(x))
+        )
+    else:
+        df = create_dataframe(data, timestamp_col=timestamp_col)
+
+    df = df.withColumn(timestamp_col, col(timestamp_col).cast(TimestampType()))
+    return df
+
+
+def create_rolling_windows_agg_dataframe(
+    spark_context, spark_session, data, timestamp_col="timestamp", use_json=False
+):
+    if use_json:
+        df = spark_session.read.json(
+            spark_context.parallelize(data).map(lambda x: json.dumps(x))
+        )
+        df = df.withColumn(
+            timestamp_col, col(timestamp_col).cast(DataType.TIMESTAMP.spark)
+        )
+    else:
+        df = create_dataframe(data, timestamp_col=timestamp_col)
+
+    return df
+
+
+def build_data(rows, base_features, dynamic_features=None):
+    """
+    Build a list of dictionaries to create a DataFrame with dynamic features.
+
+    :param rows: list of tuples with (id, timestamp, base_values, dynamic_values).
+    :param base_features: list of base feature names (strings).
+    :param dynamic_features: list of dynamic feature names, mapping to the
+        index of dynamic_values (optional).
+    :return: list of dictionaries used to create the DataFrame.
+    """
+    data = []
+    for row in rows:
+        id_value, timestamp_value, base_values, dynamic_values = row
+
+        entry = {
+            "id": id_value,
+            "timestamp": timestamp_value,
+        }
+
+        # Add the base feature values
+        entry.update(
+            {feature: value for feature, value in zip(base_features, base_values)}
+        )
+
+        # Add the dynamic feature values, if any
+        if dynamic_features:
+            entry.update(
+                {
+                    feature: dynamic_values[idx]
+                    for idx, feature in enumerate(dynamic_features)
+                }
+            )
+
+        data.append(entry)
+
+    return data
+
+
 def make_dataframe(spark_context, spark_session):
     data = [
         {
@@ -54,10 +134,7 @@ def make_dataframe(spark_context, spark_session):
             "nonfeature": 0,
         },
     ]
-    df = spark_session.read.json(spark_context.parallelize(data, 1))
-    df = df.withColumn(TIMESTAMP_COLUMN, df.ts.cast(DataType.TIMESTAMP.spark))
-
-    return df
+    return create_dataframe(data)
 
 
 def make_filtering_dataframe(spark_context, spark_session):
@@ -70,12 +147,7 @@ def make_filtering_dataframe(spark_context, spark_session):
         {"id": 1, "ts": 6, "feature1": None, "feature2": None, "feature3": None},
         {"id": 1, "ts": 7, "feature1": None, "feature2": None, "feature3": None},
     ]
-    df = spark_session.read.json(
-        spark_context.parallelize(data).map(lambda x: json.dumps(x))
-    )
-    df = df.withColumn(TIMESTAMP_COLUMN, df.ts.cast(DataType.TIMESTAMP.spark))
-
-    return df
+    return create_dataframe(data)
 
 
 def make_output_filtering_dataframe(spark_context, spark_session):
@@ -86,131 +158,94 @@ def make_output_filtering_dataframe(spark_context, spark_session):
         {"id": 1, "ts": 4, "feature1": 0, "feature2": 1, "feature3": 1},
         {"id": 1, "ts": 6, "feature1": None, "feature2": None, "feature3": None},
     ]
-    df = spark_session.read.json(
-        spark_context.parallelize(data).map(lambda x: json.dumps(x))
-    )
-    df = df.withColumn(TIMESTAMP_COLUMN, df.ts.cast(DataType.TIMESTAMP.spark))
-
-    return df
+    return create_dataframe(data)
 
 
 def make_rolling_windows_agg_dataframe(spark_context, spark_session):
-    data = [
-        {
-            "id": 1,
-            "timestamp": "2016-04-11 00:00:00",
-            "feature1__avg_over_1_week_rolling_windows": None,
-            "feature2__avg_over_1_week_rolling_windows": None,
-        },
-        {
-            "id": 1,
-            "timestamp": "2016-04-12 00:00:00",
-            "feature1__avg_over_1_week_rolling_windows": 300.0,
-            "feature2__avg_over_1_week_rolling_windows": 350.0,
-        },
-        {
-            "id": 1,
-            "timestamp": "2016-04-19 00:00:00",
-            "feature1__avg_over_1_week_rolling_windows": None,
-            "feature2__avg_over_1_week_rolling_windows": None,
-        },
-        {
-            "id": 1,
-            "timestamp": "2016-04-23 00:00:00",
-            "feature1__avg_over_1_week_rolling_windows": 1000.0,
-            "feature2__avg_over_1_week_rolling_windows": 1100.0,
-        },
-        {
-            "id": 1,
-            "timestamp": "2016-04-30 00:00:00",
-            "feature1__avg_over_1_week_rolling_windows": None,
-            "feature2__avg_over_1_week_rolling_windows": None,
-        },
+    rows = [
+        (1, "2016-04-11 00:00:00", [None, None], None),
+        (1, "2016-04-12 00:00:00", [300.0, 350.0], None),
+        (1, "2016-04-19 00:00:00", [None, None], None),
+        (1, "2016-04-23 00:00:00", [1000.0, 1100.0], None),
+        (1, "2016-04-30 00:00:00", [None, None], None),
     ]
-    df = spark_session.read.json(
-        spark_context.parallelize(data).map(lambda x: json.dumps(x))
-    )
-    df = df.withColumn("timestamp", df.timestamp.cast(DataType.TIMESTAMP.spark))
-    return df
+
+    base_features = [
+        "feature1__avg_over_1_week_rolling_windows",
+        "feature2__avg_over_1_week_rolling_windows",
+    ]
+
+    data = build_data(rows, base_features)
+    return create_dataframe_from_data(spark_context, spark_session, data)
 
 
 def 
make_rolling_windows_hour_slide_agg_dataframe(spark_context, spark_session):
-    data = [
-        {
-            "id": 1,
-            "timestamp": "2016-04-11 12:00:00",
-            "feature1__avg_over_1_day_rolling_windows": 266.6666666666667,
-            "feature2__avg_over_1_day_rolling_windows": 300.0,
-        },
-        {
-            "id": 1,
-            "timestamp": "2016-04-12 00:00:00",
-            "feature1__avg_over_1_day_rolling_windows": 300.0,
-            "feature2__avg_over_1_day_rolling_windows": 350.0,
-        },
-        {
-            "id": 1,
-            "timestamp": "2016-04-12 12:00:00",
-            "feature1__avg_over_1_day_rolling_windows": 400.0,
-            "feature2__avg_over_1_day_rolling_windows": 500.0,
-        },
+    rows = [
+        (1, "2016-04-11 12:00:00", [266.6666666666667, 300.0], None),
+        (1, "2016-04-12 00:00:00", [300.0, 350.0], None),
+        (1, "2016-04-12 12:00:00", [400.0, 500.0], None),
     ]
-    df = spark_session.read.json(
-        spark_context.parallelize(data).map(lambda x: json.dumps(x))
-    )
-    df = df.withColumn("timestamp", df.timestamp.cast(DataType.TIMESTAMP.spark))
-    return df
+
+    base_features = [
+        "feature1__avg_over_1_day_rolling_windows",
+        "feature2__avg_over_1_day_rolling_windows",
+    ]
+
+    data = build_data(rows, base_features)
+    return create_dataframe_from_data(spark_context, spark_session, data)
 
 
 def make_multiple_rolling_windows_hour_slide_agg_dataframe(
     spark_context, spark_session
 ):
-    data = [
-        {
-            "id": 1,
-            "timestamp": "2016-04-11 12:00:00",
-            "feature1__avg_over_2_days_rolling_windows": 266.6666666666667,
-            "feature1__avg_over_3_days_rolling_windows": 266.6666666666667,
-            "feature2__avg_over_2_days_rolling_windows": 300.0,
-            "feature2__avg_over_3_days_rolling_windows": 300.0,
-        },
-        {
-            "id": 1,
-            "timestamp": "2016-04-12 00:00:00",
-            "feature1__avg_over_2_days_rolling_windows": 300.0,
-            "feature1__avg_over_3_days_rolling_windows": 300.0,
-            "feature2__avg_over_2_days_rolling_windows": 350.0,
-            "feature2__avg_over_3_days_rolling_windows": 350.0,
-        },
-        {
-            "id": 1,
-            "timestamp": "2016-04-13 12:00:00",
-            "feature1__avg_over_2_days_rolling_windows": 400.0,
-            "feature1__avg_over_3_days_rolling_windows": 300.0,
-            "feature2__avg_over_2_days_rolling_windows": 500.0,
-            "feature2__avg_over_3_days_rolling_windows": 350.0,
-        },
-        {
-            "id": 1,
-            "timestamp": "2016-04-14 00:00:00",
-            "feature1__avg_over_3_days_rolling_windows": 300.0,
-            "feature2__avg_over_3_days_rolling_windows": 350.0,
-        },
-        {
-            "id": 1,
-            "timestamp": "2016-04-14 12:00:00",
-            "feature1__avg_over_3_days_rolling_windows": 400.0,
-            "feature2__avg_over_3_days_rolling_windows": 500.0,
-        },
+    rows = [
+        (
+            1,
+            "2016-04-11 12:00:00",
+            [],
+            [266.6666666666667, 266.6666666666667, 300.0, 300.0],
+        ),
+        (1, "2016-04-12 00:00:00", [], [300.0, 300.0, 350.0, 350.0]),
+        (1, "2016-04-13 12:00:00", [], [400.0, 300.0, 500.0, 350.0]),
+        (1, "2016-04-14 00:00:00", [], [None, 300.0, None, 350.0]),
+        (1, "2016-04-14 12:00:00", [], [None, 400.0, None, 500.0]),
     ]
-    df = spark_session.read.json(
-        spark_context.parallelize(data).map(lambda x: json.dumps(x))
-    )
-    df = df.withColumn("timestamp", df.timestamp.cast(DataType.TIMESTAMP.spark))
-    return df
+
+    dynamic_features = [
+        "feature1__avg_over_2_days_rolling_windows",
+        "feature1__avg_over_3_days_rolling_windows",
+        "feature2__avg_over_2_days_rolling_windows",
+        "feature2__avg_over_3_days_rolling_windows",
+    ]
+
+    data = build_data(rows, [], dynamic_features=dynamic_features)
+    return create_dataframe_from_data(spark_context, spark_session, data, use_json=True)
+
+
+def create_rolling_window_dataframe(
+    spark_context, spark_session, rows, base_features, dynamic_features=None
+):
+    """
+    Create a DataFrame with aggregated rolling window features.
+
+    :param spark_context: Spark context.
+    :param spark_session: Spark session.
+    :param rows: list of tuples with (id, timestamp, base_values, dynamic_values).
+    :param base_features: list of base feature names (strings).
+    :param dynamic_features: list of dynamic feature names, mapping to the
+        index of dynamic_values (optional).
+    :return: Spark DataFrame.
+    """
+    data = build_data(rows, base_features, dynamic_features)
+
+    # Convert the list of dictionaries into an RDD of JSON strings
+    rdd = spark_context.parallelize(data).map(lambda x: json.dumps(x))
+
+    # Create the Spark DataFrame from the RDD
+    df = spark_session.read.json(rdd)
+
+    # Cast the "timestamp" column to the TIMESTAMP type
+    df = df.withColumn("timestamp", df.timestamp.cast(DataType.TIMESTAMP.spark))
+
+    return df
 
 
 def make_fs(spark_context, spark_session):
@@ -257,8 +292,7 @@ def make_fs_dataframe_with_distinct(spark_context, spark_session):
             "h3": "86a8100efffffff",
        },
     ]
-    df = spark_session.read.json(spark_context.parallelize(data, 1))
-    df = df.withColumn("timestamp", df.timestamp.cast(DataType.TIMESTAMP.spark))
+    df = create_dataframe(data, "timestamp")
 
     return df
 
@@ -286,10 +320,7 @@ def make_target_df_distinct(spark_context, spark_session):
             "feature__sum_over_3_days_rolling_windows": None,
         },
     ]
-    df = spark_session.read.json(
-        spark_context.parallelize(data).map(lambda x: json.dumps(x))
-    )
-    df = df.withColumn("timestamp", df.timestamp.cast(DataType.TIMESTAMP.spark))
+    df = create_dataframe(data, "timestamp")
 
     return df
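
For reference, a minimal usage sketch of the API introduced in this patch, assuming a Spark session with the Delta extensions available (delta-spark is pinned in requirements.txt above). The table name "my_database.my_feature_set", the merge keys, the 168-hour retention, and the `feature_set` / `dataframe` variables are illustrative placeholders, not values defined by this diff:

    from butterfree.clients import SparkClient
    from butterfree.load.writers import DeltaWriter, HistoricalFeatureStoreWriter

    spark_client = SparkClient()

    # Upsert path: with merge_on set, write() delegates to DeltaWriter.merge
    # instead of spark_client.write_table (see historical_feature_store_writer.py).
    writer = HistoricalFeatureStoreWriter(merge_on=["id", "timestamp"])
    writer.write(
        feature_set=feature_set,  # placeholder: a butterfree FeatureSet
        dataframe=dataframe,      # placeholder: the transformed Spark dataframe
        spark_client=spark_client,
    )

    # Maintenance helpers on the Delta table itself.
    DeltaWriter.vacuum("my_database.my_feature_set", 168, spark_client)
    DeltaWriter.optimize(
        spark_client,
        table="my_database.my_feature_set",
        z_order=["id"],
        auto_compact=True,
    )

When the target table is not yet a Delta table, DeltaWriter.merge converts it first (CONVERT TO DELTA) and enables schema auto-merge before executing the upsert.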