diff --git a/.github/workflows/close-inactive-issues.yml b/.github/workflows/close-inactive-issues.yml index 9636911b2e9..5e961e2773a 100644 --- a/.github/workflows/close-inactive-issues.yml +++ b/.github/workflows/close-inactive-issues.yml @@ -23,6 +23,7 @@ jobs: close-issue-message: "Due to inactivity, this issue was automatically closed. If you are still experiencing the issue, please recreate the issue." days-before-pr-stale: -1 days-before-pr-close: -1 + only-labels: "bug" exempt-issue-labels: "Active Issue" repo-token: ${{ secrets.GITHUB_TOKEN }} operations-per-run: 500 diff --git a/invokeai/app/api/auth_dependencies.py b/invokeai/app/api/auth_dependencies.py index f5537890b63..a7b01931929 100644 --- a/invokeai/app/api/auth_dependencies.py +++ b/invokeai/app/api/auth_dependencies.py @@ -7,6 +7,9 @@ from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.services.auth.token_service import TokenData, verify_token +from invokeai.backend.util.logging import logging + +logger = logging.getLogger(__name__) # HTTP Bearer token security scheme security = HTTPBearer(auto_error=False) @@ -61,6 +64,45 @@ async def get_current_user( return token_data +async def get_current_user_or_default( + credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)], +) -> TokenData: + """Get current authenticated user from Bearer token, or return a default system user if not authenticated. + + This dependency is useful for endpoints that should work in both authenticated and non-authenticated contexts. + In single-user mode or when authentication is not provided, it returns a TokenData for the 'system' user. + + Args: + credentials: The HTTP authorization credentials containing the Bearer token + + Returns: + TokenData containing user information from the token, or system user if no credentials + """ + if credentials is None: + # Return system user for unauthenticated requests (single-user mode or backwards compatibility) + logger.debug("No authentication credentials provided, using system user") + return TokenData(user_id="system", email="system@system.invokeai", is_admin=False) + + token = credentials.credentials + token_data = verify_token(token) + + if token_data is None: + # Invalid token - still fall back to system user for backwards compatibility + logger.warning("Invalid or expired token provided, falling back to system user") + return TokenData(user_id="system", email="system@system.invokeai", is_admin=False) + + # Verify user still exists and is active + user_service = ApiDependencies.invoker.services.users + user = user_service.get(token_data.user_id) + + if user is None or not user.is_active: + # User doesn't exist or is inactive - fall back to system user + logger.warning(f"User {token_data.user_id} does not exist or is inactive, falling back to system user") + return TokenData(user_id="system", email="system@system.invokeai", is_admin=False) + + return token_data + + async def require_admin( current_user: Annotated[TokenData, Depends(get_current_user)], ) -> TokenData: @@ -82,4 +124,5 @@ async def require_admin( # Type aliases for convenient use in route dependencies CurrentUser = Annotated[TokenData, Depends(get_current_user)] +CurrentUserOrDefault = Annotated[TokenData, Depends(get_current_user_or_default)] AdminUser = Annotated[TokenData, Depends(require_admin)] diff --git a/invokeai/app/api/routers/client_state.py b/invokeai/app/api/routers/client_state.py index 188225760c7..2e34ea9fe6b 100644 --- a/invokeai/app/api/routers/client_state.py +++ b/invokeai/app/api/routers/client_state.py @@ -1,6 +1,7 @@ from fastapi import Body, HTTPException, Path, Query from fastapi.routing import APIRouter +from invokeai.app.api.auth_dependencies import CurrentUserOrDefault from invokeai.app.api.dependencies import ApiDependencies from invokeai.backend.util.logging import logging @@ -13,15 +14,16 @@ response_model=str | None, ) async def get_client_state_by_key( - queue_id: str = Path(description="The queue id to perform this operation on"), + current_user: CurrentUserOrDefault, + queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"), key: str = Query(..., description="Key to get"), ) -> str | None: - """Gets the client state""" + """Gets the client state for the current user (or system user if not authenticated)""" try: - return ApiDependencies.invoker.services.client_state_persistence.get_by_key(queue_id, key) + return ApiDependencies.invoker.services.client_state_persistence.get_by_key(current_user.user_id, key) except Exception as e: logging.error(f"Error getting client state: {e}") - raise HTTPException(status_code=500, detail="Error setting client state") + raise HTTPException(status_code=500, detail="Error getting client state") @client_state_router.post( @@ -30,13 +32,14 @@ async def get_client_state_by_key( response_model=str, ) async def set_client_state( - queue_id: str = Path(description="The queue id to perform this operation on"), + current_user: CurrentUserOrDefault, + queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"), key: str = Query(..., description="Key to set"), value: str = Body(..., description="Stringified value to set"), ) -> str: - """Sets the client state""" + """Sets the client state for the current user (or system user if not authenticated)""" try: - return ApiDependencies.invoker.services.client_state_persistence.set_by_key(queue_id, key, value) + return ApiDependencies.invoker.services.client_state_persistence.set_by_key(current_user.user_id, key, value) except Exception as e: logging.error(f"Error setting client state: {e}") raise HTTPException(status_code=500, detail="Error setting client state") @@ -48,11 +51,12 @@ async def set_client_state( responses={204: {"description": "Client state deleted"}}, ) async def delete_client_state( - queue_id: str = Path(description="The queue id to perform this operation on"), + current_user: CurrentUserOrDefault, + queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"), ) -> None: - """Deletes the client state""" + """Deletes the client state for the current user (or system user if not authenticated)""" try: - ApiDependencies.invoker.services.client_state_persistence.delete(queue_id) + ApiDependencies.invoker.services.client_state_persistence.delete(current_user.user_id) except Exception as e: logging.error(f"Error deleting client state: {e}") raise HTTPException(status_code=500, detail="Error deleting client state") diff --git a/invokeai/app/api/routers/session_queue.py b/invokeai/app/api/routers/session_queue.py index 222edc7959f..0d955ffc1fe 100644 --- a/invokeai/app/api/routers/session_queue.py +++ b/invokeai/app/api/routers/session_queue.py @@ -408,11 +408,12 @@ async def get_next_queue_item( }, ) async def get_queue_status( + current_user: CurrentUser, queue_id: str = Path(description="The queue id to perform this operation on"), ) -> SessionQueueAndProcessorStatus: """Gets the status of the session queue""" try: - queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id) + queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=current_user.user_id) processor = ApiDependencies.invoker.services.session_processor.get_status() return SessionQueueAndProcessorStatus(queue=queue, processor=processor) except Exception as e: diff --git a/invokeai/app/services/client_state_persistence/client_state_persistence_base.py b/invokeai/app/services/client_state_persistence/client_state_persistence_base.py index 193561ef898..99ad71bc8b7 100644 --- a/invokeai/app/services/client_state_persistence/client_state_persistence_base.py +++ b/invokeai/app/services/client_state_persistence/client_state_persistence_base.py @@ -4,15 +4,16 @@ class ClientStatePersistenceABC(ABC): """ Base class for client persistence implementations. - This class defines the interface for persisting client data. + This class defines the interface for persisting client data per user. """ @abstractmethod - def set_by_key(self, queue_id: str, key: str, value: str) -> str: + def set_by_key(self, user_id: str, key: str, value: str) -> str: """ Set a key-value pair for the client. Args: + user_id (str): The user ID to set state for. key (str): The key to set. value (str): The value to set for the key. @@ -22,11 +23,12 @@ def set_by_key(self, queue_id: str, key: str, value: str) -> str: pass @abstractmethod - def get_by_key(self, queue_id: str, key: str) -> str | None: + def get_by_key(self, user_id: str, key: str) -> str | None: """ Get the value for a specific key of the client. Args: + user_id (str): The user ID to get state for. key (str): The key to retrieve the value for. Returns: @@ -35,8 +37,11 @@ def get_by_key(self, queue_id: str, key: str) -> str | None: pass @abstractmethod - def delete(self, queue_id: str) -> None: + def delete(self, user_id: str) -> None: """ - Delete all client state. + Delete all client state for a user. + + Args: + user_id (str): The user ID to delete state for. """ pass diff --git a/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py b/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py index 36f22d96760..643db306857 100644 --- a/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py +++ b/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py @@ -1,5 +1,3 @@ -import json - from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC from invokeai.app.services.invoker import Invoker from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase @@ -7,59 +5,51 @@ class ClientStatePersistenceSqlite(ClientStatePersistenceABC): """ - Base class for client persistence implementations. - This class defines the interface for persisting client data. + SQLite implementation for client state persistence. + This class stores client state data per user to prevent data leakage between users. """ def __init__(self, db: SqliteDatabase) -> None: super().__init__() self._db = db - self._default_row_id = 1 def start(self, invoker: Invoker) -> None: self._invoker = invoker - def _get(self) -> dict[str, str] | None: + def set_by_key(self, user_id: str, key: str, value: str) -> str: with self._db.transaction() as cursor: cursor.execute( - f""" - SELECT data FROM client_state - WHERE id = {self._default_row_id} """ + INSERT INTO client_state (user_id, key, value) + VALUES (?, ?, ?) + ON CONFLICT(user_id, key) DO UPDATE + SET value = excluded.value; + """, + (user_id, key, value), ) - row = cursor.fetchone() - if row is None: - return None - return json.loads(row[0]) - def set_by_key(self, queue_id: str, key: str, value: str) -> str: - state = self._get() or {} - state.update({key: value}) + return value + def get_by_key(self, user_id: str, key: str) -> str | None: with self._db.transaction() as cursor: cursor.execute( - f""" - INSERT INTO client_state (id, data) - VALUES ({self._default_row_id}, ?) - ON CONFLICT(id) DO UPDATE - SET data = excluded.data; + """ + SELECT value FROM client_state + WHERE user_id = ? AND key = ? """, - (json.dumps(state),), + (user_id, key), ) + row = cursor.fetchone() + if row is None: + return None + return row[0] - return value - - def get_by_key(self, queue_id: str, key: str) -> str | None: - state = self._get() - if state is None: - return None - return state.get(key, None) - - def delete(self, queue_id: str) -> None: + def delete(self, user_id: str) -> None: with self._db.transaction() as cursor: cursor.execute( - f""" - DELETE FROM client_state - WHERE id = {self._default_row_id} """ + DELETE FROM client_state + WHERE user_id = ? + """, + (user_id,), ) diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index 5232dc9c76e..42ececa2950 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -73,8 +73,8 @@ def is_full(self, queue_id: str) -> IsFullResult: pass @abstractmethod - def get_queue_status(self, queue_id: str) -> SessionQueueStatus: - """Gets the status of the queue""" + def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus: + """Gets the status of the queue. If user_id is provided, also includes user-specific counts.""" pass @abstractmethod diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index 09820fe6217..58544422119 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -304,6 +304,12 @@ class SessionQueueStatus(BaseModel): failed: int = Field(..., description="Number of queue items with status 'error'") canceled: int = Field(..., description="Number of queue items with status 'canceled'") total: int = Field(..., description="Total number of queue items") + user_pending: Optional[int] = Field( + default=None, description="Number of queue items with status 'pending' for the current user" + ) + user_in_progress: Optional[int] = Field( + default=None, description="Number of queue items with status 'in_progress' for the current user" + ) class SessionQueueCountsByDestination(BaseModel): diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index aa5ce689b40..9e92ea6d3b5 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -773,8 +773,9 @@ def get_queue_item_ids( return ItemIdsResult(item_ids=item_ids, total_count=len(item_ids)) - def get_queue_status(self, queue_id: str) -> SessionQueueStatus: + def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus: with self._db.transaction() as cursor: + # Get total counts cursor.execute( """--sql SELECT status, count(*) @@ -786,9 +787,32 @@ def get_queue_status(self, queue_id: str) -> SessionQueueStatus: ) counts_result = cast(list[sqlite3.Row], cursor.fetchall()) + # Get user-specific counts if user_id is provided (using a single query with CASE) + user_counts_result = [] + if user_id is not None: + cursor.execute( + """--sql + SELECT status, count(*) + FROM session_queue + WHERE queue_id = ? AND user_id = ? + GROUP BY status + """, + (queue_id, user_id), + ) + user_counts_result = cast(list[sqlite3.Row], cursor.fetchall()) + current_item = self.get_current(queue_id=queue_id) total = sum(row[1] or 0 for row in counts_result) counts: dict[str, int] = {row[0]: row[1] for row in counts_result} + + # Process user-specific counts if available + user_pending = None + user_in_progress = None + if user_id is not None: + user_counts: dict[str, int] = {row[0]: row[1] for row in user_counts_result} + user_pending = user_counts.get("pending", 0) + user_in_progress = user_counts.get("in_progress", 0) + return SessionQueueStatus( queue_id=queue_id, item_id=current_item.item_id if current_item else None, @@ -800,6 +824,8 @@ def get_queue_status(self, queue_id: str) -> SessionQueueStatus: failed=counts.get("failed", 0), canceled=counts.get("canceled", 0), total=total, + user_pending=user_pending, + user_in_progress=user_in_progress, ) def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus: diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 54a0450084a..ecf769a9cf4 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -28,6 +28,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_23 import build_migration_23 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_24 import build_migration_24 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_25 import build_migration_25 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_26 import build_migration_26 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -73,6 +74,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_23(app_config=config, logger=logger)) migrator.register_migration(build_migration_24(app_config=config, logger=logger)) migrator.register_migration(build_migration_25()) + migrator.register_migration(build_migration_26()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_26.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_26.py new file mode 100644 index 00000000000..8f37404a81b --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_26.py @@ -0,0 +1,120 @@ +"""Migration 26: Add user_id to client_state table for multi-user support. + +This migration updates the client_state table to support per-user state isolation: +- Drops the single-row constraint (CHECK(id = 1)) +- Adds user_id column +- Creates unique constraint on (user_id, key) pairs +- Migrates existing data to 'system' user +""" + +import json +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration26Callback: + """Migration to add per-user client state support.""" + + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._update_client_state_table(cursor) + + def _update_client_state_table(self, cursor: sqlite3.Cursor) -> None: + """Restructure client_state table to support per-user storage.""" + # Check if client_state table exists + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='client_state';") + if cursor.fetchone() is None: + # Table doesn't exist, create it with the new schema + cursor.execute( + """ + CREATE TABLE client_state ( + user_id TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP), + PRIMARY KEY (user_id, key), + FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE + ); + """ + ) + cursor.execute("CREATE INDEX IF NOT EXISTS idx_client_state_user_id ON client_state(user_id);") + cursor.execute( + """ + CREATE TRIGGER tg_client_state_updated_at + AFTER UPDATE ON client_state + FOR EACH ROW + BEGIN + UPDATE client_state + SET updated_at = CURRENT_TIMESTAMP + WHERE user_id = OLD.user_id AND key = OLD.key; + END; + """ + ) + return + + # Table exists with old schema - migrate it + # Get existing data + cursor.execute("SELECT data FROM client_state WHERE id = 1;") + row = cursor.fetchone() + existing_data = {} + if row is not None: + try: + existing_data = json.loads(row[0]) + except (json.JSONDecodeError, TypeError): + # If data is corrupt, just start fresh + pass + + # Drop the old table + cursor.execute("DROP TABLE IF EXISTS client_state;") + + # Create new table with per-user schema + cursor.execute( + """ + CREATE TABLE client_state ( + user_id TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP), + PRIMARY KEY (user_id, key), + FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE + ); + """ + ) + + cursor.execute("CREATE INDEX IF NOT EXISTS idx_client_state_user_id ON client_state(user_id);") + + cursor.execute( + """ + CREATE TRIGGER tg_client_state_updated_at + AFTER UPDATE ON client_state + FOR EACH ROW + BEGIN + UPDATE client_state + SET updated_at = CURRENT_TIMESTAMP + WHERE user_id = OLD.user_id AND key = OLD.key; + END; + """ + ) + + # Migrate existing data to 'system' user + # The 'system' user is created by migration 25, so it's guaranteed to exist at this point + for key, value in existing_data.items(): + cursor.execute( + """ + INSERT INTO client_state (user_id, key, value) + VALUES ('system', ?, ?); + """, + (key, value), + ) + + +def build_migration_26() -> Migration: + """Builds the migration object for migrating from version 25 to version 26. + + This migration adds per-user client state support to prevent data leakage between users. + """ + return Migration( + from_version=25, + to_version=26, + callback=Migration26Callback(), + ) diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index 1be662f85e8..8c27ccac11c 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -49842,6 +49842,36 @@ } ], "description": "The workflow associated with this queue item" + }, + "user_id": { + "type": "string", + "title": "User Id", + "description": "The id of the user who created this queue item", + "default": "system" + }, + "user_display_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "User Display Name", + "description": "The display name of the user who created this queue item, if available" + }, + "user_email": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "User Email", + "description": "The email of the user who created this queue item, if available" } }, "type": "object", @@ -49932,6 +49962,30 @@ "type": "integer", "title": "Total", "description": "Total number of queue items" + }, + "user_pending": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "User Pending", + "description": "Number of queue items with status 'pending' for the current user" + }, + "user_in_progress": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "User In Progress", + "description": "Number of queue items with status 'in_progress' for the current user" } }, "type": "object", @@ -56450,6 +56504,45 @@ "output": { "$ref": "#/components/schemas/ZImageConditioningOutput" } + }, + "UserDTO": { + "type": "object", + "required": ["user_id", "email", "is_admin", "is_active"], + "properties": { + "user_id": { + "type": "string", + "title": "User Id", + "description": "The user ID" + }, + "email": { + "type": "string", + "title": "Email", + "description": "The user email" + }, + "display_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Display Name", + "description": "The user display name" + }, + "is_admin": { + "type": "boolean", + "title": "Is Admin", + "description": "Whether the user is an admin" + }, + "is_active": { + "type": "boolean", + "title": "Is Active", + "description": "Whether the user is active" + } + }, + "title": "UserDTO" } } } diff --git a/invokeai/frontend/web/public/locales/it.json b/invokeai/frontend/web/public/locales/it.json index c9cd365a823..9a58239b875 100644 --- a/invokeai/frontend/web/public/locales/it.json +++ b/invokeai/frontend/web/public/locales/it.json @@ -2763,8 +2763,7 @@ "items": [ "Gestione modelli: se non è possibile identificare un modello durante l'installazione, ora è possibile selezionare manualmente il tipo di modello e l'architettura.", "Interno: sistema di identificazione dei modelli migliorato, che semplifica per i collaboratori l'aggiunta del supporto per nuovi modelli.", - "Strumento di ritaglio per immagini di riferimento", - "Interfaccia utente migliorata della scheda Gestione modelli" + "Strumento di ritaglio per immagini di riferimento" ], "watchUiUpdatesOverview": "Guarda la panoramica degli aggiornamenti dell'interfaccia utente" }, diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/driver.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/driver.ts index 9e67770b436..fdb25b37d2c 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/driver.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/driver.ts @@ -68,10 +68,26 @@ const getIdbKey = (key: string) => { return `${IDB_STORAGE_PREFIX}${key}`; }; +// Helper to get auth headers for client_state requests +const getAuthHeaders = (): Record => { + const headers: Record = {}; + // Safe access to localStorage (not available in Node.js test environment) + if (typeof window !== 'undefined' && window.localStorage) { + const token = localStorage.getItem('auth_token'); + if (token) { + headers['Authorization'] = `Bearer ${token}`; + } + } + return headers; +}; + const getItem = async (key: string) => { try { const url = getUrl('get_by_key', key); - const res = await fetch(url, { method: 'GET' }); + const res = await fetch(url, { + method: 'GET', + headers: getAuthHeaders(), + }); if (!res.ok) { throw new Error(`Response status: ${res.status}`); } @@ -130,7 +146,11 @@ const setItem = async (key: string, value: string) => { } log.trace({ key, last: lastPersistedState.get(key), next: value }, `Persisting state for ${key}`); const url = getUrl('set_by_key', key); - const res = await fetch(url, { method: 'POST', body: value }); + const res = await fetch(url, { + method: 'POST', + body: value, + headers: getAuthHeaders(), + }); if (!res.ok) { throw new Error(`Response status: ${res.status}`); } @@ -158,7 +178,10 @@ export const clearStorage = async () => { try { persistRefCount++; const url = getUrl('delete'); - const res = await fetch(url, { method: 'POST' }); + const res = await fetch(url, { + method: 'POST', + headers: getAuthHeaders(), + }); if (!res.ok) { throw new Error(`Response status: ${res.status}`); } diff --git a/invokeai/frontend/web/src/features/auth/components/LoginPage.tsx b/invokeai/frontend/web/src/features/auth/components/LoginPage.tsx index 19ccf0949aa..e62b1289d06 100644 --- a/invokeai/frontend/web/src/features/auth/components/LoginPage.tsx +++ b/invokeai/frontend/web/src/features/auth/components/LoginPage.tsx @@ -52,13 +52,14 @@ export const LoginPage = memo(() => { is_active: result.user.is_active || true, }; dispatch(setCredentials({ token: result.token, user })); - // Navigate to main app after successful login - navigate('/app', { replace: true }); + // Force a page reload to ensure all user-specific state is loaded from server + // This is important for multiuser isolation to prevent state leakage + window.location.href = '/app'; } catch { // Error is handled by RTK Query and displayed via error state } }, - [email, password, rememberMe, login, dispatch, navigate] + [email, password, rememberMe, login, dispatch] ); const handleEmailChange = useCallback((e: ChangeEvent) => { diff --git a/invokeai/frontend/web/src/features/auth/store/authSlice.ts b/invokeai/frontend/web/src/features/auth/store/authSlice.ts index bcf932ca32d..6ac65ef03ce 100644 --- a/invokeai/frontend/web/src/features/auth/store/authSlice.ts +++ b/invokeai/frontend/web/src/features/auth/store/authSlice.ts @@ -21,9 +21,17 @@ const zAuthState = z.object({ type User = z.infer; type AuthState = z.infer; +// Helper to safely access localStorage (not available in test environment) +const getStoredAuthToken = (): string | null => { + if (typeof window !== 'undefined' && window.localStorage) { + return localStorage.getItem('auth_token'); + } + return null; +}; + const initialState: AuthState = { - isAuthenticated: !!localStorage.getItem('auth_token'), - token: localStorage.getItem('auth_token'), + isAuthenticated: !!getStoredAuthToken(), + token: getStoredAuthToken(), user: null, isLoading: false, }; @@ -38,13 +46,17 @@ const authSlice = createSlice({ state.token = action.payload.token; state.user = action.payload.user; state.isAuthenticated = true; - localStorage.setItem('auth_token', action.payload.token); + if (typeof window !== 'undefined' && window.localStorage) { + localStorage.setItem('auth_token', action.payload.token); + } }, logout: (state) => { state.token = null; state.user = null; state.isAuthenticated = false; - localStorage.removeItem('auth_token'); + if (typeof window !== 'undefined' && window.localStorage) { + localStorage.removeItem('auth_token'); + } }, setLoading: (state, action: PayloadAction) => { state.isLoading = action.payload; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index 0190aba602b..b4b328704e0 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -6,6 +6,7 @@ import { deepClone } from 'common/util/deepClone'; import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple'; import { isPlainObject } from 'es-toolkit'; import { clamp } from 'es-toolkit/compat'; +import { logout } from 'features/auth/store/authSlice'; import type { AspectRatioID, InfillMethod, ParamsState, RgbaColor } from 'features/controlLayers/store/types'; import { ASPECT_RATIO_MAP, @@ -401,6 +402,12 @@ const slice = createSlice({ }, paramsReset: (state) => resetState(state), }, + extraReducers(builder) { + // Reset params state on logout to prevent user data leakage when switching users + builder.addCase(logout, () => { + return getInitialParamsState(); + }); + }, }); const applyClipSkip = (state: { clipSkip: number }, model: ParameterModel | null, clipSkip: number) => { diff --git a/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx b/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx index 5093f89d573..3417488b09e 100644 --- a/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx +++ b/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx @@ -1,20 +1,66 @@ import { Badge, Portal } from '@invoke-ai/ui-library'; +import { useAppSelector } from 'app/store/storeHooks'; +import { selectIsAuthenticated } from 'features/auth/store/authSlice'; import type { RefObject } from 'react'; -import { memo, useEffect, useState } from 'react'; +import { memo, useEffect, useMemo, useState } from 'react'; import { useGetQueueStatusQuery } from 'services/api/endpoints/queue'; +import type { components } from 'services/api/schema'; type Props = { targetRef: RefObject; }; +type SessionQueueStatus = components['schemas']['SessionQueueStatus']; + +/** + * Determines if user-specific queue counts are available. + */ +const hasUserCounts = (queueData: SessionQueueStatus): boolean => { + return ( + queueData.user_pending !== undefined && + queueData.user_pending !== null && + queueData.user_in_progress !== undefined && + queueData.user_in_progress !== null + ); +}; + +/** + * Calculates the appropriate badge text based on queue status and authentication state. + * Returns null if badge should be hidden. + */ +const getBadgeText = (queueData: SessionQueueStatus | undefined, isAuthenticated: boolean): string | null => { + if (!queueData) { + return null; + } + + const totalPending = queueData.pending + queueData.in_progress; + + // Hide badge if there are no pending jobs + if (totalPending === 0) { + return null; + } + + // In multiuser mode (authenticated user), show "X/Y" format where X is user's jobs and Y is total jobs + if (isAuthenticated && hasUserCounts(queueData)) { + const userPending = queueData.user_pending! + queueData.user_in_progress!; + return `${userPending}/${totalPending}`; + } + + // In single-user mode or when user counts aren't available, show total count only + return totalPending.toString(); +}; + export const QueueCountBadge = memo(({ targetRef }: Props) => { const [badgePos, setBadgePos] = useState<{ x: string; y: string } | null>(null); - const { queueSize } = useGetQueueStatusQuery(undefined, { + const isAuthenticated = useAppSelector(selectIsAuthenticated); + const { queueData } = useGetQueueStatusQuery(undefined, { selectFromResult: (res) => ({ - queueSize: res.data ? res.data.queue.pending + res.data.queue.in_progress : 0, + queueData: res.data?.queue, }), }); + const badgeText = useMemo(() => getBadgeText(queueData, isAuthenticated), [queueData, isAuthenticated]); + useEffect(() => { if (!targetRef.current) { return; @@ -57,7 +103,7 @@ export const QueueCountBadge = memo(({ targetRef }: Props) => { }; }, [targetRef]); - if (queueSize === 0) { + if (!badgeText) { return null; } if (!badgePos) { @@ -75,7 +121,7 @@ export const QueueCountBadge = memo(({ targetRef }: Props) => { shadow="dark-lg" userSelect="none" > - {queueSize} + {badgeText} ); diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 24323c92dc5..e69970d8399 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -2053,7 +2053,7 @@ export type paths = { }; /** * Get Client State By Key - * @description Gets the client state + * @description Gets the client state for the current user (or system user if not authenticated) */ get: operations["get_client_state_by_key"]; put?: never; @@ -2075,7 +2075,7 @@ export type paths = { put?: never; /** * Set Client State - * @description Sets the client state + * @description Sets the client state for the current user (or system user if not authenticated) */ post: operations["set_client_state"]; delete?: never; @@ -2095,7 +2095,7 @@ export type paths = { put?: never; /** * Delete Client State - * @description Deletes the client state + * @description Deletes the client state for the current user (or system user if not authenticated) */ post: operations["delete_client_state"]; delete?: never; @@ -22549,6 +22549,16 @@ export type components = { * @description Total number of queue items */ total: number; + /** + * User Pending + * @description Number of queue items with status 'pending' for the current user + */ + user_pending?: number | null; + /** + * User In Progress + * @description Number of queue items with status 'in_progress' for the current user + */ + user_in_progress?: number | null; }; /** * SetupRequest @@ -30860,7 +30870,7 @@ export interface operations { }; header?: never; path: { - /** @description The queue id to perform this operation on */ + /** @description The queue id (ignored, kept for backwards compatibility) */ queue_id: string; }; cookie?: never; @@ -30895,7 +30905,7 @@ export interface operations { }; header?: never; path: { - /** @description The queue id to perform this operation on */ + /** @description The queue id (ignored, kept for backwards compatibility) */ queue_id: string; }; cookie?: never; @@ -30931,7 +30941,7 @@ export interface operations { query?: never; header?: never; path: { - /** @description The queue id to perform this operation on */ + /** @description The queue id (ignored, kept for backwards compatibility) */ queue_id: string; }; cookie?: never; diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index 945eebae040..74069b084ae 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -387,10 +387,12 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis ); // Invalidate caches for things we cannot easily update + // Invalidate SessionQueueStatus to refetch with user-specific counts const tagsToInvalidate: ApiTagDescription[] = [ 'CurrentSessionQueueItem', 'NextSessionQueueItem', 'InvocationCacheStatus', + 'SessionQueueStatus', 'SessionQueueItemIdList', { type: 'SessionQueueItem', id: item_id }, { type: 'SessionQueueItem', id: LIST_TAG }, @@ -401,16 +403,6 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis tagsToInvalidate.push({ type: 'QueueCountsByDestination', id: destination }); } dispatch(queueApi.util.invalidateTags(tagsToInvalidate)); - dispatch( - queueApi.util.updateQueryData('getQueueStatus', undefined, (draft) => { - draft.queue = data.queue_status; - }) - ); - dispatch( - queueApi.util.updateQueryData('getBatchStatus', { batch_id: data.batch_id }, (draft) => { - Object.assign(draft, data.batch_status); - }) - ); if (status === 'in_progress') { forEach($nodeExecutionStates.get(), (nes) => { @@ -463,6 +455,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis log.debug({ data }, 'Batch enqueued'); dispatch( queueApi.util.invalidateTags([ + 'SessionQueueStatus', 'CurrentSessionQueueItem', 'NextSessionQueueItem', 'QueueCountsByDestination', diff --git a/tests/app/routers/test_client_state_multiuser.py b/tests/app/routers/test_client_state_multiuser.py new file mode 100644 index 00000000000..2b67e8c0165 --- /dev/null +++ b/tests/app/routers/test_client_state_multiuser.py @@ -0,0 +1,296 @@ +"""Tests for multiuser client state functionality.""" + +from typing import Any + +import pytest +from fastapi import status +from fastapi.testclient import TestClient + +from invokeai.app.api.dependencies import ApiDependencies +from invokeai.app.api_app import app +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.users.users_common import UserCreateRequest + + +@pytest.fixture +def client(): + """Create a test client.""" + return TestClient(app) + + +class MockApiDependencies(ApiDependencies): + """Mock API dependencies for testing.""" + + invoker: Invoker + + def __init__(self, invoker: Invoker) -> None: + self.invoker = invoker + + +def setup_test_user( + mock_invoker: Invoker, email: str, display_name: str, password: str = "TestPass123", is_admin: bool = False +) -> str: + """Helper to create a test user and return user_id.""" + user_service = mock_invoker.services.users + user_data = UserCreateRequest( + email=email, + display_name=display_name, + password=password, + is_admin=is_admin, + ) + user = user_service.create(user_data) + return user.user_id + + +def get_user_token(client: TestClient, email: str, password: str = "TestPass123") -> str: + """Helper to login and get a user token.""" + response = client.post( + "/api/v1/auth/login", + json={ + "email": email, + "password": password, + "remember_me": False, + }, + ) + assert response.status_code == 200 + return response.json()["token"] + + +@pytest.fixture +def admin_token(monkeypatch: Any, mock_invoker: Invoker, client: TestClient): + """Get an admin token for testing.""" + # Mock ApiDependencies for auth and client_state routers + monkeypatch.setattr("invokeai.app.api.routers.auth.ApiDependencies", MockApiDependencies(mock_invoker)) + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + monkeypatch.setattr("invokeai.app.api.routers.client_state.ApiDependencies", MockApiDependencies(mock_invoker)) + + # Create admin user + setup_test_user(mock_invoker, "admin@test.com", "Admin User", is_admin=True) + + return get_user_token(client, "admin@test.com") + + +@pytest.fixture +def user1_token(monkeypatch: Any, mock_invoker: Invoker, client: TestClient, admin_token: str): + """Get a token for test user 1.""" + # Create a regular user + setup_test_user(mock_invoker, "user1@test.com", "User One", is_admin=False) + + return get_user_token(client, "user1@test.com") + + +@pytest.fixture +def user2_token(monkeypatch: Any, mock_invoker: Invoker, client: TestClient, admin_token: str): + """Get a token for test user 2.""" + # Create another regular user + setup_test_user(mock_invoker, "user2@test.com", "User Two", is_admin=False) + + return get_user_token(client, "user2@test.com") + + +def test_get_client_state_without_auth_uses_system_user(client: TestClient, monkeypatch, mock_invoker: Invoker): + """Test that getting client state without authentication uses the system user.""" + # Mock ApiDependencies + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + monkeypatch.setattr("invokeai.app.api.routers.client_state.ApiDependencies", MockApiDependencies(mock_invoker)) + + # Set a value for the system user directly + mock_invoker.services.client_state_persistence.set_by_key("system", "test_key", "system_value") + + # Get without authentication - should return system user's value + response = client.get("/api/v1/client_state/default/get_by_key?key=test_key") + assert response.status_code == status.HTTP_200_OK + assert response.json() == "system_value" + + +def test_set_client_state_without_auth_uses_system_user(client: TestClient, monkeypatch, mock_invoker: Invoker): + """Test that setting client state without authentication uses the system user.""" + # Mock ApiDependencies + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + monkeypatch.setattr("invokeai.app.api.routers.client_state.ApiDependencies", MockApiDependencies(mock_invoker)) + + # Set without authentication - should set for system user + response = client.post( + "/api/v1/client_state/default/set_by_key?key=test_key", + json="unauthenticated_value", + ) + assert response.status_code == status.HTTP_200_OK + + # Verify it was set for system user + value = mock_invoker.services.client_state_persistence.get_by_key("system", "test_key") + assert value == "unauthenticated_value" + + +def test_delete_client_state_without_auth_uses_system_user(client: TestClient, monkeypatch, mock_invoker: Invoker): + """Test that deleting client state without authentication uses the system user.""" + # Mock ApiDependencies + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + monkeypatch.setattr("invokeai.app.api.routers.client_state.ApiDependencies", MockApiDependencies(mock_invoker)) + + # Set a value for system user + mock_invoker.services.client_state_persistence.set_by_key("system", "test_key", "system_value") + + # Delete without authentication - should delete system user's data + response = client.post("/api/v1/client_state/default/delete") + assert response.status_code == status.HTTP_200_OK + + # Verify it was deleted for system user + value = mock_invoker.services.client_state_persistence.get_by_key("system", "test_key") + assert value is None + + +def test_set_and_get_client_state(client: TestClient, admin_token: str): + """Test that authenticated users can set and get their client state.""" + # Set a value + set_response = client.post( + "/api/v1/client_state/default/set_by_key?key=test_key", + json="test_value", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert set_response.status_code == status.HTTP_200_OK + assert set_response.json() == "test_value" + + # Get the value back + get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=test_key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert get_response.status_code == status.HTTP_200_OK + assert get_response.json() == "test_value" + + +def test_client_state_isolation_between_users(client: TestClient, user1_token: str, user2_token: str): + """Test that client state is isolated between different users.""" + # User 1 sets a value + user1_set_response = client.post( + "/api/v1/client_state/default/set_by_key?key=shared_key", + json="user1_value", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert user1_set_response.status_code == status.HTTP_200_OK + + # User 2 sets a different value for the same key + user2_set_response = client.post( + "/api/v1/client_state/default/set_by_key?key=shared_key", + json="user2_value", + headers={"Authorization": f"Bearer {user2_token}"}, + ) + assert user2_set_response.status_code == status.HTTP_200_OK + + # User 1 should still see their own value + user1_get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=shared_key", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert user1_get_response.status_code == status.HTTP_200_OK + assert user1_get_response.json() == "user1_value" + + # User 2 should see their own value + user2_get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=shared_key", + headers={"Authorization": f"Bearer {user2_token}"}, + ) + assert user2_get_response.status_code == status.HTTP_200_OK + assert user2_get_response.json() == "user2_value" + + +def test_get_nonexistent_key_returns_null(client: TestClient, admin_token: str): + """Test that getting a nonexistent key returns null.""" + response = client.get( + "/api/v1/client_state/default/get_by_key?key=nonexistent_key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == status.HTTP_200_OK + assert response.json() is None + + +def test_delete_client_state(client: TestClient, admin_token: str): + """Test that users can delete their own client state.""" + # Set some values + client.post( + "/api/v1/client_state/default/set_by_key?key=key1", + json="value1", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + client.post( + "/api/v1/client_state/default/set_by_key?key=key2", + json="value2", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + # Verify values exist + get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=key1", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert get_response.json() == "value1" + + # Delete all client state + delete_response = client.post( + "/api/v1/client_state/default/delete", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert delete_response.status_code == status.HTTP_200_OK + + # Verify values are gone + get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=key1", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert get_response.json() is None + + get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=key2", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert get_response.json() is None + + +def test_update_existing_key(client: TestClient, admin_token: str): + """Test that updating an existing key works correctly.""" + # Set initial value + client.post( + "/api/v1/client_state/default/set_by_key?key=update_key", + json="initial_value", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + # Update the value + update_response = client.post( + "/api/v1/client_state/default/set_by_key?key=update_key", + json="updated_value", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert update_response.status_code == status.HTTP_200_OK + + # Verify the updated value + get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=update_key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert get_response.status_code == status.HTTP_200_OK + assert get_response.json() == "updated_value" + + +def test_complex_json_values(client: TestClient, admin_token: str): + """Test that complex JSON values can be stored and retrieved.""" + import json + + complex_dict = {"params": {"model": "test-model", "steps": 50}, "prompt": "a beautiful landscape"} + complex_value = json.dumps(complex_dict) + + # Set complex value + set_response = client.post( + "/api/v1/client_state/default/set_by_key?key=complex_key", + json=complex_value, + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert set_response.status_code == status.HTTP_200_OK + + # Get it back + get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=complex_key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert get_response.status_code == status.HTTP_200_OK + assert get_response.json() == complex_value diff --git a/tests/conftest.py b/tests/conftest.py index 84e66b0501d..980a99611ab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,7 @@ from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage from invokeai.app.services.boards.boards_default import BoardService from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService +from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import ClientStatePersistenceSqlite from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_sqlite import SqliteImageRecordStorage from invokeai.app.services.images.images_default import ImageService @@ -64,7 +65,7 @@ def mock_services() -> InvocationServices: workflow_thumbnails=None, # type: ignore model_relationship_records=None, # type: ignore model_relationships=None, # type: ignore - client_state_persistence=None, # type: ignore + client_state_persistence=ClientStatePersistenceSqlite(db=db), users=UserService(db), )