diff --git a/syncmaster/dto/transfers.py b/syncmaster/dto/transfers.py index 1d133316..388afd78 100644 --- a/syncmaster/dto/transfers.py +++ b/syncmaster/dto/transfers.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import ClassVar -from onetl.file.format import CSV, JSON, JSONLine +from onetl.file.format import CSV, JSON, Excel, JSONLine @dataclass @@ -20,10 +20,17 @@ class DBTransferDTO(TransferDTO): @dataclass class FileTransferDTO(TransferDTO): directory_path: str - file_format: CSV | JSONLine | JSON + file_format: CSV | JSONLine | JSON | Excel options: dict df_schema: dict | None = None + _format_parsers = { + "csv": CSV, + "jsonline": JSONLine, + "json": JSON, + "excel": Excel, + } + def __post_init__(self): if isinstance(self.file_format, dict): self.file_format = self._get_format(self.file_format.copy()) @@ -32,13 +39,10 @@ def __post_init__(self): def _get_format(self, file_format: dict): file_type = file_format.pop("type", None) - if file_type == "csv": - return CSV.parse_obj(file_format) - if file_type == "jsonline": - return JSONLine.parse_obj(file_format) - if file_type == "json": - return JSON.parse_obj(file_format) - raise ValueError("Unknown file type") + parser_class = self._format_parsers.get(file_type) + if parser_class is not None: + return parser_class.parse_obj(file_format) + raise ValueError(f"Unknown file type: {file_type}") @dataclass diff --git a/syncmaster/worker/handlers/file/s3.py b/syncmaster/worker/handlers/file/s3.py index 2e48f443..38696e1f 100644 --- a/syncmaster/worker/handlers/file/s3.py +++ b/syncmaster/worker/handlers/file/s3.py @@ -6,12 +6,13 @@ from typing import TYPE_CHECKING from onetl.connection import SparkS3 +from onetl.file import FileDFReader from syncmaster.dto.connections import S3ConnectionDTO from syncmaster.worker.handlers.file.base import FileHandler if TYPE_CHECKING: - from pyspark.sql import SparkSession + from pyspark.sql import DataFrame, SparkSession class S3Handler(FileHandler): @@ -29,3 +30,20 @@ def connect(self, spark: SparkSession): extra=self.connection_dto.additional_params, spark=spark, ).check() + + def read(self) -> DataFrame: + from pyspark.sql.types import StructType + + options = {} + if self.transfer_dto.file_format.__class__.__name__ == "Excel": + options = {"inferSchema": True} + + reader = FileDFReader( + connection=self.connection, + format=self.transfer_dto.file_format, + source_path=self.transfer_dto.directory_path, + df_schema=StructType.fromJson(self.transfer_dto.df_schema) if self.transfer_dto.df_schema else None, + options={**options, **self.transfer_dto.options}, + ) + + return reader.run() diff --git a/syncmaster/worker/spark.py b/syncmaster/worker/spark.py index d4b649a6..61457053 100644 --- a/syncmaster/worker/spark.py +++ b/syncmaster/worker/spark.py @@ -37,6 +37,7 @@ def get_worker_spark_session( def get_packages(db_type: str) -> list[str]: from onetl.connection import MSSQL, Clickhouse, MySQL, Oracle, Postgres, SparkS3 + from onetl.file.format import Excel if db_type == "postgres": return Postgres.get_packages() @@ -53,7 +54,11 @@ def get_packages(db_type: str) -> list[str]: import pyspark spark_version = pyspark.__version__ - return SparkS3.get_packages(spark_version=spark_version) + # see supported versions from https://mvnrepository.com/artifact/com.crealytics/spark-excel + return SparkS3.get_packages(spark_version=spark_version) + Excel.get_packages(spark_version="3.5.1") + if db_type == "hdfs": + # see supported versions from https://mvnrepository.com/artifact/com.crealytics/spark-excel + return Excel.get_packages(spark_version="3.5.1") # If the database type does not require downloading .jar packages return [] diff --git a/tests/test_integration/test_run_transfer/conftest.py b/tests/test_integration/test_run_transfer/conftest.py index 6758986c..38e2084f 100644 --- a/tests/test_integration/test_run_transfer/conftest.py +++ b/tests/test_integration/test_run_transfer/conftest.py @@ -3,14 +3,15 @@ import os import secrets from collections import namedtuple -from pathlib import Path, PurePosixPath +from pathlib import Path, PosixPath, PurePosixPath import pyspark import pytest import pytest_asyncio from onetl.connection import MSSQL, Clickhouse, Hive, MySQL, Oracle, Postgres, SparkS3 +from onetl.connection.file_connection.s3 import S3 from onetl.db import DBWriter -from onetl.file.format import CSV, JSON, JSONLine +from onetl.file.format import CSV, JSON, Excel, JSONLine from pyspark.sql import DataFrame, SparkSession from pyspark.sql.types import ( DateType, @@ -112,6 +113,10 @@ def spark(settings: Settings, request: FixtureRequest) -> SparkSession: ) ) + if "hdfs" in markers or "s3" in markers: + # see supported versions from https://mvnrepository.com/artifact/com.crealytics/spark-excel + maven_packages.extend(Excel.get_packages(spark_version="3.5.1")) + if maven_packages: spark = spark.config("spark.jars.packages", ",".join(maven_packages)) @@ -462,12 +467,22 @@ def s3_file_df_connection(s3_file_connection, spark, s3_server): @pytest.fixture(scope="session") -def prepare_s3(resource_path, s3_file_connection, s3_file_df_connection_with_path: tuple[SparkS3, PurePosixPath]): - logger.info("START PREPARE HDFS") - connection, upload_to = s3_file_df_connection_with_path - files = upload_files(resource_path, upload_to, s3_file_connection) - logger.info("END PREPARE HDFS") - return connection, upload_to, files +def prepare_s3( + resource_path: PosixPath, + s3_file_connection: S3, + s3_file_df_connection_with_path: tuple[SparkS3, PurePosixPath], +): + logger.info("START PREPARE S3") + connection, remote_path = s3_file_df_connection_with_path + + s3_file_connection.remove_dir(remote_path, recursive=True) + files = upload_files(resource_path, remote_path, s3_file_connection) + + yield connection, remote_path, files + + logger.info("START POST-CLEANUP S3") + s3_file_connection.remove_dir(remote_path, recursive=True) + logger.info("END POST-CLEANUP S3") @pytest.fixture(scope="session") @@ -635,14 +650,14 @@ def prepare_clickhouse( pass def fill_with_data(df: DataFrame): - logger.info("START PREPARE ORACLE") + logger.info("START PREPARE CLICKHOUSE") db_writer = DBWriter( connection=onetl_conn, target=f"{clickhouse.user}.source_table", options=Clickhouse.WriteOptions(createTableOptions="ENGINE = Memory"), ) db_writer.run(df) - logger.info("END PREPARE ORACLE") + logger.info("END PREPARE CLICKHOUSE") yield onetl_conn, fill_with_data @@ -745,7 +760,51 @@ def fill_with_data(df: DataFrame): pass -@pytest.fixture(params=[("csv", {}), ("jsonline", {}), ("json", {})]) +@pytest.fixture +def prepare_mysql( + mysql_for_conftest: MySQLConnectionDTO, + spark: SparkSession, +): + mysql = mysql_for_conftest + onetl_conn = MySQL( + host=mysql.host, + port=mysql.port, + user=mysql.user, + password=mysql.password, + database=mysql.database_name, + spark=spark, + ).check() + try: + onetl_conn.execute(f"DROP TABLE IF EXISTS {mysql.database_name}.source_table") + except Exception: + pass + try: + onetl_conn.execute(f"DROP TABLE IF EXISTS {mysql.database_name}.target_table") + except Exception: + pass + + def fill_with_data(df: DataFrame): + logger.info("START PREPARE MYSQL") + db_writer = DBWriter( + connection=onetl_conn, + target=f"{mysql.database_name}.source_table", + ) + db_writer.run(df) + logger.info("END PREPARE MYSQL") + + yield onetl_conn, fill_with_data + + try: + onetl_conn.execute(f"DROP TABLE IF EXISTS {mysql.database_name}.source_table") + except Exception: + pass + try: + onetl_conn.execute(f"DROP TABLE IF EXISTS {mysql.database_name}.target_table") + except Exception: + pass + + +@pytest.fixture(params=[("csv", {}), ("jsonline", {}), ("json", {}), ("excel", {})]) def source_file_format(request: FixtureRequest): name, params = request.param if name == "csv": @@ -769,10 +828,17 @@ def source_file_format(request: FixtureRequest): **params, ) + if name == "excel": + return "excel", Excel( + header=True, + inferSchema=True, + **params, + ) + raise ValueError(f"Unsupported file format: {name}") -@pytest.fixture(params=[("csv", {}), ("jsonline", {})]) +@pytest.fixture(params=[("csv", {}), ("jsonline", {}), ("excel", {})]) def target_file_format(request: FixtureRequest): name, params = request.param if name == "csv": @@ -791,6 +857,12 @@ def target_file_format(request: FixtureRequest): **params, ) + if name == "excel": + return "excel", Excel( + header=False, + **params, + ) + raise ValueError(f"Unsupported file format: {name}") diff --git a/tests/test_integration/test_run_transfer/test_hdfs.py b/tests/test_integration/test_run_transfer/test_hdfs.py index 3703b366..fb906142 100644 --- a/tests/test_integration/test_run_transfer/test_hdfs.py +++ b/tests/test_integration/test_run_transfer/test_hdfs.py @@ -8,6 +8,7 @@ from onetl.db import DBReader from onetl.file import FileDFReader from pyspark.sql import DataFrame +from pyspark.sql.functions import col, date_format, date_trunc, to_timestamp from pytest import FixtureRequest from sqlalchemy.ext.asyncio import AsyncSession @@ -37,6 +38,7 @@ async def hdfs_to_postgres( file_format_flavor: str, ): format_name, file_format = source_file_format + format_name_in_path = "xlsx" if format_name == "excel" else format_name _, source_path, _ = prepare_hdfs result = await create_transfer( @@ -47,7 +49,7 @@ async def hdfs_to_postgres( target_connection_id=postgres_connection.id, source_params={ "type": "hdfs", - "directory_path": os.fspath(source_path / "file_df_connection" / format_name / file_format_flavor), + "directory_path": os.fspath(source_path / "file_df_connection" / format_name_in_path / file_format_flavor), "file_format": { "type": format_name, **file_format.dict(), @@ -121,6 +123,11 @@ async def postgres_to_hdfs( "without_compression", id="jsonline", ), + pytest.param( + ("excel", {}), + "with_header", + id="excel", + ), ], indirect=["source_file_format", "file_format_flavor"], ) @@ -135,6 +142,7 @@ async def test_run_transfer_hdfs_to_postgres( ): # Arrange postgres, _ = prepare_postgres + file_format, _ = source_file_format # Act result = await client.post( @@ -164,6 +172,12 @@ async def test_run_transfer_hdfs_to_postgres( table="public.target_table", ) df = reader.run() + + # as Excel does not support datetime values with precision greater than milliseconds + if file_format == "excel": + df = df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT"))) + init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT"))) + for field in init_df.schema: df = df.withColumn(field.name, df[field.name].cast(field.dataType)) @@ -183,6 +197,11 @@ async def test_run_transfer_hdfs_to_postgres( "without_compression", id="jsonline", ), + pytest.param( + ("excel", {}), + "with_header", + id="excel", + ), ], indirect=["target_file_format", "file_format_flavor"], ) @@ -235,6 +254,13 @@ async def test_run_transfer_postgres_to_hdfs( ) df = reader.run() + # as Excel does not support datetime values with precision greater than milliseconds + if format_name == "excel": + init_df = init_df.withColumn( + "REGISTERED_AT", + to_timestamp(date_format(col("REGISTERED_AT"), "yyyy-MM-dd HH:mm:ss.SSS")), + ) + for field in init_df.schema: df = df.withColumn(field.name, df[field.name].cast(field.dataType)) diff --git a/tests/test_integration/test_run_transfer/test_mssql.py b/tests/test_integration/test_run_transfer/test_mssql.py index 3e8049e2..93ef3895 100644 --- a/tests/test_integration/test_run_transfer/test_mssql.py +++ b/tests/test_integration/test_run_transfer/test_mssql.py @@ -118,7 +118,7 @@ async def test_run_transfer_postgres_to_mssql( ) df = reader.run() - # as spark rounds datetime up to milliseconds while writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5 + # as spark rounds datetime to nearest 3.33 milliseconds when writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5 df = df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT"))) init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT"))) @@ -173,7 +173,7 @@ async def test_run_transfer_postgres_to_mssql_mixed_naming( assert df.columns != init_df_with_mixed_column_naming.columns assert df.columns == [column.lower() for column in init_df_with_mixed_column_naming.columns] - # as spark rounds datetime up to milliseconds while writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5 + # as spark rounds datetime to nearest 3.33 milliseconds when writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5 df = df.withColumn("Registered At", date_trunc("second", col("Registered At"))) init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.withColumn( "Registered At", @@ -228,7 +228,7 @@ async def test_run_transfer_mssql_to_postgres( ) df = reader.run() - # as spark rounds datetime up to milliseconds while writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5 + # as spark rounds datetime to nearest 3.33 milliseconds when writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5 df = df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT"))) init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT"))) @@ -283,7 +283,7 @@ async def test_run_transfer_mssql_to_postgres_mixed_naming( assert df.columns != init_df_with_mixed_column_naming.columns assert df.columns == [column.lower() for column in init_df_with_mixed_column_naming.columns] - # as spark rounds datetime up to milliseconds while writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5 + # as spark rounds datetime to nearest 3.33 milliseconds when writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5 df = df.withColumn("Registered At", date_trunc("second", col("Registered At"))) init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.withColumn( "Registered At", diff --git a/tests/test_integration/test_run_transfer/test_s3.py b/tests/test_integration/test_run_transfer/test_s3.py index 6187259b..3eea28a8 100644 --- a/tests/test_integration/test_run_transfer/test_s3.py +++ b/tests/test_integration/test_run_transfer/test_s3.py @@ -8,6 +8,7 @@ from onetl.db import DBReader from onetl.file import FileDFReader from pyspark.sql import DataFrame +from pyspark.sql.functions import col, date_format, date_trunc, to_timestamp from pytest import FixtureRequest from sqlalchemy.ext.asyncio import AsyncSession @@ -37,6 +38,7 @@ async def s3_to_postgres( file_format_flavor: str, ): format_name, file_format = source_file_format + format_name_in_path = "xlsx" if format_name == "excel" else format_name _, source_path, _ = prepare_s3 result = await create_transfer( @@ -47,7 +49,7 @@ async def s3_to_postgres( target_connection_id=postgres_connection.id, source_params={ "type": "s3", - "directory_path": os.fspath(source_path / "file_df_connection" / format_name / file_format_flavor), + "directory_path": os.fspath(source_path / "file_df_connection" / format_name_in_path / file_format_flavor), "file_format": { "type": format_name, **file_format.dict(), @@ -121,6 +123,11 @@ async def postgres_to_s3( "without_compression", id="jsonline", ), + pytest.param( + ("excel", {}), + "with_header", + id="excel", + ), ], indirect=["source_file_format", "file_format_flavor"], ) @@ -135,6 +142,7 @@ async def test_run_transfer_s3_to_postgres( ): # Arrange postgres, _ = prepare_postgres + file_format, _ = source_file_format # Act result = await client.post( @@ -165,6 +173,12 @@ async def test_run_transfer_s3_to_postgres( table="public.target_table", ) df = reader.run() + + # as Excel does not support datetime values with precision greater than milliseconds + if file_format == "excel": + df = df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT"))) + init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT"))) + for field in init_df.schema: df = df.withColumn(field.name, df[field.name].cast(field.dataType)) @@ -184,6 +198,11 @@ async def test_run_transfer_s3_to_postgres( "without_compression", id="jsonline", ), + pytest.param( + ("excel", {}), + "with_header", + id="excel", + ), ], indirect=["target_file_format", "file_format_flavor"], ) @@ -193,6 +212,7 @@ async def test_run_transfer_postgres_to_s3( client: AsyncClient, s3_file_df_connection: SparkS3, prepare_postgres, + prepare_s3, postgres_to_s3: Connection, target_file_format, file_format_flavor: str, @@ -235,6 +255,13 @@ async def test_run_transfer_postgres_to_s3( ) df = reader.run() + # as Excel does not support datetime values with precision greater than milliseconds + if format_name == "excel": + init_df = init_df.withColumn( + "REGISTERED_AT", + to_timestamp(date_format(col("REGISTERED_AT"), "yyyy-MM-dd HH:mm:ss.SSS")), + ) + for field in init_df.schema: df = df.withColumn(field.name, df[field.name].cast(field.dataType))