Skip to content
134 changes: 66 additions & 68 deletions src/tagstudio/core/library/alchemy/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ class Library:
"""Class for the Library object, and all CRUD operations made upon it."""

library_dir: Path | None = None
storage_path: Path | str | None = None
engine: Engine | None = None
folder: Folder | None = None
included_files: set[Path] = set()
Expand All @@ -232,7 +231,6 @@ def close(self):
if self.engine:
self.engine.dispose()
self.library_dir = None
self.storage_path = None
self.folder = None
self.included_files = set()

Expand Down Expand Up @@ -348,33 +346,36 @@ def tag_display_name(self, tag: Tag | None) -> str:
else:
return tag.name

def open_library(
self, library_dir: Path, storage_path: Path | str | None = None
) -> LibraryStatus:
is_new: bool = True
if storage_path == ":memory:":
self.storage_path = storage_path
is_new = True
return self.open_sqlite_library(library_dir, is_new)
else:
self.storage_path = library_dir / TS_FOLDER_NAME / SQL_FILENAME
assert isinstance(self.storage_path, Path)
if self.verify_ts_folder(library_dir) and (is_new := not self.storage_path.exists()):
json_path = library_dir / TS_FOLDER_NAME / JSON_FILENAME
if json_path.exists():
return LibraryStatus(
success=False,
library_path=library_dir,
message="[JSON] Legacy v9.4 library requires conversion to v9.5+",
json_migration_req=True,
)
def open_library(self, library_dir: Path, in_memory: bool = False) -> LibraryStatus:
"""Wrapper for open_sqlite_library.

Handles in-memory storage and checks whether a JSON-migration is necessary.
"""
assert isinstance(library_dir, Path)

if in_memory:
return self.open_sqlite_library(library_dir, is_new=True, storage_path=":memory:")

is_new = True
sql_path = library_dir / TS_FOLDER_NAME / SQL_FILENAME
if self.verify_ts_folder(library_dir) and (is_new := not sql_path.exists()):
json_path = library_dir / TS_FOLDER_NAME / JSON_FILENAME
if json_path.exists():
return LibraryStatus(
success=False,
library_path=library_dir,
message="[JSON] Legacy v9.4 library requires conversion to v9.5+",
json_migration_req=True,
)

return self.open_sqlite_library(library_dir, is_new)
return self.open_sqlite_library(library_dir, is_new, str(sql_path))

def open_sqlite_library(self, library_dir: Path, is_new: bool) -> LibraryStatus:
def open_sqlite_library(
self, library_dir: Path, is_new: bool, storage_path: str
) -> LibraryStatus:
connection_string = URL.create(
drivername="sqlite",
database=str(self.storage_path),
database=storage_path,
)
# NOTE: File-based databases should use NullPool to create new DB connection in order to
# keep connections on separate threads, which prevents the DB files from being locked
Expand All @@ -383,7 +384,7 @@ def open_sqlite_library(self, library_dir: Path, is_new: bool) -> LibraryStatus:
# More info can be found on the SQLAlchemy docs:
# https://docs.sqlalchemy.org/en/20/changelog/migration_07.html
# Under -> sqlite-the-sqlite-dialect-now-uses-nullpool-for-file-based-databases
poolclass = None if self.storage_path == ":memory:" else NullPool
poolclass = None if storage_path == ":memory:" else NullPool
loaded_db_version: int = 0

