From 4775529e286936551265c96a8261df07fd7397a0 Mon Sep 17 00:00:00 2001
From: Rob Savoye
Date: Mon, 12 Feb 2024 11:20:35 -0700
Subject: [PATCH] fix: Finish refactoring to use asyncpg

---
 tm_admin/tmdb.py | 193 +++++++++++++++++++++--------------------------
 1 file changed, 84 insertions(+), 109 deletions(-)

diff --git a/tm_admin/tmdb.py b/tm_admin/tmdb.py
index 36d51a9e..01f561a6 100755
--- a/tm_admin/tmdb.py
+++ b/tm_admin/tmdb.py
@@ -24,6 +24,7 @@
 import sys
 from sys import argv
 import os
+import time
 from shapely import wkb, get_coordinates
 from shapely.geometry import MultiPolygon, Polygon, Point, shape
 from datetime import datetime
@@ -34,10 +35,12 @@
 import concurrent.futures
 from cpuinfo import get_cpu_info
 import asyncio
+from asyncpg import create_pool
 # from tm_admin.users.users import createSQLValues
 # from tm_admin.organizations.organizations import createSQLValues
 from tqdm import tqdm
 import tqdm.asyncio
+import copy
 
 # Instantiate logger
 log = logging.getLogger(__name__)
@@ -48,29 +51,34 @@
 # The number of threads is based on the CPU cores
 info = get_cpu_info()
 # More threads. Shorter import time, higher CPU load. But this is a
-# pretty low CPU load proces anyway, so more is good.
-cores = info["count"]
+# pretty low CPU load process anyway, so more is good.
+cores = info["count"] * 2
 
 async def importThread(
-    data: list,
-    db: PostgresClient,
-    tm,
-):
+        data: list,
+        pg: PostgresClient,
+        table: str,
+    ):
     """
     Thread to handle importing
 
     Args:
         data (list): The list of records to import
-        db (PostgresClient): A database connection
-        tm (TMImport): the input handle
+        pg (PostgresClient): A connection to the output database
+        table (str): The table to write the records to
     """
     # log.debug(f"There are {len(data)} data entries")
-    await tm.writeAllData(data, tm.table, db)
+    if table == 'organisations':
+        table = 'organizations'
+    tmi = TMImport(table)
+    await tmi.writeAllData(data, pg)
 
     return True
 
 class TMImport(object):
-    def __init__(self):
+    def __init__(self,
+                 config: str = None,
+                 ):
         """
         This class contains support to accessing a Tasking Manager database, and
         importing it in the TM Admin database. This works because the TM Admin
@@ -81,20 +88,24 @@ def __init__(self):
         The other change is in TM many columns are enums, but the database type
         is in. The integer values from TM are converted to the proper TM Admin
         enum value.
+        Args:
+            config (str): The YAML config file for this table
 
         Returns:
             (TMImport): An instance of this class
         """
         self.tmdb = None
-        self.admindb = None
         self.table = None
         self.columns = list()
        self.data = list()
         self.config = dict()
+        if config:
+            self.table = config
+            yaml = YamlFile(f"{rootdir}/{config}/{config}.yaml")
+            # yaml.dump()
+            self.config = yaml.getEntries()
 
     async def connect(self,
                       inuri: str,
-                      outuri: str,
-                      table: str,
                       ):
         """
         This class contains support to accessing a Tasking Manager database, and
@@ -108,52 +119,14 @@ async def connect(self,
         Args:
             inuri (str): The URI for the TM database
-            outuri (str): The URI for the TM Admin database
             table (str): The table in the TM Admin database
-"""
+        """
         # The Tasking Manager database
         self.tmdb = PostgresClient()
         await self.tmdb.connect(inuri)
 
         # The TMAdmin database
-        self.admindb = PostgresClient()
-        await self.admindb.connect(outuri)
         self.columns = list()
         self.data = list()
-        self.table = table
-
-        yaml = YamlFile(f"{rootdir}/{table}/{table}.yaml")
-        # yaml.dump()
-        self.config = yaml.getEntries()
-
-    async def getPage(self,
-                      offset: int,
-                      count: int,
-                      ):
-        """
-        Return all the data in the table.
-
-        Returns:
-            (list): The results of the query
-        """
-
-
-        # columns = await self.getColumns(self.table)
-        # keys = self.columns
-
-        # columns = str(keys)[1:-1].replace("'", "")
-        if offset == 0:
-            sql = f"SELECT row_to_json({self.table}) as row FROM {self.table} LIMIT {count}"
-        else:
-            sql = f"SELECT row_to_json({self.table}) as row FROM {self.table} LIMIT {count} OFFSET {offset}"
-
-        # print(sql)
-        result = await self.tmdb.execute(sql)
-        data = list()
-        for record in result:
-            table = dict(record)['row']
-            data.append(table)
-
-        return data
 
     async def getColumns(self,
                          table: str,
@@ -226,31 +199,15 @@ async def getAllData(self,
             data.append(table)
 
         return data
 
-    async def getRecordCount(self):
-        table = None
-        # FIXME: we should cleanup this mess between US and British spelling
-        if self.table == 'organizations':
-            table = "organisations"
-        else:
-            table = self.table
-        sql = f"SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'public.{table}'::regclass;"
-        print(sql)
-        result = await self.tmdb.execute(sql)
-        log.debug(f"There are {result[0]['estimate']} records in {self.table}")
-
-        return result[0]['estimate']
-
     async def writeAllData(self,
                            data: list,
-                           table: str,
-                           db: PostgresClient,
+                           pg: PostgresClient,
                            ):
         """
         Write the data into table in TM Admin.
 
         Args:
             data (list): The table data from TM
-            table str(): The table to get the columns for.
         """
         # log.debug(f"Writing block {len(data)} to the database")
         if len(data) == 0:
@@ -258,9 +215,6 @@ async def writeAllData(self,
 
         builtins = ['int32', 'int64', 'string', 'timestamp', 'bool']
 
-        # tr = self.admindb.pg.transaction()
-        # await tr.start()
-
         pbar = tqdm.tqdm(data)
         for record in pbar:
             # for record in data:
@@ -275,6 +229,8 @@ async def writeAllData(self,
             # bar.next()
             if type(record) == str:
                 x = eval(record)
+            elif 'row' in record:
+                x = eval(record['row'])
             else:
                 x = record
             for key, val in x.items():
@@ -406,9 +362,9 @@ async def writeAllData(self,
 
             # foo = f"str(columns)[1:-1].replace("'", "")
-            sql = f"""INSERT INTO {table}({str(columns)[1:-1].replace("'", "")}) VALUES({values[:-2]})"""
+            sql = f"""INSERT INTO {self.table}({str(columns)[1:-1].replace("'", "")}) VALUES({values[:-2]})"""
             # print(sql)
-            results = await db.execute(sql)
+            results = await pg.execute(sql)
 
         #await tr.commit()
         #bar.finish()
@@ -446,56 +402,75 @@ async def main():
     )
 
     doit = TMImport()
-    await doit.connect(args.inuri, args.outuri, args.table)
-    entries = await doit.getRecordCount()
+    await doit.connect(args.inuri)
+    entries = await doit.tmdb.getRecordCount(args.table)
 
     block = 0
     chunk = round(entries / cores) # this is the size of the pages in records
     threshold = 10000
     data = list()
+    tasks = list()
     tmpg = list()
-    tmpg = list()
-    tasks = list()
-    for i in range(0, cores + 1):
-        pg = PostgresClient()
-        await pg.connect(args.outuri)
-        tmpg.append(pg)
+    inpg = PostgresClient()
+    await inpg.connect(args.inuri)
+    # FIXME: this interestingly still had data corruption problems. If we had
+    # to do this frequently, we'd want to paginate the data, but normally
+    # importing from Tasking Manager is a one time operation.
+    # async with inpg.pg.transaction():
+    #     for index in range(0, cores):
+    #         # cur = await inpg.pg.cursor(f'SELECT row_to_json({args.table}) AS row FROM {args.table}')
+    #         cur = await inpg.pg.cursor(f'SELECT * FROM {args.table} ORDER BY id')
+    #         result = await cur.fetch(chunk)
+    #         data.append(result)
+    #         await cur.forward(chunk)
+
 
     # Some tables in the input database are huge, and can either core
     # dump python, or have performance issues. Past a certain threshold
     # the data needs to be queried in pages instead of the entire table.
-    if entries > threshold:
-        futures = list()
-        async with asyncio.TaskGroup() as tg:
-            index = 0
-            for block in range(0, entries, chunk):
-                data = await doit.getPage(block, chunk - 1)
-                # data = await doit.tmdb.getPage2(block, chunk)
-                # log.debug(f"Dispatching thread {index} {block}:{block + chunk - 1}")
-                # await importThread(data, tmpg[index], doit)
-                task = tg.create_task(importThread(data, tmpg[index], doit))
-                index += 1
-                tasks.append(task)
-    else:
-        data = list
-        # You have to love subtle cultural spelling differences.
+    # There seem to be issues with data corruption
+    futures = list()
+    async with asyncio.TaskGroup() as tg:
+        # index = 0
+        # dsn = f"postgres://rob:fu=br@localhost/tm_admin"
+        # async with create_pool(min_size=2, max_size=cores, dsn=dsn) as pool:
+        #     async with pool.acquire() as con:
+        start = 0
         if args.table == 'organizations':
-            data = await doit.getAllData('organisations')
+            table = 'organisations'
         else:
-            data = await doit.getAllData(args.table)
-
-        entries = len(data)
-        log.debug(f"There are {entries} entries in {args.table}")
-        chunk = round(entries / cores)
-
-        if entries < threshold:
-            await importThread(data, tmpg[0], doit)
-            quit()
+            table = args.table
+        sql = f"SELECT * FROM {table} ORDER BY id"
+        print(sql)
+        log.warning("This operation may be slow for large datasets.")
+        data = await inpg.execute(sql)
+        for index in range(0, cores):
+            outpg = PostgresClient()
+            await outpg.connect(args.outuri)
+            # data = await inpg.getPage(start, chunk, args.table)
+            log.debug(f"Dispatching thread {index} {start}:{start + chunk}")
+            # await importThread(data[start:start + chunk], outpg, args.table)
+            task = tg.create_task(importThread(data[start:start + chunk], outpg, table))
+            start += chunk
+            # tasks.append(task)
+            # time.sleep(1)
+    # else:
+    #     data = list
+    #     # You have to love subtle cultural spelling differences.
+    #     if args.table == 'organizations':
+    #         data = await doit.getAllData('organisations')
+    #     else:
+    #         data = await doit.getAllData(args.table)
+
+    #     entries = len(data)
+    #     log.debug(f"There are {entries} entries in {args.table}")
+    #     outpg = PostgresClient()
+    #     await outpg.connect(args.outuri)
+    #     await importThread(data, outpg, args.table)
+    #     quit()
 
     index = 0
-    data = await doit.getPage(0, entries)
-    importThread(data, tmpg[0], doit)
 
 if __name__ == "__main__":
     """This is just a hook so this file can be run standalone during development."""
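
Note: the FIXME above says that if the import had to run frequently, the data should be
paginated rather than loaded with one giant SELECT. For reference, a minimal sketch of
what that could look like using asyncpg's own pool and server-side cursors directly,
independent of the PostgresClient wrapper. The URIs, table name, and page size are
placeholders, and it deliberately skips the enum and geometry conversions that
writeAllData() performs, so it illustrates the pagination pattern only, not a drop-in
replacement:

    import asyncio
    import asyncpg

    async def copy_table(inuri: str, outuri: str, table: str, page: int = 10000):
        """Copy one table in fixed-size pages so it never has to fit in memory."""
        inpg = await asyncpg.connect(inuri)
        pool = await asyncpg.create_pool(dsn=outuri, min_size=2, max_size=8)
        async with inpg.transaction():
            # asyncpg server-side cursors are only valid inside a transaction
            cur = await inpg.cursor(f"SELECT * FROM {table} ORDER BY id")
            while rows := await cur.fetch(page):
                # build INSERT ... VALUES($1, $2, ...) from the Record keys
                columns = ", ".join(rows[0].keys())
                holders = ", ".join(f"${i + 1}" for i in range(len(rows[0])))
                sql = f"INSERT INTO {table}({columns}) VALUES({holders})"
                async with pool.acquire() as con:
                    # executemany() batches the whole page in one round trip
                    await con.executemany(sql, [tuple(r.values()) for r in rows])
        await inpg.close()
        await pool.close()

    if __name__ == "__main__":
        asyncio.run(copy_table("postgresql://localhost/tasking-manager",
                               "postgresql://localhost/tm_admin",
                               "users"))

This trades the single large SELECT for bounded memory use on the reader side; whether
it also avoids the data corruption mentioned in the FIXME has not been verified here.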