This repository has been archived by the owner on Aug 5, 2024. It is now read-only.

Commit 4775529
fix: Finish refactoring to use asyncpg
rsavoye committed Feb 12, 2024 (1 parent: 8113217)
Showing 1 changed file with 84 additions and 109 deletions.
193 changes: 84 additions & 109 deletions tm_admin/tmdb.py
@@ -24,6 +24,7 @@
import sys
from sys import argv
import os
+import time
from shapely import wkb, get_coordinates
from shapely.geometry import MultiPolygon, Polygon, Point, shape
from datetime import datetime
@@ -34,10 +35,12 @@
import concurrent.futures
from cpuinfo import get_cpu_info
import asyncio
+from asyncpg import create_pool
# from tm_admin.users.users import createSQLValues
# from tm_admin.organizations.organizations import createSQLValues
from tqdm import tqdm
import tqdm.asyncio
+import copy

# Instantiate logger
log = logging.getLogger(__name__)
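The newly imported create_pool is asyncpg's entry point for pooled connections; the commented-out code in main() further down sketches the same idea. A minimal sketch of the pattern, with an illustrative DSN that is not from this repo:

```python
import asyncio
from asyncpg import create_pool

async def demo_pool(dsn: str):
    # A shared pool instead of one connection per worker thread;
    # min_size/max_size bound how many server connections stay open.
    async with create_pool(dsn=dsn, min_size=2, max_size=8) as pool:
        async with pool.acquire() as con:
            print(await con.fetchval("SELECT version()"))

# asyncio.run(demo_pool("postgresql://localhost/tm_admin"))
```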
@@ -48,29 +51,33 @@
# The number of threads is based on the CPU cores
info = get_cpu_info()
# More threads. Shorter import time, higher CPU load. But this is a
-# pretty low CPU load proces anyway, so more is good.
-cores = info["count"]
+# pretty low CPU load anyway, so more is good.
+cores = info["count"] * 2

async def importThread(
-    data: list,
-    db: PostgresClient,
-    tm,
-):
+    data: list,
+    pg: PostgresClient,
+    table: str,
+):
"""
Thread to handle importing
Args:
data (list): The list of records to import
db (PostgresClient): A database connection
tm (TMImport): the input handle
outuri (str): The output database
"""
# log.debug(f"There are {len(data)} data entries")
-    await tm.writeAllData(data, tm.table, db)
+    if table == 'organisations':
+        table = 'organizations'
+    tmi = TMImport(table)
+    await tmi.writeAllData(data, pg)

return True
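
A hypothetical single-chunk call of the new importThread() signature; the URI and table name are illustrative, and PostgresClient.connect() is assumed to take a URI as connect() below does:

```python
import asyncio

async def demo_chunk(rows: list):
    # One worker's share of the data: a dedicated connection plus the
    # target table name. importThread() normalizes the British spelling.
    pg = PostgresClient()
    await pg.connect("postgresql://localhost/tm_admin")
    await importThread(rows, pg, "projects")

# asyncio.run(demo_chunk(rows)) would drive one worker.
```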

class TMImport(object):
-    def __init__(self):
+    def __init__(self,
+                 config: str = None,
+                 ):
"""
        This class contains support for accessing a Tasking Manager database and
        importing it into the TM Admin database. This works because the TM Admin
@@ -81,20 +88,24 @@ def __init__(self):
        The other change is that in TM many columns are enums, but the database type
        is int. The integer values from TM are converted to the proper TM Admin enum value.
Args:
+            config (str): The YAML config file for this table
Returns:
(TMImport): An instance of this class
"""
self.tmdb = None
self.admindb = None
self.table = None
self.columns = list()
self.data = list()
self.config = dict()
+        if config:
+            self.table = config
+            yaml = YamlFile(f"{rootdir}/{config}/{config}.yaml")
+            # yaml.dump()
+            self.config = yaml.getEntries()
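
With the config moved into the constructor, each importer is bound to one table at creation time. A small usage sketch, assuming getEntries() returns a dict of the YAML entries:

```python
# Hypothetical: one importer per TM Admin table, each loading its own
# schema from rootdir/<table>/<table>.yaml.
users = TMImport("users")
print(users.table)          # "users"
print(list(users.config))   # top-level keys from users.yaml
```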

async def connect(self,
inuri: str,
outuri: str,
-                  table: str,
):
"""
        This class contains support for accessing a Tasking Manager database, and
@@ -108,52 +119,14 @@ async def connect(self,
Args:
inuri (str): The URI for the TM database
outuri (str): The URI for the TM Admin database
-            table (str): The table in the TM Admin database
-        """
+        """
# The Tasking Manager database
self.tmdb = PostgresClient()
await self.tmdb.connect(inuri)
# The TMAdmin database
self.admindb = PostgresClient()
await self.admindb.connect(outuri)
-        self.columns = list()
-        self.data = list()
-        self.table = table

-        yaml = YamlFile(f"{rootdir}/{table}/{table}.yaml")
-        # yaml.dump()
-        self.config = yaml.getEntries()

-    async def getPage(self,
-                      offset: int,
-                      count: int,
-                      ):
-        """
-        Return all the data in the table.
-        Returns:
-            (list): The results of the query
-        """
-
-
-        # columns = await self.getColumns(self.table)
-        # keys = self.columns
-
-        # columns = str(keys)[1:-1].replace("'", "")
-        if offset == 0:
-            sql = f"SELECT row_to_json({self.table}) as row FROM {self.table} LIMIT {count}"
-        else:
-            sql = f"SELECT row_to_json({self.table}) as row FROM {self.table} LIMIT {count} OFFSET {offset}"
-
-        # print(sql)
-        result = await self.tmdb.execute(sql)
-        data = list()
-        for record in result:
-            table = dict(record)['row']
-            data.append(table)
-
-        return data
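
This commit drops getPage()'s LIMIT/OFFSET paging; the commented-out cursor code in main() points at the server-side-cursor alternative, which streams fixed-size pages without the repeated scans OFFSET causes. A sketch under the assumption that con is a raw asyncpg connection (like inpg.pg below):

```python
async def fetch_in_pages(con, table: str, pagesize: int = 10_000):
    # Server-side cursor: asyncpg cursors must run inside a transaction.
    async with con.transaction():
        cur = await con.cursor(f"SELECT row_to_json({table}) AS row FROM {table}")
        while True:
            page = await cur.fetch(pagesize)
            if not page:
                break
            yield [record["row"] for record in page]
```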

async def getColumns(self,
table: str,
@@ -226,41 +199,22 @@ async def getAllData(self,
data.append(table)
return data

-    async def getRecordCount(self):
-        table = None
-        # FIXME: we should cleanup this mess between US and British spelling
-        if self.table == 'organizations':
-            table = "organisations"
-        else:
-            table = self.table
-        sql = f"SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'public.{table}'::regclass;"
-        print(sql)
-        result = await self.tmdb.execute(sql)
-        log.debug(f"There are {result[0]['estimate']} records in {self.table}")
-
-        return result[0]['estimate']
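
getRecordCount() moves to PostgresClient (main() now calls doit.tmdb.getRecordCount), but the underlying trick is unchanged: read the planner's reltuples estimate instead of running a full COUNT(*). A standalone sketch against a raw asyncpg connection:

```python
async def estimated_count(con, table: str) -> int:
    # reltuples is approximate but instant, which is all the
    # chunking math in main() needs.
    row = await con.fetchrow(
        "SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = to_regclass($1)",
        f"public.{table}",
    )
    return row["estimate"]
```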

async def writeAllData(self,
data: list,
-                           table: str,
-                           db: PostgresClient,
+                           pg: PostgresClient,
):
"""
Write the data into table in TM Admin.
Args:
data (list): The table data from TM
table str(): The table to get the columns for.
"""
# log.debug(f"Writing block {len(data)} to the database")
if len(data) == 0:
return True

builtins = ['int32', 'int64', 'string', 'timestamp', 'bool']

# tr = self.admindb.pg.transaction()
# await tr.start()

pbar = tqdm.tqdm(data)
for record in pbar:
# for record in data:
@@ -275,6 +229,8 @@ async def writeAllData(self,
# bar.next()
if type(record) == str:
x = eval(record)
+        elif 'row' in record:
+            x = eval(record['row'])
else:
x = record
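The branch above decodes each record, eval()ing strings produced by row_to_json. Since that payload is JSON, json.loads is the safer decoder: true, false, and null are not valid Python literals and would trip eval(). A sketch, assuming each record is a str, a dict, or an asyncpg Record:

```python
import json

def decode_row(record) -> dict:
    # JSON-based alternative to eval() for row_to_json payloads.
    if isinstance(record, str):
        return json.loads(record)
    rec = dict(record)  # works for plain dicts and asyncpg Records
    if "row" in rec:
        value = rec["row"]
        return value if isinstance(value, dict) else json.loads(value)
    return rec
```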
for key, val in x.items():
@@ -406,9 +362,9 @@ async def writeAllData(self,


# foo = f"str(columns)[1:-1].replace("'", "")
-        sql = f"""INSERT INTO {table}({str(columns)[1:-1].replace("'", "")}) VALUES({values[:-2]})"""
+        sql = f"""INSERT INTO {self.table}({str(columns)[1:-1].replace("'", "")}) VALUES({values[:-2]})"""
# print(sql)
-        results = await db.execute(sql)
+        results = await pg.execute(sql)
#await tr.commit()

#bar.finish()
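
writeAllData() builds the INSERT by string interpolation. A hedged alternative, not what this commit does: let asyncpg bind the parameters, which handles quoting and batches a whole chunk in one executemany() call:

```python
async def insert_rows(con, table: str, columns: list, rows: list):
    # rows is a list of tuples ordered to match columns.
    placeholders = ", ".join(f"${i}" for i in range(1, len(columns) + 1))
    sql = f"INSERT INTO {table}({', '.join(columns)}) VALUES({placeholders})"
    await con.executemany(sql, rows)
```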
@@ -446,56 +402,75 @@ async def main():
)

doit = TMImport()
-    await doit.connect(args.inuri, args.outuri, args.table)
-    entries = await doit.getRecordCount()
+    await doit.connect(args.inuri)
+    entries = await doit.tmdb.getRecordCount(args.table)
-    block = 0
chunk = round(entries / cores)

# this is the size of the pages in records
threshold = 10000
-    data = list()
-    tasks = list()
-    tmpg = list()

+    tmpg = list()
+    tasks = list()
for i in range(0, cores + 1):
pg = PostgresClient()
await pg.connect(args.outuri)
tmpg.append(pg)
+    inpg = PostgresClient()
+    await inpg.connect(args.inuri)
+    # FIXME: this interestingly still had data corruption problems. If we had
+    # to do this frequently, we'd want to paginate the data, but normally
+    # importing from Tasking Manager is a one-time operation.
+    # async with inpg.pg.transaction():
+    #     for index in range(0, cores):
+    #         # cur = await inpg.pg.cursor(f'SELECT row_to_json({args.table}) AS row FROM {args.table}')
+    #         cur = await inpg.pg.cursor(f'SELECT * FROM {args.table} ORDER BY id')
+    #         result = await cur.fetch(chunk)
+    #         data.append(result)
+    #         await cur.forward(chunk)

# Some tables in the input database are huge, and can either core
# dump python, or have performance issues. Past a certain threshold
# the data needs to be queried in pages instead of the entire table.
-    if entries > threshold:
-        futures = list()
-        async with asyncio.TaskGroup() as tg:
-            index = 0
-            for block in range(0, entries, chunk):
-                data = await doit.getPage(block, chunk - 1)
-                # data = await doit.tmdb.getPage2(block, chunk)
-                # log.debug(f"Dispatching thread {index} {block}:{block + chunk - 1}")
-                # await importThread(data, tmpg[index], doit)
-                task = tg.create_task(importThread(data, tmpg[index], doit))
-                index += 1
-                tasks.append(task)
-    else:
-        data = list
-        # You have to love subtle cultural spelling differences.
+    # There seem to be issues with data corruption
+    futures = list()
+    async with asyncio.TaskGroup() as tg:
+        # index = 0
+        # dsn = f"postgres://rob:fu=br@localhost/tm_admin"
+        # async with create_pool(min_size=2, max_size=cores, dsn=dsn) as pool:
+        #     async with pool.acquire() as con:
+        start = 0
        if args.table == 'organizations':
-            data = await doit.getAllData('organisations')
+            table = 'organisations'
        else:
-            data = await doit.getAllData(args.table)
-
-        entries = len(data)
-        log.debug(f"There are {entries} entries in {args.table}")
-        chunk = round(entries / cores)
-
-        if entries < threshold:
-            await importThread(data, tmpg[0], doit)
-            quit()
+            table = args.table
+        sql = f"SELECT * FROM {table} ORDER BY id"
+        print(sql)
+        log.warning(f"This operation may be slow for large datasets.")
+        data = await inpg.execute(sql)
+        for index in range(0, cores):
+            outpg = PostgresClient()
+            await outpg.connect(args.outuri)
+            # data = await inpg.getPage(start, chunk, args.table)
+            log.debug(f"Dispatching thread {index} {start}:{start + chunk}")
+            # await importThread(data[start:start + chunk], outpg, args.table)
+            task = tg.create_task(importThread(data[start:start + chunk], outpg, table))
+            start += chunk
+            # tasks.append(task)
+            # time.sleep(1)
+    # else:
+    #     data = list
+    #     # You have to love subtle cultural spelling differences.
+    #     if args.table == 'organizations':
+    #         data = await doit.getAllData('organisations')
+    #     else:
+    #         data = await doit.getAllData(args.table)
+
+    #     entries = len(data)
+    #     log.debug(f"There are {entries} entries in {args.table}")
+    #     outpg = PostgresClient()
+    #     await outpg.connect(args.outuri)
+    #     await importThread(data, outpg, args.table)
+    #     quit()

-    index = 0
-    data = await doit.getPage(0, entries)
-    importThread(data, tmpg[0], doit)

if __name__ == "__main__":
"""This is just a hook so this file can be run standalone during development."""

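For reference, a worked example of the chunking math in main(), assuming an 8-core machine:

```python
entries = 1_000_000              # estimate from getRecordCount()
cores = 8 * 2                    # info["count"] * 2, as set at module level
chunk = round(entries / cores)   # 62500 rows handed to each worker
for index, start in enumerate(range(0, entries, chunk)):
    print(f"worker {index}: rows {start}:{start + chunk}")
```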