async def find_by_field_with_cursor(
    self,
    field_name: str,
    field_value: Any,
    limit: int | None = None,
    sort_by: dict[str, int] | None = None,
    before_id: str | None = None,
    after_id: str | None = None,
) -> builtins.list[T]:
    """
    Find documents by a given field with cursor-based (keyset) pagination.

    Every matching document is passed through ``self._deserialize`` before
    being returned.

    Args:
        field_name: The field name to search by ("id" is mapped to "_id")
        field_value: The value to search for
        limit: Optional limit on the number of documents to return;
            a falsy value falls back to DEFAULT_PAGE_LIMIT
        sort_by: Optional dictionary for sorting, e.g. {"created_at": -1} for descending
        before_id: Get documents created before this document ID
        after_id: Get documents created after this document ID
            (ignored when before_id is also given)

    Returns:
        The deserialized documents matching the query, in sort order.

    Raises:
        ServiceError: If building or running the query fails.

    Note:
        Cursor pagination uses the (created_at, _id) pair of the cursor
        document to filter results. This provides stable pagination even
        when new documents are added. If the cursor document cannot be
        found, or has no created_at, the cursor is ignored and the query
        behaves as if no cursor had been supplied.
    """
    try:
        # Map 'id' field to '_id' for MongoDB if needed
        mongo_field_name = "_id" if field_name == "id" else field_name
        mongo_field_value = field_value

        # Convert id string to ObjectId if searching by _id
        if mongo_field_name == "_id" and isinstance(mongo_field_value, str):
            try:
                mongo_field_value = ObjectId(mongo_field_value)
            except Exception:
                # Not an ObjectId-shaped string: query with the raw value.
                pass

        # Build base query
        query: dict[str, Any] = {mongo_field_name: mongo_field_value}

        # before_id takes precedence when both cursors are supplied.
        cursor_id = before_id or after_id
        if cursor_id:
            try:
                cursor_object_id = ObjectId(cursor_id)
            except Exception:
                cursor_object_id = cursor_id

            cursor_doc = self.collection.find_one({"_id": cursor_object_id})
            if cursor_doc and "created_at" in cursor_doc:
                cursor_timestamp = cursor_doc["created_at"]
                # Compound comparison on (created_at, _id):
                #   created_at strictly past the cursor, OR
                #   same created_at and _id past the cursor (tie-breaker).
                comparison = "$lt" if before_id else "$gt"
                query["$or"] = [
                    {"created_at": {comparison: cursor_timestamp}},
                    {
                        "created_at": cursor_timestamp,
                        "_id": {comparison: cursor_object_id},
                    },
                ]

        # Apply sorting, using _id as the tie-breaker. The tie-breaker must
        # run in the same direction as the primary sort key: the cursor
        # predicate above compares (created_at, _id) as a pair, so sorting
        # created_at descending but _id ascending would repeat or skip
        # documents that share a timestamp across page boundaries.
        sort_by_items = list(sort_by.items()) if sort_by else []
        if all(key != "_id" for key, _ in sort_by_items):
            tie_direction = sort_by_items[0][1] if sort_by_items else 1
            sort_by_items.append(("_id", tie_direction))

        db_cursor = (
            self.collection.find(query)
            .sort(sort_by_items)
            .limit(limit or DEFAULT_PAGE_LIMIT)
        )

        return [self._deserialize(doc) for doc in db_cursor]
    except Exception as e:
        raise ServiceError(
            message=f"Failed to find items by field with cursor in MongoDB: {e}",
            detail=str(e),
        ) from e
