diff --git a/invokeai/app/api/routers/boards.py b/invokeai/app/api/routers/boards.py index 778849927bd..e93bb8b2a9b 100644 --- a/invokeai/app/api/routers/boards.py +++ b/invokeai/app/api/routers/boards.py @@ -53,10 +53,14 @@ async def get_board( try: result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) - return result except Exception: raise HTTPException(status_code=404, detail="Board not found") + if not current_user.is_admin and result.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to access this board") + + return result + @boards_router.patch( "/{board_id}", @@ -75,6 +79,14 @@ async def update_board( changes: BoardChanges = Body(description="The changes to apply to the board"), ) -> BoardDTO: """Updates a board (user must have access to it)""" + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + except Exception: + raise HTTPException(status_code=404, detail="Board not found") + + if not current_user.is_admin and board.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to update this board") + try: result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes) return result @@ -89,6 +101,14 @@ async def delete_board( include_images: Optional[bool] = Query(description="Permanently delete all images on the board", default=False), ) -> DeleteBoardResult: """Deletes a board (user must have access to it)""" + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + except Exception: + raise HTTPException(status_code=404, detail="Board not found") + + if not current_user.is_admin and board.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to delete this board") + try: if include_images is True: deleted_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board( @@ -155,12 +175,22 @@ async def list_boards( response_model=list[str], ) async def list_all_board_image_names( + current_user: CurrentUserOrDefault, board_id: str = Path(description="The id of the board or 'none' for uncategorized images"), categories: list[ImageCategory] | None = Query(default=None, description="The categories of image to include."), is_intermediate: bool | None = Query(default=None, description="Whether to list intermediate images."), ) -> list[str]: """Gets a list of images for a board""" + if board_id != "none": + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + except Exception: + raise HTTPException(status_code=404, detail="Board not found") + + if not current_user.is_admin and board.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to access this board") + image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board( board_id, categories, diff --git a/tests/app/routers/test_boards_multiuser.py b/tests/app/routers/test_boards_multiuser.py index ca42e285c6a..d5c48481567 100644 --- a/tests/app/routers/test_boards_multiuser.py +++ b/tests/app/routers/test_boards_multiuser.py @@ -1,6 +1,7 @@ """Tests for multiuser boards functionality.""" from typing import Any +from unittest.mock import MagicMock import pytest from fastapi import status @@ -12,6 +13,15 @@ from invokeai.app.services.users.users_common import UserCreateRequest +class MockApiDependencies(ApiDependencies): + """Mock API dependencies for testing.""" + + invoker: Invoker + + def __init__(self, invoker: Invoker) -> None: + self.invoker = invoker + + @pytest.fixture def setup_jwt_secret(): """Initialize JWT secret for token generation.""" @@ -27,60 +37,75 @@ def client(): return TestClient(app) -def setup_test_admin(mock_invoker: Invoker, email: str = "admin@test.com", password: str = "TestPass123") -> str: - """Helper to create a test admin user and return user_id.""" +def setup_test_user( + mock_invoker: Invoker, + email: str, + display_name: str, + password: str = "TestPass123", + is_admin: bool = False, +) -> str: + """Helper to create a test user and return user_id.""" user_service = mock_invoker.services.users user_data = UserCreateRequest( email=email, - display_name="Test Admin", + display_name=display_name, password=password, - is_admin=True, + is_admin=is_admin, ) user = user_service.create(user_data) return user.user_id +def get_user_token(client: TestClient, email: str, password: str = "TestPass123") -> str: + """Helper to login and get a user token.""" + response = client.post( + "/api/v1/auth/login", + json={"email": email, "password": password, "remember_me": False}, + ) + assert response.status_code == 200 + return response.json()["token"] + + @pytest.fixture def enable_multiuser_for_tests(monkeypatch: Any, mock_invoker: Invoker): - """Enable multiuser mode and set up ApiDependencies for testing.""" - # Enable multiuser mode + """Enable multiuser mode and patch ApiDependencies for all relevant routers.""" mock_invoker.services.configuration.multiuser = True - - # Set ApiDependencies.invoker as a class attribute - ApiDependencies.invoker = mock_invoker - + # Provide a mock board_images service so delete/image_names endpoints don't 500 + mock_board_images = MagicMock() + mock_board_images.get_all_board_image_names_for_board.return_value = [] + mock_invoker.services.board_images = mock_board_images + + mock_deps = MockApiDependencies(mock_invoker) + monkeypatch.setattr("invokeai.app.api.routers.auth.ApiDependencies", mock_deps) + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", mock_deps) + monkeypatch.setattr("invokeai.app.api.routers.boards.ApiDependencies", mock_deps) yield - # Cleanup - if hasattr(ApiDependencies, "invoker"): - delattr(ApiDependencies, "invoker") - @pytest.fixture -def admin_token(setup_jwt_secret: str, enable_multiuser_for_tests: Any, mock_invoker: Invoker, client: TestClient): - """Get an admin token for testing.""" - # Create admin user - setup_test_admin(mock_invoker, "admin@test.com", "TestPass123") +def admin_token(setup_jwt_secret: None, enable_multiuser_for_tests: Any, mock_invoker: Invoker, client: TestClient): + """Create an admin user and return a login token.""" + setup_test_user(mock_invoker, "admin@test.com", "Test Admin", is_admin=True) + return get_user_token(client, "admin@test.com") - # Login to get token - response = client.post( - "/api/v1/auth/login", - json={ - "email": "admin@test.com", - "password": "TestPass123", - "remember_me": False, - }, - ) - assert response.status_code == 200 - return response.json()["token"] + +@pytest.fixture +def user1_token(enable_multiuser_for_tests: Any, mock_invoker: Invoker, client: TestClient, admin_token: str): + """Create a regular user and return a login token.""" + setup_test_user(mock_invoker, "user1@test.com", "User One", is_admin=False) + return get_user_token(client, "user1@test.com") @pytest.fixture -def user1_token(admin_token): - """Get a token for test user 1.""" - # For now, we'll reuse admin token since user creation requires admin - # In a full implementation, we'd create a separate user - return admin_token +def user2_token(enable_multiuser_for_tests: Any, mock_invoker: Invoker, client: TestClient, admin_token: str): + """Create a second regular user and return a login token.""" + setup_test_user(mock_invoker, "user2@test.com", "User Two", is_admin=False) + return get_user_token(client, "user2@test.com") + + +# --------------------------------------------------------------------------- +# Basic auth requirement tests +# --------------------------------------------------------------------------- def test_create_board_requires_auth(enable_multiuser_for_tests: Any, client: TestClient): @@ -95,6 +120,35 @@ def test_list_boards_requires_auth(enable_multiuser_for_tests: Any, client: Test assert response.status_code == status.HTTP_401_UNAUTHORIZED +def test_get_board_requires_auth(enable_multiuser_for_tests: Any, client: TestClient): + """Test that getting a board requires authentication.""" + response = client.get("/api/v1/boards/some-board-id") + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +def test_update_board_requires_auth(enable_multiuser_for_tests: Any, client: TestClient): + """Test that updating a board requires authentication.""" + response = client.patch("/api/v1/boards/some-board-id", json={"board_name": "New Name"}) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +def test_delete_board_requires_auth(enable_multiuser_for_tests: Any, client: TestClient): + """Test that deleting a board requires authentication.""" + response = client.delete("/api/v1/boards/some-board-id") + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +def test_list_board_image_names_requires_auth(enable_multiuser_for_tests: Any, client: TestClient): + """Test that listing board image names requires authentication.""" + response = client.get("/api/v1/boards/some-board-id/image_names") + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +# --------------------------------------------------------------------------- +# Basic create / list tests +# --------------------------------------------------------------------------- + + def test_create_board_with_auth(client: TestClient, admin_token: str): """Test that authenticated users can create boards.""" response = client.post( @@ -123,7 +177,6 @@ def test_list_boards_with_auth(client: TestClient, admin_token: str): assert response.status_code == status.HTTP_200_OK boards = response.json() assert isinstance(boards, list) - # Should include the board we just created board_names = [b["board_name"] for b in boards] assert "Listed Board" in board_names @@ -137,8 +190,7 @@ def test_user_boards_are_isolated(client: TestClient, admin_token: str, user1_to ) assert admin_response.status_code == status.HTTP_201_CREATED - # If we had separate users, we'd verify isolation here - # For now, we'll just verify the board exists + # Admin can see their own board list_response = client.get( "/api/v1/boards/?all=true", headers={"Authorization": f"Bearer {admin_token}"}, @@ -148,6 +200,248 @@ def test_user_boards_are_isolated(client: TestClient, admin_token: str, user1_to board_names = [b["board_name"] for b in boards] assert "Admin Board" in board_names + # user1 should not see admin's board in their own listing + user1_list = client.get( + "/api/v1/boards/?all=true", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert user1_list.status_code == status.HTTP_200_OK + user1_board_names = [b["board_name"] for b in user1_list.json()] + assert "Admin Board" not in user1_board_names + + +# --------------------------------------------------------------------------- +# Ownership enforcement: get_board +# --------------------------------------------------------------------------- + + +def test_get_board_owner_succeeds(client: TestClient, user1_token: str): + """Test that the board owner can retrieve their own board.""" + create = client.post( + "/api/v1/boards/?board_name=User1+Board", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert create.status_code == status.HTTP_201_CREATED + board_id = create.json()["board_id"] + + response = client.get( + f"/api/v1/boards/{board_id}", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert response.status_code == status.HTTP_200_OK + assert response.json()["board_id"] == board_id + + +def test_get_board_other_user_forbidden(client: TestClient, user1_token: str, user2_token: str): + """Test that a non-owner cannot retrieve another user's board.""" + create = client.post( + "/api/v1/boards/?board_name=User1+Private+Board", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert create.status_code == status.HTTP_201_CREATED + board_id = create.json()["board_id"] + + response = client.get( + f"/api/v1/boards/{board_id}", + headers={"Authorization": f"Bearer {user2_token}"}, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +def test_get_board_admin_can_access_any_board(client: TestClient, admin_token: str, user1_token: str): + """Test that an admin can retrieve any user's board.""" + create = client.post( + "/api/v1/boards/?board_name=User1+Board+For+Admin", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert create.status_code == status.HTTP_201_CREATED + board_id = create.json()["board_id"] + + response = client.get( + f"/api/v1/boards/{board_id}", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == status.HTTP_200_OK + + +# --------------------------------------------------------------------------- +# Ownership enforcement: update_board +# --------------------------------------------------------------------------- + + +def test_update_board_owner_succeeds(client: TestClient, user1_token: str): + """Test that the board owner can update their own board.""" + create = client.post( + "/api/v1/boards/?board_name=Original+Name", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert create.status_code == status.HTTP_201_CREATED + board_id = create.json()["board_id"] + + response = client.patch( + f"/api/v1/boards/{board_id}", + json={"board_name": "Updated Name"}, + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert response.status_code == status.HTTP_201_CREATED + assert response.json()["board_name"] == "Updated Name" + + +def test_update_board_other_user_forbidden(client: TestClient, user1_token: str, user2_token: str): + """Test that a non-owner cannot update another user's board.""" + create = client.post( + "/api/v1/boards/?board_name=User1+Board+To+Update", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert create.status_code == status.HTTP_201_CREATED + board_id = create.json()["board_id"] + + response = client.patch( + f"/api/v1/boards/{board_id}", + json={"board_name": "Hijacked Name"}, + headers={"Authorization": f"Bearer {user2_token}"}, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +def test_update_board_admin_can_update_any_board(client: TestClient, admin_token: str, user1_token: str): + """Test that an admin can update any user's board.""" + create = client.post( + "/api/v1/boards/?board_name=User1+Board+Admin+Update", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert create.status_code == status.HTTP_201_CREATED + board_id = create.json()["board_id"] + + response = client.patch( + f"/api/v1/boards/{board_id}", + json={"board_name": "Admin Updated Name"}, + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == status.HTTP_201_CREATED + assert response.json()["board_name"] == "Admin Updated Name" + + +# --------------------------------------------------------------------------- +# Ownership enforcement: delete_board +# --------------------------------------------------------------------------- + + +def test_delete_board_owner_succeeds(client: TestClient, user1_token: str): + """Test that the board owner can delete their own board.""" + create = client.post( + "/api/v1/boards/?board_name=Board+To+Delete", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert create.status_code == status.HTTP_201_CREATED + board_id = create.json()["board_id"] + + response = client.delete( + f"/api/v1/boards/{board_id}", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert response.status_code == status.HTTP_200_OK + assert response.json()["board_id"] == board_id + + +def test_delete_board_other_user_forbidden(client: TestClient, user1_token: str, user2_token: str): + """Test that a non-owner cannot delete another user's board.""" + create = client.post( + "/api/v1/boards/?board_name=User1+Board+To+Delete", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert create.status_code == status.HTTP_201_CREATED + board_id = create.json()["board_id"] + + response = client.delete( + f"/api/v1/boards/{board_id}", + headers={"Authorization": f"Bearer {user2_token}"}, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +def test_delete_board_admin_can_delete_any_board(client: TestClient, admin_token: str, user1_token: str): + """Test that an admin can delete any user's board.""" + create = client.post( + "/api/v1/boards/?board_name=User1+Board+Admin+Delete", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert create.status_code == status.HTTP_201_CREATED + board_id = create.json()["board_id"] + + response = client.delete( + f"/api/v1/boards/{board_id}", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == status.HTTP_200_OK + + +# --------------------------------------------------------------------------- +# Ownership enforcement: list_all_board_image_names +# --------------------------------------------------------------------------- + + +def test_list_board_image_names_owner_succeeds(client: TestClient, user1_token: str): + """Test that the board owner can list image names for their board.""" + create = client.post( + "/api/v1/boards/?board_name=User1+Images+Board", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert create.status_code == status.HTTP_201_CREATED + board_id = create.json()["board_id"] + + response = client.get( + f"/api/v1/boards/{board_id}/image_names", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert response.status_code == status.HTTP_200_OK + assert isinstance(response.json(), list) + + +def test_list_board_image_names_other_user_forbidden(client: TestClient, user1_token: str, user2_token: str): + """Test that a non-owner cannot list image names for another user's board.""" + create = client.post( + "/api/v1/boards/?board_name=User1+Private+Images+Board", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert create.status_code == status.HTTP_201_CREATED + board_id = create.json()["board_id"] + + response = client.get( + f"/api/v1/boards/{board_id}/image_names", + headers={"Authorization": f"Bearer {user2_token}"}, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +def test_list_board_image_names_admin_can_access_any_board(client: TestClient, admin_token: str, user1_token: str): + """Test that an admin can list image names for any user's board.""" + create = client.post( + "/api/v1/boards/?board_name=User1+Board+Admin+Images", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert create.status_code == status.HTTP_201_CREATED + board_id = create.json()["board_id"] + + response = client.get( + f"/api/v1/boards/{board_id}/image_names", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == status.HTTP_200_OK + + +def test_list_board_image_names_none_board_no_auth_check(enable_multiuser_for_tests: Any, client: TestClient): + """Test that listing image names for the 'none' board requires auth but no ownership check.""" + # The 'none' board is the uncategorized images board — no ownership check needed, + # but auth is still required in multiuser mode. + response = client.get("/api/v1/boards/none/image_names") + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +# --------------------------------------------------------------------------- +# Misc tests +# --------------------------------------------------------------------------- + def test_enqueue_batch_requires_auth(enable_multiuser_for_tests: Any, client: TestClient): """Test that enqueuing a batch requires authentication."""