logger.info(
Expand Down Expand Up @@ -421,8 +422,8 @@ def open_sqlite_library(self, library_dir: Path, is_new: bool) -> LibraryStatus:
logger.info(f"[Library] DB_VERSION: {loaded_db_version}")
make_tables(self.engine)

# Add default tag color namespaces.
if is_new:
# Add default tag color namespaces.
namespaces = default_color_groups.namespaces()
try:
session.add_all(namespaces)
Expand All @@ -431,8 +432,7 @@ def open_sqlite_library(self, library_dir: Path, is_new: bool) -> LibraryStatus:
logger.error("[Library] Couldn't add default tag color namespaces", error=e)
session.rollback()

# Add default tag colors.
if is_new:
# Add default tag colors.
tag_colors: list[TagColorGroup] = default_color_groups.standard()
tag_colors += default_color_groups.pastels()
tag_colors += default_color_groups.shades()
Expand All @@ -447,8 +447,7 @@ def open_sqlite_library(self, library_dir: Path, is_new: bool) -> LibraryStatus:
logger.error("[Library] Couldn't add default tag colors", error=e)
session.rollback()

# Add default tags.
if is_new:
# Add default tags.
tags = get_default_tags()
try:
session.add_all(tags)
Expand Down Expand Up @@ -529,35 +528,36 @@ def open_sqlite_library(self, library_dir: Path, is_new: bool) -> LibraryStatus:

# Apply any post-SQL migration patches.
if not is_new:
assert loaded_db_version >= 6

# save backup if patches will be applied
if loaded_db_version < DB_VERSION:
self.library_dir = library_dir
self.save_library_backup_to_disk()
self.library_dir = None

# NOTE: Depending on the data, some data and schema changes need to be applied in
# different orders. This chain of methods can likely be cleaned up and/or moved.
# migrate DB step by step from one version to the next
if loaded_db_version < 7:
# changes: value_type, tags
self.__apply_db7_migration(session)
if loaded_db_version < 8:
self.__apply_db8_schema_changes(session)
if loaded_db_version < 9:
self.__apply_db9_schema_changes(session)
if loaded_db_version < 103:
self.__apply_db103_schema_changes(session)
if loaded_db_version == 6:
self.__apply_repairs_for_db6(session)

if loaded_db_version >= 6 and loaded_db_version < 8:
self.__apply_db8_default_data(session)
# changes: tag_colors
self.__apply_db8_migration(session)
if loaded_db_version < 9:
self.__apply_db9_filename_population(session)
# changes: entries
self.__apply_db9_migration(session)
if loaded_db_version < 100:
self.__apply_db100_parent_repairs(session)
# changes: tag_parents
self.__apply_db100_migration(session)
if loaded_db_version < 102:
self.__apply_db102_repairs(session)
# changes: tag_parents
self.__apply_db102_migration(session)
if loaded_db_version < 103:
self.__apply_db103_default_data(session)
# changes: tags
self.__apply_db103_migration(session)

# Convert file extension list to ts_ignore file, if a .ts_ignore file does not exist
# TODO: do this in the migration step that will remove the preferences table
self.migrate_sql_to_ts_ignore(library_dir)

# Update DB_VERSION
Expand All @@ -568,8 +568,8 @@ def open_sqlite_library(self, library_dir: Path, is_new: bool) -> LibraryStatus:
self.library_dir = library_dir
return LibraryStatus(success=True, library_path=library_dir)

def __apply_repairs_for_db6(self, session: Session):
"""Apply database repairs introduced in DB_VERSION 7."""
def __apply_db7_migration(self, session: Session):
"""Migrate DB from DB_VERSION 6 to 7."""
logger.info("[Library][Migration] Applying patches to DB_VERSION: 6 library...")
with session:
# Repair "Description" fields with a TEXT_LINE key instead of a TEXT_BOX key.
Expand All @@ -582,7 +582,7 @@ def __apply_repairs_for_db6(self, session: Session):
session.flush()

# Repair tags that may have a disambiguation_id pointing towards a deleted tag.
all_tag_ids: set[int] = {tag.id for tag in self.tags}
all_tag_ids = session.scalars(text("SELECT DISTINCT id FROM tags")).all()
disam_stmt = (
update(Tag)
.where(Tag.disambiguation_id.not_in(all_tag_ids))
Expand All @@ -591,9 +591,8 @@ def __apply_repairs_for_db6(self, session: Session):
session.execute(disam_stmt)
session.commit()

def __apply_db8_schema_changes(self, session: Session):
"""Apply database schema changes introduced in DB_VERSION 8."""
# TODO: Use Alembic for this part instead
def __apply_db8_migration(self, session: Session):
"""Migrate DB from DB_VERSION 7 to 8."""
# Add the missing color_border column to the TagColorGroups table.
color_border_stmt = text(
"ALTER TABLE tag_colors ADD COLUMN color_border BOOLEAN DEFAULT FALSE NOT NULL"
Expand All @@ -609,8 +608,7 @@ def __apply_db8_schema_changes(self, session: Session):
)
session.rollback()

def __apply_db8_default_data(self, session: Session):
"""Apply default data changes introduced in DB_VERSION 8."""
# collect new default tag colors
tag_colors: list[TagColorGroup] = default_color_groups.standard()
tag_colors += default_color_groups.pastels()
tag_colors += default_color_groups.shades()
Expand Down Expand Up @@ -659,8 +657,9 @@ def __apply_db8_default_data(self, session: Session):
)
session.rollback()

def __apply_db9_schema_changes(self, session: Session):
"""Apply database schema changes introduced in DB_VERSION 9."""
def __apply_db9_migration(self, session: Session):
"""Migrate DB from DB_VERSION 8 to 9."""
# Apply database schema changes
add_filename_column = text(
"ALTER TABLE entries ADD COLUMN filename TEXT NOT NULL DEFAULT ''"
)
Expand All @@ -675,15 +674,14 @@ def __apply_db9_schema_changes(self, session: Session):
)
session.rollback()

