diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..da95e4f --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,110 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Development Commands + +**Testing:** +```bash +# Run all tests +pytest -vv --cov=hasql --cov-report=term-missing --doctest-modules --aiomisc-test-timeout=30 tests + +# Run specific test file +pytest -vv tests/test_utils.py + +# Run specific test +pytest -vv tests/test_utils.py::test_parse_connection_string_basic + +# Run tests with specific pattern +pytest -vv tests/test_utils.py -k "connection_string" + +# Run tests using tox (preferred) +tox -e py39 # Python 3.9 +tox -e py310 # Python 3.10 +tox -e py311 # Python 3.11 +``` + +**Linting and Type Checking:** +```bash +# Lint code +pylama -o pylama.ini hasql tests + +# Type checking +mypy --install-types --non-interactive hasql tests + +# Using tox (preferred) +tox -e lint +tox -e mypy +``` + +**Package Installation:** +```bash +# Install development dependencies +pip install -e ".[develop]" + +# Install specific extras +pip install -e ".[aiopg]" # aiopg support +pip install -e ".[asyncpg]" # asyncpg support +pip install -e ".[psycopg]" # psycopg3 support +``` + +## Architecture Overview + +**hasql** is a high-availability PostgreSQL connection management library that provides automatic master/replica detection and load balancing across multiple database hosts. + +### Core Components + +1. **BasePoolManager** (`hasql/base.py`) - Abstract base class that defines the core pooling interface and connection management logic + +2. **Driver-Specific Pool Managers:** + - `hasql.aiopg.PoolManager` - aiopg driver support + - `hasql.asyncpg.PoolManager` - asyncpg driver support + - `hasql.psycopg3.PoolManager` - psycopg3 driver support + - `hasql.asyncsqlalchemy.PoolManager` - SQLAlchemy async support + - `hasql.aiopg_sa.PoolManager` - aiopg with SQLAlchemy support + +3. 
**Balancer Policies** (`hasql/balancer_policy/`) - Load balancing strategies: + - `GreedyBalancerPolicy` - Chooses pool with most free connections + - `RandomWeightedBalancerPolicy` - Weighted random selection based on response times + - `RoundRobinBalancerPolicy` - Round-robin selection + +4. **Connection String Parsing** (`hasql/utils.py`) - Handles multi-host PostgreSQL connection strings with support for: + - Comma-separated hosts: `postgresql://db1,db2,db3/dbname` + - Per-host ports: `postgresql://db1:1234,db2:5678/dbname` + - Global port override: `postgresql://db1,db2:6432/dbname` + - libpq-style connection strings + +5. **Metrics and Monitoring** (`hasql/metrics.py`) - Connection and performance metrics collection + +### Key Features + +- **Automatic Role Detection:** Continuously monitors each host to determine if it's a master or replica +- **Health Monitoring:** Background tasks check host availability and automatically exclude unhealthy hosts +- **Load Balancing:** Multiple policies for distributing connections across healthy replicas +- **Failover Support:** Automatic fallback to master when replicas are unavailable +- **Multi-Driver Support:** Works with asyncpg, aiopg, psycopg3, and SQLAlchemy + +### Connection Flow + +1. Parse multi-host DSN string into individual host connections +2. Create connection pools for each host with reserved system connections +3. Background tasks continuously check each host's role (master/replica) and health +4. When acquiring connections, balancer selects appropriate pool based on read_only flag +5. 
Connections are automatically returned to their respective pools when released + +### Testing Strategy + +- Uses pytest with aiomisc test framework +- Mocks database connections for unit testing (`tests/mocks/`) +- Integration tests for each driver implementation +- Coverage reporting with pytest-cov +- Tests are organized by driver type and functionality + +## Important Notes + +- The codebase uses Python 3.8+ with async/await throughout +- All pool managers extend the abstract `BasePoolManager` class +- Connection strings support both single and multi-host PostgreSQL URLs +- Background health checking runs every `refresh_delay` seconds (default: 1s) +- System reserves one connection per pool for health monitoring +- The library automatically detects PostgreSQL role changes with slight delay \ No newline at end of file diff --git a/hasql/asyncsqlalchemy.py b/hasql/asyncsqlalchemy.py index 1d117d4..7b5e281 100644 --- a/hasql/asyncsqlalchemy.py +++ b/hasql/asyncsqlalchemy.py @@ -14,7 +14,7 @@ class PoolManager(BasePoolManager): def get_pool_freesize(self, pool: AsyncEngine): - queue_pool: QueuePool = pool.sync_engine.pool # type: ignore + queue_pool: QueuePool = pool.sync_engine.pool return queue_pool.size() - queue_pool.checkedout() def acquire_from_pool(self, pool: AsyncEngine, **kwargs): diff --git a/hasql/utils.py b/hasql/utils.py index 7ace79a..491fc15 100644 --- a/hasql/utils.py +++ b/hasql/utils.py @@ -50,28 +50,119 @@ def __init__( @classmethod def parse(cls, dsn: str) -> "Dsn": + # First try URL format match = cls.URL_EXP.match(dsn) - if match is None: - raise ValueError("Bad DSN") - - groupdict = match.groupdict() - scheme = groupdict["scheme"] - user = groupdict.get("user") - password = groupdict.get("password") - netloc: str = groupdict["netloc"] - dbname = (groupdict.get("path") or "").lstrip("/") - query = groupdict.get("query") or "" + if match is not None: + # URL format parsing + groupdict = match.groupdict() + scheme = groupdict["scheme"] + user = 
@classmethod
def _parse_connection_string_params(cls, conn_str: str) -> Dict[str, str]:
    r"""Parse a libpq-style ``key=value`` connection string into a dict.

    Pairs are separated by whitespace, and spaces around ``=`` are
    optional (``host = localhost``). A value may be quoted to carry
    whitespace or to be empty (``dbname='my database'``,
    ``password=''``); quoted text is kept verbatim, including leading
    and trailing blanks. A backslash makes the next character literal,
    so ``\'`` yields a quote and ``\\`` a backslash; a lone backslash at
    the very end of the string is kept as is. Double quotes are accepted
    in addition to the single quotes libpq defines.

    Args:
        conn_str: connection string such as
            ``"host=db1,db2 port=5432 dbname=mydb"``.

    Returns:
        Mapping of keyword to value. A later duplicate keyword overrides
        an earlier one; a blank string yields an empty dict.

    Raises:
        ValueError: if a keyword is empty, a keyword is not followed by
            ``=``, or a quoted value is never closed. Rejecting such
            input keeps arbitrary text from being accepted as a DSN.
    """
    params: Dict[str, str] = {}
    length = len(conn_str)
    pos = 0

    while pos < length:
        # Skip the whitespace that separates pairs.
        if conn_str[pos].isspace():
            pos += 1
            continue

        # Keyword: everything up to "=" or whitespace.
        start = pos
        while (
            pos < length
            and conn_str[pos] != "="
            and not conn_str[pos].isspace()
        ):
            pos += 1
        key = conn_str[start:pos]

        # Blanks are allowed on both sides of the "=" sign.
        while pos < length and conn_str[pos].isspace():
            pos += 1
        if not key:
            raise ValueError("Bad DSN: empty keyword before \"=\"")
        if pos >= length or conn_str[pos] != "=":
            raise ValueError(f"Bad DSN: missing \"=\" after {key!r}")
        pos += 1
        while pos < length and conn_str[pos].isspace():
            pos += 1

        # Value: runs to the next unquoted whitespace. Pieces are
        # collected in a list and joined once to avoid quadratic
        # string concatenation.
        chunks = []
        quote_char = None
        while pos < length:
            char = conn_str[pos]
            if char == "\\" and pos + 1 < length:
                # Escaped character is taken literally, quoted or not.
                chunks.append(conn_str[pos + 1])
                pos += 2
                continue
            if quote_char is None:
                if char.isspace():
                    break
                if char in ("'", '"'):
                    quote_char = char
                else:
                    chunks.append(char)
            elif char == quote_char:
                quote_char = None
            else:
                chunks.append(char)
            pos += 1

        if quote_char is not None:
            raise ValueError(
                f"Bad DSN: unterminated quoted value for {key!r}",
            )
        # No strip() here: whitespace can only be present if it was
        # quoted or escaped, i.e. it is intentional.
        params[key] = "".join(chunks)

    return params
