diff --git a/sotodlib/core/metadata/obsdb.py b/sotodlib/core/metadata/obsdb.py
index ab88b9393..1bcf01f72 100644
--- a/sotodlib/core/metadata/obsdb.py
+++ b/sotodlib/core/metadata/obsdb.py
@@ -17,6 +17,122 @@
     ],
 }
 
+def _generate_query_components_from_tags(query_text='1', tags=None):
+    """
+    Generate SQL query components from the given tags.
+
+    This function builds the join clauses, extra SELECT fields, and updated
+    query text needed to fold tag information into the final query.
+
+    Args:
+        query_text (str, optional): The initial query condition. Defaults to '1'.
+        tags (list of str, optional): Tags to include in the output; if they
+            are listed here then they can also be used in the query string.
+            Filtering on tag value can be done here by appending '=0' or '=1'
+            to a tag name.
+
+    Returns:
+        tuple: A tuple containing:
+            - extra_fields (str): Comma-separated string of extra fields for the SELECT clause.
+            - joins (str): String of join clauses to be added to the query.
+            - query_text (str): Updated query text including conditions for the tags.
+    """
+    joins = ''
+    extra_fields = []
+    if tags is not None and len(tags):
+        for tagi, t in enumerate(tags):
+            if '=' in t:
+                t, val = t.split('=')
+            else:
+                val = None
+            if val is None:
+                join_type = 'left join'
+                extra_fields.append(f'ifnull(tt{tagi}.obs_id,"") != "" as {t}')
+            elif val == '0':
+                join_type = 'left join'
+                extra_fields.append(f'ifnull(tt{tagi}.obs_id,"") != "" as {t}')
+                query_text += f' and {t}==0'
+            else:
+                join_type = 'join'
+                extra_fields.append(f'1 as {t}')
+            joins += (f' {join_type} (select distinct obs_id from tags where tag="{t}") as tt{tagi} on '
+                      f'obs.obs_id = tt{tagi}.obs_id')
+    extra_fields = ''.join([','+f for f in extra_fields])
+    return extra_fields, joins, query_text
+
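For orientation, this is roughly what the helper above returns for a mixed tag list. The snippet is illustrative only: it imports a private helper directly, and the output strings are paraphrased from the f-strings in the code::

    from sotodlib.core.metadata.obsdb import _generate_query_components_from_tags

    extra, joins, qtext = _generate_query_components_from_tags(
        query_text='start_time>1700000000',
        tags=['planet', 'hwp=1', 'cryo_problem=0'])
    # extra -> ',ifnull(tt0.obs_id,"") != "" as planet'
    #          ',1 as hwp'
    #          ',ifnull(tt2.obs_id,"") != "" as cryo_problem'
    # joins -> one (left) join per tag against
    #          (select distinct obs_id from tags where tag="...") as tt<i>
    # qtext -> 'start_time>1700000000 and cryo_problem==0'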
+def _generate_query_components_from_subdb(filepath,
+                                          alias,
+                                          query_list=None,
+                                          params_list=None,
+                                          table_name=None,
+                                          obs_id_name=None):
+    """
+    Generate query components from a sub-database.
+
+    This function creates SQL query components based on the provided sub-database
+    information: the join clause, extra SELECT fields, and query conditions to be
+    included in the final query.
+
+    Args:
+        filepath (str): The file path to the sub-database.
+        alias (str): The alias to be used for the sub-database in the query.
+        query_list (list of str, optional): A list of query conditions to be applied to the sub-database.
+        params_list (list of str, optional): A list of parameters to be selected from the sub-database.
+        table_name (str, optional): The name of the table in the sub-database. Defaults to 'map'.
+        obs_id_name (str, optional): The name of the observation ID field in the sub-database.
+            Defaults to 'obs:obs_id'.
+
+    Returns:
+        tuple: A tuple containing:
+            - extra_fields (str): Comma-separated string of extra fields for the SELECT clause.
+            - join (str): String of the join clause to be added to the query.
+            - query (str): String of query conditions to be added to the WHERE clause.
+    """
+    if table_name is None:
+        table_name = 'map'
+    if obs_id_name is None:
+        obs_id_name = 'obs:obs_id'
+
+    if params_list is not None and isinstance(params_list, list):
+        extra_fields = []
+        for _param in params_list:
+            extra_fields.append(f'{alias}.{table_name}.{_param}')
+        extra_fields = ''.join([','+f for f in extra_fields])
+    elif params_list is None:
+        extra_fields = ''
+    else:
+        raise ValueError('Invalid input for params_list')
+
+    join = f' join {alias}.{table_name} on obs.obs_id = {alias}.{table_name}."{obs_id_name}"'
+
+    if query_list is not None and isinstance(query_list, list):
+        query = []
+        operators = [' < ', ' > ', ' <= ', ' >= ']
+        for _query_component in query_list:
+            if ' in ' in _query_component:
+                _query_component = _query_component.split(' in ')
+                _query_component = _query_component[1]+' LIKE '+'"%'+_query_component[0]+'%"'
+                query.append(f'{alias}.{table_name}.{_query_component}')
+            elif any(op in _query_component for op in operators):
+                for op in operators:
+                    if op in _query_component:
+                        _query_component = _query_component.split(op)
+                        source_name = _query_component[0]
+                        distance = float(_query_component[1])
+                        _query_component = f'{alias}.{table_name}.source_distance'
+                        query.append(f"INSTR({_query_component}, '{source_name}:') > 0"
+                                     f" AND CAST(SUBSTR({_query_component}, "
+                                     f"INSTR({_query_component}, '{source_name}:') + LENGTH('{source_name}:'), "
+                                     f"IFNULL(NULLIF(INSTR(SUBSTR({_query_component}, "
+                                     f"INSTR({_query_component}, '{source_name}:') + LENGTH('{source_name}:')), ','), 0), "
+                                     f" LENGTH({_query_component})) - 1) AS REAL) {op} {distance}")
+                        break
+            else:
+                query.append(f'{alias}.{table_name}.{_query_component}')
+        query = ''.join([' and '+_q for _q in query])
+    elif query_list is None:
+        query = ''
+    else:
+        raise ValueError('Invalid input for query_list')
+
+    return extra_fields, join, query
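Similarly, here is roughly what this second helper returns for the two most common query_list forms; the file paths are placeholders and the output strings are paraphrased from the code. A third form, 'name <op> number' with spaces around the operator, is translated into INSTR/SUBSTR SQL that extracts 'name:<value>' from the table's source_distance column and compares the value numerically::

    from sotodlib.core.metadata.obsdb import _generate_query_components_from_subdb

    # Plain numeric threshold on a selected parameter.
    fields, join, cond = _generate_query_components_from_subdb(
        '/path/to/pwv_class.sqlite', alias='subdb0',
        query_list=['pwv_class_median<2.0'], params_list=['pwv_class_median'])
    # fields -> ',subdb0.map.pwv_class_median'
    # join   -> ' join subdb0.map on obs.obs_id = subdb0.map."obs:obs_id"'
    # cond   -> ' and subdb0.map.pwv_class_median<2.0'

    # Substring containment becomes a LIKE condition.
    _, _, cond = _generate_query_components_from_subdb(
        '/path/to/coverage.sqlite', alias='subdb0',
        query_list=['jupiter:ws0 in coverage'])
    # cond   -> ' and subdb0.map.coverage LIKE "%jupiter:ws0%"'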
 
 class ObsDb(object):
     """Observation database.
@@ -253,69 +369,130 @@ def get(self, obs_id=None, tags=None, add_prefix=''):
             output['tags'] = [r[0] for r in c]
         return output
 
-    def query(self, query_text='1', tags=None, sort=['obs_id'], add_prefix=''):
-        """Queries the ObsDb using user-provided text. Returns a ResultSet.
-
-        Args:
-          query_text (str): The sqlite query string. All fields
-            should refer to the obs table, or to tags explicitly
-            listed in the tags argument.
-          tags (list of str): Tags to include in the output; if they
-            are listed here then they can also be used in the query
-            string. Filtering on tag value can be done here by
-            appending '=0' or '=1' to a tag name.
-
+    def query(self, query_text='1', tags=None, sort=['obs_id'], add_prefix='', subdbs_info_list=None):
+        """
+        Generate and execute a query on the main database, with optional sub-databases.
+
+        This method constructs and executes an SQL query on the main database,
+        incorporating conditions, joins, and fields from optional sub-databases and tags.
+
+        Args:
+            query_text (str, optional): The initial query condition. Defaults to '1'.
+            tags (list of str, optional): A list of tags to filter the observations.
+                See _generate_query_components_from_tags for details.
+            sort (list of str, optional): A list of fields to sort the results by. Defaults to ['obs_id'].
+            add_prefix (str, optional): A prefix to add to the result keys. Defaults to ''.
+            subdbs_info_list (list of dict, optional): A list of dictionaries containing
+                sub-database information. Each dictionary must contain filepath (str) and
+                may also contain query_list (list of str), params_list (list of str),
+                table_name (str), and obs_id_name (str). See Notes (2) for details.
+                If not provided, only the query on the main obsdb is executed.
+
         Returns:
-          A ResultSet with one row for each Observation matching the
-          criteria.
+            ResultSet: The result set of the executed query, with one row per matching observation.
 
         Notes:
-          Tags are added to the output on request. For example,
-          passing tags=['planet','stare'] will cause the output to
-          include columns 'planet' and 'stare' in addition to all the
-          columns defined in the obs table. The value of 'planet' and
-          'stare' in each row will be 0 or 1 depending on whether that
-          tag is set for that observation. We can include expressions
-          involving planet and stare in the query, for example::
-
-            obsdb.query('planet=1 or stare=1', tags=['planet', 'stare'])
-
-          For simple filtering on tags, pass '=1' or '=0', like this::
-
-            obsdb.query(tags=['planet=1','hwp=1'])
-
-          When filtering is activated in this way, the returned
-          results must satisfy all the criteria (i.e. the individual
-          constraints are AND-ed).
-
+            (1) Tags are added to the output on request. For example,
+                passing tags=['planet','stare'] will cause the output to
+                include columns 'planet' and 'stare' in addition to all the
+                columns defined in the obs table. The value of 'planet' and
+                'stare' in each row will be 0 or 1 depending on whether that
+                tag is set for that observation. We can include expressions
+                involving planet and stare in the query, for example::
+
+                    obsdb.query('planet=1 or stare=1', tags=['planet', 'stare'])
+
+                For simple filtering on tags, pass '=1' or '=0', like this::
+
+                    obsdb.query(tags=['planet=1','hwp=1'])
+
+                When filtering is activated in this way, the returned
+                results must satisfy all the criteria (i.e. the individual
+                constraints are AND-ed).
+
+            (2) Sub-databases can be attached and queried within the main query.
+                Passing subdbs_info_list with appropriate parameters makes the
+                method include extra fields, joins, and query conditions from the
+                sub-databases. Each sub-database is attached under a unique alias
+                and joined to the main obs table; the query is then constructed
+                to include the necessary fields and conditions from both the main
+                and sub-databases. For instance::
+
+                    subdb_info = {
+                        'filepath': '/path/to/pwv_class.sqlite',
+                        'query_list': ['pwv_class_median<2.0', 'pwv_class_rms<0.1'],
+                        'params_list': ['pwv_class_median', 'pwv_class_rms'],
+                    }
+                    obsdb.query(query_text='start_time>1700000000 and planet=1',
+                                tags=['planet'],
+                                subdbs_info_list=[subdb_info])
+
+                This selects observations with a start_time greater than 1700000000,
+                the 'planet' tag, a median pwv smaller than 2.0 mm, and a pwv rms
+                smaller than 0.1 mm.
+
+                If you do not know which parameters a sub-database provides, you can
+                inspect them, as long as it is a ManifestDb, like this::
+
+                    from sotodlib.core import metadata
+                    subdb = metadata.ManifestDb('/path/to/pwv_class.sqlite')
+                    print(subdb.scheme._get_map_table_def())
+
         """
-        sort_text = ''
-        if sort is not None and len(sort):
-            sort_text = ' ORDER BY ' + ','.join(sort)
-        joins = ''
-        extra_fields = []
-        if tags is not None and len(tags):
-            for tagi, t in enumerate(tags):
-                if '=' in t:
-                    t, val = t.split('=')
-                else:
-                    val = None
-                if val is None:
-                    join_type = 'left join'
-                    extra_fields.append(f'ifnull(tt{tagi}.obs_id,"") != "" as {t}')
-                elif val == '0':
-                    join_type = 'left join'
-                    extra_fields.append(f'ifnull(tt{tagi}.obs_id,"") != "" as {t}')
-                    query_text += f' and {t}==0'
-                else:
-                    join_type = 'join'
-                    extra_fields.append(f'1 as {t}')
-                joins += (f' {join_type} (select distinct obs_id from tags where tag="{t}") as tt{tagi} on '
-                          f'obs.obs_id = tt{tagi}.obs_id')
-        extra_fields = ''.join([','+f for f in extra_fields])
-        q = 'select obs.* %s from obs %s where %s %s' % (extra_fields, joins, query_text, sort_text)
-        c = self.conn.execute(q)
-        results = ResultSet.from_cursor(c)
+        cursor = self.conn.cursor()
+
+        try:
+            extra_fields_main, joins_main, query_text_main = \
+                _generate_query_components_from_tags(query_text=query_text,
+                                                     tags=tags)
+            sort_text = ''
+            if sort is not None and len(sort):
+                sort_text = ' ORDER BY ' + ','.join(sort)
+
+            if subdbs_info_list is not None:
+                assert isinstance(subdbs_info_list, list)
+                extra_fields_sub = []
+                joins_sub = []
+                query_text_sub = []
+
+                aliases = []
+                for i, subdb_info in enumerate(subdbs_info_list):
+                    assert isinstance(subdb_info, dict)
+                    if 'filepath' not in subdb_info.keys():
+                        raise KeyError('subdb_info does not have "filepath" in keys')
+                    filepath = subdb_info['filepath']
+                    alias = f'subdb{i}'
+                    aliases.append(alias)
+                    attach = f"ATTACH DATABASE '{filepath}' AS '{alias}'"
+                    cursor.execute(attach)
+                    _extra_fields_sub, _join_sub, _query_sub = \
+                        _generate_query_components_from_subdb(filepath=filepath,
+                                                              alias=alias,
+                                                              query_list=subdb_info.get('query_list', None),
+                                                              params_list=subdb_info.get('params_list', None),
+                                                              table_name=subdb_info.get('table_name', None),
+                                                              obs_id_name=subdb_info.get('obs_id_name', None))
+                    extra_fields_sub.append(_extra_fields_sub)
+                    joins_sub.append(_join_sub)
+                    query_text_sub.append(_query_sub)
+                extra_fields_sub = ''.join(extra_fields_sub)
+                joins_sub = ''.join(' '+_j for _j in joins_sub)
+                query_text_sub = ''.join(' '+q for q in query_text_sub)
+                tot_query = f'SELECT obs.* {extra_fields_main} {extra_fields_sub} FROM obs {joins_main} {joins_sub} WHERE {query_text_main} {query_text_sub} {sort_text}'
+                cursor.execute(tot_query)
+                results = ResultSet.from_cursor(cursor)
+
+                for alias in aliases:
+                    cursor.execute(f"DETACH DATABASE {alias}")
+
+            else:
+                tot_query = f'SELECT obs.* {extra_fields_main} FROM obs {joins_main} WHERE {query_text_main} {sort_text}'
+                cursor.execute(tot_query)
+                results = ResultSet.from_cursor(cursor)
+
+        finally:
+            cursor.close()
+
         if add_prefix is not None:
             results.keys = [add_prefix + k for k in results.keys]
         return results
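Putting the pieces together, the docstring example corresponds to roughly the following call and generated SQL. This is a sketch only: obsdb stands for an existing ObsDb instance, the file path is a placeholder, and the statement shown is an approximation of what the two helper functions assemble::

    # 'obsdb' is assumed to be an existing ObsDb instance; the path is a placeholder.
    rs = obsdb.query('start_time>1700000000 and planet=1',
                     tags=['planet'],
                     subdbs_info_list=[{'filepath': '/path/to/pwv_class.sqlite',
                                        'query_list': ['pwv_class_median<2.0'],
                                        'params_list': ['pwv_class_median']}])
    # Internally the sub-database is ATTACHed as "subdb0" and a statement of
    # roughly this shape is executed (then subdb0 is DETACHed again):
    #
    #   SELECT obs.*, ifnull(tt0.obs_id,"") != "" as planet,
    #          subdb0.map.pwv_class_median
    #   FROM obs
    #     left join (select distinct obs_id from tags where tag="planet") as tt0
    #       on obs.obs_id = tt0.obs_id
    #     join subdb0.map on obs.obs_id = subdb0.map."obs:obs_id"
    #   WHERE start_time>1700000000 and planet=1
    #     and subdb0.map.pwv_class_median<2.0
    #   ORDER BY obs_id
    print(rs.keys)   # obs columns plus 'planet' and 'pwv_class_median'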
diff --git a/tests/test_obsdb.py b/tests/test_obsdb.py
index 6b249a006..72cf8ad09 100644
--- a/tests/test_obsdb.py
+++ b/tests/test_obsdb.py
@@ -1,4 +1,7 @@
 import unittest
+import shutil
+import tempfile
+import numpy as np
 from sotodlib.core import metadata
 import os
 
@@ -6,7 +9,6 @@
 from ._helpers import mpi_multi
 
 
-
 def get_example():
     # Create a new Db and add two columns.
     obsdb = metadata.ObsDb()
@@ -26,13 +28,68 @@ def get_example():
                                          'drift': 'rising' if i%2 else 'setting'},
                          tags=tags)
     return obsdb
+
+
+def make_extdb_pwv(dir_name, file_name):
+    # Create an external ManifestDb for pwv and return the temporary path to the database.
+    scheme = metadata.ManifestScheme()
+    scheme.add_exact_match('obs:obs_id')
+    scheme.add_data_field('pwv')
+    extdb_pwv = metadata.ManifestDb(map_file=os.path.join(dir_name, file_name),
+                                    scheme=scheme)
+    for i in range(10):
+        if i in [3, 8]:
+            pwv = np.random.uniform(2.1, 3)
+        else:
+            pwv = np.random.uniform(0, 2)
+        entry_dict = {'obs:obs_id': f'myobs{i}',
+                      'pwv': pwv}
+        extdb_pwv.add_entry(entry_dict)
+    return os.path.join(dir_name, file_name)
+
+
+def make_extdb_coverage(dir_name, file_name):
+    # Create an external ManifestDb for source coverage and return the temporary path to the database.
+    scheme = metadata.ManifestScheme()
+    scheme.add_exact_match('obs:obs_id')
+    scheme.add_data_field('coverage')
+    extdb_coverage = metadata.ManifestDb(map_file=os.path.join(dir_name, file_name),
+                                         scheme=scheme)
+    for i in range(10):
+        if i in [3, 8]:
+            coverage = 'jupiter:ws0,jupiter:ws1,saturn:ws0'
+        else:
+            coverage = 'jupiter:ws1,jupiter:ws2,saturn:ws0'
+        entry_dict = {'obs:obs_id': f'myobs{i}',
+                      'coverage': coverage}
+        extdb_coverage.add_entry(entry_dict)
+    return os.path.join(dir_name, file_name)
+
+
+def make_extdb_distance(dir_name, file_name):
+    # Create an external ManifestDb for source distance and return the temporary path to the database.
+    scheme = metadata.ManifestScheme()
+    scheme.add_exact_match('obs:obs_id')
+    scheme.add_data_field('source_distance')
+    extdb_distance = metadata.ManifestDb(map_file=os.path.join(dir_name, file_name),
+                                         scheme=scheme)
+    for i in range(10):
+        if i in [3, 8]:
+            distance = 'saturn:10.1,jupiter:12.4'
+        else:
+            distance = 'saturn:9.5,neptune:13.5'
+        entry_dict = {'obs:obs_id': f'myobs{i}',
+                      'source_distance': distance}
+        extdb_distance.add_entry(entry_dict)
+    return os.path.join(dir_name, file_name)
 
 
 @unittest.skipIf(mpi_multi(), "Running with multiple MPI processes")
 class TestObsDb(unittest.TestCase):
     def setUp(self):
+        self.test_dir = tempfile.mkdtemp()
         pass
+
+    def tearDown(self):
+        if os.path.exists(self.test_dir):
+            shutil.rmtree(self.test_dir)
 
     def test_smoke(self):
         """Basic functionality."""
@@ -49,6 +106,99 @@ def test_query(self):
         r1 = db.query('drift == "setting"')
         self.assertGreater(len(r0), 0)
         self.assertEqual(len(r0) + len(r1), len(db))
+
+    def test_query_extension(self):
+        db = get_example()
+        extdb_path = make_extdb_pwv(dir_name=self.test_dir, file_name='pwv.sqlite')
+        r0 = db.query('drift == "rising"',
+                      subdbs_info_list=[{'filepath': extdb_path,
+                                         'query_list': ['pwv<2.0'],
+                                         'params_list': ['pwv']}]
+                      )
+        r1 = db.query('drift == "setting"',
+                      subdbs_info_list=[{'filepath': extdb_path,
+                                         'query_list': ['pwv<2.0'],
+                                         'params_list': ['pwv']}]
+                      )
+        r2 = db.query('drift == "rising"',
+                      subdbs_info_list=[{'filepath': extdb_path,
+                                         'query_list': ['pwv>=2.0'],
+                                         'params_list': ['pwv']}]
+                      )
+        r3 = db.query('drift == "setting"',
+                      subdbs_info_list=[{'filepath': extdb_path,
+                                         'query_list': ['pwv>=2.0'],
+                                         'params_list': ['pwv']}]
+                      )
+        self.assertEqual(len(r0)+len(r1)+len(r2)+len(r3),
+                         len(db))
+        self.assertTrue('pwv' in r0.keys)
+
+    def test_query_extension_coverage(self):
+        db = get_example()
+        extdb_path = make_extdb_coverage(dir_name=self.test_dir, file_name='coverage.sqlite')
+        r0 = db.query('drift == "rising"',
+                      subdbs_info_list=[{'filepath': extdb_path,
+                                         'query_list': ['jupiter:ws0 in coverage'],
+                                         'params_list': ['coverage']}]
+                      )
+        r1 = db.query('drift == "setting"',
+                      subdbs_info_list=[{'filepath': extdb_path,
+                                         'query_list': ['jupiter:ws0 in coverage'],
+                                         'params_list': ['coverage']}]
+                      )
+        r2 = db.query('drift == "rising"',
+                      subdbs_info_list=[{'filepath': extdb_path,
+                                         'query_list': ['jupiter:ws2 in coverage'],
+                                         'params_list': ['coverage']}]
+                      )
+        r3 = db.query('drift == "setting"',
+                      subdbs_info_list=[{'filepath': extdb_path,
+                                         'query_list': ['jupiter:ws2 in coverage'],
+                                         'params_list': ['coverage']}]
+                      )
+        r4 = db.query(
+                      subdbs_info_list=[{'filepath': extdb_path,
+                                         'query_list': ['jupiter:ws1 in coverage'],
+                                         'params_list': ['coverage']}]
+                      )
+        self.assertEqual(len(r0)+len(r1)+len(r2)+len(r3),
+                         len(db))
+        self.assertEqual(len(r4), len(db))
+        self.assertTrue('coverage' in r0.keys)
+
+    def test_query_extension_distance(self):
+        db = get_example()
+        extdb_path = make_extdb_distance(dir_name=self.test_dir, file_name='distance.sqlite')
+        r0 = db.query('drift == "rising"',
+                      subdbs_info_list=[{'filepath': extdb_path,
+                                         'query_list': ['saturn > 10'],
+                                         'params_list': ['source_distance']}]
+                      )
+        r1 = db.query('drift == "setting"',
+                      subdbs_info_list=[{'filepath': extdb_path,
+                                         'query_list': ['saturn > 10'],
+                                         'params_list': ['source_distance']}]
+                      )
+        r2 = db.query('drift == "rising"',
+                      subdbs_info_list=[{'filepath': extdb_path,
+                                         'query_list': ['neptune >= 12'],
+                                         'params_list': ['source_distance']}]
+                      )
+        r3 = db.query('drift == "setting"',
+                      subdbs_info_list=[{'filepath': extdb_path,
+                                         'query_list': ['neptune >= 12'],
+                                         'params_list': ['source_distance']}]
+                      )
+        r4 = db.query(
+                      subdbs_info_list=[{'filepath': extdb_path,
+                                         'query_list': ['saturn < 11'],
+                                         'params_list': ['source_distance']}]
+                      )
+        self.assertEqual(len(r0)+len(r1)+len(r2)+len(r3),
+                         len(db))
+        self.assertEqual(len(r4), len(db))
+        self.assertTrue('source_distance' in r0.keys)
 
     def test_tags(self):
         db = get_example()
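For completeness, the flow these tests exercise looks like this outside the unittest harness. This is a minimal sketch: obsdb is assumed to be an existing ObsDb whose obs_ids are myobs0..myobs9 (for example the one returned by get_example() above), and the ManifestDb construction mirrors make_extdb_distance::

    import os
    import tempfile
    from sotodlib.core import metadata

    # External ManifestDb whose 'source_distance' field holds comma-separated
    # 'name:value' pairs, following make_extdb_distance above.
    tmp_dir = tempfile.mkdtemp()
    scheme = metadata.ManifestScheme()
    scheme.add_exact_match('obs:obs_id')
    scheme.add_data_field('source_distance')
    extdb_path = os.path.join(tmp_dir, 'distance.sqlite')
    extdb = metadata.ManifestDb(map_file=extdb_path, scheme=scheme)
    for i in range(10):
        extdb.add_entry({'obs:obs_id': f'myobs{i}',
                         'source_distance': 'saturn:10.1,jupiter:12.4'})

    # 'saturn > 10' (spaces around the operator) pulls the number after
    # 'saturn:' out of source_distance and compares it to 10.
    rs = obsdb.query(subdbs_info_list=[{'filepath': extdb_path,
                                        'query_list': ['saturn > 10'],
                                        'params_list': ['source_distance']}])
    print(rs.keys)   # obs columns plus 'source_distance'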