diff --git a/src/tagstudio/core/library/alchemy/library.py b/src/tagstudio/core/library/alchemy/library.py index ddb1a7bbe..9ea9316ec 100644 --- a/src/tagstudio/core/library/alchemy/library.py +++ b/src/tagstudio/core/library/alchemy/library.py @@ -42,6 +42,7 @@ text, update, ) +from sqlalchemy.dialects import sqlite from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import ( InstanceState, @@ -1478,7 +1479,7 @@ def add_tag( return None def add_tags_to_entries( - self, entry_ids: int | list[int] | set[int], tag_ids: int | list[int] | set[int] + self, entry_ids: int | Iterable[int], tag_ids: int | Iterable[int] ) -> int: """Add one or more tags to one or more entries. @@ -1494,45 +1495,57 @@ def add_tags_to_entries( entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else entry_ids tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else tag_ids + values: list[tuple[int, int]] = [] + for tag_id in tag_ids_: + values.extend((tag_id, entry_id) for entry_id in entry_ids_) + with Session(self.engine, expire_on_commit=False) as session: - for tag_id in tag_ids_: - for entry_id in entry_ids_: - try: - session.add(TagEntry(tag_id=tag_id, entry_id=entry_id)) - total_added += 1 - session.commit() - except IntegrityError: - session.rollback() + for sub_list in [ + values[i : i + MAX_SQL_VARIABLES // 2] + for i in range(0, len(values), MAX_SQL_VARIABLES // 2) + ]: + stmt = ( + sqlite.insert(TagEntry) + .values(sub_list) + .on_conflict_do_nothing() + .returning(TagEntry) + ) + added = session.scalars(stmt).all() + total_added += len(added) + session.commit() return total_added def remove_tags_from_entries( - self, entry_ids: int | list[int] | set[int], tag_ids: int | list[int] | set[int] - ) -> bool: + self, entry_ids: int | Iterable[int], tag_ids: int | Iterable[int] + ): """Remove one or more tags from one or more entries.""" - entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else entry_ids - tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else tag_ids + logger.info( + "[Library][remove_tags_from_entries]", + entry_ids=entry_ids, + tag_ids=tag_ids, + ) + + entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else list(entry_ids) + tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else list(tag_ids) + with Session(self.engine, expire_on_commit=False) as session: - try: - for tag_id in tag_ids_: - for entry_id in entry_ids_: - tag_entry = session.scalars( - select(TagEntry).where( - and_( - TagEntry.tag_id == tag_id, - TagEntry.entry_id == entry_id, - ) - ) - ).first() - if tag_entry: - session.delete(tag_entry) - session.flush() - session.commit() - return True - except IntegrityError as e: - logger.error(e) - session.rollback() - return False + for tags_sub_list in [ + tag_ids_[i : i + MAX_SQL_VARIABLES // 2] + for i in range(0, len(tag_ids_), MAX_SQL_VARIABLES // 2) + ]: + for entries_sub_list in [ + entry_ids_[i : i + MAX_SQL_VARIABLES // 2] + for i in range(0, len(entry_ids_), MAX_SQL_VARIABLES // 2) + ]: + stmt = delete(TagEntry).where( + and_( + TagEntry.tag_id.in_(tags_sub_list), + TagEntry.entry_id.in_(entries_sub_list), + ) + ) + session.execute(stmt) + session.commit() def add_color(self, color_group: TagColorGroup) -> TagColorGroup | None: with Session(self.engine, expire_on_commit=False) as session: