Skip to content

Commit

Permalink
PersistentDict: disable same-thread check on threadsafe SQLite
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Jun 10, 2024
1 parent c9012fa commit c8adf8e
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 1 deletion.
48 changes: 48 additions & 0 deletions pytools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,12 @@
.. autofunction:: unique
SQLite-related functions
------------------------
.. autofunction:: get_sqlite3_thread_safety_level
.. autofunction:: is_sqlite3_fully_threadsafe
Type Variables Used
-------------------
Expand Down Expand Up @@ -2975,6 +2981,48 @@ def unique(seq: Sequence[T]) -> Iterator[T]:
# }}}


# {{{ SQLite-related functions

def get_sqlite3_thread_safety_level() -> int:
"""Return the thread safety value of the underlying SQLite library in
Python's DBAPI 2.0 format."""
# Based on https://ricardoanderegg.com/posts/python-sqlite-thread-safety/
import sqlite3

# Map value from SQLite's THREADSAFE to Python's DBAPI 2.0
# threadsafety attribute.
sqlite_threadsafe2python_dbapi = {0: 0, 2: 1, 1: 3}
conn = sqlite3.connect(":memory:")
threadsafety = conn.execute(
"""
select * from pragma_compile_options
where compile_options like 'THREADSAFE=%'
"""
).fetchone()[0]
conn.close()

threadsafety_value = int(threadsafety.split("=")[1])
threadsafety_value_db = sqlite_threadsafe2python_dbapi[threadsafety_value]

import sys
if sys.version_info < (3, 11):
assert threadsafety_value == 1
else:
assert threadsafety_value_db == sqlite3.threadsafety

return threadsafety_value_db


def is_sqlite3_fully_threadsafe() -> bool:
"""Check if the underlying SQLite library is fully thread-safe."""
if get_sqlite3_thread_safety_level() == 3:
return True
else:
return False

# }}}


def _test():
import doctest
doctest.testmod()
Expand Down
6 changes: 5 additions & 1 deletion pytools/persistent_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,9 +459,13 @@ def __init__(self, identifier: str,
self.container_dir = container_dir
self._make_container_dir()

from pytools import is_sqlite3_fully_threadsafe

# isolation_level=None: enable autocommit mode
# https://www.sqlite.org/lang_transaction.html#implicit_versus_explicit_transactions
self.conn = sqlite3.connect(self.filename, isolation_level=None)
self.conn = sqlite3.connect(self.filename,
isolation_level=None,
check_same_thread=not is_sqlite3_fully_threadsafe())

self._exec_sql(
"CREATE TABLE IF NOT EXISTS dict "
Expand Down
5 changes: 5 additions & 0 deletions pytools/test/test_persistent_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,11 @@ def test_concurrency() -> None:
# }}}


def test_sqlite_threadsafety() -> None:
from pytools.sqlite import get_sqlite3_thread_safety_level
get_sqlite3_thread_safety_level()


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down

0 comments on commit c8adf8e

Please sign in to comment.