Skip to content

Commit

Permalink
Add some more type annotations to clients.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mfussenegger committed Sep 9, 2024
1 parent d910d77 commit 86d7107
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
25 changes: 14 additions & 11 deletions cr8/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@
import time
from urllib.parse import urlparse, parse_qs, urlunparse
from datetime import datetime, date
from typing import List, Union, Iterable, Dict
from typing import List, Union, Iterable, Dict, Optional, Any
from decimal import Decimal
from cr8.aio import asyncio # import via aio for uvloop setup

try:
import asyncpg
except ImportError:
asyncpg = None
asyncpg = None # type: ignore

try:
import simdjson
import simdjson # type: ignore
dumps = simdjson.dumps
except ImportError:
dumps = json.dumps
Expand Down Expand Up @@ -137,7 +137,7 @@ def _plain_or_callable(obj):
return obj


def _date_or_none(d: str) -> str:
def _date_or_none(d: str) -> Optional[str]:
"""Return a date as if, if valid, otherwise None
>>> _date_or_none('2017-02-27')
Expand All @@ -152,7 +152,7 @@ def _date_or_none(d: str) -> str:
return None


def _to_dsn(hosts):
def _to_dsn(hosts: str) -> str:
"""Convert a host URI into a dsn for aiopg.
>>> _to_dsn('aiopg://myhostname:4242/mydb')
Expand All @@ -177,15 +177,15 @@ def _to_dsn(hosts):
host, port = netloc.split(':', maxsplit=1)
except ValueError:
host = netloc
port = 5432
port = "5432"
dbname = p.path[1:] if p.path else 'doc'
dsn = f'postgres://{user_and_pw}@{host}:{port}/{dbname}'
if p.query:
dsn += '?' + '&'.join(k + '=' + v[0] for k, v in parse_qs(p.query).items())
return dsn


def _to_boolean(v):
def _to_boolean(v: str) -> bool:
if str(v).lower() in ("true"):
return True
elif str(v).lower() in ("false"):
Expand All @@ -194,7 +194,7 @@ def _to_boolean(v):
raise ValueError('not a boolean value')


def _verify_ssl_from_first(hosts):
def _verify_ssl_from_first(hosts: List[str]) -> bool:
"""Check if SSL validation parameter is passed in URI
>>> _verify_ssl_from_first(['https://myhost:4200/?verify_ssl=false'])
Expand Down Expand Up @@ -295,7 +295,7 @@ def __exit__(self, *exs):
self.close()


def _append_sql(host):
def _append_sql(host: str) -> str:
""" Append `/_sql` to the host, dropping any query parameters.
>>> _append_sql('http://n1:4200')
Expand All @@ -316,12 +316,15 @@ def _append_sql(host):


class HttpClient:
def __init__(self, hosts, conn_pool_limit=25, session_settings=None):
def __init__(self,
hosts: List[str],
conn_pool_limit=25,
session_settings: Optional[Dict[str, Any]]=None):
self.hosts = hosts
self.urls = itertools.cycle(list(map(_append_sql, hosts)))
self.conn_pool_limit = conn_pool_limit
self.is_cratedb = True
self._pools = {}
self._pools: Dict[str, asyncio.Queue] = {}
self.session_settings = session_settings or {}

async def _session(self, url):
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
'asyncpg'
],
extras_require={
'extra': ['uvloop', 'pysimdjson']
'extra': ['uvloop', 'pysimdjson'],
"dev": ["asyncpg-stubs", "mypy"]
},
python_requires='>=3.7',
classifiers=[
Expand Down

0 comments on commit 86d7107

Please sign in to comment.