Skip to content

Commit 22fc42e

Browse files
jmchiltonclaude
andcommitted
Migrate WES pagination to unified keyset-based approach
Replace offset-based and security.encode_id() pagination with generalized keyset abstraction supporting both single-ID and composite keysets. New abstraction (lib/galaxy/model/keyset_token_pagination.py): - KeysetToken Protocol: polymorphic via to_values()/from_values() - SingleKeysetToken: single ID keyset for runs pagination - KeysetPagination: base64+JSON encoder/decoder - TaskKeysetToken (wes.py): composite (step_order, job_index) keyset Changes to WES pagination: - list_runs: SingleKeysetToken with WHERE id < last_id (unchanged query) - get_run_tasks: TaskKeysetToken with tuple comparison WHERE (step_order, job_index) > (last_step, last_idx) - Removed offset-based _encode_page_token/_decode_page_token - Removed _encode_keyset_token/_decode_keyset_token Benefits: - Cursor-stable pagination under concurrent changes - Efficient composite keyset filtering on UNION queries - Consistent token encoding across all WES endpoints - Protocol-based design - no type checking needed Tests: 9 unit tests for keyset pagination abstraction (all passing) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 63ac76a commit 22fc42e

File tree

3 files changed

+269
-79
lines changed

3 files changed

+269
-79
lines changed
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""Keyset-based pagination support for cursor-stable pagination."""
2+
3+
import base64
4+
import json
5+
from dataclasses import dataclass
6+
from typing import (
7+
Optional,
8+
Protocol,
9+
Type,
10+
TypeVar,
11+
)
12+
13+
from galaxy import exceptions
14+
15+
16+
class KeysetToken(Protocol):
17+
"""Protocol for keyset tokens that can be encoded/decoded.
18+
19+
Implementations must provide:
20+
- to_values(): Convert token to normalized list of values for encoding
21+
- from_values(): Reconstruct token from decoded values (classmethod)
22+
"""
23+
24+
def to_values(self) -> list:
25+
"""Convert token to normalized list of values for encoding.
26+
27+
Returns:
28+
List of values to be JSON-encoded
29+
"""
30+
...
31+
32+
@classmethod
33+
def from_values(cls, values: list) -> "KeysetToken":
34+
"""Reconstruct token from decoded values.
35+
36+
Args:
37+
values: List of values from JSON decoding
38+
39+
Returns:
40+
Token instance
41+
"""
42+
...
43+
44+
45+
@dataclass
46+
class SingleKeysetToken:
47+
"""Single ID column keyset token.
48+
49+
Used for pagination on a single numeric ID column (e.g., database IDs).
50+
"""
51+
52+
last_id: int
53+
54+
def to_values(self) -> list:
55+
"""Convert to normalized values."""
56+
return [self.last_id]
57+
58+
@classmethod
59+
def from_values(cls, values: list) -> "SingleKeysetToken":
60+
"""Reconstruct from decoded values."""
61+
if len(values) < 1:
62+
raise ValueError("SingleKeysetToken requires at least 1 value")
63+
return cls(last_id=values[0])
64+
65+
66+
T = TypeVar("T", bound=KeysetToken)
67+
68+
69+
class KeysetPagination:
70+
"""Keyset pagination encoder/decoder using Protocol.
71+
72+
Encodes tokens to opaque base64 strings, works with any KeysetToken
73+
implementation via Protocol duck typing.
74+
"""
75+
76+
def encode_token(self, token: KeysetToken) -> str:
77+
"""Encode keyset token to opaque base64 string.
78+
79+
Works with any KeysetToken implementation via Protocol.
80+
81+
Args:
82+
token: Token implementing KeysetToken protocol
83+
84+
Returns:
85+
Base64-encoded token string
86+
"""
87+
values = token.to_values()
88+
payload = json.dumps(values)
89+
return base64.b64encode(payload.encode()).decode()
90+
91+
def decode_token(
92+
self,
93+
encoded: Optional[str],
94+
token_class: Type[T],
95+
) -> Optional[T]:
96+
"""Decode token using provided token class.
97+
98+
Args:
99+
encoded: Base64-encoded token string
100+
token_class: Token class with from_values() classmethod
101+
102+
Returns:
103+
Decoded token instance or None if encoded is None
104+
105+
Raises:
106+
MessageException: If token is invalid
107+
"""
108+
if not encoded:
109+
return None
110+
111+
try:
112+
payload = base64.b64decode(encoded.encode()).decode()
113+
values = json.loads(payload)
114+
return token_class.from_values(values)
115+
except (ValueError, TypeError, json.JSONDecodeError) as e:
116+
raise exceptions.MessageException(f"Invalid page_token: {str(e)}")

lib/galaxy/webapps/galaxy/services/wes.py

Lines changed: 61 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import base64
44
import json
55
import logging
6+
from dataclasses import dataclass
67
from typing import (
78
Any,
89
List,
@@ -17,6 +18,7 @@
1718
from sqlalchemy import (
1819
literal,
1920
select,
21+
tuple_,
2022
union_all,
2123
)
2224
from sqlalchemy.orm import joinedload
@@ -38,6 +40,10 @@
3840
WorkflowInvocation,
3941
WorkflowInvocationStep,
4042
)
43+
from galaxy.model.keyset_token_pagination import (
44+
KeysetPagination,
45+
SingleKeysetToken,
46+
)
4147
from galaxy.schema.wes import (
4248
DefaultWorkflowEngineParameter,
4349
RunId,
@@ -73,41 +79,28 @@
7379
}
7480

7581
WES_TO_GALAXY_STATE = {v: k for k, v in GALAXY_TO_WES_STATE.items()}
76-
PAGINATION_KEYSET_TOKEN_ENCODE_KEY = "pag_tok"
77-
7882

79-
def _encode_page_token(offset: int) -> str:
80-
"""Encode an offset as a base64 page token.
8183

