Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions invokeai/app/api/routers/session_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ async def clear(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> ClearResult:
"""Clears the queue entirely. If there's a currently-executing item, users can only cancel it if they own it or are an admin."""
"""Clears the queue entirely. Admin users clear all items; non-admin users only clear their own items. If there's a currently-executing item, users can only cancel it if they own it or are an admin."""
try:
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if queue_item is not None:
Expand All @@ -338,7 +338,9 @@ async def clear(
status_code=403, detail="You do not have permission to cancel the currently executing queue item"
)
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
# Admin users can clear all items, non-admin users can only clear their own
user_id = None if current_user.is_admin else current_user.user_id
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id, user_id=user_id)
return clear_result
except HTTPException:
raise
Expand Down
4 changes: 2 additions & 2 deletions invokeai/app/services/session_queue/session_queue_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
pass

@abstractmethod
def clear(self, queue_id: str) -> ClearResult:
"""Deletes all session queue items"""
def clear(self, queue_id: str, user_id: Optional[str] = None) -> ClearResult:
"""Deletes all session queue items. If user_id is provided, only clears items owned by that user."""
pass

@abstractmethod
Expand Down
22 changes: 15 additions & 7 deletions invokeai/app/services/session_queue/session_queue_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,24 +292,32 @@ def is_full(self, queue_id: str) -> IsFullResult:
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
return IsFullResult(is_full=is_full)

def clear(self, queue_id: str) -> ClearResult:
def clear(self, queue_id: str, user_id: Optional[str] = None) -> ClearResult:
with self._db.transaction() as cursor:
user_filter = "AND user_id = ?" if user_id is not None else ""
where = f"""--sql
WHERE queue_id = ?
{user_filter}
"""
params: list[str] = [queue_id]
if user_id is not None:
params.append(user_id)
cursor.execute(
"""--sql
f"""--sql
SELECT COUNT(*)
FROM session_queue
WHERE queue_id = ?
{where}
""",
(queue_id,),
tuple(params),
)
count = cursor.fetchone()[0]
cursor.execute(
"""--sql
f"""--sql
DELETE
FROM session_queue
WHERE queue_id = ?
{where}
""",
(queue_id,),
tuple(params),
)
self.__invoker.services.events.emit_queue_cleared(queue_id)
return ClearResult(deleted=count)
Expand Down
2 changes: 1 addition & 1 deletion invokeai/frontend/web/src/services/api/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1646,7 +1646,7 @@ export type paths = {
get?: never;
/**
* Clear
* @description Clears the queue entirely. If there's a currently-executing item, users can only cancel it if they own it or are an admin.
* @description Clears the queue entirely. Admin users clear all items; non-admin users only clear their own items. If there's a currently-executing item, users can only cancel it if they own it or are an admin.
*/
put: operations["clear"];
post?: never;
Expand Down
106 changes: 106 additions & 0 deletions tests/app/services/session_queue/test_session_queue_clear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""Tests for session queue clear() user_id scoping."""

import uuid

import pytest

from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue


@pytest.fixture
def session_queue(mock_invoker: Invoker) -> SqliteSessionQueue:
"""Create a SqliteSessionQueue backed by the mock invoker's in-memory database."""
db = mock_invoker.services.board_records._db
queue = SqliteSessionQueue(db=db)
queue.start(mock_invoker)
return queue


def _insert_queue_item(session_queue: SqliteSessionQueue, queue_id: str, user_id: str) -> None:
"""Directly insert a minimal queue item for the given user."""
session_id = str(uuid.uuid4())
batch_id = str(uuid.uuid4())
with session_queue._db.transaction() as cursor:
cursor.execute(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id, user_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(queue_id, "{}", session_id, batch_id, None, 0, None, None, None, None, user_id),
)


def _count_items(session_queue: SqliteSessionQueue, queue_id: str, user_id: str | None = None) -> int:
"""Count items in the queue, optionally filtered by user_id."""
with session_queue._db.transaction() as cursor:
if user_id is not None:
cursor.execute(
"SELECT COUNT(*) FROM session_queue WHERE queue_id = ? AND user_id = ?",
(queue_id, user_id),
)
else:
cursor.execute(
"SELECT COUNT(*) FROM session_queue WHERE queue_id = ?",
(queue_id,),
)
return cursor.fetchone()[0]


def test_clear_with_user_id_only_deletes_own_items(session_queue: SqliteSessionQueue) -> None:
"""Non-admin clear (user_id provided) should only remove that user's items."""
queue_id = "default"
user_a = "user_a"
user_b = "user_b"

_insert_queue_item(session_queue, queue_id, user_a)
_insert_queue_item(session_queue, queue_id, user_a)
_insert_queue_item(session_queue, queue_id, user_b)

result = session_queue.clear(queue_id, user_id=user_a)

assert result.deleted == 2
assert _count_items(session_queue, queue_id, user_a) == 0
assert _count_items(session_queue, queue_id, user_b) == 1


def test_clear_without_user_id_deletes_all_items(session_queue: SqliteSessionQueue) -> None:
"""Admin clear (no user_id) should remove all items in the queue."""
queue_id = "default"

_insert_queue_item(session_queue, queue_id, "user_a")
_insert_queue_item(session_queue, queue_id, "user_b")
_insert_queue_item(session_queue, queue_id, "user_c")

result = session_queue.clear(queue_id)

assert result.deleted == 3
assert _count_items(session_queue, queue_id) == 0


def test_clear_with_user_id_does_not_affect_other_queues(session_queue: SqliteSessionQueue) -> None:
"""Clearing one queue should not affect items in another queue."""
queue_a = "queue_a"
queue_b = "queue_b"
user_id = "user_x"

_insert_queue_item(session_queue, queue_a, user_id)
_insert_queue_item(session_queue, queue_b, user_id)

result = session_queue.clear(queue_a, user_id=user_id)

assert result.deleted == 1
assert _count_items(session_queue, queue_a) == 0
assert _count_items(session_queue, queue_b) == 1


def test_clear_returns_zero_when_no_matching_items(session_queue: SqliteSessionQueue) -> None:
"""Clear should return 0 deleted when there are no items for the given user."""
queue_id = "default"

_insert_queue_item(session_queue, queue_id, "user_b")

result = session_queue.clear(queue_id, user_id="user_a")

assert result.deleted == 0
assert _count_items(session_queue, queue_id) == 1