@classmethod
def _build_netloc(cls, hosts: str, ports: str) -> str:
    """Build a multi-host netloc from comma-separated hosts and ports.

    Every host is paired with a port and all pairs are joined with
    commas, e.g. ``"db1,db2"`` and ``"5432,5433"`` give
    ``"db1:5432,db2:5433"``. A single port is applied to every host and
    a single host is repeated for every port. Empty entries fall back
    to ``localhost`` and ``5432``. Hosts containing ``:`` (IPv6
    literals) are wrapped in square brackets so the port separator
    stays unambiguous.

    Args:
        hosts: comma-separated host names or addresses.
        ports: comma-separated ports.

    Returns:
        Comma-separated ``host:port`` pairs.

    Raises:
        ValueError: if several hosts and several ports are given but
            their counts differ; pairing them would otherwise silently
            drop the surplus entries.
    """
    host_list = [
        host.strip() or "localhost" for host in hosts.split(",")
    ]
    port_list = [port.strip() or "5432" for port in ports.split(",")]

    if len(port_list) == 1:
        # One port shared by all hosts.
        port_list = port_list * len(host_list)
    elif len(host_list) == 1:
        # One host listening on several ports.
        host_list = host_list * len(port_list)
    elif len(host_list) != len(port_list):
        raise ValueError(
            f"Bad DSN: cannot match {len(port_list)} ports "
            f"to {len(host_list)} hosts",
        )

    pairs = []
    for host, port in zip(host_list, port_list):
        if ":" in host and not host.startswith("["):
            # Bare IPv6 literal: bracket it, as in URL authorities.
            host = f"[{host}]"
        pairs.append(f"{host}:{port}")
    return ",".join(pairs)