82-
Args:
83-
offset: The offset (row number) for pagination
84+
@dataclass
85+
class TaskKeysetToken:
86+
"""Composite keyset token for task pagination (step_order, job_index).
8487
85-
Returns:
86-
Base64-encoded page token
88+
Used to identify position in task list for cursor-based pagination.
8789
"""
88-
return base64.b64encode(str(offset).encode()).decode()
8990

91+
step_order: int
92+
job_index: int
9093

91-
def _encode_keyset_token(security: IdEncodingHelper, last_id: int) -> str:
92-
"""Encode last seen ID as keyset page token."""
93-
return security.encode_id(last_id, kind=PAGINATION_KEYSET_TOKEN_ENCODE_KEY)
94+
def to_values(self) -> list:
95+
"""Convert token to normalized list of values for encoding."""
96+
return [self.step_order, self.job_index]
9497

95-
96-
def _decode_keyset_token(security: IdEncodingHelper, page_token: Optional[str]) -> Optional[int]:
97-
"""Decode keyset page token to last seen ID.
98-
99-
Returns None if no token, raises on invalid token.
100-
"""
101-
if not page_token:
102-
return None
103-
104-
try:
105-
last_id = security.decode_id(page_token, kind=PAGINATION_KEYSET_TOKEN_ENCODE_KEY)
106-
if last_id < 0:
107-
raise ValueError("ID cannot be negative")
108-
return last_id
109-
except (ValueError, TypeError) as e:
110-
raise exceptions.MessageException(f"Invalid page_token: {str(e)}")
98+
@classmethod
99+
def from_values(cls, values: list) -> "TaskKeysetToken":
100+
"""Reconstruct token from decoded values."""
101+
if len(values) < 2:
102+
raise ValueError("TaskKeysetToken requires at least 2 values")
103+
return cls(step_order=values[0], job_index=values[1])
111104

112105

113106
def _parse_gxworkflow_uri(workflow_url: str) -> tuple[str, bool]:
@@ -159,31 +152,6 @@ def _parse_gxworkflow_uri(workflow_url: str) -> tuple[str, bool]:
159152
raise exceptions.MessageException(f"Error parsing gxworkflow:// URI: {str(e)}")
160153

161154