def __apply_db9_filename_population(self, session: Session):
"""Populate the filename column introduced in DB_VERSION 9."""
# Populate the new filename column.
for entry in self.all_entries():
session.merge(entry).filename = entry.path.name
session.commit()
logger.info("[Library][Migration] Populated filename column in entries table")

def __apply_db100_parent_repairs(self, session: Session):
"""Swap the child_id and parent_id values in the TagParent table."""
def __apply_db100_migration(self, session: Session):
"""Migrate DB to DB_VERSION 100."""
with session:
# Repair parent-child tag relationships that are the wrong way around.
stmt = update(TagParent).values(
Expand All @@ -694,17 +692,18 @@ def __apply_db100_parent_repairs(self, session: Session):
session.commit()
logger.info("[Library][Migration] Refactored TagParent table")

def __apply_db102_repairs(self, session: Session):
"""Repair tag_parents rows with references to deleted tags."""
def __apply_db102_migration(self, session: Session):
"""Migrate DB to DB_VERSION 102."""
with session:
all_tag_ids: list[int] = [t.id for t in self.tags]
all_tag_ids = session.scalars(text("SELECT DISTINCT id FROM tags")).all()
stmt = delete(TagParent).where(TagParent.parent_id.not_in(all_tag_ids))
session.execute(stmt)
session.commit()
logger.info("[Library][Migration] Verified TagParent table data")

def __apply_db103_schema_changes(self, session: Session):
"""Apply database schema changes introduced in DB_VERSION 103."""
def __apply_db103_migration(self, session: Session):
"""Migrate DB from DB_VERSION 102 to 103."""
# add the new hidden column for tags
add_is_hidden_column = text(
"ALTER TABLE tags ADD COLUMN is_hidden BOOLEAN NOT NULL DEFAULT 0"
)
Expand All @@ -719,8 +718,7 @@ def __apply_db103_schema_changes(self, session: Session):
)
session.rollback()

def __apply_db103_default_data(self, session: Session):
"""Apply default data changes introduced in DB_VERSION 103."""
# mark the "Archived" tag as hidden
try:
session.query(Tag).filter(Tag.id == TAG_ARCHIVED).update({"is_hidden": True})
session.commit()
Expand Down
5 changes: 3 additions & 2 deletions src/tagstudio/qt/mixed/migration_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,11 +410,12 @@ def migration_iterator(self):
self.temp_path: Path = (
self.json_lib.library_dir / TS_FOLDER_NAME / "migration_ts_library.sqlite"
)
self.sql_lib.storage_path = self.temp_path
if self.temp_path.exists():
logger.info('Temporary migration file "temp_path" already exists. Removing...')
self.temp_path.unlink()
self.sql_lib.open_sqlite_library(self.json_lib.library_dir, is_new=True)
self.sql_lib.open_sqlite_library(
self.json_lib.library_dir, is_new=True, storage_path=str(self.temp_path)
)
yield Translations.format(
"json_migration.migrating_files_entries", entries=len(self.json_lib.entries)
)
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def cwd():
def file_mediatypes_library():
lib = Library()

status = lib.open_library(Path(""), ":memory:")
status = lib.open_library(Path(""), in_memory=True)
assert status.success
folder = unwrap(lib.folder)

Expand Down Expand Up @@ -84,7 +84,7 @@ def library(request, library_dir: Path): # pyright: ignore
library_path = Path(request.param)

lib = Library()
status = lib.open_library(library_path, ":memory:")
status = lib.open_library(library_path, in_memory=True)
assert status.success
folder = unwrap(lib.folder)

Expand Down