Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 46 additions & 33 deletions src/tagstudio/core/library/alchemy/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
text,
update,
)
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import (
InstanceState,
Expand Down Expand Up @@ -1478,7 +1479,7 @@ def add_tag(
return None

def add_tags_to_entries(
self, entry_ids: int | list[int] | set[int], tag_ids: int | list[int] | set[int]
self, entry_ids: int | Iterable[int], tag_ids: int | Iterable[int]
) -> int:
"""Add one or more tags to one or more entries.

Expand All @@ -1494,45 +1495,57 @@ def add_tags_to_entries(

entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else entry_ids
tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else tag_ids
values: list[tuple[int, int]] = []
for tag_id in tag_ids_:
values.extend((tag_id, entry_id) for entry_id in entry_ids_)

with Session(self.engine, expire_on_commit=False) as session:
for tag_id in tag_ids_:
for entry_id in entry_ids_:
try:
session.add(TagEntry(tag_id=tag_id, entry_id=entry_id))
total_added += 1
session.commit()
except IntegrityError:
session.rollback()
for sub_list in [
values[i : i + MAX_SQL_VARIABLES // 2]
for i in range(0, len(values), MAX_SQL_VARIABLES // 2)
]:
stmt = (
sqlite.insert(TagEntry)
.values(sub_list)
.on_conflict_do_nothing()
.returning(TagEntry)
)
added = session.scalars(stmt).all()
total_added += len(added)
session.commit()

return total_added

def remove_tags_from_entries(
self, entry_ids: int | list[int] | set[int], tag_ids: int | list[int] | set[int]
) -> bool:
self, entry_ids: int | Iterable[int], tag_ids: int | Iterable[int]
):
"""Remove one or more tags from one or more entries."""
entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else entry_ids
tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else tag_ids
logger.info(
"[Library][remove_tags_from_entries]",
entry_ids=entry_ids,
tag_ids=tag_ids,
)

entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else list(entry_ids)
tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else list(tag_ids)

with Session(self.engine, expire_on_commit=False) as session:
try:
for tag_id in tag_ids_:
for entry_id in entry_ids_:
tag_entry = session.scalars(
select(TagEntry).where(
and_(
TagEntry.tag_id == tag_id,
TagEntry.entry_id == entry_id,
)
)
).first()
if tag_entry:
session.delete(tag_entry)
session.flush()
session.commit()
return True
except IntegrityError as e:
logger.error(e)
session.rollback()
return False
for tags_sub_list in [
tag_ids_[i : i + MAX_SQL_VARIABLES // 2]
for i in range(0, len(tag_ids_), MAX_SQL_VARIABLES // 2)
]:
for entries_sub_list in [
entry_ids_[i : i + MAX_SQL_VARIABLES // 2]
for i in range(0, len(entry_ids_), MAX_SQL_VARIABLES // 2)
]:
stmt = delete(TagEntry).where(
and_(
TagEntry.tag_id.in_(tags_sub_list),
TagEntry.entry_id.in_(entries_sub_list),
)
)
session.execute(stmt)
session.commit()

def add_color(self, color_group: TagColorGroup) -> TagColorGroup | None:
with Session(self.engine, expire_on_commit=False) as session:
Expand Down