diff --git a/qcarchivetesting/conda-envs/fulltest_qcportal.yaml b/qcarchivetesting/conda-envs/fulltest_qcportal.yaml
index eeb78f26d..5f480abd6 100644
--- a/qcarchivetesting/conda-envs/fulltest_qcportal.yaml
+++ b/qcarchivetesting/conda-envs/fulltest_qcportal.yaml
@@ -11,6 +11,7 @@ dependencies:
   - pyyaml
   - pydantic
   - zstandard
+  - apsw
   - qcelemental
   - tabulate
   - tqdm
diff --git a/qcarchivetesting/conda-envs/fulltest_server.yaml b/qcarchivetesting/conda-envs/fulltest_server.yaml
index 3d7526751..f68d7e711 100644
--- a/qcarchivetesting/conda-envs/fulltest_server.yaml
+++ b/qcarchivetesting/conda-envs/fulltest_server.yaml
@@ -15,6 +15,7 @@ dependencies:
   - pyyaml
   - pydantic
   - zstandard
+  - apsw
   - qcelemental
   - tabulate
   - tqdm
diff --git a/qcarchivetesting/conda-envs/fulltest_snowflake.yaml b/qcarchivetesting/conda-envs/fulltest_snowflake.yaml
index 8458a7716..0774c42bd 100644
--- a/qcarchivetesting/conda-envs/fulltest_snowflake.yaml
+++ b/qcarchivetesting/conda-envs/fulltest_snowflake.yaml
@@ -16,6 +16,7 @@ dependencies:
  - pyyaml
   - pydantic
   - zstandard
+  - apsw
   - qcelemental
   - tabulate
   - tqdm
diff --git a/qcarchivetesting/conda-envs/fulltest_worker.yaml b/qcarchivetesting/conda-envs/fulltest_worker.yaml
index d1690defe..26afa2c62 100644
--- a/qcarchivetesting/conda-envs/fulltest_worker.yaml
+++ b/qcarchivetesting/conda-envs/fulltest_worker.yaml
@@ -15,6 +15,7 @@ dependencies:
   - pyyaml
   - pydantic
   - zstandard
+  - apsw
   - qcelemental
   - tabulate
   - tqdm
diff --git a/qcportal/pyproject.toml b/qcportal/pyproject.toml
index da9360c9c..3d379e724 100644
--- a/qcportal/pyproject.toml
+++ b/qcportal/pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [
     "pyyaml",
     "pydantic",
     "zstandard",
+    "apsw",
     "qcelemental",
     "tabulate",
     "tqdm",
diff --git a/qcportal/qcportal/cache.py b/qcportal/qcportal/cache.py
index 12ed0254d..2db8e2ef3 100644
--- a/qcportal/qcportal/cache.py
+++ b/qcportal/qcportal/cache.py
@@ -6,12 +6,12 @@
 
 import datetime
 import os
-import sqlite3
-import threading
 import uuid
 from typing import TYPE_CHECKING, Optional, TypeVar, Type, Any, List, Iterable, Tuple, Sequence
 from urllib.parse import urlparse
 
+import apsw
+
 from .utils import chunk_iterable
 
 try:
@@ -34,14 +34,14 @@
 _query_chunk_size = 125
 
 
-def compress_for_cache(data: Any) -> sqlite3.Binary:
+def compress_for_cache(data: Any) -> bytes:
     serialized_data = serialize(data, "msgpack")
     compressed_data = zstandard.compress(serialized_data, level=1)
-    return sqlite3.Binary(compressed_data)
+    return compressed_data
 
 
-def decompress_from_cache(data: sqlite3.Binary, value_type) -> Any:
-    decompressed_data = zstandard.decompress(bytes(data))
+def decompress_from_cache(data: bytes, value_type) -> Any:
+    decompressed_data = zstandard.decompress(data)
     deserialized_data = deserialize(decompressed_data, "msgpack")
     return pydantic.parse_obj_as(value_type, deserialized_data)
 
@@ -51,30 +51,17 @@ def __init__(self, cache_uri: str, read_only: bool):
         self.cache_uri = cache_uri
         self.read_only = read_only
 
-        # There is a chance that this is used from multiple threads
-        # So store the connection in thread-local storage
-        self._th_local = threading.local()
-
-        if not read_only:
-            self._create_tables()
-            self._conn.commit()
-
-    @property
-    def _conn(self):
-        # Get the connection object that is used for this thread
-        if hasattr(self._th_local, "_conn"):
-            # print(threading.get_ident(), "EXISTING CONNECTION TO ", self.cache_uri)
-            return self._th_local._conn
+        if self.read_only:
+            self._conn = apsw.Connection(self.cache_uri, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI)
         else:
-            # This thread currently doesn't have a connection. so create one
-            # Note the uri=True flag, which is needed for the sqlite3 module to recognize full URIs
-            # print(threading.get_ident(), "NEW CONNECTION TO ", self.cache_uri)
-            self._th_local._conn = sqlite3.connect(self.cache_uri, uri=True)
+            self._conn = apsw.Connection(
+                self.cache_uri, flags=apsw.SQLITE_OPEN_READWRITE | apsw.SQLITE_OPEN_CREATE | apsw.SQLITE_OPEN_URI
+            )
 
-            # Some common settings
-            self._th_local._conn.execute("PRAGMA foreign_keys = ON")
+        self._conn.pragma("foreign_keys", "ON")
 
-        return self._th_local._conn
+        if not read_only:
+            self._create_tables()
 
     def __str__(self):
         return f"<{self.__class__.__name__} path={self.cache_uri} {'ro' if self.read_only else 'rw'}>"
@@ -102,7 +89,6 @@ def update_metadata(self, key: str, value: Any) -> None:
         self._assert_writable()
         stmt = "REPLACE INTO metadata (key, value) VALUES (?, ?)"
         self._conn.execute(stmt, (key, serialize(value, "msgpack")))
-        self._conn.commit()
 
     def get_record(self, record_id: int, record_type: Type[_RECORD_T]) -> Optional[_RECORD_T]:
         stmt = "SELECT record FROM records WHERE id = ?"
@@ -150,20 +136,20 @@ def get_existing_records(self, record_ids: Iterable[int]) -> List[int]:
     def update_records(self, records: Iterable[_RECORD_T]):
         self._assert_writable()
 
-        for record_batch in chunk_iterable(records, 10):
-            n_batch = len(record_batch)
+        with self._conn:
+            for record_batch in chunk_iterable(records, 10):
+                n_batch = len(record_batch)
 
-            values_params = ",".join(["(?, ?, ?, ?)"] * n_batch)
+                values_params = ",".join(["(?, ?, ?, ?)"] * n_batch)
 
-            all_params = []
-            for r in record_batch:
-                all_params.extend((r.id, r.status, r.modified_on.timestamp(), compress_for_cache(r)))
+                all_params = []
+                for r in record_batch:
+                    all_params.extend((r.id, r.status, r.modified_on.timestamp(), compress_for_cache(r)))
 
-            stmt = f"REPLACE INTO records (id, status, modified_on, record) VALUES {values_params}"
+                stmt = f"REPLACE INTO records (id, status, modified_on, record) VALUES {values_params}"
 
-            self._conn.execute(stmt, all_params)
+                self._conn.execute(stmt, all_params)
 
-        self._conn.commit()
         for r in records:
             r._record_cache = self
             r._cache_dirty = False
@@ -182,14 +168,12 @@ def writeback_record(self, record):
         ts = record.modified_on.timestamp()
         row_data = (record.id, record.status, ts, compressed_record, record.id, ts, ts, len(compressed_record))
         self._conn.execute(stmt, row_data)
-        self._conn.commit()
 
     def delete_record(self, record_id: int):
         self._assert_writable()
 
         stmt = "DELETE FROM records WHERE id=?"
         self._conn.execute(stmt, (record_id,))
-        self._conn.commit()
 
     def delete_records(self, record_ids: Iterable[int]):
         self._assert_writable()
@@ -199,8 +183,6 @@ def delete_records(self, record_ids: Iterable[int]):
             stmt = f"DELETE FROM records WHERE id IN ({record_id_params})"
             self._conn.execute(stmt, record_id_batch)
-        self._conn.commit()
-
 
 
 class DatasetCache(RecordCache):
     def __init__(self, cache_uri: str, read_only: bool, dataset_type: Type[_DATASET_T]):
@@ -293,18 +275,17 @@ def update_entries(self, entries: Iterable[BaseModel]):
 
         assert all(isinstance(e, self._entry_type) for e in entries)
 
-        for entry_batch in chunk_iterable(entries, 50):
-            n_batch = len(entry_batch)
-            values_params = ",".join(["(?, ?)"] * n_batch)
-
-            all_params = []
-            for e in entry_batch:
-                all_params.extend((e.name, compress_for_cache(e)))
+        with self._conn:
+            for entry_batch in chunk_iterable(entries, 50):
+                n_batch = len(entry_batch)
+                values_params = ",".join(["(?, ?)"] * n_batch)
 
-            stmt = f"REPLACE INTO dataset_entries (name, entry) VALUES {values_params}"
-            self._conn.execute(stmt, all_params)
+                all_params = []
+                for e in entry_batch:
+                    all_params.extend((e.name, compress_for_cache(e)))
 
-        self._conn.commit()
+                stmt = f"REPLACE INTO dataset_entries (name, entry) VALUES {values_params}"
+                self._conn.execute(stmt, all_params)
 
     def rename_entry(self, old_name: str, new_name: str):
         self._assert_writable()
@@ -317,14 +298,12 @@
 
         stmt = "UPDATE dataset_entries SET name=?, entry=? WHERE name=?"
         self._conn.execute(stmt, (new_name, compress_for_cache(entry), old_name))
-        self._conn.commit()
 
     def delete_entry(self, name):
         self._assert_writable()
 
         stmt = "DELETE FROM dataset_entries WHERE name=?"
         self._conn.execute(stmt, (name,))
-        self._conn.commit()
 
     def specification_exists(self, name: str) -> bool:
         stmt = "SELECT 1 FROM dataset_specifications WHERE name=?"
@@ -358,18 +337,17 @@ def update_specifications(self, specifications: Iterable[BaseModel]):
 
         assert all(isinstance(s, self._specification_type) for s in specifications)
 
-        for specification_batch in chunk_iterable(specifications, 50):
-            n_batch = len(specification_batch)
-            values_params = ",".join(["(?, ?)"] * n_batch)
-
-            all_params = []
-            for s in specification_batch:
-                all_params.extend((s.name, compress_for_cache(s)))
+        with self._conn:
+            for specification_batch in chunk_iterable(specifications, 50):
+                n_batch = len(specification_batch)
+                values_params = ",".join(["(?, ?)"] * n_batch)
 
-            stmt = f"REPLACE INTO dataset_specifications (name, specification) VALUES {values_params}"
-            self._conn.execute(stmt, all_params)
+                all_params = []
+                for s in specification_batch:
+                    all_params.extend((s.name, compress_for_cache(s)))
 
-        self._conn.commit()
+                stmt = f"REPLACE INTO dataset_specifications (name, specification) VALUES {values_params}"
+                self._conn.execute(stmt, all_params)
 
     def rename_specification(self, old_name: str, new_name: str):
         self._assert_writable()
@@ -382,14 +360,12 @@
 
         stmt = "UPDATE dataset_specifications SET name=?, specification=? WHERE name=?"
         self._conn.execute(stmt, (new_name, compress_for_cache(specification), old_name))
-        self._conn.commit()
 
     def delete_specification(self, name):
         self._assert_writable()
 
         stmt = "DELETE FROM dataset_specifications WHERE name=?"
         self._conn.execute(stmt, (name,))
-        self._conn.commit()
 
     def dataset_record_exists(self, entry_name: str, specification_name: str) -> bool:
         stmt = "SELECT 1 FROM dataset_records WHERE entry_name=? and specification_name=?"
@@ -449,26 +425,24 @@ def update_dataset_records(self, record_info: Iterable[Tuple[str, str, _RECORD_T
 
         assert all(isinstance(r, self._record_type) for _, _, r in record_info)
 
-        for info_batch in chunk_iterable(record_info, 10):
-            n_batch = len(info_batch)
-            values_params = ",".join(["(?, ?, ?)"] * n_batch)
+        with self._conn:
+            for info_batch in chunk_iterable(record_info, 10):
+                n_batch = len(info_batch)
+                values_params = ",".join(["(?, ?, ?)"] * n_batch)
 
-            all_params = []
-            for e, s, r in info_batch:
-                all_params.extend((e, s, r.id))
+                all_params = []
+                for e, s, r in info_batch:
+                    all_params.extend((e, s, r.id))
 
-            stmt = f"""REPLACE INTO dataset_records (entry_name, specification_name, record_id)
-                       VALUES {values_params}"""
-            self._conn.execute(stmt, all_params)
-
-        self._conn.commit()
+                stmt = f"""REPLACE INTO dataset_records (entry_name, specification_name, record_id)
+                           VALUES {values_params}"""
+                self._conn.execute(stmt, all_params)
 
     def delete_dataset_record(self, entry_name: str, specification_name: str):
         self._assert_writable()
 
         stmt = "DELETE FROM dataset_records WHERE entry_name=? AND specification_name=?"
         self._conn.execute(stmt, (entry_name, specification_name))
-        self._conn.commit()
 
     def delete_dataset_records(
         self, entry_names: Optional[Iterable[str]], specification_names: Optional[Iterable[str]]
@@ -494,7 +468,6 @@ def delete_dataset_records(
 
         stmt += " WHERE " + " AND ".join(conds)
         self._conn.execute(stmt, all_params)
-        self._conn.commit()
 
     def get_dataset_record_info(
         self,
@@ -554,9 +527,7 @@ def get_existing_dataset_records(
 
 
 class PortalCache:
-    def __init__(
-        self, server_uri: str, cache_dir: Optional[str], max_size: int, shared_memory_key: Optional[str] = None
-    ):
+    def __init__(self, server_uri: str, cache_dir: Optional[str], max_size: int):
         parsed_url = urlparse(server_uri)
 
         # Should work as a reasonable fingerprint?
@@ -570,11 +541,6 @@ def __init__(
         else:
             self._is_disk = False
 
-            # If no shared memory key specified, make a unique one
-            if shared_memory_key is None:
-                shared_memory_key = f"{server_uri}_{os.getpid()}_{uuid.uuid4()}"
-
-            self._shared_memory_key = shared_memory_key
             self.cache_dir = None
 
     def get_cache_uri(self, cache_name: str) -> str:
@@ -582,10 +548,7 @@ def get_cache_uri(self, cache_name: str) -> str:
             file_path = os.path.join(self.cache_dir, f"{cache_name}.sqlite")
             uri = f"file:{file_path}"
         else:
-            # We always want some shared cache due to the use of threads.
-            # vfs=memdb seems to be a better way than mode=memory&cache=shared . Very little docs about it though
-            # The / after the : is apparently very important. Otherwise, the shared stuff doesn't work
-            uri = f"file:/{self._shared_memory_key}_{cache_name}?vfs=memdb"
+            uri = ":memory:"
 
         return uri
 
@@ -614,7 +577,7 @@ def read_dataset_metadata(file_path: str):
         raise RuntimeError(f'Cannot open cache file "{file_path}" - does not exist or is not a file')
 
     uri = f"file:{file_path}?mode=ro"
-    conn = sqlite3.connect(uri, uri=True)
+    conn = apsw.Connection(uri, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI)
 
     r = conn.execute("SELECT value FROM metadata WHERE key = 'dataset_metadata'")
     if r is None:
diff --git a/qcportal/qcportal/client.py b/qcportal/qcportal/client.py
index 8937f757c..09d740ecc 100644
--- a/qcportal/qcportal/client.py
+++ b/qcportal/qcportal/client.py
@@ -147,15 +147,11 @@ def __init__(
             Directory to store an internal cache of records and other data
         cache_max_size
             Maximum size of the cache directory
-        memory_cache_key
-            If set, all clients with the same memory_cache_key will share an in-memory cache. If not specified,
-            a unique one will be generated, meaning this client will not share a memory-based cache with any
-            other clients. Not used if cache_dir is set.
         """
 
         PortalClientBase.__init__(self, address, username, password, verify, show_motd)
         self._logger = logging.getLogger("PortalClient")
-        self.cache = PortalCache(address, cache_dir, cache_max_size, memory_cache_key)
+        self.cache = PortalCache(address, cache_dir, cache_max_size)
 
     def __repr__(self) -> str:
         """A short representation of the current PortalClient.
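For context on why the explicit commit() calls are dropped in cache.py: apsw leaves SQLite in its default autocommit mode (it does not do the sqlite3 module's implicit transaction handling), and an apsw Connection used as a context manager wraps the block in a transaction, which is what the `with self._conn:` batching above relies on. A minimal standalone sketch of that behavior, not part of the patch (the table name and values are made up for illustration):

import apsw

# apsw runs in autocommit mode: each statement takes effect as it executes
# unless an explicit transaction is open, so no conn.commit() is needed.
conn = apsw.Connection(":memory:")
conn.execute("CREATE TABLE records (id INTEGER PRIMARY KEY, status TEXT)")

# Using the connection as a context manager wraps the block in a transaction:
# committed on normal exit, rolled back if an exception escapes the block.
with conn:
    conn.execute("INSERT INTO records (id, status) VALUES (?, ?)", (1, "complete"))
    conn.execute("INSERT INTO records (id, status) VALUES (?, ?)", (2, "error"))

print(list(conn.execute("SELECT id, status FROM records ORDER BY id")))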