From 0d685faa2cfa1db8e4d999b9dad2935d398923a7 Mon Sep 17 00:00:00 2001 From: Ioannis Paraskevakos Date: Fri, 20 Sep 2024 16:19:30 -0400 Subject: [PATCH] wip: postres integration --- sotodlib/core/metadata/common.py | 134 ++++++++++++++++++++------ sotodlib/core/metadata/obsdb.py | 158 +++++++++++++++++++------------ 2 files changed, 204 insertions(+), 88 deletions(-) diff --git a/sotodlib/core/metadata/common.py b/sotodlib/core/metadata/common.py index 393b31345..e7e86827b 100644 --- a/sotodlib/core/metadata/common.py +++ b/sotodlib/core/metadata/common.py @@ -1,9 +1,96 @@ -import sqlite3 +import psycopg import gzip import os +GET_TABLE_CREATE = """with table_info as ( + select + c.column_name, + c.data_type, + c.character_maximum_length, + c.is_nullable, + tc.constraint_type, + tc.constraint_name + from + information_schema.columns as c + left join ( + select + kcu.column_name, + tc.constraint_type, + tc.constraint_name, + kcu.table_name + from + information_schema.table_constraints as tc + join + information_schema.key_column_usage as kcu + on tc.constraint_name = kcu.constraint_name + where + tc.table_name = '%s' + ) as tc + on c.column_name = tc.column_name + where + c.table_name = '%s' + ) + select + 'create table your_table_name (' || string_agg( + column_name || ' ' || + data_type || + case + when character_maximum_length is not null + then '(' || character_maximum_length || ')' + else '' + end || + case + when is_nullable = 'no' then ' not null' + else '' + end || + case + when constraint_type is not null then + ' constraint ' || constraint_name || ' ' || constraint_type + else '' + end, + ', ' + ) || ');' + from + table_info; +""" -def sqlite_to_file(db, filename, overwrite=True, fmt=None): + +def dump_database(conn: psycopg.Connection) -> str: + with conn.cursor() as cur: + db_dump = "" + # Fetch all table names + cur.execute( + "select table_name from information_schema.tables where table_schema='public'" + ) + tables = cur.fetchall() + + for (table_name,) in tables: + # Dump CREATE TABLE statement + cur.execute(GET_TABLE_CREATE % (table_name, table_name)) + create_table = cur.fetchone()[0] + "\n" + db_dump += create_table + columns = cur.execute( + "select column_name, data_type, character_maximum_length from " + + "information_schema.columns where table_name = '%s';" % table_name + ).fetchall() + column_names = ", ".join(f"{col[0]}" for col in columns) + # Dump data + cur.execute(f"select {column_names} from {table_name}") + rows = cur.fetchall() + for row in rows: + values = ", ".join( + "NULL" if value is None else f"'{str(value)}'" for value in row + ) + db_dump += ( + f"insert into {table_name} ({column_names}) values ({values});\n" + ) + + return db_dump + + +def postgres_to_file( + db: psycopg.Connection, filename: str, overwrite: bool = True, fmt: str = None +) -> None: """Write an sqlite db to file. Supports several output formats. Args: @@ -19,29 +106,23 @@ def sqlite_to_file(db, filename, overwrite=True, fmt=None): if filename.endswith('.gz'): fmt = 'gz' else: - fmt = 'sqlite' + fmt = "dump" if os.path.exists(filename) and not overwrite: raise RuntimeError(f'File {filename} exists; remove or pass ' 'overwrite=True.') - if fmt == 'sqlite': - if os.path.exists(filename): - os.remove(filename) - new_db = sqlite3.connect(filename) - script = ' '.join(db.iterdump()) - new_db.executescript(script) - new_db.commit() - elif fmt == 'dump': + if fmt == "dump": with open(filename, 'w') as fout: - for line in db.iterdump(): + for line in dump_database(db): fout.write(line) elif fmt == 'gz': with gzip.GzipFile(filename, 'wb') as fout: - for line in db.iterdump(): - fout.write(line.encode('utf-8')) + for line in dump_database(db): + fout.write(line) else: raise RuntimeError(f'Unknown format "{fmt}" requested.') -def sqlite_from_file(filename, fmt=None, force_new_db=True): + +def postgres_from_file(filename: str, db: psycopg.Connection, fmt: str = None) -> None: """Instantiate an sqlite3.Connection and return it, with the data copied in from the specified file. The function can either map the database file directly, or map a copy of the database in memory (see force_new_db @@ -49,30 +130,27 @@ def sqlite_from_file(filename, fmt=None, force_new_db=True): Args: filename (str): path to the file. + db: A new DB connection. fmt (str): format of the input; see to_file for details. force_new_db (bool): Used if connecting to an sqlite database. If True the - databas is copied into memory and if False returns a connection to the + database is copied into memory and if False returns a connection to the database without reading it into memory """ if fmt is None: - fmt = 'sqlite' + fmt = "dump" if filename.endswith('.gz'): fmt = 'gz' - if fmt == 'sqlite': - db0 = sqlite3.connect(f'file:{filename}?mode=ro', uri=True) - if not force_new_db: - return db0 - data = ' '.join(db0.iterdump()) - elif fmt == 'dump': + if fmt == "dump": with open(filename, 'r') as fin: - data = fin.read() + data = fin.readlines() elif fmt == 'gz': with gzip.GzipFile(filename, 'r') as fin: - data = fin.read().decode('utf-8') + data = fin.readlines().decode("utf-8") else: raise RuntimeError(f'Unknown format "{fmt}" requested.') - db = sqlite3.connect(':memory:') - db.executescript(data) - return db + with db.cursor() as cursor: + for datum in data: + cursor.execute(datum.strip()) + db.commit() diff --git a/sotodlib/core/metadata/obsdb.py b/sotodlib/core/metadata/obsdb.py index ab88b9393..14b8ce39c 100644 --- a/sotodlib/core/metadata/obsdb.py +++ b/sotodlib/core/metadata/obsdb.py @@ -1,19 +1,20 @@ -import sqlite3 import os +import psycopg +from typing import Optional, Union from .resultset import ResultSet from . import common TABLE_DEFS = { - 'obs': [ - "`obs_id` varchar(256) primary key", - "`timestamp` float", + "obs": [ + "obs_id varchar(256) primary key", + "timestamp real", ], - 'tags': [ - "`obs_id` varchar(256)", - "`tag` varchar(256)", - "CONSTRAINT one_tag UNIQUE (`obs_id`, `tag`)", + "tags": [ + "obs_id varchar(256)", + "tag varchar(256)", + "CONSTRAINT one_tag UNIQUE (obs_id, tag)", ], } @@ -37,55 +38,63 @@ class ObsDb(object): """ TABLE_TEMPLATE = [ - "`obs_id` varchar(256)", + "obs_id varchar(256)", ] - def __init__(self, map_file=None, init_db=True): + def __init__(self, map_file: psycopg.Connection = None, init_db: bool = True): """Instantiate an ObsDb. Args: - map_file (str or sqlite3.Connection): If this is a string, - it will be treated as the filename for the sqlite3 - database, and opened as an sqlite3.Connection. If this is - an sqlite3.Connection, it is cached and used. If this + map_file (psycopg.Connection): This is + an psycopg.Connection, it is cached and used. If this argument is None (the default), then the - sqlite3.Connection is opened on ':memory:'. + psycopg.Connection is opened in 'localhost:5432'. init_db (bool): If True, then any ObsDb tables that do not already exist in the database will be created. Notes: If map_file is provided, the database will be connected to - the indicated sqlite file on disk, and any changes made to + the indicated postgres server, and any changes made to this object be written back to the file. """ - if isinstance(map_file, sqlite3.Connection): + if isinstance(map_file, psycopg.Connection): self.conn = map_file else: - if map_file is None: - map_file = ':memory:' - self.conn = sqlite3.connect(map_file) + raise RuntimeError("map_file is not a postgres") - self.conn.row_factory = sqlite3.Row # access columns by name + # self.conn.row_factory = psycopg.Row # access columns by name + # Values are by default returned in tuples if init_db: - c = self.conn.cursor() - c.execute("SELECT name FROM sqlite_master " - "WHERE type='table' and name not like 'sqlite_%';") - tables = [r[0] for r in c] + # c is the cursor + obsdb_cursor = self.conn.cursor() + obsdb_cursor.execute( + "select table_name" + + " from information_schema.tables" + + " where table_type='BASE TABLE'" + + " and table_schema='public';" + ) + tables = [r[0] for r in obsdb_cursor] changes = False for k, v in TABLE_DEFS.items(): if k not in tables: - q = ('create table if not exists `%s` (' % k + - ','.join(v) + ')') - c.execute(q) + create_query = ( + f"create table if not exists {k} (" + ",".join(v) + ")" + ) + obsdb_cursor.execute(create_query) changes = True if changes: self.conn.commit() def __len__(self): - return self.conn.execute('select count(obs_id) from obs').fetchone()[0] - - def add_obs_columns(self, column_defs, ignore_duplicates=True, commit=True): + return self.conn.execute("select count(obs_id) from obs").fetchone()[0] + + def add_obs_columns( + self, + column_defs: Union[list[tuple[str, str]]], + ignore_duplicates: Optional[bool] = True, + commit: Optional[bool] = True, + ) -> "ObsDb": """Add columns to the obs table. Args: @@ -117,7 +126,10 @@ def add_obs_columns(self, column_defs, ignore_duplicates=True, commit=True): 'timestamp float, drift str' """ - current_cols = self.conn.execute('pragma table_info("obs")').fetchall() + current_cols = self.conn.execute( + "select column_name, data_type, character_maximum_length from " + + "information_schema.columns where table_name = 'obs';" + ).fetchall() current_cols = [r[1] for r in current_cols] if isinstance(column_defs, str): column_defs = column_defs.split(',') @@ -126,7 +138,7 @@ def add_obs_columns(self, column_defs, ignore_duplicates=True, commit=True): column_def = column_def.split() name, typestr = column_def if typestr is float: - typestr = 'float' + typestr = "real" elif typestr is int: typestr = 'int' elif typestr is str: @@ -139,14 +151,20 @@ def add_obs_columns(self, column_defs, ignore_duplicates=True, commit=True): if check_name in current_cols: if ignore_duplicates: continue - raise ValueError("Column %s already exists in table obs" % check_name) - self.conn.execute('ALTER TABLE obs ADD COLUMN %s %s' % (name, typestr)) + raise ValueError(f"Column {check_name} already exists in table obs") + self.conn.execute(f"alter table obs add column {name}, {typestr}") current_cols.append(check_name) if commit: self.conn.commit() return self - def update_obs(self, obs_id, data={}, tags=[], commit=True): + def update_obs( + self, + obs_id: str, + data: dict = {}, + tags: Optional[list[str]] = [], + commit: Optional[bool] = True, + ): """Update an entry in the obs table. Arguments: @@ -160,22 +178,29 @@ def update_obs(self, obs_id, data={}, tags=[], commit=True): self. """ - c = self.conn.cursor() - c.execute('INSERT OR IGNORE INTO obs (obs_id) VALUES (?)', - (obs_id,)) + obsdb_cursor = self.conn.cursor() + obsdb_cursor.execute( + f"insert into obs (obs_id) values ({obs_id}) on conflict (obs_id) do nothing", + ) + if len(data.keys()): - settors = [f'{k}=?' for k in data.keys()] - c.execute('update obs set ' + ','.join(settors) + ' ' - 'where obs_id=?', - tuple(data.values()) + (obs_id, )) + settors = [f"{key} = %s" for key in data.keys()] + obsdb_cursor.execute( + "update obs set " + ", ".join(settors) + " where obs_id = %s", + tuple(data.values()) + (obs_id,), + ) + for t in tags: if t[0] == '!': # Kill this tag. - c.execute('DELETE FROM tags WHERE obs_id=? AND tag=?', - (obs_id, t[1:])) + obsdb_cursor.execute( + "delete from tags where obs_id = %s and tag = %s", (obs_id, t[1:]) + ) else: - c.execute('INSERT OR REPLACE INTO tags (obs_id, tag) ' - 'VALUES (?,?)', (obs_id, t)) + obsdb_cursor.execute( + f"insert into tags (obs_id, tag) values ({obs_id}, {t}) " + "on conflict (obs_id, tag) do update set obs_id = excluded.obs_id, tag = excluded.tag", + ) if commit: self.conn.commit() return self @@ -195,8 +220,10 @@ def copy(self, map_file=None, overwrite=False): raise RuntimeError("Output file %s exists (overwrite=True " "to overwrite)." % map_file) new_db = ObsDb(map_file=map_file, init_db=False) - script = ' '.join(self.conn.iterdump()) - new_db.conn.executescript(script) + script = common.dump_database(self.conn) + for line in script: + new_db.conn.execute(line.strip()) + new_db.conn.commit() return new_db def to_file(self, filename, overwrite=True, fmt=None): @@ -206,18 +233,24 @@ def to_file(self, filename, overwrite=True, fmt=None): filename (str): the path to the output file. overwrite (bool): whether an existing file should be overwritten. - fmt (str): 'sqlite', 'dump', or 'gz'. Defaults to 'sqlite' + fmt (str): 'dump', or 'gz'. Defaults to 'dump' unless the filename ends with '.gz', in which it is 'gz'. """ - return common.sqlite_to_file(self.conn, filename, overwrite=overwrite, fmt=fmt) + return common.postgres_to_file( + self.conn, filename, overwrite=overwrite, fmt=fmt + ) @classmethod - def from_file(cls, filename, fmt=None, force_new_db=True): + def from_file( + cls, filename: str, conn: psycopg.Connection, fmt=None, force_new_db=True + ) -> "ObsDb": """This method calls :func:`sotodlib.core.metadata.common.sqlite_from_file` """ - conn = common.sqlite_from_file(filename, fmt=fmt, force_new_db=force_new_db) + conn = common.postgres_from_file( + filename, conn, fmt=fmt, force_new_db=force_new_db + ) return cls(conn, init_db=False) def get(self, obs_id=None, tags=None, add_prefix=''): @@ -253,7 +286,7 @@ def get(self, obs_id=None, tags=None, add_prefix=''): output['tags'] = [r[0] for r in c] return output - def query(self, query_text='1', tags=None, sort=['obs_id'], add_prefix=''): + def query(self, query_text="", tags=None, sort=["obs_id"], add_prefix=""): """Queries the ObsDb using user-provided text. Returns a ResultSet. Args: @@ -291,7 +324,7 @@ def query(self, query_text='1', tags=None, sort=['obs_id'], add_prefix=''): """ sort_text = '' if sort is not None and len(sort): - sort_text = ' ORDER BY ' + ','.join(sort) + sort_text = " order by " + ",".join(sort) joins = '' extra_fields = [] if tags is not None and len(tags): @@ -302,18 +335,23 @@ def query(self, query_text='1', tags=None, sort=['obs_id'], add_prefix=''): val = None if val is None: join_type = 'left join' - extra_fields.append(f'ifnull(tt{tagi}.obs_id,"") != "" as {t}') + extra_fields.append(f'coalesce(tt{tagi}.obs_id,"") != "" as {t}') elif val == '0': join_type = 'left join' - extra_fields.append(f'ifnull(tt{tagi}.obs_id,"") != "" as {t}') + extra_fields.append(f'coalesce(tt{tagi}.obs_id,"") != "" as {t}') query_text += f' and {t}==0' else: join_type = 'join' extra_fields.append(f'1 as {t}') - joins += (f' {join_type} (select distinct obs_id from tags where tag="{t}") as tt{tagi} on ' - f'obs.obs_id = tt{tagi}.obs_id') + joins += ( + f" {join_type} (select distinct obs_id from tags where tag='{t}') as tt{tagi} on " + f"obs.obs_id = tt{tagi}.obs_id" + ) extra_fields = ''.join([','+f for f in extra_fields]) - q = 'select obs.* %s from obs %s where %s %s' % (extra_fields, joins, query_text, sort_text) + where_statement = "" + if len(query_text): + where_statement = f" where {query_text}" + q = f"select obs.* {extra_fields} from obs {joins} {where_statement} {sort_text}" c = self.conn.execute(q) results = ResultSet.from_cursor(c) if add_prefix is not None: