From c653da499f15364cbe6cd8a1bc64aab899f580ef Mon Sep 17 00:00:00 2001
From: maxim-lixakov
Date: Thu, 5 Dec 2024 10:55:09 +0300
Subject: [PATCH] [DOP-21665] - add spark dialect extension to clickhouse

---
 docs/changelog/next_release/160.feature.rst |  1 +
 syncmaster/worker/handlers/db/clickhouse.py |  4 ++++
 syncmaster/worker/spark.py                  |  6 ++++--
 .../test_run_transfer/conftest.py           |  1 +
 .../test_run_transfer/test_clickhouse.py    | 15 ---------------
 5 files changed, 10 insertions(+), 17 deletions(-)
 create mode 100644 docs/changelog/next_release/160.feature.rst

diff --git a/docs/changelog/next_release/160.feature.rst b/docs/changelog/next_release/160.feature.rst
new file mode 100644
index 00000000..32806483
--- /dev/null
+++ b/docs/changelog/next_release/160.feature.rst
@@ -0,0 +1 @@
+Add `spark-dialect-extension <https://github.com/MobileTeleSystems/spark-dialect-extension/>`_
diff --git a/syncmaster/worker/handlers/db/clickhouse.py b/syncmaster/worker/handlers/db/clickhouse.py
index c033a431..ad7cc25d 100644
--- a/syncmaster/worker/handlers/db/clickhouse.py
+++ b/syncmaster/worker/handlers/db/clickhouse.py
@@ -23,6 +23,10 @@ class ClickhouseHandler(DBHandler):
     transfer_dto: ClickhouseTransferDTO

     def connect(self, spark: SparkSession):
+        ClickhouseDialectRegistry = (
+            spark._jvm.io.github.mtsongithub.doetl.sparkdialectextensions.clickhouse.ClickhouseDialectRegistry
+        )
+        ClickhouseDialectRegistry.register()
         self.connection = Clickhouse(
             host=self.connection_dto.host,
             port=self.connection_dto.port,
diff --git a/syncmaster/worker/spark.py b/syncmaster/worker/spark.py
index e37ba442..c8a6be48 100644
--- a/syncmaster/worker/spark.py
+++ b/syncmaster/worker/spark.py
@@ -50,8 +50,10 @@ def get_packages(db_type: str) -> list[str]:
     if db_type == "oracle":
         return Oracle.get_packages()
     if db_type == "clickhouse":
-        # TODO: add https://github.com/MobileTeleSystems/spark-dialect-extension/ to spark jars
-        return Clickhouse.get_packages()
+        return [
+            "io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2",
+            *Clickhouse.get_packages(),
+        ]
     if db_type == "mssql":
         return MSSQL.get_packages()
     if db_type == "mysql":
diff --git a/tests/test_integration/test_run_transfer/conftest.py b/tests/test_integration/test_run_transfer/conftest.py
index 94b821b3..9211d5f2 100644
--- a/tests/test_integration/test_run_transfer/conftest.py
+++ b/tests/test_integration/test_run_transfer/conftest.py
@@ -78,6 +78,7 @@ def spark(settings: Settings, request: FixtureRequest) -> SparkSession:
         maven_packages.extend(Oracle.get_packages())

     if "clickhouse" in markers:
+        maven_packages.append("io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2")
         maven_packages.extend(Clickhouse.get_packages())

     if "mssql" in markers:
diff --git a/tests/test_integration/test_run_transfer/test_clickhouse.py b/tests/test_integration/test_run_transfer/test_clickhouse.py
index d76a5e7e..eb651db5 100644
--- a/tests/test_integration/test_run_transfer/test_clickhouse.py
+++ b/tests/test_integration/test_run_transfer/test_clickhouse.py
@@ -6,7 +6,6 @@
 from onetl.connection import Clickhouse
 from onetl.db import DBReader
 from pyspark.sql import DataFrame
-from pyspark.sql.functions import col, date_trunc
 from sqlalchemy.ext.asyncio import AsyncSession

 from syncmaster.db.models import Connection, Group, Queue, Status, Transfer
@@ -117,8 +116,6 @@ async def test_run_transfer_postgres_to_clickhouse(
         table=f"{clickhouse.user}.target_table",
     )
     df = reader.run()
-    # as spark truncates milliseconds while writing to clickhouse: https://onetl.readthedocs.io/en/latest/connection/db_connection/clickhouse/types.html#id10
-    init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))

     for field in init_df.schema:
         df = df.withColumn(field.name, df[field.name].cast(field.dataType))
@@ -169,11 +166,6 @@ async def test_run_transfer_postgres_to_clickhouse_mixed_naming(
     assert df.columns != init_df_with_mixed_column_naming.columns
     assert df.columns == [column.lower() for column in init_df_with_mixed_column_naming.columns]

-    # as spark truncates milliseconds while writing to clickhouse: https://onetl.readthedocs.io/en/latest/connection/db_connection/clickhouse/types.html#id10
-    init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.withColumn(
-        "Registered At",
-        date_trunc("second", col("Registered At")),
-    )
     for field in init_df_with_mixed_column_naming.schema:
         df = df.withColumn(field.name, df[field.name].cast(field.dataType))

@@ -222,8 +214,6 @@ async def test_run_transfer_clickhouse_to_postgres(
     )
     df = reader.run()

-    # as spark truncates milliseconds while writing to clickhouse: https://onetl.readthedocs.io/en/latest/connection/db_connection/clickhouse/types.html#id10
-    init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))

     for field in init_df.schema:
         df = df.withColumn(field.name, df[field.name].cast(field.dataType))
@@ -275,11 +265,6 @@ async def test_run_transfer_clickhouse_to_postgres_mixed_naming(
     assert df.columns != init_df_with_mixed_column_naming.columns
     assert df.columns == [column.lower() for column in init_df_with_mixed_column_naming.columns]

-    # as spark truncates milliseconds while writing to clickhouse: https://onetl.readthedocs.io/en/latest/connection/db_connection/clickhouse/types.html#id10
-    init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.withColumn(
-        "Registered At",
-        date_trunc("second", col("Registered At")),
-    )
     for field in init_df_with_mixed_column_naming.schema:
         df = df.withColumn(field.name, df[field.name].cast(field.dataType))
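
Reviewer note, not part of the patch: the handler change relies on PySpark's Py4J gateway to reach a JVM class shipped in the io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2 artifact. Registering the dialect makes Spark's JDBC layer apply ClickHouse-specific type mappings, which is presumably why the date_trunc("second", ...) millisecond-truncation workarounds are dropped from the tests. A minimal standalone sketch of the same flow, assuming only the Maven coordinates from syncmaster/worker/spark.py above:

    from pyspark.sql import SparkSession

    # Fetch the dialect extension from Maven; same coordinates that
    # get_packages() returns for db_type == "clickhouse".
    spark = (
        SparkSession.builder
        .config(
            "spark.jars.packages",
            "io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2",
        )
        .getOrCreate()
    )

    # Same call as in ClickhouseHandler.connect(): reach through the JVM
    # gateway and register the ClickHouse JDBC dialect before any JDBC
    # reads or writes happen on this session.
    registry = spark._jvm.io.github.mtsongithub.doetl.sparkdialectextensions.clickhouse.ClickhouseDialectRegistry
    registry.register()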