Dsn.parse(conn_str) + assert dsn.netloc == "localhost:5432,replica:5433" + assert dsn.dbname == "mydb" + + +def test_parse_connection_string_single_port_multiple_hosts(): + """Test connection string with single port for multiple hosts.""" + conn_str = "host=localhost,replica port=5432 dbname=mydb" + dsn = Dsn.parse(conn_str) + assert dsn.netloc == "localhost:5432,replica:5432" + + +def test_parse_connection_string_with_password(): + """Test connection string with password.""" + conn_str = ( + "host=localhost port=5432 dbname=mydb user=testuser password=secret" + ) + dsn = Dsn.parse(conn_str) + assert dsn.user == "testuser" + assert dsn.password == "secret" + + +def test_parse_connection_string_with_extra_params(): + """Test connection string with additional parameters.""" + conn_str = ( + "host=localhost port=5432 dbname=mydb " + "connect_timeout=10 sslmode=require" + ) + dsn = Dsn.parse(conn_str) + assert dsn.params["connect_timeout"] == "10" + assert dsn.params["sslmode"] == "require" + + +def test_parse_connection_string_quoted_values(): + """Test connection string with quoted values.""" + conn_str = "host=localhost port=5432 dbname='my database' user='test user'" + dsn = Dsn.parse(conn_str) + assert dsn.dbname == "my database" + assert dsn.user == "test user" + + +def test_split_dsn_from_connection_string(): + """Test that split_dsn works with connection string format.""" + conn_str = "host=localhost,replica port=5432,5433 dbname=mydb user=testuser" + dsns = split_dsn(conn_str) + assert len(dsns) == 2 + assert str(dsns[0]) == "postgresql://testuser@localhost:5432/mydb" + assert str(dsns[1]) == "postgresql://testuser@replica:5433/mydb" + + +def test_connection_string_example_format(): + """Test the exact example format from the user request.""" + conn_str = ( + "host=localhost,localhost port=5432,5432 dbname=mydb connect_timeout=10" + ) + dsn = Dsn.parse(conn_str) + assert dsn.netloc == "localhost:5432,localhost:5432" + assert dsn.dbname == "mydb" + assert 
dsn.params["connect_timeout"] == "10" + + # Test that split_dsn handles duplicates correctly + dsns = split_dsn(conn_str) + assert len(dsns) == 1 # Should deduplicate identical host:port pairs + assert str(dsns[0]) == "postgresql://localhost:5432/mydb?connect_timeout=10"