Skip to content

Commit

Permalink
use a single global connection pool (#199)
Browse files Browse the repository at this point in the history
* use a single global connection pool

* actually don't make pools, even if multiple PGInterfaces are __init__ed before anyone calls connect_pg

* don't create a second pool while a first pool is connecting by sticking it in a class with a future

* fix self.pool reference

* fix close

* rename to OneTruePool and add an acquire method

* never mind, don't autoconnect on acquire

* change logging

* remove close_pools from tests
  • Loading branch information
technillogue authored Apr 29, 2022
1 parent edada47 commit 8059a15
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 34 deletions.
4 changes: 2 additions & 2 deletions forest/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ async def async_shutdown(self, *_: Any, wait: bool = False) -> None:
logging.info(f"no {utils.SIGNAL} process")
if utils.UPLOAD:
await self.datastore.mark_freed()
await pghelp.close_pools()
await pghelp.pool.close()
# this still deadlocks. see https://github.com/forestcontact/forest-draft/issues/10
if autosave._memfs_process:
executor = autosave._memfs_process._get_executor()
Expand Down Expand Up @@ -721,7 +721,7 @@ async def log_activity(self) -> None:
if not self.seen_users:
continue
try:
async with self.activity.pool.acquire() as conn:
async with pghelp.pool.acquire() as conn:
# executemany batches this into an atomic db query
await conn.executemany(
self.activity.queries["log"],
Expand Down
79 changes: 49 additions & 30 deletions forest/pghelp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,32 +39,56 @@ def get_logger(name: str) -> logging.Logger:
return logger


pools: list[asyncpg.Pool] = []
# this should be used for every insance


async def close_pools() -> None:
for pool in pools:
class OneTruePool:
connecting: Optional[asyncio.Future] = None
pool: Optional[asyncpg.Pool] = None

async def connect(self, url: str, table: str) -> None:
if not self.pool:
if self.connecting:
await self.connecting
else:
self.connecting = asyncio.Future()
logging.debug("creating pool for %s", table)
# this is helpful for connecting to an actually local db where your system username is different
# but counterproductive if you're proxying a database connection through localhost
# if "localhost" in self.database:
# pool = await asyncpg.create_pool(user="postgres")
self.pool = await asyncpg.create_pool(url)
logging.debug("created pool %s for %s", self.pool, table)
self.connecting.set_result(True)

def acquire(self) -> asyncpg.pool.PoolAcquireContext:
"""returns an async context manager. sugar around pool.pool.acquire
this *isn't* async, because pool.acquire returns an async context manager and not a coroutine"""
if not self.pool:
raise Exception("no pool, use pool.connect first")
return self.pool.acquire()

async def close(self) -> None:
try:
await pool.close()
if self.pool:
await self.pool.close()
except (asyncpg.PostgresError, asyncpg.InternalClientError) as e:
logging.error(e)


pool = OneTruePool()


class SimpleInterface:
def __init__(self, database: str) -> None:
self.database = database
self.pool: Optional[asyncpg.Pool] = None

@asynccontextmanager
async def get_connection(self) -> AsyncGenerator:
if not self.pool:
logging.info("creating pool")
if "localhost" in self.database:
self.pool = await asyncpg.create_pool(user="postgres") # self.database)
else:
self.pool = await asyncpg.create_pool(self.database)
pools.append(self.pool)
async with self.pool.acquire() as conn:
if not pool.pool:
pool.connect(self.database, "simple interface")
assert pool.pool
async with pool.acquire() as conn:
logging.info("connection acquired")
yield conn

Expand Down Expand Up @@ -104,19 +128,18 @@ def __init__(
self.queries = query_strings
self.table = self.queries.table
self.MAX_RESP_LOG_LEN = MAX_RESP_LOG_LEN
# self.loop.create_task(self.connect_pg())
self.pool = None
# self.loop.create_task(pool.connect_pg(database, self.table))
if isinstance(database, dict):
self.invocations: list[dict] = []
self.logger = get_logger(
f'{self.table}{"_fake" if not self.pool else ""}_interface'
f'{self.table}{"_fake" if not self.database else ""}_interface'
)

def finish_init(self) -> None:
"""Optionally triggers creating tables and checks existence."""
if not self.pool:
if not pool.pool:
self.logger.warning("RUNNING IN FAKE MODE")
if self.pool and self.table and not self.sync_exists():
if pool.pool and self.table and not self.sync_exists():
if AUTOCREATE:
self.sync_create_table()
self.logger.warning(f"building table {self.table}")
Expand All @@ -129,10 +152,6 @@ def finish_init(self) -> None:
self.logger.info(f"creating index via {k}")
self.__getattribute__(f"sync_{k}")()

async def connect_pg(self) -> None:
self.pool = await asyncpg.create_pool(self.database)
pools.append(self.pool)

_autocreating_table = False

async def execute(
Expand All @@ -142,10 +161,10 @@ async def execute(
) -> Optional[list[asyncpg.Record]]:
"""Invoke the asyncpg connection's `_execute` given a provided query string and set of arguments"""
timeout: int = 180
if not self.pool and not isinstance(self.database, dict):
await self.connect_pg()
if self.pool:
async with self.pool.acquire() as connection:
if not pool.pool and not isinstance(self.database, dict):
await pool.connect(self.database, self.table)
if pool.pool:
async with pool.acquire() as connection:
# try:
# except asyncpg.TooManyConnectionsError:
# await connection.execute(
Expand Down Expand Up @@ -185,9 +204,9 @@ def sync_execute(self, qstring: str, *args: Any) -> asyncpg.Record:
return ret

def sync_close(self) -> Any:
self.logger.info(f"closing connection: {self.pool}")
if self.pool:
ret = self.loop.run_until_complete(self.pool.close())
self.logger.info(f"closing connection: {pool.pool}")
if pool.pool:
ret = self.loop.run_until_complete(pool.pool.close())
return ret
return None

Expand Down Expand Up @@ -221,7 +240,7 @@ def __getattribute__(self, key: str) -> Callable[..., asyncpg.Record]:
statement = self.queries.get_query(qstring)
except KeyError as e:
raise ValueError(f"No statement of name {qstring} or {key} found!") from e
if not self.pool and isinstance(self.database, dict):
if not pool.pool and isinstance(self.database, dict):
canned_response = self.database.get(qstring, [[None]]).pop(0)
if qstring in self.database and not self.database.get(qstring, []):
self.database.pop(qstring)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_questions_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ async def bot():
bot.exiting = True
bot.handle_messages_task.cancel()
await bot.client_session.close()
await core.pghelp.close_pools()
await core.pghelp.pool.close()


@pytest.mark.asyncio
Expand Down
2 changes: 1 addition & 1 deletion tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ async def bot():
bot.exiting = True
bot.handle_messages_task.cancel()
await bot.client_session.close()
await core.pghelp.close_pools()
await core.pghelp.pool.close()


@pytest.mark.asyncio
Expand Down

0 comments on commit 8059a15

Please sign in to comment.