Replace sqlite3 with apsw
bennybp committed May 6, 2024
1 parent 0642482 commit d9dd4c3
Showing 7 changed files with 60 additions and 96 deletions.
1 change: 1 addition & 0 deletions qcarchivetesting/conda-envs/fulltest_qcportal.yaml
@@ -11,6 +11,7 @@ dependencies:
- pyyaml
- pydantic
- zstandard
+ - apsw
- qcelemental
- tabulate
- tqdm
1 change: 1 addition & 0 deletions qcarchivetesting/conda-envs/fulltest_server.yaml
@@ -15,6 +15,7 @@ dependencies:
- pyyaml
- pydantic
- zstandard
+ - apsw
- qcelemental
- tabulate
- tqdm
1 change: 1 addition & 0 deletions qcarchivetesting/conda-envs/fulltest_snowflake.yaml
@@ -16,6 +16,7 @@ dependencies:
- pyyaml
- pydantic
- zstandard
+ - apsw
- qcelemental
- tabulate
- tqdm
1 change: 1 addition & 0 deletions qcarchivetesting/conda-envs/fulltest_worker.yaml
@@ -15,6 +15,7 @@ dependencies:
- pyyaml
- pydantic
- zstandard
+ - apsw
- qcelemental
- tabulate
- tqdm
1 change: 1 addition & 0 deletions qcportal/pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [
"pyyaml",
"pydantic",
"zstandard",
"apsw",
"qcelemental",
"tabulate",
"tqdm",
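Context for the dependency additions above (a sketch, not code from this commit): apsw becomes a hard runtime requirement of qcportal because the cache module below no longer uses the stdlib sqlite3 bindings. The main surface difference is how a URI-addressed database is opened; the temporary path here is purely illustrative.

    import os
    import sqlite3
    import tempfile

    import apsw

    uri = f"file:{os.path.join(tempfile.mkdtemp(), 'example_cache.sqlite')}"  # hypothetical cache file

    # Old binding: the stdlib sqlite3 module needs uri=True to honor file: URIs
    old_conn = sqlite3.connect(uri, uri=True)

    # New binding: apsw takes SQLite's open flags directly, including SQLITE_OPEN_URI
    new_conn = apsw.Connection(
        uri, flags=apsw.SQLITE_OPEN_READWRITE | apsw.SQLITE_OPEN_CREATE | apsw.SQLITE_OPEN_URI
    )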
145 changes: 54 additions & 91 deletions qcportal/qcportal/cache.py
@@ -6,12 +6,12 @@

import datetime
import os
import sqlite3
import threading
import uuid
from typing import TYPE_CHECKING, Optional, TypeVar, Type, Any, List, Iterable, Tuple, Sequence
from urllib.parse import urlparse

import apsw

from .utils import chunk_iterable

try:
@@ -34,14 +34,14 @@
_query_chunk_size = 125


def compress_for_cache(data: Any) -> sqlite3.Binary:
def compress_for_cache(data: Any) -> bytes:
serialized_data = serialize(data, "msgpack")
compressed_data = zstandard.compress(serialized_data, level=1)
return sqlite3.Binary(compressed_data)
return compressed_data


def decompress_from_cache(data: sqlite3.Binary, value_type) -> Any:
decompressed_data = zstandard.decompress(bytes(data))
def decompress_from_cache(data: bytes, value_type) -> Any:
decompressed_data = zstandard.decompress(data)
deserialized_data = deserialize(decompressed_data, "msgpack")
return pydantic.parse_obj_as(value_type, deserialized_data)
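A practical effect of the helper changes above: apsw binds Python bytes to BLOB parameters directly, so the sqlite3.Binary wrapper on write and the bytes() copy on read are no longer needed. A rough sketch of the round trip, with a plain byte string standing in for the msgpack-serialized record (not code from this commit):

    import zstandard

    serialized = b'{"id": 1, "status": "complete"}'  # hypothetical stand-in for serialize(record, "msgpack")

    # compress_for_cache: zstd level 1, returned as plain bytes for apsw to store as a BLOB
    blob = zstandard.compress(serialized, level=1)

    # decompress_from_cache: apsw hands BLOBs back as bytes, so they decompress directly
    assert zstandard.decompress(blob) == serialized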

@@ -51,30 +51,17 @@ def __init__(self, cache_uri: str, read_only: bool):
self.cache_uri = cache_uri
self.read_only = read_only

# There is a chance that this is used from multiple threads
# So store the connection in thread-local storage
self._th_local = threading.local()

if not read_only:
self._create_tables()
self._conn.commit()

@property
def _conn(self):
# Get the connection object that is used for this thread
if hasattr(self._th_local, "_conn"):
# print(threading.get_ident(), "EXISTING CONNECTION TO ", self.cache_uri)
return self._th_local._conn
if self.read_only:
self._conn = apsw.Connection(self.cache_uri, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI)
else:
# This thread currently doesn't have a connection. so create one
# Note the uri=True flag, which is needed for the sqlite3 module to recognize full URIs
# print(threading.get_ident(), "NEW CONNECTION TO ", self.cache_uri)
self._th_local._conn = sqlite3.connect(self.cache_uri, uri=True)
self._conn = apsw.Connection(
self.cache_uri, flags=apsw.SQLITE_OPEN_READWRITE | apsw.SQLITE_OPEN_CREATE | apsw.SQLITE_OPEN_URI
)

# Some common settings
self._th_local._conn.execute("PRAGMA foreign_keys = ON")
self._conn.pragma("foreign_keys", "ON")

return self._th_local._conn
if not read_only:
self._create_tables()

def __str__(self):
return f"<{self.__class__.__name__} path={self.cache_uri} {'ro' if self.read_only else 'rw'}>"
@@ -102,7 +89,6 @@ def update_metadata(self, key: str, value: Any) -> None:
self._assert_writable()
stmt = "REPLACE INTO metadata (key, value) VALUES (?, ?)"
self._conn.execute(stmt, (key, serialize(value, "msgpack")))
self._conn.commit()

def get_record(self, record_id: int, record_type: Type[_RECORD_T]) -> Optional[_RECORD_T]:
stmt = "SELECT record FROM records WHERE id = ?"
@@ -150,20 +136,20 @@ def get_existing_records(self, record_ids: Iterable[int]) -> List[int]:
def update_records(self, records: Iterable[_RECORD_T]):
self._assert_writable()

for record_batch in chunk_iterable(records, 10):
n_batch = len(record_batch)
with self._conn:
for record_batch in chunk_iterable(records, 10):
n_batch = len(record_batch)

values_params = ",".join(["(?, ?, ?, ?)"] * n_batch)
values_params = ",".join(["(?, ?, ?, ?)"] * n_batch)

all_params = []
for r in record_batch:
all_params.extend((r.id, r.status, r.modified_on.timestamp(), compress_for_cache(r)))
all_params = []
for r in record_batch:
all_params.extend((r.id, r.status, r.modified_on.timestamp(), compress_for_cache(r)))

stmt = f"REPLACE INTO records (id, status, modified_on, record) VALUES {values_params}"
stmt = f"REPLACE INTO records (id, status, modified_on, record) VALUES {values_params}"

self._conn.execute(stmt, all_params)
self._conn.execute(stmt, all_params)

self._conn.commit()
for r in records:
r._record_cache = self
r._cache_dirty = False
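The `with self._conn:` block above is also why the explicit commit() calls disappear throughout this diff: apsw does not start implicit transactions the way sqlite3 does, so standalone statements take effect immediately, and a connection used as a context manager wraps the block in one transaction that commits on clean exit and rolls back if an exception escapes. A small sketch of the batched REPLACE under that pattern, against a made-up table:

    import apsw

    conn = apsw.Connection(":memory:")
    conn.execute("CREATE TABLE records (id INTEGER PRIMARY KEY, status TEXT, record BLOB)")

    rows = [(1, "complete", b"zstd-bytes"), (2, "error", b"zstd-bytes")]

    # One transaction for the whole batch: committed when the block exits cleanly
    with conn:
        values_params = ",".join(["(?, ?, ?)"] * len(rows))
        all_params = [value for row in rows for value in row]
        conn.execute(f"REPLACE INTO records (id, status, record) VALUES {values_params}", all_params)

    # Outside the block apsw stays in autocommit mode, so no commit() is needed here either
    conn.execute("DELETE FROM records WHERE id = ?", (1,))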
@@ -182,14 +168,12 @@ def writeback_record(self, record):
ts = record.modified_on.timestamp()
row_data = (record.id, record.status, ts, compressed_record, record.id, ts, ts, len(compressed_record))
self._conn.execute(stmt, row_data)
self._conn.commit()

def delete_record(self, record_id: int):
self._assert_writable()

stmt = "DELETE FROM records WHERE id=?"
self._conn.execute(stmt, (record_id,))
self._conn.commit()

def delete_records(self, record_ids: Iterable[int]):
self._assert_writable()
@@ -199,8 +183,6 @@ def delete_records(self, record_ids: Iterable[int]):
stmt = f"DELETE FROM records WHERE id IN ({record_id_params})"
self._conn.execute(stmt, record_id_batch)

self._conn.commit()


class DatasetCache(RecordCache):
def __init__(self, cache_uri: str, read_only: bool, dataset_type: Type[_DATASET_T]):
@@ -293,18 +275,17 @@ def update_entries(self, entries: Iterable[BaseModel]):

assert all(isinstance(e, self._entry_type) for e in entries)

for entry_batch in chunk_iterable(entries, 50):
n_batch = len(entry_batch)
values_params = ",".join(["(?, ?)"] * n_batch)

all_params = []
for e in entry_batch:
all_params.extend((e.name, compress_for_cache(e)))
with self._conn:
for entry_batch in chunk_iterable(entries, 50):
n_batch = len(entry_batch)
values_params = ",".join(["(?, ?)"] * n_batch)

stmt = f"REPLACE INTO dataset_entries (name, entry) VALUES {values_params}"
self._conn.execute(stmt, all_params)
all_params = []
for e in entry_batch:
all_params.extend((e.name, compress_for_cache(e)))

self._conn.commit()
stmt = f"REPLACE INTO dataset_entries (name, entry) VALUES {values_params}"
self._conn.execute(stmt, all_params)

def rename_entry(self, old_name: str, new_name: str):
self._assert_writable()
@@ -317,14 +298,12 @@ def rename_entry(self, old_name: str, new_name: str):

stmt = "UPDATE dataset_entries SET name=?, entry=? WHERE name=?"
self._conn.execute(stmt, (new_name, compress_for_cache(entry), old_name))
self._conn.commit()

def delete_entry(self, name):
self._assert_writable()

stmt = "DELETE FROM dataset_entries WHERE name=?"
self._conn.execute(stmt, (name,))
self._conn.commit()

def specification_exists(self, name: str) -> bool:
stmt = "SELECT 1 FROM dataset_specifications WHERE name=?"
@@ -358,18 +337,17 @@ def update_specifications(self, specifications: Iterable[BaseModel]):

assert all(isinstance(s, self._specification_type) for s in specifications)

for specification_batch in chunk_iterable(specifications, 50):
n_batch = len(specification_batch)
values_params = ",".join(["(?, ?)"] * n_batch)

all_params = []
for s in specification_batch:
all_params.extend((s.name, compress_for_cache(s)))
with self._conn:
for specification_batch in chunk_iterable(specifications, 50):
n_batch = len(specification_batch)
values_params = ",".join(["(?, ?)"] * n_batch)

stmt = f"REPLACE INTO dataset_specifications (name, specification) VALUES {values_params}"
self._conn.execute(stmt, all_params)
all_params = []
for s in specification_batch:
all_params.extend((s.name, compress_for_cache(s)))

self._conn.commit()
stmt = f"REPLACE INTO dataset_specifications (name, specification) VALUES {values_params}"
self._conn.execute(stmt, all_params)

def rename_specification(self, old_name: str, new_name: str):
self._assert_writable()
@@ -382,14 +360,12 @@ def rename_specification(self, old_name: str, new_name: str):

stmt = "UPDATE dataset_specifications SET name=?, specification=? WHERE name=?"
self._conn.execute(stmt, (new_name, compress_for_cache(specification), old_name))
self._conn.commit()

def delete_specification(self, name):
self._assert_writable()

stmt = "DELETE FROM dataset_specifications WHERE name=?"
self._conn.execute(stmt, (name,))
self._conn.commit()

def dataset_record_exists(self, entry_name: str, specification_name: str) -> bool:
stmt = "SELECT 1 FROM dataset_records WHERE entry_name=? and specification_name=?"
@@ -449,26 +425,24 @@ def update_dataset_records(self, record_info: Iterable[Tuple[str, str, _RECORD_T

assert all(isinstance(r, self._record_type) for _, _, r in record_info)

for info_batch in chunk_iterable(record_info, 10):
n_batch = len(info_batch)
values_params = ",".join(["(?, ?, ?)"] * n_batch)
with self._conn:
for info_batch in chunk_iterable(record_info, 10):
n_batch = len(info_batch)
values_params = ",".join(["(?, ?, ?)"] * n_batch)

all_params = []
for e, s, r in info_batch:
all_params.extend((e, s, r.id))
all_params = []
for e, s, r in info_batch:
all_params.extend((e, s, r.id))

stmt = f"""REPLACE INTO dataset_records (entry_name, specification_name, record_id)
VALUES {values_params}"""
self._conn.execute(stmt, all_params)

self._conn.commit()
stmt = f"""REPLACE INTO dataset_records (entry_name, specification_name, record_id)
VALUES {values_params}"""
self._conn.execute(stmt, all_params)

def delete_dataset_record(self, entry_name: str, specification_name: str):
self._assert_writable()

stmt = "DELETE FROM dataset_records WHERE entry_name=? AND specification_name=?"
self._conn.execute(stmt, (entry_name, specification_name))
self._conn.commit()

def delete_dataset_records(
self, entry_names: Optional[Iterable[str]], specification_names: Optional[Iterable[str]]
@@ -494,7 +468,6 @@ def delete_dataset_records(
stmt += " WHERE " + " AND ".join(conds)

self._conn.execute(stmt, all_params)
self._conn.commit()

def get_dataset_record_info(
self,
@@ -554,9 +527,7 @@ def get_existing_dataset_records(


class PortalCache:
def __init__(
self, server_uri: str, cache_dir: Optional[str], max_size: int, shared_memory_key: Optional[str] = None
):
def __init__(self, server_uri: str, cache_dir: Optional[str], max_size: int):
parsed_url = urlparse(server_uri)

# Should work as a reasonable fingerprint?
@@ -570,22 +541,14 @@ def __init__(
else:
self._is_disk = False

# If no shared memory key specified, make a unique one
if shared_memory_key is None:
shared_memory_key = f"{server_uri}_{os.getpid()}_{uuid.uuid4()}"

self._shared_memory_key = shared_memory_key
self.cache_dir = None

def get_cache_uri(self, cache_name: str) -> str:
if self._is_disk:
file_path = os.path.join(self.cache_dir, f"{cache_name}.sqlite")
uri = f"file:{file_path}"
else:
# We always want some shared cache due to the use of threads.
# vfs=memdb seems to be a better way than mode=memory&cache=shared . Very little docs about it though
# The / after the : is apparently very important. Otherwise, the shared stuff doesn't work
uri = f"file:/{self._shared_memory_key}_{cache_name}?vfs=memdb"
uri = ":memory:"

return uri
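The in-memory branch gets simpler because each apsw ":memory:" connection is its own private database; with a single Connection object now shared by the cache, there is no need for the named shared-memory URI and its key. A short illustration of that isolation (not code from this commit):

    import apsw

    # Each ":memory:" connection is its own private database ...
    a = apsw.Connection(":memory:")
    a.execute("CREATE TABLE t (x INTEGER)")

    # ... so a second ":memory:" connection sees none of it, which is why sharing
    # now happens by sharing the Connection object itself
    b = apsw.Connection(":memory:")
    assert list(b.execute("SELECT name FROM sqlite_master WHERE type='table'")) == []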

@@ -614,7 +577,7 @@ def read_dataset_metadata(file_path: str):
raise RuntimeError(f'Cannot open cache file "{file_path}" - does not exist or is not a file')

uri = f"file:{file_path}?mode=ro"
conn = sqlite3.connect(uri, uri=True)
conn = apsw.Connection(uri, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI)

r = conn.execute("SELECT value FROM metadata WHERE key = 'dataset_metadata'")
if r is None:
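For reading an existing cache file, the function above switches to a read-only, URI-addressed apsw connection. One hedged way to pull a single metadata value in that style is sketched below; note that apsw's execute() returns a cursor, so rows are obtained by iterating (or fetching) rather than by testing the return value, and the metadata table layout is assumed from the queries earlier in this file:

    import apsw

    def read_metadata_value(file_path: str, key: str):
        # Read-only open of an existing cache file, as in read_dataset_metadata
        uri = f"file:{file_path}?mode=ro"
        conn = apsw.Connection(uri, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI)

        # execute() yields rows; return the first matching value, if any
        for (value,) in conn.execute("SELECT value FROM metadata WHERE key = ?", (key,)):
            return value
        return None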
6 changes: 1 addition & 5 deletions qcportal/qcportal/client.py
@@ -147,15 +147,11 @@ def __init__(
Directory to store an internal cache of records and other data
cache_max_size
Maximum size of the cache directory
- memory_cache_key
-     If set, all clients with the same memory_cache_key will share an in-memory cache. If not specified,
-     a unique one will be generated, meaning this client will not share a memory-based cache with any
-     other clients. Not used if cache_dir is set.
"""

PortalClientBase.__init__(self, address, username, password, verify, show_motd)
self._logger = logging.getLogger("PortalClient")
- self.cache = PortalCache(address, cache_dir, cache_max_size, memory_cache_key)
+ self.cache = PortalCache(address, cache_dir, cache_max_size)

def __repr__(self) -> str:
"""A short representation of the current PortalClient.
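With the shared-memory key gone, cache construction takes only the server address, an optional cache directory, and a maximum size. A usage sketch, assuming PortalCache can be built directly like this; the address and size are illustrative, not defaults from the codebase:

    from qcportal.cache import PortalCache

    # Memory-only cache: no cache_dir and no shared-memory key to coordinate
    cache = PortalCache("https://qcademo.example.org", cache_dir=None, max_size=1024**3)
    print(cache.get_cache_uri("dataset_123"))  # now simply ":memory:"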
