diff --git a/tm_admin/tmdb.py b/tm_admin/tmdb.py index 585e98ae..36d51a9e 100755 --- a/tm_admin/tmdb.py +++ b/tm_admin/tmdb.py @@ -28,7 +28,6 @@ from shapely.geometry import MultiPolygon, Polygon, Point, shape from datetime import datetime from osm_rawdata.pgasync import PostgresClient -# from osm_rawdata.postgres import uriParser, PostgresClient from progress.bar import Bar, PixelBar from tm_admin.types_tm import Userrole, Mappinglevel, Organizationtype, Taskcreationmode, Projectstatus, Permissions, Projectpriority, Projectdifficulty, Mappingtypes, Editors, Teamvisibility, Taskstatus from tm_admin.yamlfile import YamlFile @@ -146,7 +145,7 @@ async def getPage(self, sql = f"SELECT row_to_json({self.table}) as row FROM {self.table} LIMIT {count}" else: sql = f"SELECT row_to_json({self.table}) as row FROM {self.table} LIMIT {count} OFFSET {offset}" - # sql = f"SELECT {columns} FROM {self.table} ORDER BY id LIMIT {count} OFFSET {offset}" + # print(sql) result = await self.tmdb.execute(sql) data = list() @@ -169,7 +168,7 @@ async def getColumns(self, (dict): The table definition. """ sql = f"SELECT column_name, data_type,column_default FROM information_schema.columns WHERE table_name = '{table}' ORDER BY dtd_identifier;" - results = await self.tmdb.queryLocal(sql) + results = await self.tmdb.execute(sql) # log.info(f"There are {len(results)} columns in the TM '{table}' table") table = dict() for column in results: @@ -211,13 +210,13 @@ async def getAllData(self, Returns: (list): All the data from the table. """ - columns = self.getColumns(table) + columns = await self.getColumns(table) keys = self.columns columns = str(keys)[1:-1].replace("'", "") sql = f"SELECT {columns} FROM {table}" # sql = f"SELECT row_to_json({table}) as row FROM {table}" - results = self.tmdb.queryLocal(sql) + results = await self.tmdb.execute(sql) log.info(f"There are {len(results)} records in the TM '{table}' table") data = list() # this is actually faster than using row_to_json(), and the @@ -228,8 +227,14 @@ async def getAllData(self, return data async def getRecordCount(self): - sql = f"SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'public.{self.table}'::regclass;" - # print(sql) + table = None + # FIXME: we should cleanup this mess between US and British spelling + if self.table == 'organizations': + table = "organisations" + else: + table = self.table + sql = f"SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'public.{table}'::regclass;" + print(sql) result = await self.tmdb.execute(sql) log.debug(f"There are {result[0]['estimate']} records in {self.table}") @@ -268,7 +273,10 @@ async def writeAllData(self, # values += createSQLValues(record, self.config) # print(values) # bar.next() - x = eval(record) + if type(record) == str: + x = eval(record) + else: + x = record for key, val in x.items(): columns.append(key) # print(f"FIXME: {key} = {self.config[key]}") @@ -462,7 +470,10 @@ async def main(): async with asyncio.TaskGroup() as tg: index = 0 for block in range(0, entries, chunk): - data = await doit.getPage(block, chunk) + data = await doit.getPage(block, chunk - 1) + # data = await doit.tmdb.getPage2(block, chunk) + # log.debug(f"Dispatching thread {index} {block}:{block + chunk - 1}") + # await importThread(data, tmpg[index], doit) task = tg.create_task(importThread(data, tmpg[index], doit)) index += 1 tasks.append(task) @@ -470,16 +481,16 @@ async def main(): data = list # You have to love subtle cultural spelling differences. if args.table == 'organizations': - data = doit.getAllData('organisations') + data = await doit.getAllData('organisations') else: - data = doit.getAllData(args.table) + data = await doit.getAllData(args.table) entries = len(data) log.debug(f"There are {entries} entries in {args.table}") chunk = round(entries / cores) if entries < threshold: - importThread(data, tmpg[0], doit) + await importThread(data, tmpg[0], doit) quit() index = 0