@router.get(
    "/paginated",
    response_model=PaginatedMessagesResponse,
)
async def list_messages_paginated(
    task_id: DAuthorizedQuery(AgentexResourceType.task, AuthorizedOperationType.read),
    message_use_case: DMessageUseCase,
    limit: int = 50,
    cursor: str | None = None,
    direction: Literal["older", "newer"] = "older",
) -> PaginatedMessagesResponse:
    """
    List messages for a task with cursor-based pagination.

    This endpoint is designed for infinite scroll UIs where new messages may arrive
    while paginating through older ones.

    Args:
        task_id: The task ID to filter messages by
        limit: Maximum number of messages to return (default: 50); values
            below 1 are treated as 1
        cursor: Opaque cursor string for pagination. Pass the `next_cursor` from
            a previous response to get the next page. An undecodable cursor is
            ignored and the first page is returned.
        direction: Pagination direction - "older" to get older messages (default),
            "newer" to get newer messages.

    Returns:
        PaginatedMessagesResponse with:
        - data: List of messages (newest first when direction="older",
          oldest first when direction="newer")
        - next_cursor: Cursor for fetching the next page (null if no more pages)
        - has_more: Whether there are more messages to fetch

    Example:
        First request: GET /messages/paginated?task_id=xxx&limit=50
        Next page: GET /messages/paginated?task_id=xxx&limit=50&cursor={next_cursor}
    """
    # A non-positive limit would break the "fetch limit + 1, slice to limit"
    # arithmetic below (a negative slice silently drops rows).
    limit = max(limit, 1)

    # Decode cursor if provided
    before_id = None
    after_id = None
    if cursor:
        try:
            cursor_data = decode_cursor(cursor)
        except ValueError:
            # Invalid cursor, ignore and return from start
            pass
        else:
            if direction == "older":
                before_id = cursor_data.id
            else:
                after_id = cursor_data.id

    # Walk away from the cursor in the requested direction. Paging "newer"
    # must read oldest-first so the page is the one adjacent to the cursor;
    # reading newest-first would jump to the latest messages, skip everything
    # in between, and hand back a cursor that re-fetches what was just seen.
    order_direction = "desc" if direction == "older" else "asc"

    # Fetch one extra to determine if there are more results
    task_message_entities = await message_use_case.list_messages(
        task_id=task_id,
        limit=limit + 1,
        page_number=1,
        order_by=None,
        order_direction=order_direction,
        before_id=before_id,
        after_id=after_id,
    )

    # Check if there are more results
    has_more = len(task_message_entities) > limit
    task_message_entities = task_message_entities[:limit]

    # Build next cursor from the last message of the page, i.e. the one
    # furthest along in the paging direction.
    next_cursor = None
    if has_more and task_message_entities:
        last_message = task_message_entities[-1]
        next_cursor = encode_cursor(last_message.id, last_message.created_at)

    messages = [TaskMessage.model_validate(entity) for entity in task_message_entities]

    return PaginatedMessagesResponse(
        data=messages,
        next_cursor=next_cursor,
        has_more=has_more,
    )
after_id: str | None = None, ) -> list[TaskMessageEntity]: """ - Get all messages for a specific task. + Get all messages for a specific task with optional cursor-based pagination. Args: task_id: The task ID - limit: Optional limit on the number of messages to return - order_by: Optional field name to order by (defaults to created_at) - order_direction: Optional direction to order by ("asc" or "desc", defaults to "desc") + limit: Maximum number of messages to return + page_number: Page number for offset-based pagination + order_by: Field name to order by (defaults to created_at) + order_direction: Direction to order by ("asc" or "desc", defaults to "desc") + before_id: Get messages created before this message ID (cursor pagination) + after_id: Get messages created after this message ID (cursor pagination) Returns: List of TaskMessageEntity objects for the task + + Note: + When using before_id or after_id, page_number is ignored. """ # Default to created_at descending (newest first) sort_field = order_by or "created_at" sort_direction = 1 if order_direction.lower() == "asc" else -1 + # If cursor pagination is requested, use cursor-based query + if before_id or after_id: + return await self.repository.find_by_field_with_cursor( + field_name="task_id", + field_value=task_id, + limit=limit, + sort_by={sort_field: sort_direction}, + before_id=before_id, + after_id=after_id, + ) + + # Otherwise use standard offset-based pagination return await self.repository.find_by_field( "task_id", task_id, diff --git a/agentex/src/domain/use_cases/messages_use_case.py b/agentex/src/domain/use_cases/messages_use_case.py index ce5eb7a..b5e8b12 100644 --- a/agentex/src/domain/use_cases/messages_use_case.py +++ b/agentex/src/domain/use_cases/messages_use_case.py @@ -111,18 +111,26 @@ async def list_messages( page_number: int, order_by: str | None = None, order_direction: str = "desc", + before_id: str | None = None, + after_id: str | None = None, ) -> list[TaskMessageEntity]: """ - Get 
all messages for a task. + Get all messages for a task with optional cursor-based pagination. Args: task_id: The task ID - limit: Optional limit on the number of messages to return - order_by: Optional field name to order by (defaults to created_at) - order_direction: Optional direction to order by ("asc" or "desc", defaults to "desc") + limit: Maximum number of messages to return + page_number: Page number for offset-based pagination + order_by: Field name to order by (defaults to created_at) + order_direction: Direction to order by ("asc" or "desc", defaults to "desc") + before_id: Get messages created before this message ID (cursor pagination) + after_id: Get messages created after this message ID (cursor pagination) Returns: List of TaskMessageEntity objects for the task + + Note: + When using before_id or after_id, page_number is ignored. """ return await self.task_message_service.get_messages( task_id=task_id, @@ -130,6 +138,8 @@ async def list_messages( page_number=page_number, order_by=order_by, order_direction=order_direction, + before_id=before_id, + after_id=after_id, ) diff --git a/agentex/src/utils/pagination.py b/agentex/src/utils/pagination.py new file mode 100644 index 0000000..247883d --- /dev/null +++ b/agentex/src/utils/pagination.py @@ -0,0 +1,78 @@ +""" +Cursor-based pagination utilities. + +Provides encode/decode functions for creating opaque cursor strings +that can be used for stable pagination through result sets. 
import base64
from datetime import datetime
from typing import Any

