diff --git a/docker/Dockerfile.worker b/docker/Dockerfile.worker
index 5cc28c70..91dd8e56 100644
--- a/docker/Dockerfile.worker
+++ b/docker/Dockerfile.worker
@@ -36,7 +36,7 @@ CMD ["--loglevel=info"]
 
 FROM prod as test
 
-ENV CREATE_SPARK_SESSION_FUNCTION=tests.spark.get_worker_spark_session.get_worker_spark_session
+ENV CREATE_SPARK_SESSION_FUNCTION=tests.spark.get_worker_spark_session
 
 # CI runs tests in the worker container, so we need backend dependencies too
 RUN poetry install --no-root --extras "worker backend" --with test --without docs,dev
diff --git a/docs/changelog/next_release/38.breaking.rst b/docs/changelog/next_release/38.breaking.rst
new file mode 100644
index 00000000..ecbf00c7
--- /dev/null
+++ b/docs/changelog/next_release/38.breaking.rst
@@ -0,0 +1,2 @@
+Pass current ``Run`` to ``CREATE_SPARK_SESSION_FUNCTION``. This allows using run/transfer/group information for Spark session options,
+like ``appName`` or custom ones.
diff --git a/syncmaster/dto/connections.py b/syncmaster/dto/connections.py
index 91877325..46f369b8 100644
--- a/syncmaster/dto/connections.py
+++ b/syncmaster/dto/connections.py
@@ -1,11 +1,12 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
 from dataclasses import dataclass
+from typing import ClassVar
 
 
 @dataclass
 class ConnectionDTO:
-    pass
+    type: ClassVar[str]
 
 
 @dataclass
@@ -16,7 +17,7 @@ class PostgresConnectionDTO(ConnectionDTO):
     password: str
     additional_params: dict
     database_name: str
-    type: str = "postgres"
+    type: ClassVar[str] = "postgres"
 
 
 @dataclass
@@ -28,7 +29,7 @@ class OracleConnectionDTO(ConnectionDTO):
     additional_params: dict
     sid: str | None = None
     service_name: str | None = None
-    type: str = "oracle"
+    type: ClassVar[str] = "oracle"
 
 
 @dataclass
@@ -36,7 +37,7 @@ class HiveConnectionDTO(ConnectionDTO):
     user: str
     password: str
     cluster: str
-    type: str = "hive"
+    type: ClassVar[str] = "hive"
 
 
 @dataclass
@@ -44,7 +45,7 @@ class HDFSConnectionDTO(ConnectionDTO):
     user: str
     password: str
    cluster: str
-    type: str = "hdfs"
+    type: ClassVar[str] = "hdfs"
 
 
 @dataclass
@@ -57,4 +58,4 @@ class S3ConnectionDTO(ConnectionDTO):
     additional_params: dict
     region: str | None = None
     protocol: str = "https"
-    type: str = "s3"
+    type: ClassVar[str] = "s3"
diff --git a/syncmaster/dto/transfers.py b/syncmaster/dto/transfers.py
index 44824811..5838cd88 100644
--- a/syncmaster/dto/transfers.py
+++ b/syncmaster/dto/transfers.py
@@ -1,46 +1,66 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
+import json
 from dataclasses import dataclass
+from typing import ClassVar
 
 from syncmaster.schemas.v1.transfers.file_format import CSV, JSON, JSONLine
 
 
 @dataclass
 class TransferDTO:
-    pass
+    type: ClassVar[str]
 
 
 @dataclass
-class PostgresTransferDTO(TransferDTO):
+class DBTransferDTO(TransferDTO):
     table_name: str
-    type: str = "postgres"
 
 
 @dataclass
-class OracleTransferDTO(TransferDTO):
-    table_name: str
-    type: str = "oracle"
+class FileTransferDTO(TransferDTO):
+    directory_path: str
+    file_format: CSV | JSONLine | JSON
+    options: dict
+    df_schema: dict | None = None
+
+    def __post_init__(self):
+        if isinstance(self.file_format, dict):
+            self.file_format = self._get_format(self.file_format)
+        if isinstance(self.df_schema, str):
+            self.df_schema = json.loads(self.df_schema)
+
+    def _get_format(self, file_format: dict):
+        file_type = file_format["type"]
+        if file_type == "csv":
+            return CSV.parse_obj(file_format)
+        if file_type == "jsonline":
+            return JSONLine.parse_obj(file_format)
+        if file_type == "json":
+            return JSON.parse_obj(file_format)
+        raise ValueError("Unknown file type")
 
 
 @dataclass
-class HiveTransferDTO(TransferDTO):
-    table_name: str
-    type: str = "hive"
+class PostgresTransferDTO(DBTransferDTO):
+    type: ClassVar[str] = "postgres"
 
 
 @dataclass
-class S3TransferDTO(TransferDTO):
-    directory_path: str
-    file_format: CSV | JSONLine | JSON
-    options: dict
-    df_schema: dict | None = None
-    type: str = "s3"
+class OracleTransferDTO(DBTransferDTO):
+    type: ClassVar[str] = "oracle"
 
 
 @dataclass
-class HDFSTransferDTO(TransferDTO):
-    directory_path: str
-    file_format: CSV | JSONLine | JSON
-    options: dict
-    df_schema: dict | None = None
-    type: str = "hdfs"
+class HiveTransferDTO(DBTransferDTO):
+    type: ClassVar[str] = "hive"
+
+
+@dataclass
+class S3TransferDTO(FileTransferDTO):
+    type: ClassVar[str] = "s3"
+
+
+@dataclass
+class HDFSTransferDTO(FileTransferDTO):
+    type: ClassVar[str] = "hdfs"
diff --git a/syncmaster/worker/config.py b/syncmaster/worker/config.py
index 4ea02c23..81ce1c7e 100644
--- a/syncmaster/worker/config.py
+++ b/syncmaster/worker/config.py
@@ -12,7 +12,7 @@
     broker=settings.build_rabbit_connection_uri(),
     backend="db+" + settings.build_db_connection_uri(driver="psycopg2"),
     task_cls=WorkerTask,
-    imports=[
+    include=[
        "syncmaster.worker.transfer",
    ],
 )
diff --git a/syncmaster/worker/controller.py b/syncmaster/worker/controller.py
index 26408e63..3f0b423b 100644
--- a/syncmaster/worker/controller.py
+++ b/syncmaster/worker/controller.py
@@ -21,11 +21,11 @@
 )
 from syncmaster.exceptions.connection import ConnectionTypeNotRecognizedError
 from syncmaster.worker.handlers.base import Handler
+from syncmaster.worker.handlers.db.hive import HiveHandler
+from syncmaster.worker.handlers.db.oracle import OracleHandler
+from syncmaster.worker.handlers.db.postgres import PostgresHandler
 from syncmaster.worker.handlers.file.hdfs import HDFSHandler
 from syncmaster.worker.handlers.file.s3 import S3Handler
-from syncmaster.worker.handlers.hive import HiveHandler
-from syncmaster.worker.handlers.oracle import OracleHandler
-from syncmaster.worker.handlers.postgres import PostgresHandler
 
 logger = logging.getLogger(__name__)
 
@@ -72,6 +72,8 @@ def __init__(
         target_auth_data: dict,
         settings: Settings,
     ):
+        self.transfer = transfer
+        self.settings = settings
         self.source_handler = self.get_handler(
             connection_data=source_connection.data,
             transfer_params=transfer.source_params,
@@ -82,30 +84,21 @@ def __init__(
             transfer_params=transfer.target_params,
             connection_auth_data=target_auth_data,
         )
-        spark = settings.CREATE_SPARK_SESSION_FUNCTION(
-            settings,
-            target=self.target_handler.connection_dto,
-            source=self.source_handler.connection_dto,
-        )
-
-        self.source_handler.set_spark(spark)
-        self.target_handler.set_spark(spark)
-
         logger.info("source connection = %s", self.source_handler)
         logger.info("target connection = %s", self.target_handler)
 
     def start_transfer(self) -> None:
-        self.source_handler.init_connection()
-        self.source_handler.init_reader()
-
-        self.target_handler.init_connection()
-        self.target_handler.init_writer()
-        logger.info("Source and target were initialized")
+        spark = self.settings.CREATE_SPARK_SESSION_FUNCTION(
+            settings=self.settings,
+            transfer=self.transfer,
+            source=self.source_handler.connection_dto,
+            target=self.target_handler.connection_dto,
+        )
 
-        df = self.target_handler.normalize_column_name(self.source_handler.read())
-        logger.info("Data has been read")
+        with spark:
+            self.source_handler.connect(spark)
+            self.target_handler.connect(spark)
 
-        self.target_handler.write(df)
-        logger.info("Data has been inserted")
+            df = self.source_handler.read()
+            self.target_handler.write(df)
 
     def get_handler(
         self,
diff --git a/syncmaster/worker/handlers/base.py b/syncmaster/worker/handlers/base.py
index e6a4fe98..aa5ab7f7 100644
--- a/syncmaster/worker/handlers/base.py
+++ b/syncmaster/worker/handlers/base.py
@@ -1,49 +1,31 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
-from abc import ABC
 
-from onetl.db import DBReader, DBWriter
-from pyspark.sql import SparkSession
-from pyspark.sql.dataframe import DataFrame
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING
 
 from syncmaster.dto.connections import ConnectionDTO
 from syncmaster.dto.transfers import TransferDTO
 
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+    from pyspark.sql.dataframe import DataFrame
+
 
 class Handler(ABC):
     def __init__(
         self,
         connection_dto: ConnectionDTO,
         transfer_dto: TransferDTO,
-        spark: SparkSession | None = None,
-    ) -> None:
-        self.spark = spark
-        self.reader: DBReader | None = None
-        self.writer: DBWriter | None = None
+    ):
         self.connection_dto = connection_dto
         self.transfer_dto = transfer_dto
 
-    def init_connection(self): ...
-
-    def set_spark(self, spark: SparkSession):
-        self.spark = spark
-
-    def init_reader(self):
-        if self.connection_dto is None:
-            raise ValueError("At first you need to initialize connection. Run `init_connection")
-
-    def init_writer(self):
-        if self.connection_dto is None:
-            raise ValueError("At first you need to initialize connection. Run `init_connection")
-
-    def read(self) -> DataFrame:
-        if self.reader is None:
-            raise ValueError("Reader is not initialized")
-        return self.reader.run()
+    @abstractmethod
+    def connect(self, spark: SparkSession) -> None: ...
 
-    def write(self, df: DataFrame) -> None:
-        if self.writer is None:
-            raise ValueError("Writer is not initialized")
-        return self.writer.run(df=df)
+    @abstractmethod
+    def read(self) -> DataFrame: ...
 
-    def normalize_column_name(self, df: DataFrame) -> DataFrame: ...
+    @abstractmethod
+    def write(self, df: DataFrame) -> None: ...
diff --git a/syncmaster/worker/handlers/db/__init__.py b/syncmaster/worker/handlers/db/__init__.py
new file mode 100644
index 00000000..104aecaf
--- /dev/null
+++ b/syncmaster/worker/handlers/db/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
diff --git a/syncmaster/worker/handlers/db/base.py b/syncmaster/worker/handlers/db/base.py
new file mode 100644
index 00000000..3210e1f4
--- /dev/null
+++ b/syncmaster/worker/handlers/db/base.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import abstractmethod
+from typing import TYPE_CHECKING
+
+from onetl.base import BaseDBConnection
+from onetl.db import DBReader, DBWriter
+
+from syncmaster.dto.transfers import DBTransferDTO
+from syncmaster.worker.handlers.base import Handler
+
+if TYPE_CHECKING:
+    from pyspark.sql.dataframe import DataFrame
+
+
+class DBHandler(Handler):
+    connection: BaseDBConnection
+    transfer_dto: DBTransferDTO
+
+    def read(self) -> DataFrame:
+        reader = DBReader(
+            connection=self.connection,
+            table=self.transfer_dto.table_name,
+        )
+        df = reader.run()
+        return self.normalize_column_name(df)
+
+    def write(self, df: DataFrame) -> None:
+        writer = DBWriter(
+            connection=self.connection,
+            table=self.transfer_dto.table_name,
+        )
+        return writer.run(df=df)
+
+    @abstractmethod
+    def normalize_column_name(self, df: DataFrame) -> DataFrame: ...
diff --git a/syncmaster/worker/handlers/db/hive.py b/syncmaster/worker/handlers/db/hive.py
new file mode 100644
index 00000000..db783249
--- /dev/null
+++ b/syncmaster/worker/handlers/db/hive.py
@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import TYPE_CHECKING
+
+from onetl.connection import Hive
+
+from syncmaster.dto.connections import HiveConnectionDTO
+from syncmaster.dto.transfers import HiveTransferDTO
+from syncmaster.worker.handlers.db.base import DBHandler
+
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+    from pyspark.sql.dataframe import DataFrame
+
+
+class HiveHandler(DBHandler):
+    connection: Hive
+    connection_dto: HiveConnectionDTO
+    transfer_dto: HiveTransferDTO
+
+    def connect(self, spark: SparkSession):
+        self.connection = Hive(
+            cluster=self.connection_dto.cluster,
+            spark=spark,
+        ).check()
+
+    def read(self) -> DataFrame:
+        self.connection.spark.catalog.refreshTable(self.transfer_dto.table_name)
+        return super().read()
+
+    def normalize_column_name(self, df: DataFrame) -> DataFrame:
+        for column_name in df.columns:
+            df = df.withColumnRenamed(column_name, column_name.lower())
+        return df
diff --git a/syncmaster/worker/handlers/oracle.py b/syncmaster/worker/handlers/db/oracle.py
similarity index 61%
rename from syncmaster/worker/handlers/oracle.py
rename to syncmaster/worker/handlers/db/oracle.py
index f8a79ff9..170dc7f3 100644
--- a/syncmaster/worker/handlers/oracle.py
+++ b/syncmaster/worker/handlers/db/oracle.py
@@ -1,20 +1,25 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
+
+from typing import TYPE_CHECKING
+
 from onetl.connection import Oracle
-from onetl.db import DBReader, DBWriter
-from pyspark.sql.dataframe import DataFrame
 
 from syncmaster.dto.connections import OracleConnectionDTO
 from syncmaster.dto.transfers import OracleTransferDTO
 from syncmaster.worker.handlers.base import Handler
 
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+    from pyspark.sql.dataframe import DataFrame
+
 
 class OracleHandler(Handler):
     connection: Oracle
     connection_dto: OracleConnectionDTO
     transfer_dto: OracleTransferDTO
 
-    def init_connection(self):
+    def connect(self, spark: SparkSession):
         self.connection = Oracle(
             host=self.connection_dto.host,
             port=self.connection_dto.port,
@@ -23,25 +28,9 @@ def init_connection(self):
             sid=self.connection_dto.sid,
             service_name=self.connection_dto.service_name,
             extra=self.connection_dto.additional_params,
-            spark=self.spark,
+            spark=spark,
         ).check()
 
-    def init_reader(self):
-        super().init_reader()
-        df = self.connection.get_df_schema(self.transfer_dto.table_name)
-        self.reader = DBReader(
-            connection=self.connection,
-            table=self.transfer_dto.table_name,
-            columns=[f'"{f}"' for f in df.fieldNames()],
-        )
-
-    def init_writer(self):
-        super().init_writer()
-        self.writer = DBWriter(
-            connection=self.connection,
-            table=self.transfer_dto.table_name,
-        )
-
     def normalize_column_name(self, df: DataFrame) -> DataFrame:
         for column_name in df.columns:
             df = df.withColumnRenamed(column_name, column_name.upper())
diff --git a/syncmaster/worker/handlers/postgres.py b/syncmaster/worker/handlers/db/postgres.py
similarity index 60%
rename from syncmaster/worker/handlers/postgres.py
rename to syncmaster/worker/handlers/db/postgres.py
index 25ddf337..4dbae77b 100644
--- a/syncmaster/worker/handlers/postgres.py
+++ b/syncmaster/worker/handlers/db/postgres.py
@@ -1,20 +1,25 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
+
+from typing import TYPE_CHECKING
+
 from onetl.connection import Postgres
-from onetl.db import DBReader, DBWriter
-from pyspark.sql.dataframe import DataFrame
 
 from syncmaster.dto.connections import PostgresConnectionDTO
 from syncmaster.dto.transfers import PostgresTransferDTO
 from syncmaster.worker.handlers.base import Handler
 
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+    from pyspark.sql.dataframe import DataFrame
+
 
 class PostgresHandler(Handler):
     connection: Postgres
     connection_dto: PostgresConnectionDTO
     transfer_dto: PostgresTransferDTO
 
-    def init_connection(self):
+    def connect(self, spark: SparkSession):
         self.connection = Postgres(
             host=self.connection_dto.host,
             user=self.connection_dto.user,
@@ -22,25 +27,9 @@ def init_connection(self):
             port=self.connection_dto.port,
             database=self.connection_dto.database_name,
             extra=self.connection_dto.additional_params,
-            spark=self.spark,
+            spark=spark,
         ).check()
 
-    def init_reader(self):
-        super().init_reader()
-        df = self.connection.get_df_schema(self.transfer_dto.table_name)
-        self.reader = DBReader(
-            connection=self.connection,
-            table=self.transfer_dto.table_name,
-            columns=[f'"{f}"' for f in df.fieldNames()],
-        )
-
-    def init_writer(self):
-        super().init_writer()
-        self.writer = DBWriter(
-            connection=self.connection,
-            table=self.transfer_dto.table_name,
-        )
-
     def normalize_column_name(self, df: DataFrame) -> DataFrame:
         for column_name in df.columns:
             df = df.withColumnRenamed(column_name, column_name.lower())
diff --git a/syncmaster/worker/handlers/file/base.py b/syncmaster/worker/handlers/file/base.py
index d2656d0c..cf77dda0 100644
--- a/syncmaster/worker/handlers/file/base.py
+++ b/syncmaster/worker/handlers/file/base.py
@@ -1,56 +1,41 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
-import json
+from typing import TYPE_CHECKING
 
 from onetl.base.base_file_df_connection import BaseFileDFConnection
 from onetl.file import FileDFReader, FileDFWriter
-from onetl.file.format import CSV, JSON, JSONLine
-from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.types import StructType
 
 from syncmaster.dto.connections import ConnectionDTO
-from syncmaster.dto.transfers import TransferDTO
+from syncmaster.dto.transfers import FileTransferDTO
 from syncmaster.worker.handlers.base import Handler
 
+if TYPE_CHECKING:
+    from pyspark.sql.dataframe import DataFrame
+    from pyspark.sql.types import StructType
+
 
 class FileHandler(Handler):
     connection: BaseFileDFConnection
     connection_dto: ConnectionDTO
-    transfer_dto: TransferDTO
-
-    def init_connection(self): ...
-
-    def init_reader(self):
-        super().init_reader()
+    transfer_dto: FileTransferDTO
 
-        self.reader = FileDFReader(
+    def read(self) -> DataFrame:
+        reader = FileDFReader(
             connection=self.connection,
-            format=self._get_format(),
+            format=self.transfer_dto.file_format,
             source_path=self.transfer_dto.directory_path,
-            df_schema=StructType.fromJson(json.loads(self.transfer_dto.df_schema)),
+            df_schema=StructType.fromJson(self.transfer_dto.df_schema) if self.transfer_dto.df_schema else None,
             options=self.transfer_dto.options,
         )
 
-    def init_writer(self):
-        super().init_writer()
+        return reader.run()
 
-        self.writer = FileDFWriter(
+    def write(self, df: DataFrame):
+        writer = FileDFWriter(
             connection=self.connection,
-            format=self._get_format(),
+            format=self.transfer_dto.file_format,
             target_path=self.transfer_dto.directory_path,
             options=self.transfer_dto.options,
         )
 
-    def normalize_column_name(self, df: DataFrame) -> DataFrame:
-        return df
-
-    def _get_format(self):
-        file_type = self.transfer_dto.file_format["type"]
-        if file_type == "csv":
-            return CSV.parse_obj(self.transfer_dto.file_format)
-        elif file_type == "jsonline":
-            return JSONLine.parse_obj(self.transfer_dto.file_format)
-        elif file_type == "json":
-            return JSON.parse_obj(self.transfer_dto.file_format)
-        else:
-            raise ValueError("Unknown file type")
+        return writer.run(df=df)
diff --git a/syncmaster/worker/handlers/file/hdfs.py b/syncmaster/worker/handlers/file/hdfs.py
index ce0a7441..50d5bbf7 100644
--- a/syncmaster/worker/handlers/file/hdfs.py
+++ b/syncmaster/worker/handlers/file/hdfs.py
@@ -1,14 +1,22 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import TYPE_CHECKING
+
 from onetl.connection import SparkHDFS
 
+from syncmaster.dto.connections import HDFSConnectionDTO
 from syncmaster.worker.handlers.file.base import FileHandler
 
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+
 
 class HDFSHandler(FileHandler):
-    def init_connection(self):
+    connection_dto: HDFSConnectionDTO
+
+    def connect(self, spark: SparkSession):
         self.connection = SparkHDFS(
             cluster=self.connection_dto.cluster,
-            spark=self.spark,
+            spark=spark,
         ).check()
diff --git a/syncmaster/worker/handlers/file/s3.py b/syncmaster/worker/handlers/file/s3.py
index 69082541..7296456c 100644
--- a/syncmaster/worker/handlers/file/s3.py
+++ b/syncmaster/worker/handlers/file/s3.py
@@ -1,12 +1,21 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
+
+from typing import TYPE_CHECKING
+
 from onetl.connection import SparkS3
 
+from syncmaster.dto.connections import S3ConnectionDTO
 from syncmaster.worker.handlers.file.base import FileHandler
 
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+
 
 class S3Handler(FileHandler):
-    def init_connection(self):
+    connection_dto: S3ConnectionDTO
+
+    def connect(self, spark: SparkSession):
         self.connection = SparkS3(
             host=self.connection_dto.host,
             port=self.connection_dto.port,
@@ -16,5 +25,5 @@ def init_connection(self):
             protocol=self.connection_dto.protocol,
             region=self.connection_dto.region,
             extra=self.connection_dto.additional_params,
-            spark=self.spark,
+            spark=spark,
         ).check()
diff --git a/syncmaster/worker/handlers/hive.py b/syncmaster/worker/handlers/hive.py
deleted file mode 100644
index 3646e1c2..00000000
--- a/syncmaster/worker/handlers/hive.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
-# SPDX-License-Identifier: Apache-2.0
-from onetl.connection import Hive
-from onetl.db import DBReader, DBWriter
-from pyspark.sql.dataframe import DataFrame
-
-from syncmaster.dto.connections import HiveConnectionDTO
-from syncmaster.dto.transfers import HiveTransferDTO
-from syncmaster.worker.handlers.base import Handler
-
-
-class HiveHandler(Handler):
-    connection: Hive
-    connection_dto: HiveConnectionDTO
-    transfer_dto: HiveTransferDTO
-
-    def init_connection(self):
-        self.connection = Hive(
-            cluster=self.connection_dto.cluster,
-            spark=self.spark,
-        ).check()
-
-    def init_reader(self):
-        super().init_reader()
-        self.spark.catalog.refreshTable(self.transfer_dto.table_name)
-        self.reader = DBReader(
-            connection=self.connection,
-            table=self.transfer_dto.table_name,
-        )
-
-    def init_writer(self):
-        super().init_writer()
-        self.writer = DBWriter(
-            connection=self.connection,
-            table=self.transfer_dto.table_name,
-        )
-
-    def normalize_column_name(self, df: DataFrame) -> DataFrame:
-        for column_name in df.columns:
-            df = df.withColumnRenamed(column_name, column_name.lower())
-        return df
diff --git a/syncmaster/worker/spark.py b/syncmaster/worker/spark.py
index ff9688dc..7c55b1f7 100644
--- a/syncmaster/worker/spark.py
+++ b/syncmaster/worker/spark.py
@@ -1,24 +1,31 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
 import logging
+from typing import TYPE_CHECKING
 
-import pyspark
 from onetl.connection import Oracle, Postgres, SparkS3
-from pyspark.sql import SparkSession
 
 from syncmaster.config import Settings
+from syncmaster.db.models import Run
 from syncmaster.dto.connections import ConnectionDTO
 
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+
 log = logging.getLogger(__name__)
 
 
 def get_worker_spark_session(
-    settings: Settings,  # used in test spark session definition
+    settings: Settings,
+    run: Run,
     source: ConnectionDTO,
     target: ConnectionDTO,
 ) -> SparkSession:
-    """Through the source and target parameters you can get credentials for authorization at the source"""
-    spark_builder = SparkSession.builder.appName("celery_worker")
+    """Construct Spark Session using run parameters and application settings"""
+    from pyspark.sql import SparkSession
+
+    name = run.transfer.group.name + "_" + run.transfer.name
+    spark_builder = SparkSession.builder.appName(f"syncmaster_{name}")
 
     for k, v in get_spark_session_conf(source, target).items():
         spark_builder = spark_builder.config(k, v)
@@ -36,6 +43,8 @@ def get_packages(db_type: str) -> list[str]:
     if db_type == "oracle":
         return Oracle.get_packages()
     if db_type == "s3":
+        import pyspark
+
         spark_version = pyspark.__version__
         return SparkS3.get_packages(spark_version=spark_version)
 
diff --git a/syncmaster/worker/transfer.py b/syncmaster/worker/transfer.py
index e243a2ae..0e604c1c 100644
--- a/syncmaster/worker/transfer.py
+++ b/syncmaster/worker/transfer.py
@@ -1,7 +1,7 @@
 # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
 # SPDX-License-Identifier: Apache-2.0
 import logging
-from datetime import datetime
+from datetime import datetime, timezone
 
 import onetl
 from sqlalchemy import select
@@ -38,14 +38,16 @@ def run_transfer(session: Session, run_id: int, settings: Settings):
         Run,
         run_id,
         options=(
+            selectinload(Run.transfer).selectinload(Transfer.group),
             selectinload(Run.transfer).selectinload(Transfer.source_connection),
             selectinload(Run.transfer).selectinload(Transfer.target_connection),
         ),
     )
     if run is None:
         raise RunNotFoundError
+
     run.status = Status.STARTED
-    run.started_at = datetime.utcnow()
+    run.started_at = datetime.now(tz=timezone.utc)
     session.add(run)
     session.commit()
 
@@ -73,12 +75,6 @@ def run_transfer(session: Session, run_id: int, settings: Settings):
     else:
         run.status = Status.FINISHED
         logger.warning("Run `%s` was successful", run.id)
-    finally:
-        # Both the source and the receiver use the same spark session,
-        # so it is enough to stop the session at the source.
-        if controller is not None and controller.source_handler.spark is not None:
-            controller.source_handler.spark.sparkContext.stop()
-            controller.source_handler.spark.stop()
 
     run.ended_at = datetime.utcnow()
     session.add(run)
diff --git a/tests/spark/__init__.py b/tests/spark/__init__.py
index e69de29b..293b7131 100644
--- a/tests/spark/__init__.py
+++ b/tests/spark/__init__.py
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from celery.signals import worker_process_init, worker_process_shutdown
+from coverage import Coverage
+from onetl.connection import SparkHDFS
+from onetl.hooks import hook
+
+from syncmaster.worker.spark import get_worker_spark_session
+
+# this is just to automatically import hooks
+get_worker_spark_session = get_worker_spark_session
+
+
+@SparkHDFS.Slots.get_cluster_namenodes.bind
+@hook
+def get_cluster_namenodes(cluster: str) -> set[str] | None:
+    if cluster == "test-hive":
+        return {"test-hive"}
+    return None
+
+
+@SparkHDFS.Slots.is_namenode_active.bind
+@hook
+def is_namenode_active(host: str, cluster: str) -> bool:
+    if cluster == "test-hive":
+        return True
+    return False
+
+
+@SparkHDFS.Slots.get_ipc_port.bind
+@hook
+def get_ipc_port(cluster: str) -> int | None:
+    if cluster == "test-hive":
+        return 9820
+    return None
+
+
+# Needed to collect code coverage by tests in the worker
+# https://github.com/nedbat/coveragepy/issues/689#issuecomment-656706935
+COV = None
+
+
+@worker_process_init.connect
+def start_coverage(**kwargs):
+    global COV
+
+    COV = Coverage(data_suffix=True)
+    COV.start()
+
+
+@worker_process_shutdown.connect
+def save_coverage(**kwargs):
+    if COV is not None:
+        COV.stop()
+        COV.save()
diff --git a/tests/spark/get_worker_spark_session.py b/tests/spark/get_worker_spark_session.py
deleted file mode 100644
index 8c40d97a..00000000
--- a/tests/spark/get_worker_spark_session.py
+++ /dev/null
@@ -1,78 +0,0 @@
-from __future__ import annotations
-
-import logging
-
-from celery.signals import worker_process_init, worker_process_shutdown
-from coverage import Coverage
-from onetl.connection import SparkHDFS
-from onetl.hooks import hook
-from pyspark.sql import SparkSession
-
-from syncmaster.config import Settings
-from syncmaster.dto.connections import ConnectionDTO
-from syncmaster.worker.spark import get_spark_session_conf
-
-log = logging.getLogger(__name__)
-
-
-@SparkHDFS.Slots.get_cluster_namenodes.bind
-@hook
-def get_cluster_namenodes(cluster: str) -> set[str] | None:
-    if cluster == "test-hive":
-        return {"test-hive"}
-    return None
-
-
-@SparkHDFS.Slots.is_namenode_active.bind
-@hook
-def is_namenode_active(host: str, cluster: str) -> bool:
-    if cluster == "test-hive":
-        return True
-    return False
-
-
-@SparkHDFS.Slots.get_ipc_port.bind
-@hook
-def get_ipc_port(cluster: str) -> int | None:
-    if cluster == "test-hive":
-        return 9820
-    return None
-
-
-def get_worker_spark_session(
-    settings: Settings,
-    source: ConnectionDTO,
-    target: ConnectionDTO,
-) -> SparkSession:
-    spark_builder = SparkSession.builder.appName("celery_worker")
-
-    for k, v in get_spark_session_conf(source, target).items():
-        spark_builder = spark_builder.config(k, v)
-
-    if source.type == "hive" or target.type == "hive":
-        log.debug("Enabling Hive support")
-        spark_builder = spark_builder.enableHiveSupport()
-
-    return spark_builder.getOrCreate()
-
-
-# Needed to collect code coverage by tests in the worker
-# https://github.com/nedbat/coveragepy/issues/689#issuecomment-656706935
-
-
-COV = None
-
-
-@worker_process_init.connect
-def start_coverage(**kwargs):
-    global COV
-
-    COV = Coverage(data_suffix=True)
-    COV.start()
-
-
-@worker_process_shutdown.connect
-def save_coverage(**kwargs):
-    if COV is not None:
-        COV.stop()
-        COV.save()
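The ``CREATE_SPARK_SESSION_FUNCTION`` contract shown in ``syncmaster/worker/spark.py`` and ``tests/spark/__init__.py`` receives the application settings, the current ``Run`` and both connection DTOs, so a deployment can point the setting at its own session factory, just as the test image does with ``ENV CREATE_SPARK_SESSION_FUNCTION=tests.spark.get_worker_spark_session``. Below is a minimal sketch of such a custom factory; the module path ``myproject/spark.py`` and the function name are hypothetical, while ``get_spark_session_conf``, the DTO ``type`` field and the ``run.transfer`` / ``run.transfer.group`` attributes are the ones used in this diff.

# myproject/spark.py -- hypothetical module; point CREATE_SPARK_SESSION_FUNCTION at it
from pyspark.sql import SparkSession

from syncmaster.config import Settings
from syncmaster.db.models import Run
from syncmaster.dto.connections import ConnectionDTO
from syncmaster.worker.spark import get_spark_session_conf


def get_custom_spark_session(
    settings: Settings,
    run: Run,
    source: ConnectionDTO,
    target: ConnectionDTO,
) -> SparkSession:
    # Run/transfer/group information is now available for session options,
    # e.g. a readable per-transfer appName.
    name = f"{run.transfer.group.name}_{run.transfer.name}"
    spark_builder = SparkSession.builder.appName(f"syncmaster_{name}")

    # Reuse the connector packages/options derived from both connection types.
    for key, value in get_spark_session_conf(source, target).items():
        spark_builder = spark_builder.config(key, value)

    # Hive connections need Hive support enabled on the session.
    if source.type == "hive" or target.type == "hive":
        spark_builder = spark_builder.enableHiveSupport()

    return spark_builder.getOrCreate()

Since the controller now owns the session lifecycle (``with spark:`` in ``TransferController.start_transfer``), a factory like this only builds and returns the session; it does not need to stop it, which is why the ``finally`` block in ``syncmaster/worker/transfer.py`` was removed.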