diff --git a/entgql/internal/todo/ent/gql_collection.go b/entgql/internal/todo/ent/gql_collection.go index d784c3076..ee6a0a0ae 100644 --- a/entgql/internal/todo/ent/gql_collection.go +++ b/entgql/internal/todo/ent/gql_collection.go @@ -41,13 +41,13 @@ func (bp *BillProductQuery) CollectFields(ctx context.Context, satisfies ...stri if fc == nil { return bp, nil } - if err := bp.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := bp.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return bp, nil } -func (bp *BillProductQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (bp *BillProductQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -118,13 +118,13 @@ func (c *CategoryQuery) CollectFields(ctx context.Context, satisfies ...string) if fc == nil { return c, nil } - if err := c.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := c.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return c, nil } -func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (c *CategoryQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -133,6 +133,7 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "todos": var ( alias = field.Alias @@ -204,19 +205,24 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(category.TodosColumn, limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(category.TodosColumn, limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } c.WithNamedTodos(alias, func(wq *TodoQuery) { *wq = *query }) + case "subCategories": var ( alias = field.Alias @@ -292,13 +298,17 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, categoryImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, categoryImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(category.SubCategoriesPrimaryKey[0], limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(category.SubCategoriesPrimaryKey[0], limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } @@ -415,13 +425,13 @@ func (f *FriendshipQuery) CollectFields(ctx context.Context, satisfies ...string if fc == nil { return f, nil } - if err := f.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := f.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return f, nil } -func (f *FriendshipQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (f *FriendshipQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -430,13 +440,14 @@ func (f *FriendshipQuery) collectField(ctx context.Context, opCtx *graphql.Opera ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "user": var ( alias = field.Alias path = append(path, alias) query = (&UserClient{config: f.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { return err } f.withUser = query @@ -444,13 +455,14 @@ func (f *FriendshipQuery) collectField(ctx context.Context, opCtx *graphql.Opera selectedFields = append(selectedFields, friendship.FieldUserID) fieldSeen[friendship.FieldUserID] = struct{}{} } + case "friend": var ( alias = field.Alias path = append(path, alias) query = (&UserClient{config: f.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { return err } f.withFriend = query @@ -520,13 +532,13 @@ func (gr *GroupQuery) CollectFields(ctx context.Context, satisfies ...string) (* if fc == nil { return gr, nil } - if err := gr.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := gr.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return gr, nil } -func (gr *GroupQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (gr *GroupQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -535,6 +547,7 @@ func (gr *GroupQuery) collectField(ctx context.Context, opCtx *graphql.Operation ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "users": var ( alias = field.Alias @@ -610,13 +623,17 @@ func (gr *GroupQuery) collectField(ctx context.Context, opCtx *graphql.Operation } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(group.UsersPrimaryKey[1], limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(group.UsersPrimaryKey[1], limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } @@ -675,13 +692,13 @@ func (otm *OneToManyQuery) CollectFields(ctx context.Context, satisfies ...strin if fc == nil { return otm, nil } - if err := otm.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := otm.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return otm, nil } -func (otm *OneToManyQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (otm *OneToManyQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -690,13 +707,14 @@ func (otm *OneToManyQuery) collectField(ctx context.Context, opCtx *graphql.Oper ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "parent": var ( alias = field.Alias path = append(path, alias) query = (&OneToManyClient{config: otm.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, onetomanyImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, onetomanyImplementors)...); err != nil { return err } otm.withParent = query @@ -704,13 +722,14 @@ func (otm *OneToManyQuery) collectField(ctx context.Context, opCtx *graphql.Oper selectedFields = append(selectedFields, onetomany.FieldParentID) fieldSeen[onetomany.FieldParentID] = struct{}{} } + case "children": var ( alias = field.Alias path = append(path, alias) query = (&OneToManyClient{config: otm.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, onetomanyImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, field, path, mayAddCondition(satisfies, onetomanyImplementors)...); err != nil { return err } otm.WithNamedChildren(alias, func(wq *OneToManyQuery) { @@ -795,16 +814,17 @@ func (pr *ProjectQuery) CollectFields(ctx context.Context, satisfies ...string) if fc == nil { return pr, nil } - if err := pr.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := pr.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return pr, nil } -func (pr *ProjectQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (pr *ProjectQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "todos": var ( alias = field.Alias @@ -876,13 +896,17 @@ func (pr *ProjectQuery) collectField(ctx context.Context, opCtx *graphql.Operati } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(project.TodosColumn, limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(project.TodosColumn, limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } @@ -929,13 +953,13 @@ func (t *TodoQuery) CollectFields(ctx context.Context, satisfies ...string) (*To if fc == nil { return t, nil } - if err := t.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := t.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return t, nil } -func (t *TodoQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (t *TodoQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -944,16 +968,18 @@ func (t *TodoQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "parent": var ( alias = field.Alias path = append(path, alias) query = (&TodoClient{config: t.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { return err } t.withParent = query + case "children": var ( alias = field.Alias @@ -1025,26 +1051,31 @@ func (t *TodoQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(todo.ChildrenColumn, limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(todo.ChildrenColumn, limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } t.WithNamedChildren(alias, func(wq *TodoQuery) { *wq = *query }) + case "category": var ( alias = field.Alias path = append(path, alias) query = (&CategoryClient{config: t.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, categoryImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, categoryImplementors)...); err != nil { return err } t.withCategory = query @@ -1167,13 +1198,13 @@ func (u *UserQuery) CollectFields(ctx context.Context, satisfies ...string) (*Us if fc == nil { return u, nil } - if err := u.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := u.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return u, nil } -func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (u *UserQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -1182,6 +1213,7 @@ func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "groups": var ( alias = field.Alias @@ -1257,19 +1289,24 @@ func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, groupImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, groupImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(user.GroupsPrimaryKey[0], limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(user.GroupsPrimaryKey[0], limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } u.WithNamedGroups(alias, func(wq *GroupQuery) { *wq = *query }) + case "friends": var ( alias = field.Alias @@ -1345,19 +1382,24 @@ func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(user.FriendsPrimaryKey[0], limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(user.FriendsPrimaryKey[0], limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } u.WithNamedFriends(alias, func(wq *UserQuery) { *wq = *query }) + case "friendships": var ( alias = field.Alias @@ -1429,13 +1471,17 @@ func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, friendshipImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, friendshipImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(user.FriendshipsColumn, limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(user.FriendshipsColumn, limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } @@ -1531,13 +1577,13 @@ func (w *WorkspaceQuery) CollectFields(ctx context.Context, satisfies ...string) if fc == nil { return w, nil } - if err := w.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := w.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return w, nil } -func (w *WorkspaceQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (w *WorkspaceQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -1644,29 +1690,6 @@ func unmarshalArgs(ctx context.Context, whereInput any, args map[string]any) map return args } -func limitRows(partitionBy string, limit int, orderBy ...sql.Querier) func(s *sql.Selector) { - return func(s *sql.Selector) { - d := sql.Dialect(s.Dialect()) - s.SetDistinct(false) - with := d.With("src_query"). - As(s.Clone()). - With("limited_query"). - As( - d.Select("*"). - AppendSelectExprAs( - sql.RowNumber().PartitionBy(partitionBy).OrderExpr(orderBy...), - "row_number", - ). - From(d.Table("src_query")), - ) - t := d.Table("limited_query").As(s.TableName()) - *s = *d.Select(s.UnqualifiedColumns()...). - From(t). - Where(sql.LTE(t.C("row_number"), limit)). - Prefix(with) - } -} - // mayAddCondition appends another type condition to the satisfies list // if it does not exist in the list. func mayAddCondition(satisfies []string, typeCond []string) []string { diff --git a/entgql/internal/todo/ent/gql_node.go b/entgql/internal/todo/ent/gql_node.go index 173823bc2..6af6bc93e 100644 --- a/entgql/internal/todo/ent/gql_node.go +++ b/entgql/internal/todo/ent/gql_node.go @@ -155,111 +155,84 @@ func (c *Client) noder(ctx context.Context, table string, id int) (Noder, error) case billproduct.Table: query := c.BillProduct.Query(). Where(billproduct.ID(id)) - query, err := query.CollectFields(ctx, billproductImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, billproductImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case category.Table: query := c.Category.Query(). Where(category.ID(id)) - query, err := query.CollectFields(ctx, categoryImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, categoryImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case friendship.Table: query := c.Friendship.Query(). Where(friendship.ID(id)) - query, err := query.CollectFields(ctx, friendshipImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, friendshipImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case group.Table: query := c.Group.Query(). Where(group.ID(id)) - query, err := query.CollectFields(ctx, groupImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, groupImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case onetomany.Table: query := c.OneToMany.Query(). Where(onetomany.ID(id)) - query, err := query.CollectFields(ctx, onetomanyImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, onetomanyImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case project.Table: query := c.Project.Query(). Where(project.ID(id)) - query, err := query.CollectFields(ctx, projectImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, projectImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case todo.Table: query := c.Todo.Query(). Where(todo.ID(id)) - query, err := query.CollectFields(ctx, todoImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, todoImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case user.Table: query := c.User.Query(). Where(user.ID(id)) - query, err := query.CollectFields(ctx, userImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, userImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case workspace.Table: query := c.Workspace.Query(). Where(workspace.ID(id)) - query, err := query.CollectFields(ctx, workspaceImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, workspaceImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) default: return nil, fmt.Errorf("cannot resolve noder from table %q: %w", table, errNodeInvalidID) } diff --git a/entgql/internal/todo/ent/gql_pagination.go b/entgql/internal/todo/ent/gql_pagination.go index 239849cd3..26d14cf89 100644 --- a/entgql/internal/todo/ent/gql_pagination.go +++ b/entgql/internal/todo/ent/gql_pagination.go @@ -310,11 +310,12 @@ func (bp *BillProductQuery) Paginate( if bp, err = pager.applyCursors(bp, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { bp.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := bp.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := bp.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -607,11 +608,12 @@ func (c *CategoryQuery) Paginate( if c, err = pager.applyCursors(c, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { c.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := c.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := c.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -997,11 +999,12 @@ func (f *FriendshipQuery) Paginate( if f, err = pager.applyCursors(f, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { f.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := f.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := f.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -1281,11 +1284,12 @@ func (gr *GroupQuery) Paginate( if gr, err = pager.applyCursors(gr, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { gr.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := gr.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := gr.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -1529,11 +1533,12 @@ func (otm *OneToManyQuery) Paginate( if otm, err = pager.applyCursors(otm, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { otm.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := otm.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := otm.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -1824,11 +1829,12 @@ func (pr *ProjectQuery) Paginate( if pr, err = pager.applyCursors(pr, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { pr.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := pr.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := pr.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -2121,11 +2127,12 @@ func (t *TodoQuery) Paginate( if t, err = pager.applyCursors(t, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { t.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := t.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := t.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -2550,11 +2557,12 @@ func (u *UserQuery) Paginate( if u, err = pager.applyCursors(u, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { u.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := u.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := u.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -2853,11 +2861,12 @@ func (w *WorkspaceQuery) Paginate( if w, err = pager.applyCursors(w, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { w.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := w.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := w.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } diff --git a/entgql/internal/todo/todo_test.go b/entgql/internal/todo/todo_test.go index 3fbe711bd..3397fec95 100644 --- a/entgql/internal/todo/todo_test.go +++ b/entgql/internal/todo/todo_test.go @@ -2254,6 +2254,168 @@ func (r *queryRecorder) Query(ctx context.Context, query string, args, v interfa return r.Driver.Query(ctx, query, args, v) } +func TestReduceQueryComplexity(t *testing.T) { + ctx := context.Background() + drv, err := sql.Open(dialect.SQLite, fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", t.Name())) + require.NoError(t, err) + rec := &queryRecorder{Driver: drv} + ec := enttest.NewClient(t, + enttest.WithOptions(ent.Driver(rec)), + enttest.WithMigrateOptions(migrate.WithGlobalUniqueID(true)), + ) + var ( + // language=GraphQL + query = `query Todo($id: ID!) { + node(id: $id) { + ... on Todo { + text + children (first: 10) { + edges { + node { + text + } + } + } + } + } + }` + gqlc = client.New(handler.NewDefaultServer(gen.NewSchema(ec))) + ) + t1 := ec.Todo.Create().SetText("t1").SetStatus(todo.StatusInProgress).SaveX(ctx) + rec.reset() + require.NoError(t, gqlc.Post(query, new(any), client.Var("id", t1.ID))) + require.Equal(t, []string{ + // Node mapping (cached). + "SELECT `type` FROM `ent_types` ORDER BY `id` ASC", + // Top-level todo. + "SELECT `todos`.`id`, `todos`.`text` FROM `todos` WHERE `todos`.`id` = ? LIMIT 2", + // Children todos (without CTE). + "SELECT `todos`.`id`, `todos`.`text`, `todos`.`project_todos`, `todos`.`todo_children`, `todos`.`todo_secret` FROM `todos` WHERE `todos`.`todo_children` IN (?) ORDER BY `todos`.`id` LIMIT 11", + }, rec.queries) + + // language=GraphQL + query = `query Todos($ids: [ID!]!) { + todos: nodes (ids: $ids) { + ... on Todo { + text + children (first: 10) { + edges { + node { + text + } + } + } + } + } + }` + rec.reset() + require.NoError(t, gqlc.Post(query, new(any), client.Var("ids", []int{t1.ID}))) + // A single ID is implemented by the `node` query. + require.Equal(t, []string{ + // Top-level todo. + "SELECT `todos`.`id`, `todos`.`text` FROM `todos` WHERE `todos`.`id` = ? LIMIT 2", + // Children todos (without CTE). + "SELECT `todos`.`id`, `todos`.`text`, `todos`.`project_todos`, `todos`.`todo_children`, `todos`.`todo_secret` FROM `todos` WHERE `todos`.`todo_children` IN (?) ORDER BY `todos`.`id` LIMIT 11", + }, rec.queries) + + rec.reset() + require.NoError(t, gqlc.Post(query, new(any), client.Var("ids", []int{t1.ID, t1.ID}))) + require.Equal(t, []string{ + // Top-level todo. + "SELECT `todos`.`id`, `todos`.`text` FROM `todos` WHERE `todos`.`id` IN (?, ?)", + // Children todos (with CTE). + "WITH `src_query` AS (SELECT `todos`.`id`, `todos`.`text`, `todos`.`project_todos`, `todos`.`todo_children`, `todos`.`todo_secret` FROM `todos` WHERE `todos`.`todo_children` IN (?)), `limited_query` AS (SELECT *, (ROW_NUMBER() OVER (PARTITION BY `todo_children` ORDER BY `id` ASC)) AS `row_number` FROM `src_query`) SELECT `id`, `text`, `project_todos`, `todo_children`, `todo_secret` FROM `limited_query` AS `todos` WHERE `todos`.`row_number` <= ?", + }, rec.queries) + + // Propagate uniqueness to one-child edges. + // language=GraphQL + query = `query Todo($id: ID!) { + node(id: $id) { + ... on Todo { + parent { + text + children (first: 5) { + edges { + node { + text + } + } + } + } + category { + text + todos (first: 10) { + edges { + node { + text + } + } + } + } + } + } + }` + ec.Todo.Create().SetText("t0").SetStatus(todo.StatusInProgress).AddChildren(t1).SaveX(ctx) + ec.Category.Create().AddTodos(t1).SetText("c0").SetStatus(category.StatusEnabled).SaveX(ctx) + rec.reset() + require.NoError(t, gqlc.Post(query, new(any), client.Var("id", t1.ID))) + require.Equal(t, []string{ + // Top-level todo. + "SELECT `todos`.`id`, `todos`.`category_id`, `todos`.`project_todos`, `todos`.`todo_children`, `todos`.`todo_secret` FROM `todos` WHERE `todos`.`id` = ? LIMIT 2", + // Parent todo. + "SELECT `todos`.`id`, `todos`.`text` FROM `todos` WHERE `todos`.`id` IN (?)", + // Parent children. + "SELECT `todos`.`id`, `todos`.`text`, `todos`.`project_todos`, `todos`.`todo_children`, `todos`.`todo_secret` FROM `todos` WHERE `todos`.`todo_children` IN (?) ORDER BY `todos`.`id` LIMIT 6", + // Category. + "SELECT `categories`.`id`, `categories`.`text` FROM `categories` WHERE `categories`.`id` IN (?)", + // Category todos. + "SELECT `todos`.`id`, `todos`.`text`, `todos`.`category_id`, `todos`.`project_todos`, `todos`.`todo_children`, `todos`.`todo_secret` FROM `todos` WHERE `todos`.`category_id` IN (?) ORDER BY `todos`.`id` LIMIT 11", + }, rec.queries) + + // Same as above, but with multiple IDs. + // language=GraphQL + query = `query Todo($id: ID!) { + nodes(ids: [$id, $id]) { + ... on Todo { + parent { + text + children (first: 5) { + edges { + node { + text + } + } + } + } + category { + text + todos (first: 10) { + edges { + node { + text + } + } + } + } + } + } + }` + rec.reset() + require.NoError(t, gqlc.Post(query, new(any), client.Var("id", t1.ID))) + require.Equal(t, []string{ + // Root nodes. + "SELECT `todos`.`id`, `todos`.`category_id`, `todos`.`project_todos`, `todos`.`todo_children`, `todos`.`todo_secret` FROM `todos` WHERE `todos`.`id` IN (?, ?)", + // Their parents (2 max). + "SELECT `todos`.`id`, `todos`.`text` FROM `todos` WHERE `todos`.`id` IN (?)", + // 5 children for each parent. + "WITH `src_query` AS (SELECT `todos`.`id`, `todos`.`text`, `todos`.`project_todos`, `todos`.`todo_children`, `todos`.`todo_secret` FROM `todos` WHERE `todos`.`todo_children` IN (?)), `limited_query` AS (SELECT *, (ROW_NUMBER() OVER (PARTITION BY `todo_children` ORDER BY `id` ASC)) AS `row_number` FROM `src_query`) SELECT `id`, `text`, `project_todos`, `todo_children`, `todo_secret` FROM `limited_query` AS `todos` WHERE `todos`.`row_number` <= ?", + // Category. + "SELECT `categories`.`id`, `categories`.`text` FROM `categories` WHERE `categories`.`id` IN (?)", + // 10 todos for each category. + "WITH `src_query` AS (SELECT `todos`.`id`, `todos`.`text`, `todos`.`category_id`, `todos`.`project_todos`, `todos`.`todo_children`, `todos`.`todo_secret` FROM `todos` WHERE `todos`.`category_id` IN (?)), `limited_query` AS (SELECT *, (ROW_NUMBER() OVER (PARTITION BY `category_id` ORDER BY `id` ASC)) AS `row_number` FROM `src_query`) SELECT `id`, `text`, `category_id`, `project_todos`, `todo_children`, `todo_secret` FROM `limited_query` AS `todos` WHERE `todos`.`row_number` <= ?", + }, rec.queries) +} + func TestFieldSelection(t *testing.T) { ctx := context.Background() drv, err := sql.Open(dialect.SQLite, fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", t.Name())) diff --git a/entgql/internal/todofed/ent/gql_collection.go b/entgql/internal/todofed/ent/gql_collection.go index 517a7df8d..b7035f8ad 100644 --- a/entgql/internal/todofed/ent/gql_collection.go +++ b/entgql/internal/todofed/ent/gql_collection.go @@ -22,7 +22,6 @@ import ( "entgo.io/contrib/entgql" "entgo.io/contrib/entgql/internal/todofed/ent/category" "entgo.io/contrib/entgql/internal/todofed/ent/todo" - "entgo.io/ent/dialect/sql" "github.com/99designs/gqlgen/graphql" ) @@ -32,13 +31,13 @@ func (c *CategoryQuery) CollectFields(ctx context.Context, satisfies ...string) if fc == nil { return c, nil } - if err := c.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := c.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return c, nil } -func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (c *CategoryQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -47,13 +46,14 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "todos": var ( alias = field.Alias path = append(path, alias) query = (&TodoClient{config: c.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { return err } c.WithNamedTodos(alias, func(wq *TodoQuery) { @@ -155,13 +155,13 @@ func (t *TodoQuery) CollectFields(ctx context.Context, satisfies ...string) (*To if fc == nil { return t, nil } - if err := t.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := t.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return t, nil } -func (t *TodoQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (t *TodoQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -170,35 +170,38 @@ func (t *TodoQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "parent": var ( alias = field.Alias path = append(path, alias) query = (&TodoClient{config: t.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { return err } t.withParent = query + case "children": var ( alias = field.Alias path = append(path, alias) query = (&TodoClient{config: t.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { return err } t.WithNamedChildren(alias, func(wq *TodoQuery) { *wq = *query }) + case "category": var ( alias = field.Alias path = append(path, alias) query = (&CategoryClient{config: t.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, categoryImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, categoryImplementors)...); err != nil { return err } t.withCategory = query @@ -339,29 +342,6 @@ func unmarshalArgs(ctx context.Context, whereInput any, args map[string]any) map return args } -func limitRows(partitionBy string, limit int, orderBy ...sql.Querier) func(s *sql.Selector) { - return func(s *sql.Selector) { - d := sql.Dialect(s.Dialect()) - s.SetDistinct(false) - with := d.With("src_query"). - As(s.Clone()). - With("limited_query"). - As( - d.Select("*"). - AppendSelectExprAs( - sql.RowNumber().PartitionBy(partitionBy).OrderExpr(orderBy...), - "row_number", - ). - From(d.Table("src_query")), - ) - t := d.Table("limited_query").As(s.TableName()) - *s = *d.Select(s.UnqualifiedColumns()...). - From(t). - Where(sql.LTE(t.C("row_number"), limit)). - Prefix(with) - } -} - // mayAddCondition appends another type condition to the satisfies list // if it does not exist in the list. func mayAddCondition(satisfies []string, typeCond []string) []string { diff --git a/entgql/internal/todofed/ent/gql_node.go b/entgql/internal/todofed/ent/gql_node.go index 718ac831e..6d4cdbde6 100644 --- a/entgql/internal/todofed/ent/gql_node.go +++ b/entgql/internal/todofed/ent/gql_node.go @@ -109,27 +109,21 @@ func (c *Client) noder(ctx context.Context, table string, id int) (Noder, error) case category.Table: query := c.Category.Query(). Where(category.ID(id)) - query, err := query.CollectFields(ctx, categoryImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, categoryImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case todo.Table: query := c.Todo.Query(). Where(todo.ID(id)) - query, err := query.CollectFields(ctx, todoImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, todoImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) default: return nil, fmt.Errorf("cannot resolve noder from table %q: %w", table, errNodeInvalidID) } diff --git a/entgql/internal/todofed/ent/gql_pagination.go b/entgql/internal/todofed/ent/gql_pagination.go index b09981ef4..6db8d185c 100644 --- a/entgql/internal/todofed/ent/gql_pagination.go +++ b/entgql/internal/todofed/ent/gql_pagination.go @@ -303,11 +303,12 @@ func (c *CategoryQuery) Paginate( if c, err = pager.applyCursors(c, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { c.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := c.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := c.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -616,11 +617,12 @@ func (t *TodoQuery) Paginate( if t, err = pager.applyCursors(t, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { t.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := t.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := t.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } diff --git a/entgql/internal/todogotype/ent/gql_collection.go b/entgql/internal/todogotype/ent/gql_collection.go index ba3196641..edd4ba70a 100644 --- a/entgql/internal/todogotype/ent/gql_collection.go +++ b/entgql/internal/todogotype/ent/gql_collection.go @@ -40,13 +40,13 @@ func (bp *BillProductQuery) CollectFields(ctx context.Context, satisfies ...stri if fc == nil { return bp, nil } - if err := bp.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := bp.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return bp, nil } -func (bp *BillProductQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (bp *BillProductQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -117,13 +117,13 @@ func (c *CategoryQuery) CollectFields(ctx context.Context, satisfies ...string) if fc == nil { return c, nil } - if err := c.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := c.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return c, nil } -func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (c *CategoryQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -132,6 +132,7 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "todos": var ( alias = field.Alias @@ -203,19 +204,24 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(category.TodosColumn, limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(category.TodosColumn, limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } c.WithNamedTodos(alias, func(wq *TodoQuery) { *wq = *query }) + case "subCategories": var ( alias = field.Alias @@ -291,13 +297,17 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, categoryImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, categoryImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(category.SubCategoriesPrimaryKey[0], limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(category.SubCategoriesPrimaryKey[0], limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } @@ -414,13 +424,13 @@ func (f *FriendshipQuery) CollectFields(ctx context.Context, satisfies ...string if fc == nil { return f, nil } - if err := f.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := f.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return f, nil } -func (f *FriendshipQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (f *FriendshipQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -429,13 +439,14 @@ func (f *FriendshipQuery) collectField(ctx context.Context, opCtx *graphql.Opera ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "user": var ( alias = field.Alias path = append(path, alias) query = (&UserClient{config: f.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { return err } f.withUser = query @@ -443,13 +454,14 @@ func (f *FriendshipQuery) collectField(ctx context.Context, opCtx *graphql.Opera selectedFields = append(selectedFields, friendship.FieldUserID) fieldSeen[friendship.FieldUserID] = struct{}{} } + case "friend": var ( alias = field.Alias path = append(path, alias) query = (&UserClient{config: f.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { return err } f.withFriend = query @@ -519,13 +531,13 @@ func (gr *GroupQuery) CollectFields(ctx context.Context, satisfies ...string) (* if fc == nil { return gr, nil } - if err := gr.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := gr.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return gr, nil } -func (gr *GroupQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (gr *GroupQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -534,6 +546,7 @@ func (gr *GroupQuery) collectField(ctx context.Context, opCtx *graphql.Operation ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "users": var ( alias = field.Alias @@ -609,13 +622,17 @@ func (gr *GroupQuery) collectField(ctx context.Context, opCtx *graphql.Operation } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(group.UsersPrimaryKey[1], limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(group.UsersPrimaryKey[1], limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } @@ -674,13 +691,13 @@ func (pe *PetQuery) CollectFields(ctx context.Context, satisfies ...string) (*Pe if fc == nil { return pe, nil } - if err := pe.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := pe.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return pe, nil } -func (pe *PetQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (pe *PetQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -741,13 +758,13 @@ func (t *TodoQuery) CollectFields(ctx context.Context, satisfies ...string) (*To if fc == nil { return t, nil } - if err := t.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := t.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return t, nil } -func (t *TodoQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (t *TodoQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -756,16 +773,18 @@ func (t *TodoQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "parent": var ( alias = field.Alias path = append(path, alias) query = (&TodoClient{config: t.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { return err } t.withParent = query + case "children": var ( alias = field.Alias @@ -837,26 +856,31 @@ func (t *TodoQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(todo.ChildrenColumn, limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(todo.ChildrenColumn, limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } t.WithNamedChildren(alias, func(wq *TodoQuery) { *wq = *query }) + case "category": var ( alias = field.Alias path = append(path, alias) query = (&CategoryClient{config: t.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, categoryImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, categoryImplementors)...); err != nil { return err } t.withCategory = query @@ -979,13 +1003,13 @@ func (u *UserQuery) CollectFields(ctx context.Context, satisfies ...string) (*Us if fc == nil { return u, nil } - if err := u.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := u.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return u, nil } -func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (u *UserQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -994,6 +1018,7 @@ func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "groups": var ( alias = field.Alias @@ -1069,38 +1094,44 @@ func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, groupImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, groupImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(user.GroupsPrimaryKey[0], limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(user.GroupsPrimaryKey[0], limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } u.WithNamedGroups(alias, func(wq *GroupQuery) { *wq = *query }) + case "friends": var ( alias = field.Alias path = append(path, alias) query = (&UserClient{config: u.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { return err } u.WithNamedFriends(alias, func(wq *UserQuery) { *wq = *query }) + case "friendships": var ( alias = field.Alias path = append(path, alias) query = (&FriendshipClient{config: u.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, friendshipImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, field, path, mayAddCondition(satisfies, friendshipImplementors)...); err != nil { return err } u.WithNamedFriendships(alias, func(wq *FriendshipQuery) { @@ -1226,29 +1257,6 @@ func unmarshalArgs(ctx context.Context, whereInput any, args map[string]any) map return args } -func limitRows(partitionBy string, limit int, orderBy ...sql.Querier) func(s *sql.Selector) { - return func(s *sql.Selector) { - d := sql.Dialect(s.Dialect()) - s.SetDistinct(false) - with := d.With("src_query"). - As(s.Clone()). - With("limited_query"). - As( - d.Select("*"). - AppendSelectExprAs( - sql.RowNumber().PartitionBy(partitionBy).OrderExpr(orderBy...), - "row_number", - ). - From(d.Table("src_query")), - ) - t := d.Table("limited_query").As(s.TableName()) - *s = *d.Select(s.UnqualifiedColumns()...). - From(t). - Where(sql.LTE(t.C("row_number"), limit)). - Prefix(with) - } -} - // mayAddCondition appends another type condition to the satisfies list // if it does not exist in the list. func mayAddCondition(satisfies []string, typeCond []string) []string { diff --git a/entgql/internal/todogotype/ent/gql_node.go b/entgql/internal/todogotype/ent/gql_node.go index 85aece025..f3f4f2f2a 100644 --- a/entgql/internal/todogotype/ent/gql_node.go +++ b/entgql/internal/todogotype/ent/gql_node.go @@ -151,15 +151,12 @@ func (c *Client) noder(ctx context.Context, table string, id string) (Noder, err case billproduct.Table: query := c.BillProduct.Query(). Where(billproduct.ID(id)) - query, err := query.CollectFields(ctx, billproductImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, billproductImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case category.Table: var uid bigintgql.BigInt if err := uid.UnmarshalGQL(id); err != nil { @@ -167,39 +164,30 @@ func (c *Client) noder(ctx context.Context, table string, id string) (Noder, err } query := c.Category.Query(). Where(category.ID(uid)) - query, err := query.CollectFields(ctx, categoryImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, categoryImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case friendship.Table: query := c.Friendship.Query(). Where(friendship.ID(id)) - query, err := query.CollectFields(ctx, friendshipImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, friendshipImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case group.Table: query := c.Group.Query(). Where(group.ID(id)) - query, err := query.CollectFields(ctx, groupImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, groupImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case pet.Table: var uid uintgql.Uint64 if err := uid.UnmarshalGQL(id); err != nil { @@ -207,39 +195,30 @@ func (c *Client) noder(ctx context.Context, table string, id string) (Noder, err } query := c.Pet.Query(). Where(pet.ID(uid)) - query, err := query.CollectFields(ctx, petImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, petImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case todo.Table: query := c.Todo.Query(). Where(todo.ID(id)) - query, err := query.CollectFields(ctx, todoImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, todoImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case user.Table: query := c.User.Query(). Where(user.ID(id)) - query, err := query.CollectFields(ctx, userImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, userImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) default: return nil, fmt.Errorf("cannot resolve noder from table %q: %w", table, errNodeInvalidID) } diff --git a/entgql/internal/todogotype/ent/gql_pagination.go b/entgql/internal/todogotype/ent/gql_pagination.go index 6cd34d46d..bf65247a9 100644 --- a/entgql/internal/todogotype/ent/gql_pagination.go +++ b/entgql/internal/todogotype/ent/gql_pagination.go @@ -308,11 +308,12 @@ func (bp *BillProductQuery) Paginate( if bp, err = pager.applyCursors(bp, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { bp.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := bp.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := bp.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -605,11 +606,12 @@ func (c *CategoryQuery) Paginate( if c, err = pager.applyCursors(c, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { c.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := c.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := c.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -977,11 +979,12 @@ func (f *FriendshipQuery) Paginate( if f, err = pager.applyCursors(f, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { f.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := f.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := f.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -1225,11 +1228,12 @@ func (gr *GroupQuery) Paginate( if gr, err = pager.applyCursors(gr, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { gr.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := gr.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := gr.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -1473,11 +1477,12 @@ func (pe *PetQuery) Paginate( if pe, err = pager.applyCursors(pe, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { pe.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := pe.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := pe.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -1770,11 +1775,12 @@ func (t *TodoQuery) Paginate( if t, err = pager.applyCursors(t, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { t.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := t.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := t.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -2199,11 +2205,12 @@ func (u *UserQuery) Paginate( if u, err = pager.applyCursors(u, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { u.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := u.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := u.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } diff --git a/entgql/internal/todopulid/ent/gql_collection.go b/entgql/internal/todopulid/ent/gql_collection.go index acaac35b2..0080b46da 100644 --- a/entgql/internal/todopulid/ent/gql_collection.go +++ b/entgql/internal/todopulid/ent/gql_collection.go @@ -39,13 +39,13 @@ func (bp *BillProductQuery) CollectFields(ctx context.Context, satisfies ...stri if fc == nil { return bp, nil } - if err := bp.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := bp.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return bp, nil } -func (bp *BillProductQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (bp *BillProductQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -116,13 +116,13 @@ func (c *CategoryQuery) CollectFields(ctx context.Context, satisfies ...string) if fc == nil { return c, nil } - if err := c.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := c.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return c, nil } -func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (c *CategoryQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -131,6 +131,7 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "todos": var ( alias = field.Alias @@ -202,19 +203,24 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(category.TodosColumn, limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(category.TodosColumn, limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } c.WithNamedTodos(alias, func(wq *TodoQuery) { *wq = *query }) + case "subCategories": var ( alias = field.Alias @@ -290,13 +296,17 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, categoryImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, categoryImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(category.SubCategoriesPrimaryKey[0], limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(category.SubCategoriesPrimaryKey[0], limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } @@ -413,13 +423,13 @@ func (f *FriendshipQuery) CollectFields(ctx context.Context, satisfies ...string if fc == nil { return f, nil } - if err := f.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := f.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return f, nil } -func (f *FriendshipQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (f *FriendshipQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -428,13 +438,14 @@ func (f *FriendshipQuery) collectField(ctx context.Context, opCtx *graphql.Opera ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "user": var ( alias = field.Alias path = append(path, alias) query = (&UserClient{config: f.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { return err } f.withUser = query @@ -442,13 +453,14 @@ func (f *FriendshipQuery) collectField(ctx context.Context, opCtx *graphql.Opera selectedFields = append(selectedFields, friendship.FieldUserID) fieldSeen[friendship.FieldUserID] = struct{}{} } + case "friend": var ( alias = field.Alias path = append(path, alias) query = (&UserClient{config: f.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { return err } f.withFriend = query @@ -518,13 +530,13 @@ func (gr *GroupQuery) CollectFields(ctx context.Context, satisfies ...string) (* if fc == nil { return gr, nil } - if err := gr.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := gr.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return gr, nil } -func (gr *GroupQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (gr *GroupQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -533,6 +545,7 @@ func (gr *GroupQuery) collectField(ctx context.Context, opCtx *graphql.Operation ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "users": var ( alias = field.Alias @@ -608,13 +621,17 @@ func (gr *GroupQuery) collectField(ctx context.Context, opCtx *graphql.Operation } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(group.UsersPrimaryKey[1], limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(group.UsersPrimaryKey[1], limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } @@ -673,13 +690,13 @@ func (t *TodoQuery) CollectFields(ctx context.Context, satisfies ...string) (*To if fc == nil { return t, nil } - if err := t.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := t.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return t, nil } -func (t *TodoQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (t *TodoQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -688,16 +705,18 @@ func (t *TodoQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "parent": var ( alias = field.Alias path = append(path, alias) query = (&TodoClient{config: t.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { return err } t.withParent = query + case "children": var ( alias = field.Alias @@ -769,26 +788,31 @@ func (t *TodoQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(todo.ChildrenColumn, limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(todo.ChildrenColumn, limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } t.WithNamedChildren(alias, func(wq *TodoQuery) { *wq = *query }) + case "category": var ( alias = field.Alias path = append(path, alias) query = (&CategoryClient{config: t.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, categoryImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, categoryImplementors)...); err != nil { return err } t.withCategory = query @@ -911,13 +935,13 @@ func (u *UserQuery) CollectFields(ctx context.Context, satisfies ...string) (*Us if fc == nil { return u, nil } - if err := u.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := u.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return u, nil } -func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (u *UserQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -926,6 +950,7 @@ func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "groups": var ( alias = field.Alias @@ -1001,38 +1026,44 @@ func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, groupImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, groupImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(user.GroupsPrimaryKey[0], limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(user.GroupsPrimaryKey[0], limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } u.WithNamedGroups(alias, func(wq *GroupQuery) { *wq = *query }) + case "friends": var ( alias = field.Alias path = append(path, alias) query = (&UserClient{config: u.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { return err } u.WithNamedFriends(alias, func(wq *UserQuery) { *wq = *query }) + case "friendships": var ( alias = field.Alias path = append(path, alias) query = (&FriendshipClient{config: u.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, friendshipImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, field, path, mayAddCondition(satisfies, friendshipImplementors)...); err != nil { return err } u.WithNamedFriendships(alias, func(wq *FriendshipQuery) { @@ -1173,29 +1204,6 @@ func unmarshalArgs(ctx context.Context, whereInput any, args map[string]any) map return args } -func limitRows(partitionBy string, limit int, orderBy ...sql.Querier) func(s *sql.Selector) { - return func(s *sql.Selector) { - d := sql.Dialect(s.Dialect()) - s.SetDistinct(false) - with := d.With("src_query"). - As(s.Clone()). - With("limited_query"). - As( - d.Select("*"). - AppendSelectExprAs( - sql.RowNumber().PartitionBy(partitionBy).OrderExpr(orderBy...), - "row_number", - ). - From(d.Table("src_query")), - ) - t := d.Table("limited_query").As(s.TableName()) - *s = *d.Select(s.UnqualifiedColumns()...). - From(t). - Where(sql.LTE(t.C("row_number"), limit)). - Prefix(with) - } -} - // mayAddCondition appends another type condition to the satisfies list // if it does not exist in the list. func mayAddCondition(satisfies []string, typeCond []string) []string { diff --git a/entgql/internal/todopulid/ent/gql_node.go b/entgql/internal/todopulid/ent/gql_node.go index d1eff5d95..e2cc76613 100644 --- a/entgql/internal/todopulid/ent/gql_node.go +++ b/entgql/internal/todopulid/ent/gql_node.go @@ -135,15 +135,12 @@ func (c *Client) noder(ctx context.Context, table string, id pulid.ID) (Noder, e } query := c.BillProduct.Query(). Where(billproduct.ID(uid)) - query, err := query.CollectFields(ctx, billproductImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, billproductImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case category.Table: var uid pulid.ID if err := uid.UnmarshalGQL(id); err != nil { @@ -151,15 +148,12 @@ func (c *Client) noder(ctx context.Context, table string, id pulid.ID) (Noder, e } query := c.Category.Query(). Where(category.ID(uid)) - query, err := query.CollectFields(ctx, categoryImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, categoryImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case friendship.Table: var uid pulid.ID if err := uid.UnmarshalGQL(id); err != nil { @@ -167,15 +161,12 @@ func (c *Client) noder(ctx context.Context, table string, id pulid.ID) (Noder, e } query := c.Friendship.Query(). Where(friendship.ID(uid)) - query, err := query.CollectFields(ctx, friendshipImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, friendshipImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case group.Table: var uid pulid.ID if err := uid.UnmarshalGQL(id); err != nil { @@ -183,15 +174,12 @@ func (c *Client) noder(ctx context.Context, table string, id pulid.ID) (Noder, e } query := c.Group.Query(). Where(group.ID(uid)) - query, err := query.CollectFields(ctx, groupImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, groupImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case todo.Table: var uid pulid.ID if err := uid.UnmarshalGQL(id); err != nil { @@ -199,15 +187,12 @@ func (c *Client) noder(ctx context.Context, table string, id pulid.ID) (Noder, e } query := c.Todo.Query(). Where(todo.ID(uid)) - query, err := query.CollectFields(ctx, todoImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, todoImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case user.Table: var uid pulid.ID if err := uid.UnmarshalGQL(id); err != nil { @@ -215,15 +200,12 @@ func (c *Client) noder(ctx context.Context, table string, id pulid.ID) (Noder, e } query := c.User.Query(). Where(user.ID(uid)) - query, err := query.CollectFields(ctx, userImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, userImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) default: return nil, fmt.Errorf("cannot resolve noder from table %q: %w", table, errNodeInvalidID) } diff --git a/entgql/internal/todopulid/ent/gql_pagination.go b/entgql/internal/todopulid/ent/gql_pagination.go index d1141be50..9e134650a 100644 --- a/entgql/internal/todopulid/ent/gql_pagination.go +++ b/entgql/internal/todopulid/ent/gql_pagination.go @@ -308,11 +308,12 @@ func (bp *BillProductQuery) Paginate( if bp, err = pager.applyCursors(bp, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { bp.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := bp.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := bp.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -605,11 +606,12 @@ func (c *CategoryQuery) Paginate( if c, err = pager.applyCursors(c, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { c.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := c.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := c.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -977,11 +979,12 @@ func (f *FriendshipQuery) Paginate( if f, err = pager.applyCursors(f, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { f.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := f.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := f.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -1261,11 +1264,12 @@ func (gr *GroupQuery) Paginate( if gr, err = pager.applyCursors(gr, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { gr.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := gr.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := gr.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -1558,11 +1562,12 @@ func (t *TodoQuery) Paginate( if t, err = pager.applyCursors(t, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { t.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := t.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := t.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -1987,11 +1992,12 @@ func (u *UserQuery) Paginate( if u, err = pager.applyCursors(u, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { u.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := u.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := u.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } diff --git a/entgql/internal/todouuid/ent/gql_collection.go b/entgql/internal/todouuid/ent/gql_collection.go index 39291f1a9..fd8260a92 100644 --- a/entgql/internal/todouuid/ent/gql_collection.go +++ b/entgql/internal/todouuid/ent/gql_collection.go @@ -39,13 +39,13 @@ func (bp *BillProductQuery) CollectFields(ctx context.Context, satisfies ...stri if fc == nil { return bp, nil } - if err := bp.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := bp.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return bp, nil } -func (bp *BillProductQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (bp *BillProductQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -116,13 +116,13 @@ func (c *CategoryQuery) CollectFields(ctx context.Context, satisfies ...string) if fc == nil { return c, nil } - if err := c.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := c.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return c, nil } -func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (c *CategoryQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -131,6 +131,7 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "todos": var ( alias = field.Alias @@ -202,19 +203,24 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(category.TodosColumn, limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(category.TodosColumn, limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } c.WithNamedTodos(alias, func(wq *TodoQuery) { *wq = *query }) + case "subCategories": var ( alias = field.Alias @@ -290,13 +296,17 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, categoryImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, categoryImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(category.SubCategoriesPrimaryKey[0], limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(category.SubCategoriesPrimaryKey[0], limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } @@ -413,13 +423,13 @@ func (f *FriendshipQuery) CollectFields(ctx context.Context, satisfies ...string if fc == nil { return f, nil } - if err := f.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := f.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return f, nil } -func (f *FriendshipQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (f *FriendshipQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -428,13 +438,14 @@ func (f *FriendshipQuery) collectField(ctx context.Context, opCtx *graphql.Opera ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "user": var ( alias = field.Alias path = append(path, alias) query = (&UserClient{config: f.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { return err } f.withUser = query @@ -442,13 +453,14 @@ func (f *FriendshipQuery) collectField(ctx context.Context, opCtx *graphql.Opera selectedFields = append(selectedFields, friendship.FieldUserID) fieldSeen[friendship.FieldUserID] = struct{}{} } + case "friend": var ( alias = field.Alias path = append(path, alias) query = (&UserClient{config: f.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { return err } f.withFriend = query @@ -518,13 +530,13 @@ func (gr *GroupQuery) CollectFields(ctx context.Context, satisfies ...string) (* if fc == nil { return gr, nil } - if err := gr.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := gr.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return gr, nil } -func (gr *GroupQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (gr *GroupQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -533,6 +545,7 @@ func (gr *GroupQuery) collectField(ctx context.Context, opCtx *graphql.Operation ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "users": var ( alias = field.Alias @@ -608,13 +621,17 @@ func (gr *GroupQuery) collectField(ctx context.Context, opCtx *graphql.Operation } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(group.UsersPrimaryKey[1], limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(group.UsersPrimaryKey[1], limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } @@ -673,13 +690,13 @@ func (t *TodoQuery) CollectFields(ctx context.Context, satisfies ...string) (*To if fc == nil { return t, nil } - if err := t.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := t.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return t, nil } -func (t *TodoQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (t *TodoQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -688,16 +705,18 @@ func (t *TodoQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "parent": var ( alias = field.Alias path = append(path, alias) query = (&TodoClient{config: t.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { return err } t.withParent = query + case "children": var ( alias = field.Alias @@ -769,26 +788,31 @@ func (t *TodoQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, todoImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(todo.ChildrenColumn, limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(todo.ChildrenColumn, limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } t.WithNamedChildren(alias, func(wq *TodoQuery) { *wq = *query }) + case "category": var ( alias = field.Alias path = append(path, alias) query = (&CategoryClient{config: t.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, categoryImplementors)...); err != nil { + if err := query.collectField(ctx, oneNode, opCtx, field, path, mayAddCondition(satisfies, categoryImplementors)...); err != nil { return err } t.withCategory = query @@ -911,13 +935,13 @@ func (u *UserQuery) CollectFields(ctx context.Context, satisfies ...string) (*Us if fc == nil { return u, nil } - if err := u.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := u.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return u, nil } -func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func (u *UserQuery) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) var ( unknownSeen bool @@ -926,6 +950,7 @@ func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo ) for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { + case "groups": var ( alias = field.Alias @@ -1001,38 +1026,44 @@ func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, groupImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, *field, path, mayAddCondition(satisfies, groupImplementors)...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(user.GroupsPrimaryKey[0], limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + modify := entgql.LimitPerRow(user.GroupsPrimaryKey[0], limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } u.WithNamedGroups(alias, func(wq *GroupQuery) { *wq = *query }) + case "friends": var ( alias = field.Alias path = append(path, alias) query = (&UserClient{config: u.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, field, path, mayAddCondition(satisfies, userImplementors)...); err != nil { return err } u.WithNamedFriends(alias, func(wq *UserQuery) { *wq = *query }) + case "friendships": var ( alias = field.Alias path = append(path, alias) query = (&FriendshipClient{config: u.config}).Query() ) - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, friendshipImplementors)...); err != nil { + if err := query.collectField(ctx, false, opCtx, field, path, mayAddCondition(satisfies, friendshipImplementors)...); err != nil { return err } u.WithNamedFriendships(alias, func(wq *FriendshipQuery) { @@ -1173,29 +1204,6 @@ func unmarshalArgs(ctx context.Context, whereInput any, args map[string]any) map return args } -func limitRows(partitionBy string, limit int, orderBy ...sql.Querier) func(s *sql.Selector) { - return func(s *sql.Selector) { - d := sql.Dialect(s.Dialect()) - s.SetDistinct(false) - with := d.With("src_query"). - As(s.Clone()). - With("limited_query"). - As( - d.Select("*"). - AppendSelectExprAs( - sql.RowNumber().PartitionBy(partitionBy).OrderExpr(orderBy...), - "row_number", - ). - From(d.Table("src_query")), - ) - t := d.Table("limited_query").As(s.TableName()) - *s = *d.Select(s.UnqualifiedColumns()...). - From(t). - Where(sql.LTE(t.C("row_number"), limit)). - Prefix(with) - } -} - // mayAddCondition appends another type condition to the satisfies list // if it does not exist in the list. func mayAddCondition(satisfies []string, typeCond []string) []string { diff --git a/entgql/internal/todouuid/ent/gql_node.go b/entgql/internal/todouuid/ent/gql_node.go index fad087443..e7badafe5 100644 --- a/entgql/internal/todouuid/ent/gql_node.go +++ b/entgql/internal/todouuid/ent/gql_node.go @@ -131,75 +131,57 @@ func (c *Client) noder(ctx context.Context, table string, id uuid.UUID) (Noder, case billproduct.Table: query := c.BillProduct.Query(). Where(billproduct.ID(id)) - query, err := query.CollectFields(ctx, billproductImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, billproductImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case category.Table: query := c.Category.Query(). Where(category.ID(id)) - query, err := query.CollectFields(ctx, categoryImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, categoryImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case friendship.Table: query := c.Friendship.Query(). Where(friendship.ID(id)) - query, err := query.CollectFields(ctx, friendshipImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, friendshipImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case group.Table: query := c.Group.Query(). Where(group.ID(id)) - query, err := query.CollectFields(ctx, groupImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, groupImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case todo.Table: query := c.Todo.Query(). Where(todo.ID(id)) - query, err := query.CollectFields(ctx, todoImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, todoImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) case user.Table: query := c.User.Query(). Where(user.ID(id)) - query, err := query.CollectFields(ctx, userImplementors...) - if err != nil { - return nil, err - } - n, err := query.Only(ctx) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, userImplementors...); err != nil { + return nil, err + } } - return n, nil + return query.Only(ctx) default: return nil, fmt.Errorf("cannot resolve noder from table %q: %w", table, errNodeInvalidID) } diff --git a/entgql/internal/todouuid/ent/gql_pagination.go b/entgql/internal/todouuid/ent/gql_pagination.go index eb2b956e1..599cc7b41 100644 --- a/entgql/internal/todouuid/ent/gql_pagination.go +++ b/entgql/internal/todouuid/ent/gql_pagination.go @@ -308,11 +308,12 @@ func (bp *BillProductQuery) Paginate( if bp, err = pager.applyCursors(bp, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { bp.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := bp.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := bp.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -605,11 +606,12 @@ func (c *CategoryQuery) Paginate( if c, err = pager.applyCursors(c, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { c.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := c.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := c.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -977,11 +979,12 @@ func (f *FriendshipQuery) Paginate( if f, err = pager.applyCursors(f, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { f.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := f.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := f.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -1261,11 +1264,12 @@ func (gr *GroupQuery) Paginate( if gr, err = pager.applyCursors(gr, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { gr.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := gr.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := gr.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -1558,11 +1562,12 @@ func (t *TodoQuery) Paginate( if t, err = pager.applyCursors(t, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { t.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := t.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := t.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } @@ -1987,11 +1992,12 @@ func (u *UserQuery) Paginate( if u, err = pager.applyCursors(u, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { u.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := u.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := u.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } } diff --git a/entgql/pagination.go b/entgql/pagination.go index 1cf6ccebc..6e536f42a 100644 --- a/entgql/pagination.go +++ b/entgql/pagination.go @@ -234,3 +234,29 @@ func multiPredicate[T any](cursor *Cursor[T], opts *MultiCursorsOptions) (func(* s.Where(sql.Or(or...)) }, nil } + +// LimitPerRow returns a query modifier that limits the number of (edges) rows returned +// by the given partition. This helper function is used mainly by the paginated API to +// override the default Limit behavior for limit returned per node and not limit for all query. +func LimitPerRow(partitionBy string, limit int, orderBy ...sql.Querier) func(s *sql.Selector) { + return func(s *sql.Selector) { + d := sql.Dialect(s.Dialect()) + s.SetDistinct(false) + with := d.With("src_query"). + As(s.Clone()). + With("limited_query"). + As( + d.Select("*"). + AppendSelectExprAs( + sql.RowNumber().PartitionBy(partitionBy).OrderExpr(orderBy...), + "row_number", + ). + From(d.Table("src_query")), + ) + t := d.Table("limited_query").As(s.TableName()) + *s = *d.Select(s.UnqualifiedColumns()...). + From(t). + Where(sql.LTE(t.C("row_number"), limit)). + Prefix(with) + } +} diff --git a/entgql/template/collection.tmpl b/entgql/template/collection.tmpl index a6e7cc440..f117223ca 100644 --- a/entgql/template/collection.tmpl +++ b/entgql/template/collection.tmpl @@ -32,13 +32,13 @@ func ({{ $receiver }} *{{ $query }}) CollectFields(ctx context.Context, satisfie if fc == nil { return {{ $receiver }}, nil } - if err := {{ $receiver }}.collectField(ctx, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { + if err := {{ $receiver }}.collectField(ctx, false, graphql.GetOperationContext(ctx), fc.Field, nil, satisfies...); err != nil { return nil, err } return {{ $receiver }}, nil } -func ({{ $receiver }} *{{ $query }}) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { +func ({{ $receiver }} *{{ $query }}) collectField(ctx context.Context, oneNode bool, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) {{- $fields := filterFields $node.Fields (skipMode "type") }} {{- $collects := fieldCollections (filterEdges $node.Edges (skipMode "type")) }} @@ -59,6 +59,8 @@ func ({{ $receiver }} *{{ $query }}) collectField(ctx context.Context, opCtx *gr switch field.Name { {{- range $i, $fc := $collects }} {{- $e := $fc.Edge }} + {{- /* If the edge is unique, we inherit the cardinality of the parent. */}} + {{ $oneNode := "false" }}{{- if $e.Unique }}{{ $oneNode = "oneNode" }}{{ end }} case {{ range $i, $value := $fc.Mapping }}{{ if $i }}, {{ end }}"{{ $value }}"{{ end }}: var ( alias = field.Alias @@ -112,23 +114,28 @@ func ({{ $receiver }} *{{ $query }}) collectField(ctx context.Context, opCtx *gr } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, {{ nodeImplementorsVar $e.Type }})...); err != nil { + if err := query.collectField(ctx, {{ $oneNode }}, opCtx, *field, path, mayAddCondition(satisfies, {{ nodeImplementorsVar $e.Type }})...); err != nil { return err } } if limit := paginateLimit(args.first, args.last); limit > 0 { - {{- $fk := print $node.Package "." $fc.Edge.ColumnConstant }} - {{- if $e.M2M }} - {{- $i := 0 }}{{ if $e.IsInverse }}{{ $i = 1 }}{{ end }} - {{- $fk = print $node.Package "." $e.PKConstant "[" $i "]" }} - {{- end }} - modify := limitRows({{ $fk }}, limit, pager.orderExpr(query)) - query.modifiers = append(query.modifiers, modify) + {{- /* Limit per row is not required, as there is only node returned by the top query. */}} + if oneNode { + pager.applyOrder(query.Limit(limit)) + } else { + {{- $fk := print $node.Package "." $fc.Edge.ColumnConstant }} + {{- if $e.M2M }} + {{- $i := 0 }}{{ if $e.IsInverse }}{{ $i = 1 }}{{ end }} + {{- $fk = print $node.Package "." $e.PKConstant "[" $i "]" }} + {{- end }} + modify := entgql.LimitPerRow({{ $fk }}, limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } } else { query = pager.applyOrder(query) } {{- else }} - if err := query.collectField(ctx, opCtx, field, path, mayAddCondition(satisfies, {{ nodeImplementorsVar $e.Type }})...); err != nil { + if err := query.collectField(ctx, {{ $oneNode }}, opCtx, field, path, mayAddCondition(satisfies, {{ nodeImplementorsVar $e.Type }})...); err != nil { return err } {{- end }} @@ -318,29 +325,6 @@ func unmarshalArgs(ctx context.Context, whereInput any, args map[string]any) map return args } -func limitRows(partitionBy string, limit int, orderBy ...sql.Querier) func(s *sql.Selector) { - return func(s *sql.Selector) { - d := sql.Dialect(s.Dialect()) - s.SetDistinct(false) - with := d.With("src_query"). - As(s.Clone()). - With("limited_query"). - As( - d.Select("*"). - AppendSelectExprAs( - sql.RowNumber().PartitionBy(partitionBy).OrderExpr(orderBy...), - "row_number", - ). - From(d.Table("src_query")), - ) - t := d.Table("limited_query").As(s.TableName()) - *s = *d.Select(s.UnqualifiedColumns()...). - From(t). - Where(sql.LTE(t.C("row_number"), limit)). - Prefix(with) - } -} - // mayAddCondition appends another type condition to the satisfies list // if it does not exist in the list. func mayAddCondition(satisfies []string, typeCond []string) []string { diff --git a/entgql/template/node.tmpl b/entgql/template/node.tmpl index 530ab2ea9..1cc67765f 100644 --- a/entgql/template/node.tmpl +++ b/entgql/template/node.tmpl @@ -141,16 +141,13 @@ func (c *Client) noder(ctx context.Context, table string, id {{ $idType }}) (Nod query := c.{{ $n.Name }}.Query(). Where({{ $n.Package }}.ID({{ if $unmarshalID }}u{{ end }}id)) {{- if hasTemplate "gql_collection" }} - query, err := query.CollectFields(ctx, {{ nodeImplementorsVar $n }}...) - if err != nil { - return nil, err + if fc := graphql.GetFieldContext(ctx); fc != nil { + if err := query.collectField(ctx, true, graphql.GetOperationContext(ctx), fc.Field, nil, {{ nodeImplementorsVar $n }}...); err != nil { + return nil, err + } } {{- end }} - n, err := query.Only(ctx) - if err != nil { - return nil, err - } - return n, nil + return query.Only(ctx) {{- end }} default: return nil, fmt.Errorf("cannot resolve noder from table %q: %w", table, errNodeInvalidID) diff --git a/entgql/template/pagination.tmpl b/entgql/template/pagination.tmpl index ee66cb41c..5ddfa8fbd 100644 --- a/entgql/template/pagination.tmpl +++ b/entgql/template/pagination.tmpl @@ -648,11 +648,12 @@ func ({{ $r }} *{{ $name }}) ToEdge(order *{{ $order }}) *{{ $edge }} { if {{ $r }}, err = pager.applyCursors({{ $r }}, after, before); err != nil { return nil, err } - if limit := paginateLimit(first, last); limit != 0 { + limit := paginateLimit(first, last) + if limit != 0 { {{ $r }}.Limit(limit) } if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := {{ $r }}.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { + if err := {{ $r }}.collectField(ctx, limit == 1, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { return nil, err } }