|
3 | 3 | import asyncio
|
4 | 4 | import json
|
5 | 5 | import uuid
|
6 |
| -from typing import Any |
| 6 | +from typing import Any, Type |
7 | 7 | from unittest.mock import AsyncMock, MagicMock, patch
|
8 | 8 |
|
9 | 9 | import psycopg
|
|
17 | 17 | MESSAGE_CHUNKED_UUID,
|
18 | 18 | MESSAGE_LENGTH,
|
19 | 19 | MESSAGE_XX_HASH,
|
| 20 | + MissingRequiredArgumentError, |
| 21 | + _validate_args, |
20 | 22 | )
|
21 | 23 | from extensions.eda.plugins.event_source.pg_listener import main as pg_listener_main
|
22 | 24 |
|
@@ -180,3 +182,133 @@ def my_iterator() -> _AsyncIterator:
|
180 | 182 | },
|
181 | 183 | )
|
182 | 184 | )
|
| 185 | + |
| 186 | + |
| 187 | +def test_validate_args_with_missing_keys() -> None: |
| 188 | + """Test missing required arguments.""" |
| 189 | + args: dict[str, str] = {} |
| 190 | + with pytest.raises(MissingRequiredArgumentError) as exc: |
| 191 | + _validate_args(args) |
| 192 | + assert str(exc.value) == "Missing required arguments: channels" |
| 193 | + |
| 194 | + |
| 195 | +def test_validate_args_with_missing_dsn_and_postgres_params() -> None: |
| 196 | + """Test missing dsn and postgres_params.""" |
| 197 | + args = {"channels": ["test"]} |
| 198 | + with pytest.raises(MissingRequiredArgumentError) as exc: |
| 199 | + _validate_args(args) |
| 200 | + assert str(exc.value) == "Missing dsn or postgres_params, at least one is required" |
| 201 | + |
| 202 | + |
| 203 | +def test_validate_args_with_missing_dsn() -> None: |
| 204 | + """Test missing dsn.""" |
| 205 | + args = { |
| 206 | + "postgres_params": {"user": "postgres", "password": "password"}, |
| 207 | + "channels": ["test"], |
| 208 | + } |
| 209 | + with ( |
| 210 | + patch( |
| 211 | + "extensions.eda.plugins.event_source.pg_listener.REQUIRED_KEYS", |
| 212 | + ["dsn"], |
| 213 | + ), |
| 214 | + pytest.raises(MissingRequiredArgumentError) as exc, |
| 215 | + ): |
| 216 | + _validate_args(args) |
| 217 | + assert str(exc.value) == "Missing required arguments: dsn" |
| 218 | + |
| 219 | + |
| 220 | +def test_validate_args_with_missing_postgres_params() -> None: |
| 221 | + """Test missing postgres_params.""" |
| 222 | + args = { |
| 223 | + "dsn": "host=localhost dbname=mydb user=postgres password=password", |
| 224 | + "channels": ["test"], |
| 225 | + } |
| 226 | + with ( |
| 227 | + patch( |
| 228 | + "extensions.eda.plugins.event_source.pg_listener.REQUIRED_KEYS", |
| 229 | + ["postgres_params"], |
| 230 | + ), |
| 231 | + pytest.raises(MissingRequiredArgumentError) as exc, |
| 232 | + ): |
| 233 | + _validate_args(args) |
| 234 | + assert str(exc.value) == "Missing required arguments: postgres_params" |
| 235 | + |
| 236 | + |
| 237 | +def test_validate_args_with_valid_args() -> None: |
| 238 | + """Test valid arguments.""" |
| 239 | + args = { |
| 240 | + "dsn": "host=localhost dbname=mydb user=postgres password=password", |
| 241 | + "channels": ["test"], |
| 242 | + } |
| 243 | + _validate_args(args) # No exception should be raised |
| 244 | + |
| 245 | + |
| 246 | +@pytest.mark.parametrize( |
| 247 | + "args, expected_exception, expected_message", |
| 248 | + [ |
| 249 | + # Valid channels |
| 250 | + ({"channels": ["channel1", "channel2"], "dsn": "dummy"}, None, None), |
| 251 | + # Empty channels |
| 252 | + ( |
| 253 | + {"channels": [], "dsn": "dummy"}, |
| 254 | + ValueError, |
| 255 | + "Channels must be a list and not empty", |
| 256 | + ), |
| 257 | + # Non-list channels |
| 258 | + ( |
| 259 | + {"channels": "channel1", "dsn": "dummy"}, |
| 260 | + ValueError, |
| 261 | + "Channels must be a list and not empty", |
| 262 | + ), |
| 263 | + # Valid dsn |
| 264 | + ( |
| 265 | + { |
| 266 | + "channels": ["channel1"], |
| 267 | + "dsn": "postgres://user:password@host:port/database", |
| 268 | + }, |
| 269 | + None, |
| 270 | + None, |
| 271 | + ), |
| 272 | + # Invalid dsn |
| 273 | + ( |
| 274 | + {"channels": ["channel1"], "dsn": 123}, |
| 275 | + ValueError, |
| 276 | + "DSN must be a string", |
| 277 | + ), |
| 278 | + # Valid postgres params |
| 279 | + ( |
| 280 | + { |
| 281 | + "channels": ["channel1"], |
| 282 | + "postgres_params": {"host": "localhost", "port": 5432}, |
| 283 | + }, |
| 284 | + None, |
| 285 | + None, |
| 286 | + ), |
| 287 | + # Invalid postgres params |
| 288 | + ( |
| 289 | + {"channels": ["channel1"], "postgres_params": "invalid_params"}, |
| 290 | + ValueError, |
| 291 | + "Postgres params must be a dictionary", |
| 292 | + ), |
| 293 | + # Invalid postgres params |
| 294 | + ( |
| 295 | + { |
| 296 | + "channels": ["channel1"], |
| 297 | + "postgres_params": [{"host": "localhost"}, {"port": "5432"}], |
| 298 | + }, |
| 299 | + ValueError, |
| 300 | + "Postgres params must be a dictionary", |
| 301 | + ), |
| 302 | + ], |
| 303 | +) |
| 304 | +def test_validate_args_type_checks( |
| 305 | + args: dict[str, Any], |
| 306 | + expected_exception: Type[Exception], |
| 307 | + expected_message: str, |
| 308 | +) -> None: |
| 309 | + """Test _validate_args type checks.""" |
| 310 | + if expected_exception is None: |
| 311 | + _validate_args(args) |
| 312 | + else: |
| 313 | + with pytest.raises(expected_exception, match=expected_message): |
| 314 | + _validate_args(args) |
0 commit comments