diff --git a/entgql/internal/todo/todo_test.go b/entgql/internal/todo/todo_test.go index 3fbe711bd..54a718e90 100644 --- a/entgql/internal/todo/todo_test.go +++ b/entgql/internal/todo/todo_test.go @@ -1691,7 +1691,7 @@ func TestNestedConnection(t *testing.T) { ) require.NoError(t, err) require.Equal(t, 1, len(rsp.Group.Users.Edges)) - require.Equal(t, "gaFp0wAAAAcAAAAI", rsp.Group.Users.Edges[0].Cursor) + require.Equal(t, "gaFpzwAAAAcAAAAI", rsp.Group.Users.Edges[0].Cursor) }) } @@ -2536,6 +2536,56 @@ func TestOrderByEdgeCount(t *testing.T) { } }) + t.Run("MultiOrderWithPagination", func(t *testing.T) { + var ( + // language=GraphQL + query = `query CategoryByTodosCount($first: Int, $after: Cursor) { + categories( + first: $first, + after: $after + orderBy: [{field: TODOS_COUNT, direction: DESC}], + ) { + edges { + cursor + node { + id + text + } + } + } + }` + rsp struct { + Categories struct { + Edges []struct { + Cursor string + Node struct { + ID string + Text string + } + } + } + } + ) + gqlc.MustPost( + query, + &rsp, + client.Var("first", 2), + client.Var("after", nil), + ) + require.Len(t, rsp.Categories.Edges, 2) + + // Do another query to get the next node after the first in our original query. + expectedNode := rsp.Categories.Edges[1].Node + gqlc.MustPost( + query, + &rsp, + client.Var("first", 1), + client.Var("after", rsp.Categories.Edges[0].Cursor), + ) + require.Len(t, rsp.Categories.Edges, 1) + require.Equal(t, expectedNode.ID, rsp.Categories.Edges[0].Node.ID) + }) + t.Run("NestedEdgeCountOrdering", func(t *testing.T) { var ( // language=GraphQL diff --git a/entgql/pagination.go b/entgql/pagination.go index 1cf6ccebc..a075cf2c6 100644 --- a/entgql/pagination.go +++ b/entgql/pagination.go @@ -146,7 +146,7 @@ func CursorsPredicate[T any](after, before *Cursor[T], idField, field string, di s.Where(sql.P(func(b *sql.Builder) { // The predicate function is executed on query generation time. column := s.C(field) - // If there is a non-ambiguis match, we use it. That is because + // If there is a non-ambiguous match, we use it. That is because // some order terms may append joined information to query selection. if matches := s.FindSelection(field); len(matches) == 1 { column = matches[0] @@ -218,16 +218,32 @@ func multiPredicate[T any](cursor *Cursor[T], opts *MultiCursorsOptions) (func(* return func(s *sql.Selector) { // Given the following terms: x DESC, y ASC, etc. The following predicate will be // generated: (x < x1 OR (x = x1 AND y > y1) OR (x = x1 AND y = y1 AND id > last)). + + // getColumnNameForField gets the name for the term and considers non-ambigous matching of + // terms that may be joined instead of a column on the table. + getColumnNameForField := func(field string) string { + // The predicate function is executed on query generation time. + column := s.C(field) + // If there is a non-ambiguous match, we use it. That is because + // some order terms may append joined information to query selection. + if matches := s.FindSelection(field); len(matches) == 1 { + column = matches[0] + } + return column + } + var or []*sql.Predicate for i := range opts.Fields { var ands []*sql.Predicate for j := 0; j < i; j++ { - ands = append(ands, sql.EQ(s.C(opts.Fields[j]), values[j])) + c := getColumnNameForField(opts.Fields[j]) + ands = append(ands, sql.EQ(c, values[j])) } + c := getColumnNameForField(opts.Fields[i]) if opts.Directions[i] == OrderDirectionAsc { - ands = append(ands, sql.GT(s.C(opts.Fields[i]), values[i])) + ands = append(ands, sql.GT(c, values[i])) } else { - ands = append(ands, sql.LT(s.C(opts.Fields[i]), values[i])) + ands = append(ands, sql.LT(c, values[i])) } or = append(or, sql.And(ands...)) }