
Commit 8113217
fix: Minor changes to use pgasync everywhere
rsavoye committed Feb 9, 2024
1 parent 998e0a4 commit 8113217
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions tm_admin/tmdb.py
@@ -28,7 +28,6 @@
 from shapely.geometry import MultiPolygon, Polygon, Point, shape
 from datetime import datetime
 from osm_rawdata.pgasync import PostgresClient
-# from osm_rawdata.postgres import uriParser, PostgresClient
 from progress.bar import Bar, PixelBar
 from tm_admin.types_tm import Userrole, Mappinglevel, Organizationtype, Taskcreationmode, Projectstatus, Permissions, Projectpriority, Projectdifficulty, Mappingtypes, Editors, Teamvisibility, Taskstatus
 from tm_admin.yamlfile import YamlFile
@@ -146,7 +145,7 @@ async def getPage(self,
             sql = f"SELECT row_to_json({self.table}) as row FROM {self.table} LIMIT {count}"
         else:
             sql = f"SELECT row_to_json({self.table}) as row FROM {self.table} LIMIT {count} OFFSET {offset}"
-        # sql = f"SELECT {columns} FROM {self.table} ORDER BY id LIMIT {count} OFFSET {offset}"
+
         # print(sql)
         result = await self.tmdb.execute(sql)
         data = list()
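A note on the hunk above: getPage() pages through a table with LIMIT/OFFSET, and row_to_json(<table>) returns each row as a JSON object keyed by column name, which is what the import threads consume downstream. A minimal sketch of the SQL this builds (the table name is illustrative, not from the commit):

def page_sql(table: str, count: int, offset: int = 0) -> str:
    # Mirrors the diff: the first page omits OFFSET, later pages include it.
    sql = f"SELECT row_to_json({table}) as row FROM {table} LIMIT {count}"
    return sql if offset == 0 else f"{sql} OFFSET {offset}"

print(page_sql("projects", 1000))        # ... LIMIT 1000
print(page_sql("projects", 1000, 2000))  # ... LIMIT 1000 OFFSET 2000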
Expand All @@ -169,7 +168,7 @@ async def getColumns(self,
(dict): The table definition.
"""
sql = f"SELECT column_name, data_type,column_default FROM information_schema.columns WHERE table_name = '{table}' ORDER BY dtd_identifier;"
results = await self.tmdb.queryLocal(sql)
results = await self.tmdb.execute(sql)
# log.info(f"There are {len(results)} columns in the TM '{table}' table")
table = dict()
for column in results:
@@ -211,13 +210,13 @@ async def getAllData(self,
         Returns:
             (list): All the data from the table.
         """
-        columns = self.getColumns(table)
+        columns = await self.getColumns(table)
         keys = self.columns
 
         columns = str(keys)[1:-1].replace("'", "")
         sql = f"SELECT {columns} FROM {table}"
         # sql = f"SELECT row_to_json({table}) as row FROM {table}"
-        results = self.tmdb.queryLocal(sql)
+        results = await self.tmdb.execute(sql)
         log.info(f"There are {len(results)} records in the TM '{table}' table")
         data = list()
         # this is actually faster than using row_to_json(), and the
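The two await fixes in this hunk matter because calling a coroutine without await never runs it: the old self.tmdb.queryLocal(sql) spelling hands back a coroutine object, and the following len(results) would raise a TypeError. A self-contained sketch of the failure mode, using a stand-in fetch() coroutine (hypothetical, not the osm_rawdata API):

import asyncio

async def fetch(sql: str) -> list:
    # Stands in for an async driver call such as PostgresClient.execute().
    await asyncio.sleep(0)
    return [{"id": 1}]

async def main():
    broken = fetch("SELECT 1")         # coroutine object; the query never runs
    print(type(broken))                # <class 'coroutine'>
    working = await fetch("SELECT 1")  # actually executes
    print(working)                     # [{'id': 1}]
    broken.close()                     # suppress the "never awaited" warning

asyncio.run(main())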
@@ -228,8 +227,14 @@
         return data
 
     async def getRecordCount(self):
-        sql = f"SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'public.{self.table}'::regclass;"
-        # print(sql)
+        table = None
+        # FIXME: we should cleanup this mess between US and British spelling
+        if self.table == 'organizations':
+            table = "organisations"
+        else:
+            table = self.table
+        sql = f"SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'public.{table}'::regclass;"
+        print(sql)
         result = await self.tmdb.execute(sql)
         log.debug(f"There are {result[0]['estimate']} records in {self.table}")
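The query in this hunk reads reltuples from the pg_class catalog, which is the planner's row-count estimate rather than an exact count: it returns instantly even on very large tables, where SELECT COUNT(*) would force a full scan, at the cost of being only as fresh as the last VACUUM or ANALYZE. A sketch of the same idea written directly against asyncpg (an assumption for illustration; the commit actually goes through osm_rawdata's PostgresClient, and the DSN and table name below are made up):

import asyncio
import asyncpg

async def estimate_rows(dsn: str, table: str) -> int:
    conn = await asyncpg.connect(dsn)
    try:
        # reltuples is the planner's estimate, refreshed by VACUUM/ANALYZE;
        # casting the qualified name to regclass resolves the table's oid.
        return await conn.fetchval(
            "SELECT reltuples::bigint FROM pg_class WHERE oid = $1::regclass",
            f"public.{table}",
        )
    finally:
        await conn.close()

# asyncio.run(estimate_rows("postgresql://localhost/tm_admin", "projects"))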
@@ -268,7 +273,10 @@ async def writeAllData(self,
             # values += createSQLValues(record, self.config)
             # print(values)
             # bar.next()
-            x = eval(record)
+            if type(record) == str:
+                x = eval(record)
+            else:
+                x = record
             for key, val in x.items():
                 columns.append(key)
                 # print(f"FIXME: {key} = {self.config[key]}")
@@ -462,24 +470,27 @@ async def main():
         async with asyncio.TaskGroup() as tg:
             index = 0
             for block in range(0, entries, chunk):
-                data = await doit.getPage(block, chunk)
+                data = await doit.getPage(block, chunk - 1)
+                # data = await doit.tmdb.getPage2(block, chunk)
                 # log.debug(f"Dispatching thread {index} {block}:{block + chunk - 1}")
+                # await importThread(data, tmpg[index], doit)
                 task = tg.create_task(importThread(data, tmpg[index], doit))
                 index += 1
                 tasks.append(task)
     else:
         data = list
+        # You have to love subtle cultural spelling differences.
         if args.table == 'organizations':
-            data = doit.getAllData('organisations')
+            data = await doit.getAllData('organisations')
         else:
-            data = doit.getAllData(args.table)
+            data = await doit.getAllData(args.table)
 
         entries = len(data)
         log.debug(f"There are {entries} entries in {args.table}")
         chunk = round(entries / cores)
 
         if entries < threshold:
-            importThread(data, tmpg[0], doit)
+            await importThread(data, tmpg[0], doit)
             quit()
 
         index = 0
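For context on the fan-out above: asyncio.TaskGroup (Python 3.11+) awaits every task created inside its async with block and cancels the remainder if any task raises, so main() can dispatch one importThread() per page and simply fall out of the block once every page is imported. A minimal self-contained sketch of the same pattern (the worker and sizes are made up):

import asyncio

async def import_page(offset: int, size: int) -> int:
    await asyncio.sleep(0.01)  # stands in for fetching and writing one page
    return size

async def main():
    entries, chunk = 100, 25
    tasks = []
    async with asyncio.TaskGroup() as tg:
        for block in range(0, entries, chunk):
            tasks.append(tg.create_task(import_page(block, chunk)))
    # Leaving the block guarantees every task has completed.
    print(sum(t.result() for t in tasks))  # 100

asyncio.run(main())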