from pydantic import BaseModel

from src.utils.logging import make_logger

logger = make_logger(__name__)


class CursorData(BaseModel):
    """Internal cursor structure - versioned for future compatibility."""

    v: int = 1  # Version for backwards compatibility
    id: str  # Document ID
    created_at: str  # ISO format timestamp


def encode_cursor(id: str, created_at: datetime | None) -> str | None:
    """
    Encode pagination position into an opaque cursor string.

    Args:
        id: The document ID
        created_at: The document's creation timestamp

    Returns:
        URL-safe base64 encoding of the JSON-serialized CursorData,
        or None if created_at is null

    Note:
        Returns None if created_at is null since cursors require timestamps
        for stable pagination ordering.
    """
    if created_at is None:
        return None

    cursor_data = CursorData(
        id=id,
        created_at=created_at.isoformat(),
    )
    json_str = cursor_data.model_dump_json()
    return base64.urlsafe_b64encode(json_str.encode()).decode()


def decode_cursor(cursor: str) -> CursorData:
    """
    Decode cursor string back to pagination data.

    Args:
        cursor: URL-safe base64-encoded cursor string, with or without its
            trailing "=" padding

    Returns:
        CursorData with id and created_at

    Raises:
        ValueError: If cursor format is invalid
    """
    try:
        # Trailing "=" padding is often stripped when a token travels through
        # URLs; restore it so an otherwise valid cursor is not rejected.
        padded = cursor + "=" * (-len(cursor) % 4)
        json_str = base64.urlsafe_b64decode(padded.encode()).decode()
        return CursorData.model_validate_json(json_str)
    except Exception as e:
        # Normalize every decoding/validation failure to ValueError so
        # callers have a single exception type to handle.
        raise ValueError(f"Invalid cursor format: {e}") from e


class PaginatedResponse(BaseModel):
    """Response wrapper with cursor pagination metadata."""

    data: list[Any]  # Items of the current page
    next_cursor: str | None = None  # Cursor for the following page, if any
    has_more: bool = False  # Whether more items remain after this page
pagination.""" - # Given - A message record exists - # (created by test_pagination_messages fixture) + """Test GET /messages/paginated endpoint with cursor-based pagination.""" + # Given - 60 message records exist (created by test_pagination_messages fixture) - # When - List all messages with pagination + # When - List messages with default limit using the paginated endpoint response = await isolated_client.get( - "/messages", params={"task_id": test_task.id} + "/messages/paginated", params={"task_id": test_task.id} ) assert response.status_code == 200 response_data = response.json() - # Default limit if none specified - assert len(response_data) == 50 - page_number = 1 + # Validate paginated response structure + assert "data" in response_data + assert "next_cursor" in response_data + assert "has_more" in response_data + + # Default limit is 50, we have 60 messages so has_more should be True + assert len(response_data["data"]) == 50 + assert response_data["has_more"] is True + assert response_data["next_cursor"] is not None + + # Test cursor-based pagination - collect all messages + cursor = None paginated_messages = [] while True: - response = await isolated_client.get( - "/messages", - params={ - "limit": 7, - "page_number": page_number, - "task_id": test_task.id, - }, - ) + params = {"task_id": test_task.id, "limit": 7} + if cursor: + params["cursor"] = cursor + + response = await isolated_client.get("/messages/paginated", params=params) assert response.status_code == 200 - messages_data = response.json() - paginated_messages.extend(messages_data) - if len(messages_data) < 1: + page_data = response.json() + + paginated_messages.extend(page_data["data"]) + + if not page_data["has_more"]: break - page_number += 1 + cursor = page_data["next_cursor"] + assert len(paginated_messages) == len(test_pagination_messages) assert {(d["id"], d["content"]["content"]) for d in paginated_messages} == { (d.id, d.content.content) for d in test_pagination_messages @@ -338,7 +348,7 
@@ async def test_list_messages_with_order_by( }, ) - # Then - Should return messages in ascending order + # Then - Should return messages in ascending order (list response) assert response_asc.status_code == 200 messages_asc = response_asc.json() assert len(messages_asc) == 3 @@ -357,7 +367,7 @@ async def test_list_messages_with_order_by( }, ) - # Then - Should return messages in descending order + # Then - Should return messages in descending order (list response) assert response_desc.status_code == 200 messages_desc = response_desc.json() assert len(messages_desc) == 3 @@ -414,7 +424,7 @@ async def test_list_messages_order_by_defaults_to_desc( params={"task_id": task.id}, ) - # Then - Should return messages successfully + # Then - Should return messages successfully (list response) assert response.status_code == 200 messages = response.json() assert len(messages) == 3