Skip to content

feat: support kwargs for pg conn in pg_listener #383

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 47 additions & 13 deletions extensions/eda/plugins/event_source/pg_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,34 @@
pg_pub_sub

Arguments:
# cSpell:ignore libpq
---------
dsn: The connection string/dsn for Postgres
channels: The list of channels to listen

Example:
dsn: Optional, the connection string/dsn for Postgres as supported by psycopg/libpq
refer to https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING-KEYWORD-VALUE
Either dsn or postgres_params is required
postgres_params: Optional, dict, the parameters for the pg connection as they are supported by psycopg/libpq
refer to https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
If the param is already in the dsn, it will be overridden by the value in postgres_params
Either dsn or postgres_params is required
channels: Required, the list of channels to listen

Examples:
-------
- ansible.eda.pg_listener:
dsn: "host=localhost port=5432 dbname=mydb"
channels:
- my_events
- my_alerts

- ansible.eda.pg_listener:
postgres_params:
host: localhost
port: 5432
dbname: mydb
channels:
- my_events
- my_alerts

Chunking:
---------
This is just informational a user doesn't have to do anything
Expand Down Expand Up @@ -55,7 +71,7 @@
MESSAGE_CHUNK = "_chunk"
MESSAGE_LENGTH = "_message_length"
MESSAGE_XX_HASH = "_message_xx_hash"
REQUIRED_KEYS = ("dsn", "channels")
REQUIRED_KEYS = ["channels"]

REQUIRED_CHUNK_KEYS = (
MESSAGE_CHUNK_COUNT,
Expand All @@ -69,10 +85,6 @@
class MissingRequiredArgumentError(Exception):
"""Exception class for missing arguments."""

def __init__(self: "MissingRequiredArgumentError", key: str) -> None:
"""Class constructor with the missing key."""
super().__init__(f"PG Listener {key} is a required argument")


class MissingChunkKeyError(Exception):
"""Exception class for missing chunking keys."""
Expand All @@ -88,16 +100,38 @@ def _validate_chunked_payload(payload: dict[str, Any]) -> None:
raise MissingChunkKeyError(key)


def _validate_args(args: dict[str, Any]) -> None:
"""Validate the arguments and raise exception accordingly."""
missing_keys = [key for key in REQUIRED_KEYS if key not in args]
if missing_keys:
msg = f"Missing required arguments: {', '.join(missing_keys)}"
raise MissingRequiredArgumentError(msg)
if args.get("dsn") is None and args.get("postgres_params") is None:
msg = "Missing dsn or postgres_params, at least one is required"
raise MissingRequiredArgumentError(msg)

# Type checking
# TODO(alejandro): We should implement a standard way to validate the schema
# of the arguments for all the plugins
if not isinstance(args["channels"], list) or not args["channels"]:
raise ValueError("Channels must be a list and not empty")
if args.get("dsn") is not None and not isinstance(args["dsn"], str):
raise ValueError("DSN must be a string")
if args.get("postgres_params") is not None and not isinstance(
args["postgres_params"], dict
):
raise ValueError("Postgres params must be a dictionary")


async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None:
"""Listen for events from a channel."""
for key in REQUIRED_KEYS:
if key not in args:
raise MissingRequiredArgumentError(key)
_validate_args(args)

try:
async with await AsyncConnection.connect(
conninfo=args["dsn"],
conninfo=args.get("dsn", ""),
autocommit=True,
**args.get("postgres_params", {}),
) as conn:
chunked_cache: dict[str, Any] = {}
cursor = conn.cursor()
Expand Down
134 changes: 133 additions & 1 deletion tests/unit/event_source/test_pg_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import json
import uuid
from typing import Any
from typing import Any, Type
from unittest.mock import AsyncMock, MagicMock, patch

import psycopg
Expand All @@ -17,6 +17,8 @@
MESSAGE_CHUNKED_UUID,
MESSAGE_LENGTH,
MESSAGE_XX_HASH,
MissingRequiredArgumentError,
_validate_args,
)
from extensions.eda.plugins.event_source.pg_listener import main as pg_listener_main

Expand Down Expand Up @@ -180,3 +182,133 @@ def my_iterator() -> _AsyncIterator:
},
)
)


def test_validate_args_with_missing_keys() -> None:
"""Test missing required arguments."""
args: dict[str, str] = {}
with pytest.raises(MissingRequiredArgumentError) as exc:
_validate_args(args)
assert str(exc.value) == "Missing required arguments: channels"


def test_validate_args_with_missing_dsn_and_postgres_params() -> None:
"""Test missing dsn and postgres_params."""
args = {"channels": ["test"]}
with pytest.raises(MissingRequiredArgumentError) as exc:
_validate_args(args)
assert str(exc.value) == "Missing dsn or postgres_params, at least one is required"


def test_validate_args_with_missing_dsn() -> None:
"""Test missing dsn."""
args = {
"postgres_params": {"user": "postgres", "password": "password"},
"channels": ["test"],
}
with (
patch(
"extensions.eda.plugins.event_source.pg_listener.REQUIRED_KEYS",
["dsn"],
),
pytest.raises(MissingRequiredArgumentError) as exc,
):
_validate_args(args)
assert str(exc.value) == "Missing required arguments: dsn"


def test_validate_args_with_missing_postgres_params() -> None:
"""Test missing postgres_params."""
args = {
"dsn": "host=localhost dbname=mydb user=postgres password=password",
"channels": ["test"],
}
with (
patch(
"extensions.eda.plugins.event_source.pg_listener.REQUIRED_KEYS",
["postgres_params"],
),
pytest.raises(MissingRequiredArgumentError) as exc,
):
_validate_args(args)
assert str(exc.value) == "Missing required arguments: postgres_params"


def test_validate_args_with_valid_args() -> None:
"""Test valid arguments."""
args = {
"dsn": "host=localhost dbname=mydb user=postgres password=password",
"channels": ["test"],
}
_validate_args(args) # No exception should be raised


@pytest.mark.parametrize(
"args, expected_exception, expected_message",
[
# Valid channels
({"channels": ["channel1", "channel2"], "dsn": "dummy"}, None, None),
# Empty channels
(
{"channels": [], "dsn": "dummy"},
ValueError,
"Channels must be a list and not empty",
),
# Non-list channels
(
{"channels": "channel1", "dsn": "dummy"},
ValueError,
"Channels must be a list and not empty",
),
# Valid dsn
(
{
"channels": ["channel1"],
"dsn": "postgres://user:password@host:port/database",
},
None,
None,
),
# Invalid dsn
(
{"channels": ["channel1"], "dsn": 123},
ValueError,
"DSN must be a string",
),
# Valid postgres params
(
{
"channels": ["channel1"],
"postgres_params": {"host": "localhost", "port": 5432},
},
None,
None,
),
# Invalid postgres params
(
{"channels": ["channel1"], "postgres_params": "invalid_params"},
ValueError,
"Postgres params must be a dictionary",
),
# Invalid postgres params
(
{
"channels": ["channel1"],
"postgres_params": [{"host": "localhost"}, {"port": "5432"}],
},
ValueError,
"Postgres params must be a dictionary",
),
],
)
def test_validate_args_type_checks(
args: dict[str, Any],
expected_exception: Type[Exception],
expected_message: str,
) -> None:
"""Test _validate_args type checks."""
if expected_exception is None:
_validate_args(args)
else:
with pytest.raises(expected_exception, match=expected_message):
_validate_args(args)