Skip to content

Commit

Permalink
refactor: init_replication is now in pkg ns
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicolas ESTRADA committed Jan 22, 2025
1 parent d695afb commit 8f45283
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 81 deletions.
79 changes: 76 additions & 3 deletions sources/pg_legacy_replication/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
"""Replicates postgres tables in batch using logical decoding."""

from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union
from collections import defaultdict
from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Union

import dlt
from dlt.extract import DltResource
from dlt.extract.items import TDataItem
from dlt.sources.credentials import ConnectionStringCredentials
from collections import defaultdict
from dlt.sources.sql_database import sql_table

from .helpers import (
BackendHandler,
ItemGenerator,
ReplicationOptions,
SqlTableOptions,
advance_slot,
cleanup_snapshot_resources,
configure_engine,
create_replication_slot,
drop_replication_slot,
get_max_lsn,
init_replication,
get_rep_conn,
)


Expand Down Expand Up @@ -132,6 +137,74 @@ def _create_table_dispatch(
return handler


@dlt.source
def init_replication(
slot_name: str,
schema: str,
table_names: Optional[Union[str, Sequence[str]]] = None,
credentials: ConnectionStringCredentials = dlt.secrets.value,
take_snapshots: bool = False,
table_options: Optional[Mapping[str, SqlTableOptions]] = None,
reset: bool = False,
) -> Iterable[DltResource]:
"""
Initializes a replication session for Postgres using logical replication.
Optionally, snapshots of specified tables can be taken during initialization.
Args:
slot_name (str):
The name of the logical replication slot to be used or created.
schema (str):
Name of the schema to replicate tables from.
table_names (Optional[Union[str, Sequence[str]]]):
The name(s) of the table(s) to replicate. Can be a single table name or a list of table names.
If not provided, no tables will be replicated unless `take_snapshots` is `True`.
credentials (ConnectionStringCredentials):
Database credentials for connecting to the Postgres instance.
take_snapshots (bool):
Whether to take initial snapshots of the specified tables.
Defaults to `False`.
table_options (Optional[Mapping[str, SqlTableOptions]]):
Additional options for configuring replication for specific tables.
These are the exact same parameters for the `dlt.sources.sql_database.sql_table` function.
Argument is only used if `take_snapshots` is `True`.
reset (bool, optional):
If `True`, drops the existing replication slot before creating a new one.
Use with caution, as this will clear existing replication state.
Defaults to `False`.
Returns:
- None if `take_snapshots` is `False`
- a list of `DltResource` objects for the snapshot table(s) if `take_snapshots` is `True`.
Notes:
- If `reset` is `True`, the existing replication slot will be dropped before creating a new one.
- When `take_snapshots` is `True`, the function configures a snapshot isolation level for consistent table snapshots.
"""
rep_conn = get_rep_conn(credentials)
rep_cur = rep_conn.cursor()
if reset:
drop_replication_slot(slot_name, rep_cur)
slot = create_replication_slot(slot_name, rep_cur)

# Close connection if no snapshots are needed
if not take_snapshots:
rep_conn.close()
return

assert table_names is not None

engine = configure_engine(
credentials, rep_conn, slot.get("snapshot_name") if slot else None
)

table_names = [table_names] if isinstance(table_names, str) else table_names or []

for table in table_names:
table_args = (table_options or {}).get(table, {}).copy()
yield sql_table(credentials=engine, table=table, schema=schema, **table_args)


__all__ = [
"ReplicationOptions",
"cleanup_snapshot_resources",
Expand Down
79 changes: 4 additions & 75 deletions sources/pg_legacy_replication/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
Sequence,
Set,
TypedDict,
Union,
)

import dlt
Expand All @@ -26,7 +25,7 @@
from dlt.common.schema.typing import TColumnSchema, TTableSchema, TTableSchemaColumns
from dlt.common.schema.utils import merge_column
from dlt.common.typing import TDataItem
from dlt.extract import DltResource, DltSource
from dlt.extract import DltSource
from dlt.extract.items import DataItemWithMeta
from dlt.sources.credentials import ConnectionStringCredentials
from dlt.sources.sql_database import (
Expand All @@ -36,7 +35,6 @@
TTypeAdapter,
arrow_helpers as arrow,
engine_from_credentials,
sql_table,
)
from psycopg2.extensions import connection as ConnectionExt, cursor
from psycopg2.extras import (
Expand Down Expand Up @@ -75,76 +73,7 @@ class SqlTableOptions(TypedDict, total=False):
type_adapter_callback: Optional[TTypeAdapter]


@dlt.sources.config.with_config(sections=("sources", "pg_legacy_replication"))
@dlt.source
def init_replication(
slot_name: str,
schema: str,
table_names: Optional[Union[str, Sequence[str]]] = None,
credentials: ConnectionStringCredentials = dlt.secrets.value,
take_snapshots: bool = False,
table_options: Optional[Mapping[str, SqlTableOptions]] = None,
reset: bool = False,
) -> Iterable[DltResource]:
"""
Initializes a replication session for Postgres using logical replication.
Optionally, snapshots of specified tables can be taken during initialization.
Args:
slot_name (str):
The name of the logical replication slot to be used or created.
schema (str):
Name of the schema to replicate tables from.
table_names (Optional[Union[str, Sequence[str]]]):
The name(s) of the table(s) to replicate. Can be a single table name or a list of table names.
If not provided, no tables will be replicated unless `take_snapshots` is `True`.
credentials (ConnectionStringCredentials):
Database credentials for connecting to the Postgres instance.
take_snapshots (bool):
Whether to take initial snapshots of the specified tables.
Defaults to `False`.
table_options (Optional[Mapping[str, SqlTableOptions]]):
Additional options for configuring replication for specific tables.
These are the exact same parameters for the `dlt.sources.sql_database.sql_table` function.
Argument is only used if `take_snapshots` is `True`.
reset (bool, optional):
If `True`, drops the existing replication slot before creating a new one.
Use with caution, as this will clear existing replication state.
Defaults to `False`.
Returns:
- None if `take_snapshots` is `False`
- a list of `DltResource` objects for the snapshot table(s) if `take_snapshots` is `True`.
Notes:
- If `reset` is `True`, the existing replication slot will be dropped before creating a new one.
- When `take_snapshots` is `True`, the function configures a snapshot isolation level for consistent table snapshots.
"""
rep_conn = _get_rep_conn(credentials)
rep_cur = rep_conn.cursor()
if reset:
drop_replication_slot(slot_name, rep_cur)
slot = create_replication_slot(slot_name, rep_cur)

# Close connection if no snapshots are needed
if not take_snapshots:
rep_conn.close()
return

assert table_names is not None

engine = _configure_engine(
credentials, rep_conn, slot.get("snapshot_name") if slot else None
)

table_names = [table_names] if isinstance(table_names, str) else table_names or []

for table in table_names:
table_args = (table_options or {}).get(table, {}).copy()
yield sql_table(credentials=engine, table=table, schema=schema, **table_args)


def _configure_engine(
def configure_engine(
credentials: ConnectionStringCredentials,
rep_conn: LogicalReplicationConnection,
snapshot_name: Optional[str],
Expand Down Expand Up @@ -288,7 +217,7 @@ def _get_conn(
)


def _get_rep_conn(
def get_rep_conn(
credentials: ConnectionStringCredentials,
) -> LogicalReplicationConnection:
"""
Expand Down Expand Up @@ -453,7 +382,7 @@ def __iter__(self) -> Iterator[TableItems]:
Maintains LSN of last consumed commit message in object state.
Advances the slot only when all messages have been consumed.
"""
with _get_rep_conn(self.credentials) as conn:
with get_rep_conn(self.credentials) as conn:
cur = conn.cursor()
cur.start_replication(slot_name=self.slot_name, start_lsn=self.start_lsn)
consumer = MessageConsumer(
Expand Down
2 changes: 1 addition & 1 deletion sources/pg_legacy_replication/schema_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import re
from functools import lru_cache
from typing import Any, Callable, List, Dict, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import pendulum
from dlt.common import Decimal, logger
Expand Down
3 changes: 1 addition & 2 deletions sources/pg_legacy_replication_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from dlt.common.destination import Destination
from dlt.destinations.impl.postgres.configuration import PostgresCredentials

from pg_legacy_replication import replication_source
from pg_legacy_replication.helpers import init_replication
from pg_legacy_replication import init_replication, replication_source

PG_CREDS = dlt.secrets.get("sources.pg_replication.credentials", PostgresCredentials)

Expand Down

0 comments on commit 8f45283

Please sign in to comment.