[DOP-15023] Pass Run to CREATE_SPARK_SESSION_FUNCTION
dolfinus committed Apr 22, 2024
1 parent 67a17b0 commit 19b5e3c
Showing 20 changed files with 318 additions and 313 deletions.
2 changes: 1 addition & 1 deletion docker/Dockerfile.worker
@@ -39,7 +39,7 @@ COPY ./syncmaster/ /app/syncmaster/

FROM base as test

-ENV CREATE_SPARK_SESSION_FUNCTION=tests.spark.get_worker_spark_session.get_worker_spark_session
+ENV CREATE_SPARK_SESSION_FUNCTION=tests.spark.get_worker_spark_session

# CI runs tests in the worker container, so we need backend dependencies too
RUN poetry install --no-root --extras "worker backend" --with test --without docs,dev
2 changes: 2 additions & 0 deletions docs/changelog/next_release/38.breaking.rst
@@ -0,0 +1,2 @@
+Pass current ``Run`` to ``CREATE_SPARK_SESSION_FUNCTION``. This allows using run/transfer/group information for Spark session options,
+like ``appName`` or custom ones.
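
For example, a deployment could point ``CREATE_SPARK_SESSION_FUNCTION`` at a custom factory that tags the Spark application with run details. A minimal sketch, not part of this commit; the keyword arguments mirror the call site in ``TransferController.perform_transfer`` below, while ``Run`` attributes such as ``run.id`` and ``run.transfer.group_id`` are assumptions:

from pyspark.sql import SparkSession

def get_custom_spark_session(settings, run, source, target) -> SparkSession:
    # settings: application Settings; run: the current Run (new in this commit);
    # source/target: connection DTOs, as before.
    return (
        SparkSession.builder
        .appName(f"syncmaster_run_{run.id}")  # run.id is an assumed attribute
        .config("spark.syncmaster.group", str(run.transfer.group_id))  # assumed attribute
        .getOrCreate()
    )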
13 changes: 7 additions & 6 deletions syncmaster/dto/connections.py
@@ -1,11 +1,12 @@
# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
+from typing import ClassVar


@dataclass
class ConnectionDTO:
-    pass
+    type: ClassVar[str]


@dataclass
@@ -16,7 +17,7 @@ class PostgresConnectionDTO(ConnectionDTO):
    password: str
    additional_params: dict
    database_name: str
-    type: str = "postgres"
+    type: ClassVar[str] = "postgres"


@dataclass
@@ -28,23 +29,23 @@ class OracleConnectionDTO(ConnectionDTO):
    additional_params: dict
    sid: str | None = None
    service_name: str | None = None
-    type: str = "oracle"
+    type: ClassVar[str] = "oracle"


@dataclass
class HiveConnectionDTO(ConnectionDTO):
    user: str
    password: str
    cluster: str
-    type: str = "hive"
+    type: ClassVar[str] = "hive"


@dataclass
class HDFSConnectionDTO(ConnectionDTO):
    user: str
    password: str
    cluster: str
-    type: str = "hdfs"
+    type: ClassVar[str] = "hdfs"


@dataclass
@@ -57,4 +58,4 @@ class S3ConnectionDTO(ConnectionDTO):
    additional_params: dict
    region: str | None = None
    protocol: str = "https"
-    type: str = "s3"
+    type: ClassVar[str] = "s3"
62 changes: 41 additions & 21 deletions syncmaster/dto/transfers.py
@@ -1,46 +1,66 @@
# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
# SPDX-License-Identifier: Apache-2.0
+import json
from dataclasses import dataclass
+from typing import ClassVar

from syncmaster.schemas.v1.transfers.file_format import CSV, JSON, JSONLine


@dataclass
class TransferDTO:
-    pass
+    type: ClassVar[str]


@dataclass
-class PostgresTransferDTO(TransferDTO):
+class DBTransferDTO(TransferDTO):
    table_name: str
-    type: str = "postgres"


@dataclass
-class OracleTransferDTO(TransferDTO):
-    table_name: str
-    type: str = "oracle"
+class FileTransferDTO(TransferDTO):
+    directory_path: str
+    file_format: CSV | JSONLine | JSON
+    options: dict
+    df_schema: dict | None = None
+
+    def __post_init__(self):
+        if isinstance(self.file_format, dict):
+            self.file_format = self._get_format(self.file_format)
+        if isinstance(self.df_schema, str):
+            self.df_schema = json.loads(self.df_schema)
+
+    def _get_format(self, file_format: dict):
+        file_type = file_format["type"]
+        if file_type == "csv":
+            return CSV.parse_obj(file_format)
+        if file_type == "jsonline":
+            return JSONLine.parse_obj(file_format)
+        if file_type == "json":
+            return JSON.parse_obj(file_format)
+        raise ValueError("Unknown file type")


@dataclass
-class HiveTransferDTO(TransferDTO):
-    table_name: str
-    type: str = "hive"
+class PostgresTransferDTO(DBTransferDTO):
+    type: ClassVar[str] = "postgres"


@dataclass
-class S3TransferDTO(TransferDTO):
-    directory_path: str
-    file_format: CSV | JSONLine | JSON
-    options: dict
-    df_schema: dict | None = None
-    type: str = "s3"
+class OracleTransferDTO(DBTransferDTO):
+    type: ClassVar[str] = "oracle"


@dataclass
-class HDFSTransferDTO(TransferDTO):
-    directory_path: str
-    file_format: CSV | JSONLine | JSON
-    options: dict
-    df_schema: dict | None = None
-    type: str = "hdfs"
+class HiveTransferDTO(DBTransferDTO):
+    type: ClassVar[str] = "hive"


+@dataclass
+class S3TransferDTO(FileTransferDTO):
+    type: ClassVar[str] = "s3"
+
+
+@dataclass
+class HDFSTransferDTO(FileTransferDTO):
+    type: ClassVar[str] = "hdfs"
50 changes: 22 additions & 28 deletions syncmaster/worker/controller.py
@@ -4,7 +4,7 @@
from typing import Any

from syncmaster.config import Settings
-from syncmaster.db.models import Connection, Transfer
+from syncmaster.db.models import Connection, Run
from syncmaster.dto.connections import (
    HDFSConnectionDTO,
    HiveConnectionDTO,
@@ -21,11 +21,11 @@
)
from syncmaster.exceptions.connection import ConnectionTypeNotRecognizedError
from syncmaster.worker.handlers.base import Handler
+from syncmaster.worker.handlers.db.hive import HiveHandler
+from syncmaster.worker.handlers.db.oracle import OracleHandler
+from syncmaster.worker.handlers.db.postgres import PostgresHandler
from syncmaster.worker.handlers.file.hdfs import HDFSHandler
from syncmaster.worker.handlers.file.s3 import S3Handler
-from syncmaster.worker.handlers.hive import HiveHandler
-from syncmaster.worker.handlers.oracle import OracleHandler
-from syncmaster.worker.handlers.postgres import PostgresHandler

logger = logging.getLogger(__name__)

@@ -65,47 +65,40 @@ class TransferController:

    def __init__(
        self,
-        transfer: Transfer,
+        run: Run,
        source_connection: Connection,
        source_auth_data: dict,
        target_connection: Connection,
        target_auth_data: dict,
        settings: Settings,
    ):
+        self.run = run
        self.settings = settings
        self.source_handler = self.get_handler(
            connection_data=source_connection.data,
-            transfer_params=transfer.source_params,
+            transfer_params=run.transfer.source_params,
            connection_auth_data=source_auth_data,
        )
        self.target_handler = self.get_handler(
            connection_data=target_connection.data,
-            transfer_params=transfer.target_params,
+            transfer_params=run.transfer.target_params,
            connection_auth_data=target_auth_data,
        )
-        spark = settings.CREATE_SPARK_SESSION_FUNCTION(
-            settings,
-            target=self.target_handler.connection_dto,
+
+    def perform_transfer(self) -> None:
+        spark = self.settings.CREATE_SPARK_SESSION_FUNCTION(
+            settings=self.settings,
+            run=self.run,
            source=self.source_handler.connection_dto,
+            target=self.target_handler.connection_dto,
        )

-        self.source_handler.set_spark(spark)
-        self.target_handler.set_spark(spark)
-        logger.info("source connection = %s", self.source_handler)
-        logger.info("target connection = %s", self.target_handler)
-
-    def start_transfer(self) -> None:
-        self.source_handler.init_connection()
-        self.source_handler.init_reader()
-
-        self.target_handler.init_connection()
-        self.target_handler.init_writer()
-        logger.info("Source and target were initialized")
-
-        df = self.target_handler.normalize_column_name(self.source_handler.read())
-        logger.info("Data has been read")
+        with spark:
+            self.source_handler.connect(spark)
+            self.target_handler.connect(spark)

-        self.target_handler.write(df)
-        logger.info("Data has been inserted")
+            df = self.source_handler.read()
+            self.target_handler.write(df)

    def get_handler(
        self,
@@ -114,7 +107,8 @@ def get_handler(
        transfer_params: dict[str, Any],
    ) -> Handler:
        connection_data.update(connection_auth_data)
-        handler_type = connection_data["type"]
+        handler_type = connection_data.pop("type")
+        transfer_params.pop("type", None)

        if connection_handler_proxy.get(handler_type, None) is None:
            raise ConnectionTypeNotRecognizedError
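
With this change the controller no longer builds the Spark session in ``__init__``; it is created inside ``perform_transfer``, and since a ``SparkSession`` is a context manager, ``with spark:`` stops it once the transfer finishes. A hedged usage sketch (the calling Celery task is not shown in this excerpt, so the variable names are assumptions):

controller = TransferController(
    run=run,  # Run ORM object, replaces the old Transfer argument
    source_connection=source_connection,
    source_auth_data=source_auth_data,
    target_connection=target_connection,
    target_auth_data=target_auth_data,
    settings=settings,
)
controller.perform_transfer()  # creates the Spark session, connects handlers, reads, writes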
46 changes: 15 additions & 31 deletions syncmaster/worker/handlers/base.py
@@ -1,49 +1,33 @@
# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
# SPDX-License-Identifier: Apache-2.0
-from abc import ABC
-
-from onetl.db import DBReader, DBWriter
-from pyspark.sql import SparkSession
-from pyspark.sql.dataframe import DataFrame
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING

from syncmaster.dto.connections import ConnectionDTO
from syncmaster.dto.transfers import TransferDTO

+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+    from pyspark.sql.dataframe import DataFrame


class Handler(ABC):
    def __init__(
        self,
        connection_dto: ConnectionDTO,
        transfer_dto: TransferDTO,
-        spark: SparkSession | None = None,
-    ) -> None:
-        self.spark = spark
-        self.reader: DBReader | None = None
-        self.writer: DBWriter | None = None
+    ):
        self.connection_dto = connection_dto
        self.transfer_dto = transfer_dto

-    def init_connection(self): ...
-
-    def set_spark(self, spark: SparkSession):
-        self.spark = spark
-
-    def init_reader(self):
-        if self.connection_dto is None:
-            raise ValueError("At first you need to initialize connection. Run `init_connection")
-
-    def init_writer(self):
-        if self.connection_dto is None:
-            raise ValueError("At first you need to initialize connection. Run `init_connection")
-
-    def read(self) -> DataFrame:
-        if self.reader is None:
-            raise ValueError("Reader is not initialized")
-        return self.reader.run()
+    @abstractmethod
+    def connect(self, spark: SparkSession) -> None: ...

-    def write(self, df: DataFrame) -> None:
-        if self.writer is None:
-            raise ValueError("Writer is not initialized")
-        return self.writer.run(df=df)
+    @abstractmethod
+    def read(self) -> DataFrame: ...

-    def normalize_column_name(self, df: DataFrame) -> DataFrame: ...
+    @abstractmethod
+    def write(self, df: DataFrame) -> None: ...
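
After this refactoring ``Handler`` is a pure abstract base: concrete handlers must implement ``connect``, ``read`` and ``write`` instead of the old init_reader/init_writer flow. A minimal sketch of a hypothetical subclass, for illustration only (real handlers build an onetl connection in ``connect`` and delegate I/O to DBReader/DBWriter, as DBHandler below does):

from __future__ import annotations

from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame

from syncmaster.worker.handlers.base import Handler


class DummyHandler(Handler):
    def connect(self, spark: SparkSession) -> None:
        # Keep a reference to the session created by CREATE_SPARK_SESSION_FUNCTION.
        self.spark = spark

    def read(self) -> DataFrame:
        # Hypothetical in-memory source.
        return self.spark.createDataFrame([(1, "a")], ["id", "value"])

    def write(self, df: DataFrame) -> None:
        # The "noop" sink discards rows; real handlers write to the target system.
        df.write.format("noop").mode("overwrite").save()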
2 changes: 2 additions & 0 deletions syncmaster/worker/handlers/db/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
39 changes: 39 additions & 0 deletions syncmaster/worker/handlers/db/base.py
@@ -0,0 +1,39 @@
+# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from abc import abstractmethod
+from typing import TYPE_CHECKING
+
+from onetl.base import BaseDBConnection
+from onetl.db import DBReader, DBWriter
+
+from syncmaster.dto.transfers import DBTransferDTO
+from syncmaster.worker.handlers.base import Handler
+
+if TYPE_CHECKING:
+    from pyspark.sql.dataframe import DataFrame
+
+
+class DBHandler(Handler):
+    connection: BaseDBConnection
+    transfer_dto: DBTransferDTO
+
+    def read(self) -> DataFrame:
+        reader = DBReader(
+            connection=self.connection,
+            table=self.transfer_dto.table_name,
+        )
+        df = reader.run()
+        return self.normalize_column_name(df)
+
+    def write(self, df: DataFrame) -> None:
+        writer = DBWriter(
+            connection=self.connection,
+            table=self.transfer_dto.table_name,
+        )
+        return writer.run(df=df)
+
+    @abstractmethod
+    def normalize_column_name(self, df: DataFrame) -> DataFrame: ...
37 changes: 37 additions & 0 deletions syncmaster/worker/handlers/db/hive.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from onetl.connection import Hive
+
+from syncmaster.dto.connections import HiveConnectionDTO
+from syncmaster.dto.transfers import HiveTransferDTO
+from syncmaster.worker.handlers.db.base import DBHandler
+
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+    from pyspark.sql.dataframe import DataFrame
+
+
+class HiveHandler(DBHandler):
+    connection: Hive
+    connection_dto: HiveConnectionDTO
+    transfer_dto: HiveTransferDTO
+
+    def connect(self, spark: SparkSession):
+        self.connection = Hive(
+            cluster=self.connection_dto.cluster,
+            spark=spark,
+        ).check()
+
+    def read(self) -> DataFrame:
+        self.connection.spark.catalog.refreshTable(self.transfer_dto.table_name)
+        return super().read()
+
+    def normalize_column_name(self, df: DataFrame) -> DataFrame:
+        for column_name in df.columns:
+            df = df.withColumnRenamed(column_name, column_name.lower())
+        return df