Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/close-inactive-issues.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ jobs:
close-issue-message: "Due to inactivity, this issue was automatically closed. If you are still experiencing the issue, please recreate the issue."
days-before-pr-stale: -1
days-before-pr-close: -1
only-labels: "bug"
exempt-issue-labels: "Active Issue"
repo-token: ${{ secrets.GITHUB_TOKEN }}
operations-per-run: 500
43 changes: 43 additions & 0 deletions invokeai/app/api/auth_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.auth.token_service import TokenData, verify_token
from invokeai.backend.util.logging import logging

logger = logging.getLogger(__name__)

# HTTP Bearer token security scheme
security = HTTPBearer(auto_error=False)
Expand Down Expand Up @@ -61,6 +64,45 @@ async def get_current_user(
return token_data


async def get_current_user_or_default(
credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)],
) -> TokenData:
"""Get current authenticated user from Bearer token, or return a default system user if not authenticated.

This dependency is useful for endpoints that should work in both authenticated and non-authenticated contexts.
In single-user mode or when authentication is not provided, it returns a TokenData for the 'system' user.

Args:
credentials: The HTTP authorization credentials containing the Bearer token

Returns:
TokenData containing user information from the token, or system user if no credentials
"""
if credentials is None:
# Return system user for unauthenticated requests (single-user mode or backwards compatibility)
logger.debug("No authentication credentials provided, using system user")
return TokenData(user_id="system", email="system@system.invokeai", is_admin=False)

token = credentials.credentials
token_data = verify_token(token)

if token_data is None:
# Invalid token - still fall back to system user for backwards compatibility
logger.warning("Invalid or expired token provided, falling back to system user")
return TokenData(user_id="system", email="system@system.invokeai", is_admin=False)

# Verify user still exists and is active
user_service = ApiDependencies.invoker.services.users
user = user_service.get(token_data.user_id)

if user is None or not user.is_active:
# User doesn't exist or is inactive - fall back to system user
logger.warning(f"User {token_data.user_id} does not exist or is inactive, falling back to system user")
return TokenData(user_id="system", email="system@system.invokeai", is_admin=False)

return token_data


