-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[DOP-15023] Pass Run to CREATE_SPARK_SESSION_FUNCTION
- Loading branch information
Showing
21 changed files
with
320 additions
and
316 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
Pass current ``Run`` to ``CREATE_SPARK_SESSION_FUNCTION``. This allows using run/transfer/group information for Spark session options, | ||
like ``appName`` or custom ones. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,66 @@ | ||
# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems) | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import json | ||
from dataclasses import dataclass | ||
from typing import ClassVar | ||
|
||
from syncmaster.schemas.v1.transfers.file_format import CSV, JSON, JSONLine | ||
|
||
|
||
@dataclass | ||
class TransferDTO: | ||
pass | ||
type: ClassVar[str] | ||
|
||
|
||
@dataclass | ||
class PostgresTransferDTO(TransferDTO): | ||
class DBTransferDTO(TransferDTO): | ||
table_name: str | ||
type: str = "postgres" | ||
|
||
|
||
@dataclass | ||
class OracleTransferDTO(TransferDTO): | ||
table_name: str | ||
type: str = "oracle" | ||
class FileTransferDTO(TransferDTO): | ||
directory_path: str | ||
file_format: CSV | JSONLine | JSON | ||
options: dict | ||
df_schema: dict | None = None | ||
|
||
def __post_init__(self): | ||
if isinstance(self.file_format, dict): | ||
self.file_format = self._get_format(self.file_format) | ||
if isinstance(self.df_schema, str): | ||
self.df_schema = json.loads(self.df_schema) | ||
|
||
def _get_format(self, file_format: dict): | ||
file_type = file_format["type"] | ||
if file_type == "csv": | ||
return CSV.parse_obj(file_format) | ||
if file_type == "jsonline": | ||
return JSONLine.parse_obj(file_format) | ||
if file_type == "json": | ||
return JSON.parse_obj(file_format) | ||
raise ValueError("Unknown file type") | ||
|
||
|
||
@dataclass | ||
class HiveTransferDTO(TransferDTO): | ||
table_name: str | ||
type: str = "hive" | ||
class PostgresTransferDTO(DBTransferDTO): | ||
type: ClassVar[str] = "postgres" | ||
|
||
|
||
@dataclass | ||
class S3TransferDTO(TransferDTO): | ||
directory_path: str | ||
file_format: CSV | JSONLine | JSON | ||
options: dict | ||
df_schema: dict | None = None | ||
type: str = "s3" | ||
class OracleTransferDTO(DBTransferDTO): | ||
type: ClassVar[str] = "oracle" | ||
|
||
|
||
@dataclass | ||
class HDFSTransferDTO(TransferDTO): | ||
directory_path: str | ||
file_format: CSV | JSONLine | JSON | ||
options: dict | ||
df_schema: dict | None = None | ||
type: str = "hdfs" | ||
class HiveTransferDTO(DBTransferDTO): | ||
type: ClassVar[str] = "hive" | ||
|
||
|
||
@dataclass | ||
class S3TransferDTO(FileTransferDTO): | ||
type: ClassVar[str] = "s3" | ||
|
||
|
||
@dataclass | ||
class HDFSTransferDTO(FileTransferDTO): | ||
type: ClassVar[str] = "hdfs" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,49 +1,33 @@ | ||
# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems) | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from abc import ABC | ||
|
||
from onetl.db import DBReader, DBWriter | ||
from pyspark.sql import SparkSession | ||
from pyspark.sql.dataframe import DataFrame | ||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import TYPE_CHECKING | ||
|
||
from syncmaster.dto.connections import ConnectionDTO | ||
from syncmaster.dto.transfers import TransferDTO | ||
|
||
if TYPE_CHECKING: | ||
from pyspark.sql import SparkSession | ||
from pyspark.sql.dataframe import DataFrame | ||
|
||
|
||
class Handler(ABC): | ||
def __init__( | ||
self, | ||
connection_dto: ConnectionDTO, | ||
transfer_dto: TransferDTO, | ||
spark: SparkSession | None = None, | ||
) -> None: | ||
self.spark = spark | ||
self.reader: DBReader | None = None | ||
self.writer: DBWriter | None = None | ||
): | ||
self.connection_dto = connection_dto | ||
self.transfer_dto = transfer_dto | ||
|
||
def init_connection(self): ... | ||
|
||
def set_spark(self, spark: SparkSession): | ||
self.spark = spark | ||
|
||
def init_reader(self): | ||
if self.connection_dto is None: | ||
raise ValueError("At first you need to initialize connection. Run `init_connection") | ||
|
||
def init_writer(self): | ||
if self.connection_dto is None: | ||
raise ValueError("At first you need to initialize connection. Run `init_connection") | ||
|
||
def read(self) -> DataFrame: | ||
if self.reader is None: | ||
raise ValueError("Reader is not initialized") | ||
return self.reader.run() | ||
@abstractmethod | ||
def connect(self, spark: SparkSession) -> None: ... | ||
|
||
def write(self, df: DataFrame) -> None: | ||
if self.writer is None: | ||
raise ValueError("Writer is not initialized") | ||
return self.writer.run(df=df) | ||
@abstractmethod | ||
def read(self) -> DataFrame: ... | ||
|
||
def normalize_column_name(self, df: DataFrame) -> DataFrame: ... | ||
@abstractmethod | ||
def write(self, df: DataFrame) -> None: ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems) | ||
# SPDX-License-Identifier: Apache-2.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems) | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from __future__ import annotations | ||
|
||
from abc import abstractmethod | ||
from typing import TYPE_CHECKING | ||
|
||
from onetl.base import BaseDBConnection | ||
from onetl.db import DBReader, DBWriter | ||
|
||
from syncmaster.dto.transfers import DBTransferDTO | ||
from syncmaster.worker.handlers.base import Handler | ||
|
||
if TYPE_CHECKING: | ||
from pyspark.sql.dataframe import DataFrame | ||
|
||
|
||
class DBHandler(Handler): | ||
connection: BaseDBConnection | ||
transfer_dto: DBTransferDTO | ||
|
||
def read(self) -> DataFrame: | ||
reader = DBReader( | ||
connection=self.connection, | ||
table=self.transfer_dto.table_name, | ||
) | ||
return reader.run() | ||
|
||
def write(self, df: DataFrame) -> None: | ||
writer = DBWriter( | ||
connection=self.connection, | ||
table=self.transfer_dto.table_name, | ||
) | ||
return writer.run(df=self.normalize_column_names(df)) | ||
|
||
@abstractmethod | ||
def normalize_column_names(self, df: DataFrame) -> DataFrame: ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems) | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
from onetl.connection import Hive | ||
|
||
from syncmaster.dto.connections import HiveConnectionDTO | ||
from syncmaster.dto.transfers import HiveTransferDTO | ||
from syncmaster.worker.handlers.db.base import DBHandler | ||
|
||
if TYPE_CHECKING: | ||
from pyspark.sql import SparkSession | ||
from pyspark.sql.dataframe import DataFrame | ||
|
||
|
||
class HiveHandler(DBHandler): | ||
connection: Hive | ||
connection_dto: HiveConnectionDTO | ||
transfer_dto: HiveTransferDTO | ||
|
||
def connect(self, spark: SparkSession): | ||
self.connection = Hive( | ||
cluster=self.connection_dto.cluster, | ||
spark=spark, | ||
).check() | ||
|
||
def read(self) -> DataFrame: | ||
self.connection.spark.catalog.refreshTable(self.transfer_dto.table_name) | ||
return super().read() | ||
|
||
def normalize_column_names(self, df: DataFrame) -> DataFrame: | ||
for column_name in df.columns: | ||
df = df.withColumnRenamed(column_name, column_name.lower()) | ||
return df |
Oops, something went wrong.