diff --git a/docker/Dockerfile.worker b/docker/Dockerfile.worker
index 7b32bfa8..84830b65 100644
--- a/docker/Dockerfile.worker
+++ b/docker/Dockerfile.worker
@@ -39,7 +39,7 @@ COPY ./syncmaster/ /app/syncmaster/

 FROM base as test

-ENV CREATE_SPARK_SESSION_FUNCTION=tests.spark.get_worker_spark_session.get_worker_spark_session
+ENV CREATE_SPARK_SESSION_FUNCTION=tests.spark.get_worker_spark_session

 # CI runs tests in the worker container, so we need backend dependencies too
 RUN poetry install --no-root --extras "worker backend" --with test --without docs,dev
diff --git a/docs/changelog/next_release/38.breaking.rst b/docs/changelog/next_release/38.breaking.rst
new file mode 100644
index 00000000..ecbf00c7
--- /dev/null
+++ b/docs/changelog/next_release/38.breaking.rst
@@ -0,0 +1,2 @@
+Pass current ``Run`` to ``CREATE_SPARK_SESSION_FUNCTION``. This allows using run/transfer/group information for Spark session options,
+like ``appName`` or custom ones.
diff --git a/syncmaster/dto/connections.py b/syncmaster/dto/connections.py
index 91877325..46f369b8 100644
--- a/syncmaster/dto/connections.py
+++ b/syncmaster/dto/connections.py
@@ -1,11 +1,12 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
 from dataclasses import dataclass
+from typing import ClassVar


 @dataclass
 class ConnectionDTO:
-    pass
+    type: ClassVar[str]


 @dataclass
@@ -16,7 +17,7 @@ class PostgresConnectionDTO(ConnectionDTO):
     password: str
     additional_params: dict
     database_name: str
-    type: str = "postgres"
+    type: ClassVar[str] = "postgres"


 @dataclass
@@ -28,7 +29,7 @@ class OracleConnectionDTO(ConnectionDTO):
     additional_params: dict
     sid: str | None = None
     service_name: str | None = None
-    type: str = "oracle"
+    type: ClassVar[str] = "oracle"


 @dataclass
@@ -36,7 +37,7 @@ class HiveConnectionDTO(ConnectionDTO):
     user: str
     password: str
     cluster: str
-    type: str = "hive"
+    type: ClassVar[str] = "hive"


 @dataclass
@@ -44,7 +45,7 @@ class HDFSConnectionDTO(ConnectionDTO):
     user: str
     password: str
     cluster: str
-    type: str = "hdfs"
+    type: ClassVar[str] = "hdfs"


 @dataclass
@@ -57,4 +58,4 @@ class S3ConnectionDTO(ConnectionDTO):
     additional_params: dict
     region: str | None = None
     protocol: str = "https"
-    type: str = "s3"
+    type: ClassVar[str] = "s3"
diff --git a/syncmaster/dto/transfers.py b/syncmaster/dto/transfers.py
index 44824811..5838cd88 100644
--- a/syncmaster/dto/transfers.py
+++ b/syncmaster/dto/transfers.py
@@ -1,46 +1,66 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
+import json
 from dataclasses import dataclass
+from typing import ClassVar

 from syncmaster.schemas.v1.transfers.file_format import CSV, JSON, JSONLine


 @dataclass
 class TransferDTO:
-    pass
+    type: ClassVar[str]


 @dataclass
-class PostgresTransferDTO(TransferDTO):
+class DBTransferDTO(TransferDTO):
     table_name: str
-    type: str = "postgres"


 @dataclass
-class OracleTransferDTO(TransferDTO):
-    table_name: str
-    type: str = "oracle"
+class FileTransferDTO(TransferDTO):
+    directory_path: str
+    file_format: CSV | JSONLine | JSON
+    options: dict
+    df_schema: dict | None = None
+
+    def __post_init__(self):
+        if isinstance(self.file_format, dict):
+            self.file_format = self._get_format(self.file_format)
+        if isinstance(self.df_schema, str):
+            self.df_schema = json.loads(self.df_schema)
+
+    def _get_format(self, file_format: dict):
+        file_type = file_format["type"]
+        if file_type == "csv":
+            return CSV.parse_obj(file_format)
+        if file_type == "jsonline":
+            return JSONLine.parse_obj(file_format)
+        if file_type == "json":
+            return JSON.parse_obj(file_format)
+        raise ValueError("Unknown file type")


 @dataclass
-class HiveTransferDTO(TransferDTO):
-    table_name: str
-    type: str = "hive"
+class PostgresTransferDTO(DBTransferDTO):
+    type: ClassVar[str] = "postgres"


 @dataclass
-class S3TransferDTO(TransferDTO):
-    directory_path: str
-    file_format: CSV | JSONLine | JSON
-    options: dict
-    df_schema: dict | None = None
-    type: str = "s3"
+class OracleTransferDTO(DBTransferDTO):
+    type: ClassVar[str] = "oracle"


 @dataclass
-class HDFSTransferDTO(TransferDTO):
-    directory_path: str
-    file_format: CSV | JSONLine | JSON
-    options: dict
-    df_schema: dict | None = None
-    type: str = "hdfs"
+class HiveTransferDTO(DBTransferDTO):
+    type: ClassVar[str] = "hive"
+
+
+@dataclass
+class S3TransferDTO(FileTransferDTO):
+    type: ClassVar[str] = "s3"
+
+
+@dataclass
+class HDFSTransferDTO(FileTransferDTO):
+    type: ClassVar[str] = "hdfs"
diff --git a/syncmaster/worker/controller.py b/syncmaster/worker/controller.py
index 26408e63..72eca1b9 100644
--- a/syncmaster/worker/controller.py
+++ b/syncmaster/worker/controller.py
@@ -4,7 +4,7 @@
 from typing import Any

 from syncmaster.config import Settings
-from syncmaster.db.models import Connection, Transfer
+from syncmaster.db.models import Connection, Run
 from syncmaster.dto.connections import (
     HDFSConnectionDTO,
     HiveConnectionDTO,
@@ -21,11 +21,11 @@
 )
 from syncmaster.exceptions.connection import ConnectionTypeNotRecognizedError
 from syncmaster.worker.handlers.base import Handler
+from syncmaster.worker.handlers.db.hive import HiveHandler
+from syncmaster.worker.handlers.db.oracle import OracleHandler
+from syncmaster.worker.handlers.db.postgres import PostgresHandler
 from syncmaster.worker.handlers.file.hdfs import HDFSHandler
 from syncmaster.worker.handlers.file.s3 import S3Handler
-from syncmaster.worker.handlers.hive import HiveHandler
-from syncmaster.worker.handlers.oracle import OracleHandler
-from syncmaster.worker.handlers.postgres import PostgresHandler

 logger = logging.getLogger(__name__)

@@ -65,47 +65,40 @@
 class TransferController:

     def __init__(
         self,
-        transfer: Transfer,
+        run: Run,
         source_connection: Connection,
         source_auth_data: dict,
         target_connection: Connection,
         target_auth_data: dict,
         settings: Settings,
     ):
+        self.run = run
+        self.settings = settings
         self.source_handler = self.get_handler(
             connection_data=source_connection.data,
-            transfer_params=transfer.source_params,
+            transfer_params=run.transfer.source_params,
             connection_auth_data=source_auth_data,
         )
         self.target_handler = self.get_handler(
             connection_data=target_connection.data,
-            transfer_params=transfer.target_params,
+            transfer_params=run.transfer.target_params,
             connection_auth_data=target_auth_data,
         )
-        spark = settings.CREATE_SPARK_SESSION_FUNCTION(
-            settings,
-            target=self.target_handler.connection_dto,
+
+    def perform_transfer(self) -> None:
+        spark = self.settings.CREATE_SPARK_SESSION_FUNCTION(
+            settings=self.settings,
+            run=self.run,
             source=self.source_handler.connection_dto,
+            target=self.target_handler.connection_dto,
         )
-        self.source_handler.set_spark(spark)
-        self.target_handler.set_spark(spark)
-        logger.info("source connection = %s", self.source_handler)
-        logger.info("target connection = %s", self.target_handler)
-
-    def start_transfer(self) -> None:
-        self.source_handler.init_connection()
-        self.source_handler.init_reader()
-
-        self.target_handler.init_connection()
-        self.target_handler.init_writer()
-        logger.info("Source and target were initialized")
-
-        df = self.target_handler.normalize_column_name(self.source_handler.read())
-        logger.info("Data has been read")
+        with spark:
+            self.source_handler.connect(spark)
+            self.target_handler.connect(spark)

-        self.target_handler.write(df)
-        logger.info("Data has been inserted")
+            df = self.source_handler.read()
+            self.target_handler.write(df)

     def get_handler(
         self,
@@ -114,7 +107,8 @@ def get_handler(
         transfer_params: dict[str, Any],
     ) -> Handler:
         connection_data.update(connection_auth_data)
-        handler_type = connection_data["type"]
+        handler_type = connection_data.pop("type")
+        transfer_params.pop("type", None)

         if connection_handler_proxy.get(handler_type, None) is None:
             raise ConnectionTypeNotRecognizedError
diff --git a/syncmaster/worker/handlers/base.py b/syncmaster/worker/handlers/base.py
index e6a4fe98..0d20a607 100644
--- a/syncmaster/worker/handlers/base.py
+++ b/syncmaster/worker/handlers/base.py
@@ -1,49 +1,33 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
-from abc import ABC

-from onetl.db import DBReader, DBWriter
-from pyspark.sql import SparkSession
-from pyspark.sql.dataframe import DataFrame
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING

 from syncmaster.dto.connections import ConnectionDTO
 from syncmaster.dto.transfers import TransferDTO

+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+    from pyspark.sql.dataframe import DataFrame
+

 class Handler(ABC):
     def __init__(
         self,
         connection_dto: ConnectionDTO,
         transfer_dto: TransferDTO,
-        spark: SparkSession | None = None,
-    ) -> None:
-        self.spark = spark
-        self.reader: DBReader | None = None
-        self.writer: DBWriter | None = None
+    ):
         self.connection_dto = connection_dto
         self.transfer_dto = transfer_dto

-    def init_connection(self): ...
-
-    def set_spark(self, spark: SparkSession):
-        self.spark = spark
-
-    def init_reader(self):
-        if self.connection_dto is None:
-            raise ValueError("At first you need to initialize connection. Run `init_connection")
-
-    def init_writer(self):
-        if self.connection_dto is None:
-            raise ValueError("At first you need to initialize connection. Run `init_connection")
-
-    def read(self) -> DataFrame:
-        if self.reader is None:
-            raise ValueError("Reader is not initialized")
-        return self.reader.run()
+    @abstractmethod
+    def connect(self, spark: SparkSession) -> None: ...

-    def write(self, df: DataFrame) -> None:
-        if self.writer is None:
-            raise ValueError("Writer is not initialized")
-        return self.writer.run(df=df)
+    @abstractmethod
+    def read(self) -> DataFrame: ...

-    def normalize_column_name(self, df: DataFrame) -> DataFrame: ...
+    @abstractmethod
+    def write(self, df: DataFrame) -> None: ...
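Editorial note (illustrative sketch, not part of this diff): the changelog entry and the ``perform_transfer()`` change above define the new ``CREATE_SPARK_SESSION_FUNCTION`` call signature of ``(settings, run, source, target)``. A minimal custom factory matching that signature could look like the sketch below. It reuses ``get_spark_session_conf()`` from ``syncmaster.worker.spark``; the function name ``get_custom_spark_session`` and the ``spark.syncmaster.run.id`` option are hypothetical examples.

    # Illustrative sketch only, not part of the PR.
    from __future__ import annotations

    from typing import TYPE_CHECKING

    from syncmaster.config import Settings
    from syncmaster.db.models import Run
    from syncmaster.dto.connections import ConnectionDTO
    from syncmaster.worker.spark import get_spark_session_conf

    if TYPE_CHECKING:
        from pyspark.sql import SparkSession


    def get_custom_spark_session(
        settings: Settings,  # part of the expected signature, even if unused here
        run: Run,
        source: ConnectionDTO,
        target: ConnectionDTO,
    ) -> SparkSession:
        from pyspark.sql import SparkSession

        # same appName scheme as get_worker_spark_session(), with the run id appended
        name = f"syncmaster_{run.transfer.group.name}_{run.transfer.name}_{run.id}"
        spark_builder = SparkSession.builder.appName(name)

        # connection-specific packages and credentials, as in the worker implementation
        for key, value in get_spark_session_conf(source, target).items():
            spark_builder = spark_builder.config(key, value)

        # hypothetical custom option derived from the current Run
        spark_builder = spark_builder.config("spark.syncmaster.run.id", str(run.id))

        if source.type == "hive" or target.type == "hive":
            spark_builder = spark_builder.enableHiveSupport()

        return spark_builder.getOrCreate()

Such a callable would be wired in through the ``CREATE_SPARK_SESSION_FUNCTION`` setting, just as the test image points it at ``tests.spark.get_worker_spark_session`` in ``docker/Dockerfile.worker`` above.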
diff --git a/syncmaster/worker/handlers/db/__init__.py b/syncmaster/worker/handlers/db/__init__.py
new file mode 100644
index 00000000..104aecaf
--- /dev/null
+++ b/syncmaster/worker/handlers/db/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
diff --git a/syncmaster/worker/handlers/db/base.py b/syncmaster/worker/handlers/db/base.py
new file mode 100644
index 00000000..e476d8bf
--- /dev/null
+++ b/syncmaster/worker/handlers/db/base.py
@@ -0,0 +1,39 @@
+# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from abc import abstractmethod
+from typing import TYPE_CHECKING
+
+from onetl.base import BaseDBConnection
+from onetl.db import DBReader, DBWriter
+
+from syncmaster.dto.transfers import DBTransferDTO
+from syncmaster.worker.handlers.base import Handler
+
+if TYPE_CHECKING:
+    from pyspark.sql.dataframe import DataFrame
+
+
+class DBHandler(Handler):
+    connection: BaseDBConnection
+    transfer_dto: DBTransferDTO
+
+    def read(self) -> DataFrame:
+        reader = DBReader(
+            connection=self.connection,
+            table=self.transfer_dto.table_name,
+        )
+        df = reader.run()
+        return self.normalize_column_name(df)
+
+    def write(self, df: DataFrame) -> None:
+        writer = DBWriter(
+            connection=self.connection,
+            table=self.transfer_dto.table_name,
+        )
+        return writer.run(df=df)
+
+    @abstractmethod
+    def normalize_column_name(self, df: DataFrame) -> DataFrame: ...
diff --git a/syncmaster/worker/handlers/db/hive.py b/syncmaster/worker/handlers/db/hive.py
new file mode 100644
index 00000000..66172da1
--- /dev/null
+++ b/syncmaster/worker/handlers/db/hive.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from onetl.connection import Hive
+
+from syncmaster.dto.connections import HiveConnectionDTO
+from syncmaster.dto.transfers import HiveTransferDTO
+from syncmaster.worker.handlers.db.base import DBHandler
+
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+    from pyspark.sql.dataframe import DataFrame
+
+
+class HiveHandler(DBHandler):
+    connection: Hive
+    connection_dto: HiveConnectionDTO
+    transfer_dto: HiveTransferDTO
+
+    def connect(self, spark: SparkSession):
+        self.connection = Hive(
+            cluster=self.connection_dto.cluster,
+            spark=spark,
+        ).check()
+
+    def read(self) -> DataFrame:
+        self.connection.spark.catalog.refreshTable(self.transfer_dto.table_name)
+        return super().read()
+
+    def normalize_column_name(self, df: DataFrame) -> DataFrame:
+        for column_name in df.columns:
+            df = df.withColumnRenamed(column_name, column_name.lower())
+        return df
diff --git a/syncmaster/worker/handlers/oracle.py b/syncmaster/worker/handlers/db/oracle.py
similarity index 56%
rename from syncmaster/worker/handlers/oracle.py
rename to syncmaster/worker/handlers/db/oracle.py
index f8a79ff9..708fb576 100644
--- a/syncmaster/worker/handlers/oracle.py
+++ b/syncmaster/worker/handlers/db/oracle.py
@@ -1,20 +1,27 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 from onetl.connection import Oracle
-from onetl.db import DBReader, DBWriter
-from pyspark.sql.dataframe import DataFrame

 from syncmaster.dto.connections import OracleConnectionDTO
 from syncmaster.dto.transfers import OracleTransferDTO
-from syncmaster.worker.handlers.base import Handler
+from syncmaster.worker.handlers.db.base import DBHandler

+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+    from pyspark.sql.dataframe import DataFrame

-class OracleHandler(Handler):
+
+class OracleHandler(DBHandler):
     connection: Oracle
     connection_dto: OracleConnectionDTO
     transfer_dto: OracleTransferDTO

-    def init_connection(self):
+    def connect(self, spark: SparkSession):
         self.connection = Oracle(
             host=self.connection_dto.host,
             port=self.connection_dto.port,
@@ -23,25 +30,9 @@ def init_connection(self):
             sid=self.connection_dto.sid,
             service_name=self.connection_dto.service_name,
             extra=self.connection_dto.additional_params,
-            spark=self.spark,
+            spark=spark,
         ).check()

-    def init_reader(self):
-        super().init_reader()
-        df = self.connection.get_df_schema(self.transfer_dto.table_name)
-        self.reader = DBReader(
-            connection=self.connection,
-            table=self.transfer_dto.table_name,
-            columns=[f'"{f}"' for f in df.fieldNames()],
-        )
-
-    def init_writer(self):
-        super().init_writer()
-        self.writer = DBWriter(
-            connection=self.connection,
-            table=self.transfer_dto.table_name,
-        )
-
     def normalize_column_name(self, df: DataFrame) -> DataFrame:
         for column_name in df.columns:
             df = df.withColumnRenamed(column_name, column_name.upper())
diff --git a/syncmaster/worker/handlers/postgres.py b/syncmaster/worker/handlers/db/postgres.py
similarity index 55%
rename from syncmaster/worker/handlers/postgres.py
rename to syncmaster/worker/handlers/db/postgres.py
index 25ddf337..07acadf4 100644
--- a/syncmaster/worker/handlers/postgres.py
+++ b/syncmaster/worker/handlers/db/postgres.py
@@ -1,20 +1,27 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 from onetl.connection import Postgres
-from onetl.db import DBReader, DBWriter
-from pyspark.sql.dataframe import DataFrame

 from syncmaster.dto.connections import PostgresConnectionDTO
 from syncmaster.dto.transfers import PostgresTransferDTO
-from syncmaster.worker.handlers.base import Handler
+from syncmaster.worker.handlers.db.base import DBHandler

+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+    from pyspark.sql.dataframe import DataFrame

-class PostgresHandler(Handler):
+
+class PostgresHandler(DBHandler):
     connection: Postgres
     connection_dto: PostgresConnectionDTO
     transfer_dto: PostgresTransferDTO

-    def init_connection(self):
+    def connect(self, spark: SparkSession):
         self.connection = Postgres(
             host=self.connection_dto.host,
             user=self.connection_dto.user,
@@ -22,25 +29,9 @@ def init_connection(self):
             port=self.connection_dto.port,
             database=self.connection_dto.database_name,
             extra=self.connection_dto.additional_params,
-            spark=self.spark,
+            spark=spark,
         ).check()

-    def init_reader(self):
-        super().init_reader()
-        df = self.connection.get_df_schema(self.transfer_dto.table_name)
-        self.reader = DBReader(
-            connection=self.connection,
-            table=self.transfer_dto.table_name,
-            columns=[f'"{f}"' for f in df.fieldNames()],
-        )
-
-    def init_writer(self):
-        super().init_writer()
-        self.writer = DBWriter(
-            connection=self.connection,
-            table=self.transfer_dto.table_name,
-        )
-
     def normalize_column_name(self, df: DataFrame) -> DataFrame:
         for column_name in df.columns:
             df = df.withColumnRenamed(column_name, column_name.lower())
diff --git a/syncmaster/worker/handlers/file/base.py b/syncmaster/worker/handlers/file/base.py
index d2656d0c..7b32443a 100644
--- a/syncmaster/worker/handlers/file/base.py
+++ b/syncmaster/worker/handlers/file/base.py
@@ -1,56 +1,45 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
-import json
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING

 from onetl.base.base_file_df_connection import BaseFileDFConnection
 from onetl.file import FileDFReader, FileDFWriter
-from onetl.file.format import CSV, JSON, JSONLine
-from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.types import StructType

 from syncmaster.dto.connections import ConnectionDTO
-from syncmaster.dto.transfers import TransferDTO
+from syncmaster.dto.transfers import FileTransferDTO
 from syncmaster.worker.handlers.base import Handler

+if TYPE_CHECKING:
+    from pyspark.sql.dataframe import DataFrame
+

 class FileHandler(Handler):
     connection: BaseFileDFConnection
     connection_dto: ConnectionDTO
-    transfer_dto: TransferDTO
-
-    def init_connection(self): ...
+    transfer_dto: FileTransferDTO

-    def init_reader(self):
-        super().init_reader()
+    def read(self) -> DataFrame:
+        from pyspark.sql.types import StructType

-        self.reader = FileDFReader(
+        reader = FileDFReader(
             connection=self.connection,
-            format=self._get_format(),
+            format=self.transfer_dto.file_format,
             source_path=self.transfer_dto.directory_path,
-            df_schema=StructType.fromJson(json.loads(self.transfer_dto.df_schema)),
+            df_schema=StructType.fromJson(self.transfer_dto.df_schema) if self.transfer_dto.df_schema else None,
             options=self.transfer_dto.options,
         )

-    def init_writer(self):
-        super().init_writer()
+        return reader.run()

-        self.writer = FileDFWriter(
+    def write(self, df: DataFrame):
+        writer = FileDFWriter(
             connection=self.connection,
-            format=self._get_format(),
+            format=self.transfer_dto.file_format,
             target_path=self.transfer_dto.directory_path,
             options=self.transfer_dto.options,
         )

-    def normalize_column_name(self, df: DataFrame) -> DataFrame:
-        return df
-
-    def _get_format(self):
-        file_type = self.transfer_dto.file_format["type"]
-        if file_type == "csv":
-            return CSV.parse_obj(self.transfer_dto.file_format)
-        elif file_type == "jsonline":
-            return JSONLine.parse_obj(self.transfer_dto.file_format)
-        elif file_type == "json":
-            return JSON.parse_obj(self.transfer_dto.file_format)
-        else:
-            raise ValueError("Unknown file type")
+        return writer.run(df=df)
diff --git a/syncmaster/worker/handlers/file/hdfs.py b/syncmaster/worker/handlers/file/hdfs.py
index ce0a7441..a80949da 100644
--- a/syncmaster/worker/handlers/file/hdfs.py
+++ b/syncmaster/worker/handlers/file/hdfs.py
@@ -1,14 +1,24 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 from onetl.connection import SparkHDFS

+from syncmaster.dto.connections import HDFSConnectionDTO
 from syncmaster.worker.handlers.file.base import FileHandler

+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+

 class HDFSHandler(FileHandler):
-    def init_connection(self):
+    connection_dto: HDFSConnectionDTO
+
+    def connect(self, spark: SparkSession):
         self.connection = SparkHDFS(
             cluster=self.connection_dto.cluster,
-            spark=self.spark,
+            spark=spark,
         ).check()
diff --git a/syncmaster/worker/handlers/file/s3.py b/syncmaster/worker/handlers/file/s3.py
index 69082541..b3c085e1 100644
--- a/syncmaster/worker/handlers/file/s3.py
+++ b/syncmaster/worker/handlers/file/s3.py
@@ -1,12 +1,23 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 from onetl.connection import SparkS3

+from syncmaster.dto.connections import S3ConnectionDTO
 from syncmaster.worker.handlers.file.base import FileHandler

+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+

 class S3Handler(FileHandler):
-    def init_connection(self):
+    connection_dto: S3ConnectionDTO
+
+    def connect(self, spark: SparkSession):
         self.connection = SparkS3(
             host=self.connection_dto.host,
             port=self.connection_dto.port,
@@ -16,5 +27,5 @@ def init_connection(self):
             protocol=self.connection_dto.protocol,
             region=self.connection_dto.region,
             extra=self.connection_dto.additional_params,
-            spark=self.spark,
+            spark=spark,
         ).check()
diff --git a/syncmaster/worker/handlers/hive.py b/syncmaster/worker/handlers/hive.py
deleted file mode 100644
index 3646e1c2..00000000
--- a/syncmaster/worker/handlers/hive.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
-# SPDX-License-Identifier: Apache-2.0
-from onetl.connection import Hive
-from onetl.db import DBReader, DBWriter
-from pyspark.sql.dataframe import DataFrame
-
-from syncmaster.dto.connections import HiveConnectionDTO
-from syncmaster.dto.transfers import HiveTransferDTO
-from syncmaster.worker.handlers.base import Handler
-
-
-class HiveHandler(Handler):
-    connection: Hive
-    connection_dto: HiveConnectionDTO
-    transfer_dto: HiveTransferDTO
-
-    def init_connection(self):
-        self.connection = Hive(
-            cluster=self.connection_dto.cluster,
-            spark=self.spark,
-        ).check()
-
-    def init_reader(self):
-        super().init_reader()
-        self.spark.catalog.refreshTable(self.transfer_dto.table_name)
-        self.reader = DBReader(
-            connection=self.connection,
-            table=self.transfer_dto.table_name,
-        )
-
-    def init_writer(self):
-        super().init_writer()
-        self.writer = DBWriter(
-            connection=self.connection,
-            table=self.transfer_dto.table_name,
-        )
-
-    def normalize_column_name(self, df: DataFrame) -> DataFrame:
-        for column_name in df.columns:
-            df = df.withColumnRenamed(column_name, column_name.lower())
-        return df
diff --git a/syncmaster/worker/spark.py b/syncmaster/worker/spark.py
index ff9688dc..b4882f57 100644
--- a/syncmaster/worker/spark.py
+++ b/syncmaster/worker/spark.py
@@ -1,24 +1,33 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
 import logging
+from typing import TYPE_CHECKING

-import pyspark
 from onetl.connection import Oracle, Postgres, SparkS3
-from pyspark.sql import SparkSession

 from syncmaster.config import Settings
+from syncmaster.db.models import Run
 from syncmaster.dto.connections import ConnectionDTO

+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+
 log = logging.getLogger(__name__)


 def get_worker_spark_session(
-    settings: Settings,  # used in test spark session definition
+    settings: Settings,
+    run: Run,
     source: ConnectionDTO,
     target: ConnectionDTO,
 ) -> SparkSession:
-    """Through the source and target parameters you can get credentials for authorization at the source"""
-    spark_builder = SparkSession.builder.appName("celery_worker")
+    """Construct Spark Session using run parameters and application settings"""
+    from pyspark.sql import SparkSession
+
+    name = run.transfer.group.name + "_" + run.transfer.name
+    spark_builder = SparkSession.builder.appName(f"syncmaster_{name}")

     for k, v in get_spark_session_conf(source, target).items():
         spark_builder = spark_builder.config(k, v)
@@ -36,6 +45,8 @@ def get_packages(db_type: str) -> list[str]:
     if db_type == "oracle":
         return Oracle.get_packages()
     if db_type == "s3":
+        import pyspark
+
         spark_version = pyspark.__version__
         return SparkS3.get_packages(spark_version=spark_version)
diff --git a/syncmaster/worker/transfer.py b/syncmaster/worker/transfer.py
index e243a2ae..25c8e667 100644
--- a/syncmaster/worker/transfer.py
+++ b/syncmaster/worker/transfer.py
@@ -1,7 +1,7 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
 import logging
-from datetime import datetime
+from datetime import datetime, timezone

 import onetl
 from sqlalchemy import select
@@ -21,9 +21,6 @@
 @celery.task(name="run_transfer_task", bind=True, track_started=True)
 def run_transfer_task(self: WorkerTask, run_id: int) -> None:
     onetl.log.setup_logging(level=logging.INFO)
-    """Task for make transfer data"""
-    logger.info("Before spark initializing")
-    logger.info("Spark initialized")
     with Session(self.engine) as session:
         run_transfer(
             session=session,
@@ -33,24 +30,25 @@ def run_transfer_task(self: WorkerTask, run_id: int) -> None:


 def run_transfer(session: Session, run_id: int, settings: Settings):
-    logger.info("Start transfering")
+    logger.info("Start transfer")
     run = session.get(
         Run,
         run_id,
         options=(
+            selectinload(Run.transfer),
+            selectinload(Run.transfer).selectinload(Transfer.group),
             selectinload(Run.transfer).selectinload(Transfer.source_connection),
             selectinload(Run.transfer).selectinload(Transfer.target_connection),
         ),
     )
     if run is None:
         raise RunNotFoundError
+
     run.status = Status.STARTED
-    run.started_at = datetime.utcnow()
+    run.started_at = datetime.now(tz=timezone.utc)
     session.add(run)
     session.commit()

-    controller = None
-
     q_source_auth_data = select(AuthData).where(AuthData.connection_id == run.transfer.source_connection.id)
     q_target_auth_data = select(AuthData).where(AuthData.connection_id == run.transfer.target_connection.id)

@@ -59,27 +57,21 @@ def run_transfer(session: Session, run_id: int, settings: Settings):

     try:
         controller = TransferController(
-            transfer=run.transfer,
+            run=run,
             source_connection=run.transfer.source_connection,
             target_connection=run.transfer.target_connection,
             source_auth_data=source_auth_data,
             target_auth_data=target_auth_data,
             settings=settings,
         )
-        controller.start_transfer()
+        controller.perform_transfer()
     except Exception:
         run.status = Status.FAILED
-        logger.exception("Run `%s` was failed", run.id)
+        logger.exception("Run %r was failed", run.id)
     else:
         run.status = Status.FINISHED
-        logger.warning("Run `%s` was successful", run.id)
-    finally:
-        # Both the source and the receiver use the same spark session,
-        # so it is enough to stop the session at the source.
-        if controller is not None and controller.source_handler.spark is not None:
-            controller.source_handler.spark.sparkContext.stop()
-            controller.source_handler.spark.stop()
+        logger.warning("Run %r was successful", run.id)

-    run.ended_at = datetime.utcnow()
+    run.ended_at = datetime.now(tz=timezone.utc)
     session.add(run)
     session.commit()
diff --git a/tests/spark/__init__.py b/tests/spark/__init__.py
index e69de29b..293b7131 100644
--- a/tests/spark/__init__.py
+++ b/tests/spark/__init__.py
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from celery.signals import worker_process_init, worker_process_shutdown
+from coverage import Coverage
+from onetl.connection import SparkHDFS
+from onetl.hooks import hook
+
+from syncmaster.worker.spark import get_worker_spark_session
+
+# this is just to automatically import hooks
+get_worker_spark_session = get_worker_spark_session
+
+
+@SparkHDFS.Slots.get_cluster_namenodes.bind
+@hook
+def get_cluster_namenodes(cluster: str) -> set[str] | None:
+    if cluster == "test-hive":
+        return {"test-hive"}
+    return None
+
+
+@SparkHDFS.Slots.is_namenode_active.bind
+@hook
+def is_namenode_active(host: str, cluster: str) -> bool:
+    if cluster == "test-hive":
+        return True
+    return False
+
+
+@SparkHDFS.Slots.get_ipc_port.bind
+@hook
+def get_ipc_port(cluster: str) -> int | None:
+    if cluster == "test-hive":
+        return 9820
+    return None
+
+
+# Needed to collect code coverage by tests in the worker
+# https://github.com/nedbat/coveragepy/issues/689#issuecomment-656706935
+COV = None
+
+
+@worker_process_init.connect
+def start_coverage(**kwargs):
+    global COV
+
+    COV = Coverage(data_suffix=True)
+    COV.start()
+
+
+@worker_process_shutdown.connect
+def save_coverage(**kwargs):
+    if COV is not None:
+        COV.stop()
+        COV.save()
diff --git a/tests/spark/get_worker_spark_session.py b/tests/spark/get_worker_spark_session.py
deleted file mode 100644
index 8c40d97a..00000000
--- a/tests/spark/get_worker_spark_session.py
+++ /dev/null
@@ -1,78 +0,0 @@
-from __future__ import annotations
-
-import logging
-
-from celery.signals import worker_process_init, worker_process_shutdown
-from coverage import Coverage
-from onetl.connection import SparkHDFS
-from onetl.hooks import hook
-from pyspark.sql import SparkSession
-
-from syncmaster.config import Settings
-from syncmaster.dto.connections import ConnectionDTO
-from syncmaster.worker.spark import get_spark_session_conf
-
-log = logging.getLogger(__name__)
-
-
-@SparkHDFS.Slots.get_cluster_namenodes.bind
-@hook
-def get_cluster_namenodes(cluster: str) -> set[str] | None:
-    if cluster == "test-hive":
-        return {"test-hive"}
-    return None
-
-
-@SparkHDFS.Slots.is_namenode_active.bind
-@hook
-def is_namenode_active(host: str, cluster: str) -> bool:
-    if cluster == "test-hive":
-        return True
-    return False
-
-
-@SparkHDFS.Slots.get_ipc_port.bind
-@hook
-def get_ipc_port(cluster: str) -> int | None:
-    if cluster == "test-hive":
-        return 9820
-    return None
-
-
-def get_worker_spark_session(
-    settings: Settings,
-    source: ConnectionDTO,
-    target: ConnectionDTO,
-) -> SparkSession:
-    spark_builder = SparkSession.builder.appName("celery_worker")
-
-    for k, v in get_spark_session_conf(source, target).items():
-        spark_builder = spark_builder.config(k, v)
-
-    if source.type == "hive" or target.type == "hive":
-        log.debug("Enabling Hive support")
-        spark_builder = spark_builder.enableHiveSupport()
-
-    return spark_builder.getOrCreate()
-
-
-# Needed to collect code coverage by tests in the worker
-# https://github.com/nedbat/coveragepy/issues/689#issuecomment-656706935
-
-
-COV = None
-
-
-@worker_process_init.connect
-def start_coverage(**kwargs):
-    global COV
-
-    COV = Coverage(data_suffix=True)
-    COV.start()
-
-
-@worker_process_shutdown.connect
-def save_coverage(**kwargs):
-    if COV is not None:
-        COV.stop()
-        COV.save()
diff --git a/tests/test_integration/test_run_transfer/conftest.py b/tests/test_integration/test_run_transfer/conftest.py
index 58f6fcf7..8233d102 100644
--- a/tests/test_integration/test_run_transfer/conftest.py
+++ b/tests/test_integration/test_run_transfer/conftest.py
@@ -98,7 +98,6 @@ def get_spark_session(connection_settings: Settings) -> SparkSession:
 )
 def hive(test_settings: TestSettings) -> HiveConnectionDTO:
     return HiveConnectionDTO(
-        type="hive",
         cluster=test_settings.TEST_HIVE_CLUSTER,
         user=test_settings.TEST_HIVE_USER,
         password=test_settings.TEST_HIVE_PASSWORD,
@@ -111,7 +110,6 @@ def hive(test_settings: TestSettings) -> HiveConnectionDTO:
 )
 def hdfs(test_settings: TestSettings) -> HDFSConnectionDTO:
     return HDFSConnectionDTO(
-        type="hdfs",
         cluster=test_settings.TEST_HIVE_CLUSTER,
         user=test_settings.TEST_HIVE_USER,
         password=test_settings.TEST_HIVE_PASSWORD,
@@ -124,7 +122,6 @@ def oracle(test_settings: TestSettings) -> OracleConnectionDTO:
     return OracleConnectionDTO(
-        type="oracle",
         host=test_settings.TEST_ORACLE_HOST,
         port=test_settings.TEST_ORACLE_PORT,
         user=test_settings.TEST_ORACLE_USER,
@@ -141,7 +138,6 @@ def postgres(test_settings: TestSettings) -> PostgresConnectionDTO:
     return PostgresConnectionDTO(
-        type="postgres",
         host=test_settings.TEST_POSTGRES_HOST,
         port=test_settings.TEST_POSTGRES_PORT,
         user=test_settings.TEST_POSTGRES_USER,
@@ -157,7 +153,6 @@ def s3(test_settings: TestSettings) -> S3ConnectionDTO:
     return S3ConnectionDTO(
-        type="s3",
         host=test_settings.TEST_S3_HOST,
         port=test_settings.TEST_S3_PORT,
         bucket=test_settings.TEST_S3_BUCKET,
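Editorial note (illustrative sketch, not part of this diff): the fixtures above no longer pass ``type=...`` because the DTOs now declare ``type`` as ``ClassVar[str]``, which makes it a class-level discriminator rather than an ``__init__`` field, while ``get_handler()`` reads the type out of the raw connection dict via ``pop("type")``. A short standalone demonstration of that dataclass behaviour, using a hypothetical ``ExampleDTO``:

    # Illustrative sketch only; ExampleDTO is a hypothetical stand-in for the ConnectionDTO subclasses.
    from dataclasses import asdict, dataclass, fields
    from typing import ClassVar


    @dataclass
    class ExampleDTO:
        cluster: str
        type: ClassVar[str] = "example"


    dto = ExampleDTO(cluster="test-hive")               # no type= argument anymore
    assert dto.type == "example"                        # still readable as an attribute
    assert "type" not in {f.name for f in fields(dto)}  # ClassVar is not a dataclass field
    assert "type" not in asdict(dto)                    # and is not serialized per instance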