async def require_admin(
current_user: Annotated[TokenData, Depends(get_current_user)],
) -> TokenData:
Expand All @@ -82,4 +124,5 @@ async def require_admin(

# Type aliases for convenient use in route dependencies
CurrentUser = Annotated[TokenData, Depends(get_current_user)]
CurrentUserOrDefault = Annotated[TokenData, Depends(get_current_user_or_default)]
AdminUser = Annotated[TokenData, Depends(require_admin)]
24 changes: 14 additions & 10 deletions invokeai/app/api/routers/client_state.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter

from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.backend.util.logging import logging

Expand All @@ -13,15 +14,16 @@
response_model=str | None,
)
async def get_client_state_by_key(
queue_id: str = Path(description="The queue id to perform this operation on"),
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"),
key: str = Query(..., description="Key to get"),
) -> str | None:
"""Gets the client state"""
"""Gets the client state for the current user (or system user if not authenticated)"""
try:
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(queue_id, key)
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(current_user.user_id, key)
except Exception as e:
logging.error(f"Error getting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
raise HTTPException(status_code=500, detail="Error getting client state")


@client_state_router.post(
Expand All @@ -30,13 +32,14 @@ async def get_client_state_by_key(
response_model=str,
)
async def set_client_state(
queue_id: str = Path(description="The queue id to perform this operation on"),
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"),
key: str = Query(..., description="Key to set"),
value: str = Body(..., description="Stringified value to set"),
) -> str:
"""Sets the client state"""
"""Sets the client state for the current user (or system user if not authenticated)"""
try:
return ApiDependencies.invoker.services.client_state_persistence.set_by_key(queue_id, key, value)
return ApiDependencies.invoker.services.client_state_persistence.set_by_key(current_user.user_id, key, value)
except Exception as e:
logging.error(f"Error setting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
Expand All @@ -48,11 +51,12 @@ async def set_client_state(
responses={204: {"description": "Client state deleted"}},
)
async def delete_client_state(
queue_id: str = Path(description="The queue id to perform this operation on"),
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"),
) -> None:
"""Deletes the client state"""
"""Deletes the client state for the current user (or system user if not authenticated)"""
try:
ApiDependencies.invoker.services.client_state_persistence.delete(queue_id)
ApiDependencies.invoker.services.client_state_persistence.delete(current_user.user_id)
except Exception as e:
logging.error(f"Error deleting client state: {e}")
raise HTTPException(status_code=500, detail="Error deleting client state")
3 changes: 2 additions & 1 deletion invokeai/app/api/routers/session_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,11 +408,12 @@ async def get_next_queue_item(
},
)
async def get_queue_status(
current_user: CurrentUser,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionQueueAndProcessorStatus:
"""Gets the status of the session queue"""
try:
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=current_user.user_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
class ClientStatePersistenceABC(ABC):
"""
Base class for client persistence implementations.
This class defines the interface for persisting client data.
This class defines the interface for persisting client data per user.
"""

@abstractmethod
def set_by_key(self, queue_id: str, key: str, value: str) -> str:
def set_by_key(self, user_id: str, key: str, value: str) -> str:
"""
Set a key-value pair for the client.

Args:
user_id (str): The user ID to set state for.
key (str): The key to set.
value (str): The value to set for the key.

Expand All @@ -22,11 +23,12 @@ def set_by_key(self, queue_id: str, key: str, value: str) -> str:
pass

@abstractmethod
def get_by_key(self, queue_id: str, key: str) -> str | None:
def get_by_key(self, user_id: str, key: str) -> str | None:
"""
Get the value for a specific key of the client.

Args:
user_id (str): The user ID to get state for.
key (str): The key to retrieve the value for.

Returns:
Expand All @@ -35,8 +37,11 @@ def get_by_key(self, queue_id: str, key: str) -> str | None:
pass

@abstractmethod
def delete(self, queue_id: str) -> None:
def delete(self, user_id: str) -> None:
"""
Delete all client state.
Delete all client state for a user.

Args:
user_id (str): The user ID to delete state for.
"""
pass
Original file line number Diff line number Diff line change
@@ -1,65 +1,55 @@
import json

from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase


class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
"""
Base class for client persistence implementations.
This class defines the interface for persisting client data.
SQLite implementation for client state persistence.
This class stores client state data per user to prevent data leakage between users.
"""

def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._default_row_id = 1

def start(self, invoker: Invoker) -> None:
self._invoker = invoker

def _get(self) -> dict[str, str] | None:
def set_by_key(self, user_id: str, key: str, value: str) -> str:
with self._db.transaction() as cursor:
cursor.execute(
f"""
SELECT data FROM client_state
WHERE id = {self._default_row_id}
"""
INSERT INTO client_state (user_id, key, value)
VALUES (?, ?, ?)
ON CONFLICT(user_id, key) DO UPDATE
SET value = excluded.value;
""",
(user_id, key, value),
)
row = cursor.fetchone()
if row is None:
return None
return json.loads(row[0])

def set_by_key(self, queue_id: str, key: str, value: str) -> str:
state = self._get() or {}
state.update({key: value})
return value

def get_by_key(self, user_id: str, key: str) -> str | None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
INSERT INTO client_state (id, data)
VALUES ({self._default_row_id}, ?)
ON CONFLICT(id) DO UPDATE
SET data = excluded.data;
"""
SELECT value FROM client_state
WHERE user_id = ? AND key = ?
""",
(json.dumps(state),),
(user_id, key),
)
row = cursor.fetchone()
if row is None:
return None
return row[0]

return value

def get_by_key(self, queue_id: str, key: str) -> str | None:
state = self._get()
if state is None:
return None
return state.get(key, None)

def delete(self, queue_id: str) -> None:
def delete(self, user_id: str) -> None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
DELETE FROM client_state
WHERE id = {self._default_row_id}
"""
DELETE FROM client_state
WHERE user_id = ?
""",
(user_id,),
)
4 changes: 2 additions & 2 deletions invokeai/app/services/session_queue/session_queue_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def is_full(self, queue_id: str) -> IsFullResult:
pass

@abstractmethod
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
"""Gets the status of the queue"""
def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus:
"""Gets the status of the queue. If user_id is provided, also includes user-specific counts."""
pass

@abstractmethod
Expand Down
6 changes: 6 additions & 0 deletions invokeai/app/services/session_queue/session_queue_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,12 @@ class SessionQueueStatus(BaseModel):
failed: int = Field(..., description="Number of queue items with status 'error'")
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
total: int = Field(..., description="Total number of queue items")
user_pending: Optional[int] = Field(
default=None, description="Number of queue items with status 'pending' for the current user"
)
user_in_progress: Optional[int] = Field(
default=None, description="Number of queue items with status 'in_progress' for the current user"
)


class SessionQueueCountsByDestination(BaseModel):
Expand Down
28 changes: 27 additions & 1 deletion invokeai/app/services/session_queue/session_queue_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,8 +773,9 @@ def get_queue_item_ids(

return ItemIdsResult(item_ids=item_ids, total_count=len(item_ids))

def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus:
with self._db.transaction() as cursor:
# Get total counts
cursor.execute(
"""--sql
SELECT status, count(*)
Expand All @@ -786,9 +787,32 @@ def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())

# Get user-specific counts if user_id is provided (using a single query with CASE)
user_counts_result = []
if user_id is not None:
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ? AND user_id = ?
GROUP BY status
""",
(queue_id, user_id),
)
user_counts_result = cast(list[sqlite3.Row], cursor.fetchall())

current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] or 0 for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}

# Process user-specific counts if available
user_pending = None
user_in_progress = None
if user_id is not None:
user_counts: dict[str, int] = {row[0]: row[1] for row in user_counts_result}
user_pending = user_counts.get("pending", 0)
user_in_progress = user_counts.get("in_progress", 0)

return SessionQueueStatus(
queue_id=queue_id,
item_id=current_item.item_id if current_item else None,
Expand All @@ -800,6 +824,8 @@ def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
failed=counts.get("failed", 0),
canceled=counts.get("canceled", 0),
total=total,
user_pending=user_pending,
user_in_progress=user_in_progress,
)

def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
Expand Down
2 changes: 2 additions & 0 deletions invokeai/app/services/shared/sqlite/sqlite_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_23 import build_migration_23
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_24 import build_migration_24
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_25 import build_migration_25
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_26 import build_migration_26
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator


Expand Down Expand Up @@ -73,6 +74,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_23(app_config=config, logger=logger))
migrator.register_migration(build_migration_24(app_config=config, logger=logger))
migrator.register_migration(build_migration_25())
migrator.register_migration(build_migration_26())
migrator.run_migrations()

return db
Loading