Commit: Add Excel integration tests (#149)

IlyasDevelopment authored Dec 2, 2024
1 parent 38fc8ad commit 525d635

Showing 7 changed files with 181 additions and 29 deletions.
22 changes: 13 additions & 9 deletions syncmaster/dto/transfers.py
@@ -4,7 +4,7 @@
from dataclasses import dataclass
from typing import ClassVar

from onetl.file.format import CSV, JSON, JSONLine
from onetl.file.format import CSV, JSON, Excel, JSONLine


@dataclass
@@ -20,10 +20,17 @@ class DBTransferDTO(TransferDTO):
@dataclass
class FileTransferDTO(TransferDTO):
directory_path: str
file_format: CSV | JSONLine | JSON
file_format: CSV | JSONLine | JSON | Excel
options: dict
df_schema: dict | None = None

_format_parsers = {
"csv": CSV,
"jsonline": JSONLine,
"json": JSON,
"excel": Excel,
}

def __post_init__(self):
if isinstance(self.file_format, dict):
self.file_format = self._get_format(self.file_format.copy())
@@ -32,13 +39,10 @@ def __post_init__(self):

def _get_format(self, file_format: dict):
file_type = file_format.pop("type", None)
if file_type == "csv":
return CSV.parse_obj(file_format)
if file_type == "jsonline":
return JSONLine.parse_obj(file_format)
if file_type == "json":
return JSON.parse_obj(file_format)
raise ValueError("Unknown file type")
parser_class = self._format_parsers.get(file_type)
if parser_class is not None:
return parser_class.parse_obj(file_format)
raise ValueError(f"Unknown file type: {file_type}")


@dataclass
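The table-driven `_format_parsers` dispatch above replaces the old if/elif chain in `_get_format`. A minimal standalone sketch of the same idea, for readers who want to see how a serialized `file_format` dict becomes an onetl format object; `parse_format` is a hypothetical helper written for illustration, and it assumes onetl is installed with Excel support:

```python
from onetl.file.format import CSV, JSON, Excel, JSONLine

# Same dispatch idea as FileTransferDTO._get_format, restated outside the dataclass.
_format_parsers = {"csv": CSV, "jsonline": JSONLine, "json": JSON, "excel": Excel}


def parse_format(file_format: dict):
    payload = file_format.copy()
    file_type = payload.pop("type", None)
    parser_class = _format_parsers.get(file_type)
    if parser_class is not None:
        # onetl formats are pydantic models, so parse_obj validates the remaining keys
        return parser_class.parse_obj(payload)
    raise ValueError(f"Unknown file type: {file_type}")


# "header" and "inferSchema" mirror the options used by the test fixtures below.
excel_format = parse_format({"type": "excel", "header": True, "inferSchema": True})
print(type(excel_format).__name__)  # Excel

# parse_format({"type": "parquet"})  # would raise ValueError: Unknown file type: parquet
```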
20 changes: 19 additions & 1 deletion syncmaster/worker/handlers/file/s3.py
@@ -6,12 +6,13 @@
from typing import TYPE_CHECKING

from onetl.connection import SparkS3
from onetl.file import FileDFReader

from syncmaster.dto.connections import S3ConnectionDTO
from syncmaster.worker.handlers.file.base import FileHandler

if TYPE_CHECKING:
from pyspark.sql import SparkSession
from pyspark.sql import DataFrame, SparkSession


class S3Handler(FileHandler):
@@ -29,3 +30,20 @@ def connect(self, spark: SparkSession):
extra=self.connection_dto.additional_params,
spark=spark,
).check()

def read(self) -> DataFrame:
from pyspark.sql.types import StructType

options = {}
if self.transfer_dto.file_format.__class__.__name__ == "Excel":
options = {"inferSchema": True}

reader = FileDFReader(
connection=self.connection,
format=self.transfer_dto.file_format,
source_path=self.transfer_dto.directory_path,
df_schema=StructType.fromJson(self.transfer_dto.df_schema) if self.transfer_dto.df_schema else None,
options={**options, **self.transfer_dto.options},
)

return reader.run()
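Two details of `read()` are easy to miss: the `inferSchema=True` default is injected only for Excel sources and can still be overridden by transfer-level options (the transfer's dict is merged last), and an explicit `df_schema` travels as a plain dict and is rebuilt with `StructType.fromJson`. A minimal sketch of both behaviours, with illustrative values not taken from this commit (the handler itself compares the class name instead of using `isinstance`):

```python
from onetl.file.format import Excel
from pyspark.sql.types import StructType

# Option precedence as in S3Handler.read(): transfer options are merged last, so they win.
file_format = Excel(header=True)
defaults = {"inferSchema": True} if isinstance(file_format, Excel) else {}
transfer_options = {"inferSchema": False}  # hypothetical user-supplied transfer options
print({**defaults, **transfer_options})    # {'inferSchema': False}

# df_schema is stored as a JSON-like dict and rebuilt into a Spark schema.
schema_dict = {
    "type": "struct",
    "fields": [{"name": "ID", "type": "long", "nullable": True, "metadata": {}}],
}
print(StructType.fromJson(schema_dict))    # StructType([StructField('ID', LongType(), True)])
```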
7 changes: 6 additions & 1 deletion syncmaster/worker/spark.py
@@ -37,6 +37,7 @@ def get_worker_spark_session(

def get_packages(db_type: str) -> list[str]:
from onetl.connection import MSSQL, Clickhouse, MySQL, Oracle, Postgres, SparkS3
from onetl.file.format import Excel

if db_type == "postgres":
return Postgres.get_packages()
@@ -53,7 +54,11 @@
import pyspark

spark_version = pyspark.__version__
return SparkS3.get_packages(spark_version=spark_version)
# see supported versions at https://mvnrepository.com/artifact/com.crealytics/spark-excel
return SparkS3.get_packages(spark_version=spark_version) + Excel.get_packages(spark_version="3.5.1")
if db_type == "hdfs":
# see supported versions at https://mvnrepository.com/artifact/com.crealytics/spark-excel
return Excel.get_packages(spark_version="3.5.1")

# If the database type does not require downloading .jar packages
return []
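For context on how this package list is consumed: `get_worker_spark_session` (only its signature appears in this diff) is presumably where the Maven coordinates end up in `spark.jars.packages`. A rough sketch under that assumption; the app name and builder wiring are illustrative, not taken from the repository. Note that the Excel packages are pinned to Spark 3.5.1 rather than `pyspark.__version__`, per the inline comment about which versions spark-excel publishes.

```python
from pyspark.sql import SparkSession

from syncmaster.worker.spark import get_packages

maven_packages = get_packages(db_type="s3")  # SparkS3 jars plus the spark-excel jars added here

builder = SparkSession.builder.appName("syncmaster-worker")  # hypothetical app name
if maven_packages:
    builder = builder.config("spark.jars.packages", ",".join(maven_packages))
spark = builder.getOrCreate()
```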
96 changes: 84 additions & 12 deletions tests/test_integration/test_run_transfer/conftest.py
@@ -3,14 +3,15 @@
import os
import secrets
from collections import namedtuple
from pathlib import Path, PurePosixPath
from pathlib import Path, PosixPath, PurePosixPath

import pyspark
import pytest
import pytest_asyncio
from onetl.connection import MSSQL, Clickhouse, Hive, MySQL, Oracle, Postgres, SparkS3
from onetl.connection.file_connection.s3 import S3
from onetl.db import DBWriter
from onetl.file.format import CSV, JSON, JSONLine
from onetl.file.format import CSV, JSON, Excel, JSONLine
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import (
DateType,
@@ -112,6 +113,10 @@ def spark(settings: Settings, request: FixtureRequest) -> SparkSession:
)
)

if "hdfs" in markers or "s3" in markers:
# see supported versions at https://mvnrepository.com/artifact/com.crealytics/spark-excel
maven_packages.extend(Excel.get_packages(spark_version="3.5.1"))

if maven_packages:
spark = spark.config("spark.jars.packages", ",".join(maven_packages))

@@ -462,12 +467,22 @@ def s3_file_df_connection(s3_file_connection, spark, s3_server):


@pytest.fixture(scope="session")
def prepare_s3(resource_path, s3_file_connection, s3_file_df_connection_with_path: tuple[SparkS3, PurePosixPath]):
logger.info("START PREPARE HDFS")
connection, upload_to = s3_file_df_connection_with_path
files = upload_files(resource_path, upload_to, s3_file_connection)
logger.info("END PREPARE HDFS")
return connection, upload_to, files
def prepare_s3(
resource_path: PosixPath,
s3_file_connection: S3,
s3_file_df_connection_with_path: tuple[SparkS3, PurePosixPath],
):
logger.info("START PREPARE S3")
connection, remote_path = s3_file_df_connection_with_path

s3_file_connection.remove_dir(remote_path, recursive=True)
files = upload_files(resource_path, remote_path, s3_file_connection)

yield connection, remote_path, files

logger.info("START POST-CLEANUP S3")
s3_file_connection.remove_dir(remote_path, recursive=True)
logger.info("END POST-CLEANUP S3")


@pytest.fixture(scope="session")
@@ -635,14 +650,14 @@ def prepare_clickhouse(
pass

def fill_with_data(df: DataFrame):
logger.info("START PREPARE ORACLE")
logger.info("START PREPARE CLICKHOUSE")
db_writer = DBWriter(
connection=onetl_conn,
target=f"{clickhouse.user}.source_table",
options=Clickhouse.WriteOptions(createTableOptions="ENGINE = Memory"),
)
db_writer.run(df)
logger.info("END PREPARE ORACLE")
logger.info("END PREPARE CLICKHOUSE")

yield onetl_conn, fill_with_data

@@ -745,7 +760,51 @@ def fill_with_data(df: DataFrame):
pass


@pytest.fixture(params=[("csv", {}), ("jsonline", {}), ("json", {})])
@pytest.fixture
def prepare_mysql(
mysql_for_conftest: MySQLConnectionDTO,
spark: SparkSession,
):
mysql = mysql_for_conftest
onetl_conn = MySQL(
host=mysql.host,
port=mysql.port,
user=mysql.user,
password=mysql.password,
database=mysql.database_name,
spark=spark,
).check()
try:
onetl_conn.execute(f"DROP TABLE IF EXISTS {mysql.database_name}.source_table")
except Exception:
pass
try:
onetl_conn.execute(f"DROP TABLE IF EXISTS {mysql.database_name}.target_table")
except Exception:
pass

def fill_with_data(df: DataFrame):
logger.info("START PREPARE MYSQL")
db_writer = DBWriter(
connection=onetl_conn,
target=f"{mysql.database_name}.source_table",
)
db_writer.run(df)
logger.info("END PREPARE MYSQL")

yield onetl_conn, fill_with_data

try:
onetl_conn.execute(f"DROP TABLE IF EXISTS {mysql.database_name}.source_table")
except Exception:
pass
try:
onetl_conn.execute(f"DROP TABLE IF EXISTS {mysql.database_name}.target_table")
except Exception:
pass


@pytest.fixture(params=[("csv", {}), ("jsonline", {}), ("json", {}), ("excel", {})])
def source_file_format(request: FixtureRequest):
name, params = request.param
if name == "csv":
@@ -769,10 +828,17 @@ def source_file_format(request: FixtureRequest):
**params,
)

if name == "excel":
return "excel", Excel(
header=True,
inferSchema=True,
**params,
)

raise ValueError(f"Unsupported file format: {name}")


@pytest.fixture(params=[("csv", {}), ("jsonline", {})])
@pytest.fixture(params=[("csv", {}), ("jsonline", {}), ("excel", {})])
def target_file_format(request: FixtureRequest):
name, params = request.param
if name == "csv":
@@ -791,6 +857,12 @@ def target_file_format(request: FixtureRequest):
**params,
)

if name == "excel":
return "excel", Excel(
header=False,
**params,
)

raise ValueError(f"Unsupported file format: {name}")


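The new Excel entries in `source_file_format` and `target_file_format` are selected by the integration tests through indirect parametrization (see test_hdfs.py below). A hypothetical minimal test showing how the fixture resolves; the test name and assertions are illustrative only:

```python
import pytest
from onetl.file.format import Excel


@pytest.mark.parametrize("source_file_format", [("excel", {})], indirect=True)
def test_source_file_format_resolves_excel(source_file_format):
    name, file_format = source_file_format
    assert name == "excel"
    assert isinstance(file_format, Excel)  # built with header=True, inferSchema=True
```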
28 changes: 27 additions & 1 deletion tests/test_integration/test_run_transfer/test_hdfs.py
@@ -8,6 +8,7 @@
from onetl.db import DBReader
from onetl.file import FileDFReader
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, date_format, date_trunc, to_timestamp
from pytest import FixtureRequest
from sqlalchemy.ext.asyncio import AsyncSession

@@ -37,6 +38,7 @@ async def hdfs_to_postgres(
file_format_flavor: str,
):
format_name, file_format = source_file_format
format_name_in_path = "xlsx" if format_name == "excel" else format_name
_, source_path, _ = prepare_hdfs

result = await create_transfer(
@@ -47,7 +49,7 @@
target_connection_id=postgres_connection.id,
source_params={
"type": "hdfs",
"directory_path": os.fspath(source_path / "file_df_connection" / format_name / file_format_flavor),
"directory_path": os.fspath(source_path / "file_df_connection" / format_name_in_path / file_format_flavor),
"file_format": {
"type": format_name,
**file_format.dict(),
@@ -121,6 +123,11 @@ async def postgres_to_hdfs(
"without_compression",
id="jsonline",
),
pytest.param(
("excel", {}),
"with_header",
id="excel",
),
],
indirect=["source_file_format", "file_format_flavor"],
)
Expand All @@ -135,6 +142,7 @@ async def test_run_transfer_hdfs_to_postgres(
):
# Arrange
postgres, _ = prepare_postgres
file_format, _ = source_file_format

# Act
result = await client.post(
@@ -164,6 +172,12 @@ async def test_run_transfer_hdfs_to_postgres(
table="public.target_table",
)
df = reader.run()

# as Excel does not support datetime values with precision greater than milliseconds
if file_format == "excel":
df = df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))
init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))

for field in init_df.schema:
df = df.withColumn(field.name, df[field.name].cast(field.dataType))

@@ -183,6 +197,11 @@ async def test_run_transfer_hdfs_to_postgres(
"without_compression",
id="jsonline",
),
pytest.param(
("excel", {}),
"with_header",
id="excel",
),
],
indirect=["target_file_format", "file_format_flavor"],
)
Expand Down Expand Up @@ -235,6 +254,13 @@ async def test_run_transfer_postgres_to_hdfs(
)
df = reader.run()

# as Excel does not support datetime values with precision greater than milliseconds
if format_name == "excel":
init_df = init_df.withColumn(
"REGISTERED_AT",
to_timestamp(date_format(col("REGISTERED_AT"), "yyyy-MM-dd HH:mm:ss.SSS")),
)

for field in init_df.schema:
df = df.withColumn(field.name, df[field.name].cast(field.dataType))

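Both HDFS tests normalize `REGISTERED_AT` because Excel cannot store timestamps below millisecond precision: the read-side test truncates to whole seconds before comparing, while the write-side test reformats the expected values down to milliseconds. A self-contained sketch of the two normalizations on a hypothetical DataFrame:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, date_format, date_trunc, to_timestamp

spark = SparkSession.builder.master("local[1]").appName("excel-ts-demo").getOrCreate()

# Hypothetical row with microsecond precision, which Excel cannot preserve.
df = spark.createDataFrame([("2024-12-02 10:15:30.123456",)], ["REGISTERED_AT"])
df = df.withColumn("REGISTERED_AT", to_timestamp(col("REGISTERED_AT")))

# hdfs -> postgres test: drop everything below whole seconds on both sides.
seconds_only = df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))

# postgres -> hdfs test: keep milliseconds, matching what the Excel file can round-trip.
millis_only = df.withColumn(
    "REGISTERED_AT",
    to_timestamp(date_format(col("REGISTERED_AT"), "yyyy-MM-dd HH:mm:ss.SSS")),
)

seconds_only.show(truncate=False)  # 2024-12-02 10:15:30
millis_only.show(truncate=False)   # 2024-12-02 10:15:30.123
```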
8 changes: 4 additions & 4 deletions tests/test_integration/test_run_transfer/test_mssql.py
@@ -118,7 +118,7 @@ async def test_run_transfer_postgres_to_mssql(
)
df = reader.run()

# as spark rounds datetime up to milliseconds while writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5
# as spark rounds datetime to nearest 3.33 milliseconds when writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5
df = df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))
init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))

@@ -173,7 +173,7 @@ async def test_run_transfer_postgres_to_mssql_mixed_naming(
assert df.columns != init_df_with_mixed_column_naming.columns
assert df.columns == [column.lower() for column in init_df_with_mixed_column_naming.columns]

# as spark rounds datetime up to milliseconds while writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5
# as spark rounds datetime to nearest 3.33 milliseconds when writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5
df = df.withColumn("Registered At", date_trunc("second", col("Registered At")))
init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.withColumn(
"Registered At",
@@ -228,7 +228,7 @@ async def test_run_transfer_mssql_to_postgres(
)
df = reader.run()

# as spark rounds datetime up to milliseconds while writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5
# as spark rounds datetime to nearest 3.33 milliseconds when writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5
df = df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))
init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))

@@ -283,7 +283,7 @@ async def test_run_transfer_mssql_to_postgres_mixed_naming(
assert df.columns != init_df_with_mixed_column_naming.columns
assert df.columns == [column.lower() for column in init_df_with_mixed_column_naming.columns]

# as spark rounds datetime up to milliseconds while writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5
# as spark rounds datetime to nearest 3.33 milliseconds when writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5
df = df.withColumn("Registered At", date_trunc("second", col("Registered At")))
init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.withColumn(
"Registered At",