Skip to content

Commit 64a32f1

Browse files
feat: support kwarg for pg conn in pg_listener (#383)
1 parent ced1116 commit 64a32f1

File tree

2 files changed

+180
-14
lines changed

2 files changed

+180
-14
lines changed

extensions/eda/plugins/event_source/pg_listener.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,34 @@
44
pg_pub_sub
55
66
Arguments:
7+
# cSpell:ignore libpq
78
---------
8-
dsn: The connection string/dsn for Postgres
9-
channels: The list of channels to listen
10-
11-
Example:
9+
dsn: Optional, the connection string/dsn for Postgres as supported by psycopg/libpq
10+
refer to https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING-KEYWORD-VALUE
11+
Either dsn or postgres_params is required
12+
postgres_params: Optional, dict, the parameters for the pg connection as they are supported by psycopg/libpq
13+
refer to https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
14+
If the param is already in the dsn, it will be overridden by the value in postgres_params
15+
Either dsn or postgres_params is required
16+
channels: Required, the list of channels to listen
17+
18+
Examples:
1219
-------
1320
- ansible.eda.pg_listener:
1421
dsn: "host=localhost port=5432 dbname=mydb"
1522
channels:
1623
- my_events
1724
- my_alerts
1825
26+
- ansible.eda.pg_listener:
27+
postgres_params:
28+
host: localhost
29+
port: 5432
30+
dbname: mydb
31+
channels:
32+
- my_events
33+
- my_alerts
34+
1935
Chunking:
2036
---------
2137
This is just informational a user doesn't have to do anything
@@ -55,7 +71,7 @@
5571
MESSAGE_CHUNK = "_chunk"
5672
MESSAGE_LENGTH = "_message_length"
5773
MESSAGE_XX_HASH = "_message_xx_hash"
58-
REQUIRED_KEYS = ("dsn", "channels")
74+
REQUIRED_KEYS = ["channels"]
5975

6076
REQUIRED_CHUNK_KEYS = (
6177
MESSAGE_CHUNK_COUNT,
@@ -69,10 +85,6 @@
6985
class MissingRequiredArgumentError(Exception):
7086
"""Exception class for missing arguments."""
7187

72-
def __init__(self: "MissingRequiredArgumentError", key: str) -> None:
73-
"""Class constructor with the missing key."""
74-
super().__init__(f"PG Listener {key} is a required argument")
75-
7688

7789
class MissingChunkKeyError(Exception):
7890
"""Exception class for missing chunking keys."""
@@ -88,16 +100,38 @@ def _validate_chunked_payload(payload: dict[str, Any]) -> None:
88100
raise MissingChunkKeyError(key)
89101

90102

103+
def _validate_args(args: dict[str, Any]) -> None:
104+
"""Validate the arguments and raise exception accordingly."""
105+
missing_keys = [key for key in REQUIRED_KEYS if key not in args]
106+
if missing_keys:
107+
msg = f"Missing required arguments: {', '.join(missing_keys)}"
108+
raise MissingRequiredArgumentError(msg)
109+
if args.get("dsn") is None and args.get("postgres_params") is None:
110+
msg = "Missing dsn or postgres_params, at least one is required"
111+
raise MissingRequiredArgumentError(msg)
112+
113+
# Type checking
114+
# TODO(alejandro): We should implement a standard way to validate the schema
115+
# of the arguments for all the plugins
116+
if not isinstance(args["channels"], list) or not args["channels"]:
117+
raise ValueError("Channels must be a list and not empty")
118+
if args.get("dsn") is not None and not isinstance(args["dsn"], str):
119+
raise ValueError("DSN must be a string")
120+
if args.get("postgres_params") is not None and not isinstance(
121+
args["postgres_params"], dict
122+
):
123+
raise ValueError("Postgres params must be a dictionary")
124+
125+
91126
async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None:
92127
"""Listen for events from a channel."""
93-
for key in REQUIRED_KEYS:
94-
if key not in args:
95-
raise MissingRequiredArgumentError(key)
128+
_validate_args(args)
96129

97130
try:
98131
async with await AsyncConnection.connect(
99-
conninfo=args["dsn"],
132+
conninfo=args.get("dsn", ""),
100133
autocommit=True,
134+
**args.get("postgres_params", {}),
101135
) as conn:
102136
chunked_cache: dict[str, Any] = {}
103137
cursor = conn.cursor()

tests/unit/event_source/test_pg_listener.py

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import asyncio
44
import json
55
import uuid
6-
from typing import Any
6+
from typing import Any, Type
77
from unittest.mock import AsyncMock, MagicMock, patch
88

99
import psycopg
@@ -17,6 +17,8 @@
1717
MESSAGE_CHUNKED_UUID,
1818
MESSAGE_LENGTH,
1919
MESSAGE_XX_HASH,
20+
MissingRequiredArgumentError,
21+
_validate_args,
2022
)
2123
from extensions.eda.plugins.event_source.pg_listener import main as pg_listener_main
2224

@@ -180,3 +182,133 @@ def my_iterator() -> _AsyncIterator:
180182
},
181183
)
182184
)
185+
186+
187+
def test_validate_args_with_missing_keys() -> None:
188+
"""Test missing required arguments."""
189+
args: dict[str, str] = {}
190+
with pytest.raises(MissingRequiredArgumentError) as exc:
191+
_validate_args(args)
192+
assert str(exc.value) == "Missing required arguments: channels"
193+
194+
195+
def test_validate_args_with_missing_dsn_and_postgres_params() -> None:
196+
"""Test missing dsn and postgres_params."""
197+
args = {"channels": ["test"]}
198+
with pytest.raises(MissingRequiredArgumentError) as exc:
199+
_validate_args(args)
200+
assert str(exc.value) == "Missing dsn or postgres_params, at least one is required"
201+
202+
203+
def test_validate_args_with_missing_dsn() -> None:
204+
"""Test missing dsn."""
205+
args = {
206+
"postgres_params": {"user": "postgres", "password": "password"},
207+
"channels": ["test"],
208+
}
209+
with (
210+
patch(
211+
"extensions.eda.plugins.event_source.pg_listener.REQUIRED_KEYS",
212+
["dsn"],
213+
),
214+
pytest.raises(MissingRequiredArgumentError) as exc,
215+
):
216+
_validate_args(args)
217+
assert str(exc.value) == "Missing required arguments: dsn"
218+
219+
220+
def test_validate_args_with_missing_postgres_params() -> None:
221+
"""Test missing postgres_params."""
222+
args = {
223+
"dsn": "host=localhost dbname=mydb user=postgres password=password",
224+
"channels": ["test"],
225+
}
226+
with (
227+
patch(
228+
"extensions.eda.plugins.event_source.pg_listener.REQUIRED_KEYS",
229+
["postgres_params"],
230+
),
231+
pytest.raises(MissingRequiredArgumentError) as exc,
232+
):
233+
_validate_args(args)
234+
assert str(exc.value) == "Missing required arguments: postgres_params"
235+
236+
237+
def test_validate_args_with_valid_args() -> None:
238+
"""Test valid arguments."""
239+
args = {
240+
"dsn": "host=localhost dbname=mydb user=postgres password=password",
241+
"channels": ["test"],
242+
}
243+
_validate_args(args) # No exception should be raised
244+
245+
246+
@pytest.mark.parametrize(
247+
"args, expected_exception, expected_message",
248+
[
249+
# Valid channels
250+
({"channels": ["channel1", "channel2"], "dsn": "dummy"}, None, None),
251+
# Empty channels
252+
(
253+
{"channels": [], "dsn": "dummy"},
254+
ValueError,
255+
"Channels must be a list and not empty",
256+
),
257+
# Non-list channels
258+
(
259+
{"channels": "channel1", "dsn": "dummy"},
260+
ValueError,
261+
"Channels must be a list and not empty",
262+
),
263+
# Valid dsn
264+
(
265+
{
266+
"channels": ["channel1"],
267+
"dsn": "postgres://user:password@host:port/database",
268+
},
269+
None,
270+
None,
271+
),
272+
# Invalid dsn
273+
(
274+
{"channels": ["channel1"], "dsn": 123},
275+
ValueError,
276+
"DSN must be a string",
277+
),
278+
# Valid postgres params
279+
(
280+
{
281+
"channels": ["channel1"],
282+
"postgres_params": {"host": "localhost", "port": 5432},
283+
},
284+
None,
285+
None,
286+
),
287+
# Invalid postgres params
288+
(
289+
{"channels": ["channel1"], "postgres_params": "invalid_params"},
290+
ValueError,
291+
"Postgres params must be a dictionary",
292+
),
293+
# Invalid postgres params
294+
(
295+
{
296+
"channels": ["channel1"],
297+
"postgres_params": [{"host": "localhost"}, {"port": "5432"}],
298+
},
299+
ValueError,
300+
"Postgres params must be a dictionary",
301+
),
302+
],
303+
)
304+
def test_validate_args_type_checks(
305+
args: dict[str, Any],
306+
expected_exception: Type[Exception],
307+
expected_message: str,
308+
) -> None:
309+
"""Test _validate_args type checks."""
310+
if expected_exception is None:
311+
_validate_args(args)
312+
else:
313+
with pytest.raises(expected_exception, match=expected_message):
314+
_validate_args(args)

0 commit comments

Comments
 (0)