162-
def _decode_page_token(page_token: Optional[str]) -> int:
163-
"""Decode a base64 page token to an offset.
164-
165-
Args:
166-
page_token: The base64-encoded page token
167-
168-
Returns:
169-
The offset (row number) for pagination
170-
171-
Raises:
172-
exceptions.MessageException: If token is invalid
173-
"""
174-
if not page_token:
175-
return 0
176-
177-
try:
178-
offset_str = base64.b64decode(page_token.encode()).decode()
179-
offset = int(offset_str)
180-
if offset < 0:
181-
raise ValueError("Offset cannot be negative")
182-
return offset
183-
except (ValueError, TypeError) as e:
184-
raise exceptions.MessageException(f"Invalid page_token: {str(e)}")
185-
186-
187155
def _load_workflow_content(
188156
trans: ProvidesUserContext,
189157
workflow_attachment: Optional[UploadFile],
@@ -384,6 +352,7 @@ def __init__(
384352
self._workflows_service = workflows_service
385353
self._config = config
386354
self._security = security
355+
self._keyset_pagination = KeysetPagination()
387356

388357
def service_info(self, trans: ProvidesUserContext, request_url: str) -> ServiceInfo:
389358
"""Return WES service information.
@@ -628,7 +597,8 @@ def list_runs(
628597
RunListResponse with paginated list of runs and next_page_token if more results exist
629598
"""
630599
# Decode keyset token to get last seen ID
631-
last_id = _decode_keyset_token(self._security, page_token)
600+
token = self._keyset_pagination.decode_token(page_token, token_class=SingleKeysetToken)
601+
last_id = token.last_id if token else None
632602

633603
# Build query with keyset filtering
634604
query = trans.sa_session.query(WorkflowInvocation).join(History).where(History.user_id == trans.user.id)
@@ -650,7 +620,8 @@ def list_runs(
650620
next_page_token = None
651621
if has_more and invocations:
652622
last_invocation = invocations[page_size - 1]
653-
next_page_token = _encode_keyset_token(self._security, last_invocation.id)
623+
token = SingleKeysetToken(last_id=last_invocation.id)
624+
next_page_token = self._keyset_pagination.encode_token(token)
654625

655626
return RunListResponse(runs=runs, next_page_token=next_page_token)
656627

@@ -777,33 +748,42 @@ def _get_paginated_task_rows(
777748
self,
778749
trans: ProvidesUserContext,
779750
invocation_id: int,
780-
offset: int,
751+
last_token: Optional[TaskKeysetToken],
781752
limit: int,
782753
) -> List[dict]:
783-
"""Fetch paginated task rows from database.
754+
"""Fetch paginated task rows using composite keyset pagination.
755+
756+
Uses (step_order, job_index) as composite keyset for cursor-based pagination.
784757
785758
Returns list of dicts with keys: step_id, step_order, task_type, job_id, job_index
786759
"""
787760
# Build UNION subquery
788761
task_rows_subquery = self._build_task_rows_query(invocation_id).subquery()
789762

790763
# Apply ordering and pagination
791-
stmt = (
792-
select(
793-
task_rows_subquery.c.step_id,
794-
task_rows_subquery.c.step_order,
795-
task_rows_subquery.c.task_type,
796-
task_rows_subquery.c.job_id,
797-
task_rows_subquery.c.job_index,
798-
)
799-
.order_by(
800-
task_rows_subquery.c.step_order,
801-
task_rows_subquery.c.job_index,
802-
)
803-
.offset(offset)
804-
.limit(limit + 1) # Fetch one extra to detect more results
764+
stmt = select(
765+
task_rows_subquery.c.step_id,
766+
task_rows_subquery.c.step_order,
767+
task_rows_subquery.c.task_type,
768+
task_rows_subquery.c.job_id,
769+
task_rows_subquery.c.job_index,
770+
).order_by(
771+
task_rows_subquery.c.step_order,
772+
task_rows_subquery.c.job_index,
805773
)
806774

775+
# Apply composite keyset filter if we have a cursor
776+
if last_token is not None:
777+
stmt = stmt.where(
778+
tuple_(
779+
task_rows_subquery.c.step_order,
780+
task_rows_subquery.c.job_index,
781+
)
782+
> tuple_(last_token.step_order, last_token.job_index)
783+
)
784+
785+
stmt = stmt.limit(limit + 1) # Fetch one extra to detect more results
786+
807787
result = trans.sa_session.execute(stmt)
808788
return [dict(row._mapping) for row in result]
809789

@@ -900,28 +880,28 @@ def get_run_tasks(
900880
) -> TaskListResponse:
901881
"""Get paginated list of tasks for a workflow run.
902882
903-
Uses database-level pagination via UNION query to avoid loading
904-
all steps/jobs into memory.
883+
Uses composite keyset pagination via UNION query to avoid loading
884+
all steps/jobs into memory and for cursor-based stability.
905885
906886
Args:
907887
trans: Galaxy transaction/context
908888
run_id: The WES run ID (Galaxy invocation ID)
909889
page_size: Number of tasks per page (default 10, max 100)
910-
page_token: Token for pagination (base64-encoded offset)
890+
page_token: Token for pagination (composite keyset: step_order, job_index)
911891
912892
Returns:
913893
TaskListResponse with paginated tasks
914894
"""
915895
invocation = self._get_invocation(trans, run_id)
916896

917-
# Decode page token to offset
918-
offset = _decode_page_token(page_token)
897+
# Decode composite keyset token
898+
token = self._keyset_pagination.decode_token(page_token, token_class=TaskKeysetToken)
919899

920900
# Fetch paginated task rows (+1 to detect more results)
921901
task_rows = self._get_paginated_task_rows(
922902
trans,
923903
invocation.id,
924-
offset,
904+
token,
925905
page_size,
926906
)
927907

@@ -938,8 +918,10 @@ def get_run_tasks(
938918

939919
# Generate next page token
940920
next_page_token = None
941-
if has_more:
942-
next_page_token = _encode_page_token(offset + page_size)
921+
if has_more and task_rows:
922+
last_row = task_rows[-1]
923+
token = TaskKeysetToken(step_order=last_row["step_order"], job_index=last_row["job_index"])
924+
next_page_token = self._keyset_pagination.encode_token(token)
943925

944926
return TaskListResponse(
945927
task_logs=task_logs if task_logs else None,

0 commit comments

Comments
 (0)