From 24731ec3549dfec7a8dc9bfa63214a111bdc7f1b Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Sun, 23 Apr 2023 11:40:59 +0300 Subject: [PATCH] entgql: satisfies fragments on Node/Nodes queries This fixes a long time bug that was found now in the new fields selection optimization --- entgql/internal/todo/ent/gql_collection.go | 30 +++++--- entgql/internal/todo/todo_test.go | 70 +++++++++++++++++++ entgql/internal/todofed/ent/gql_collection.go | 14 ++++ .../internal/todogotype/ent/gql_collection.go | 24 +++++-- .../internal/todopulid/ent/gql_collection.go | 24 +++++-- .../internal/todouuid/ent/gql_collection.go | 24 +++++-- entgql/template/collection.tmpl | 16 ++++- 7 files changed, 178 insertions(+), 24 deletions(-) diff --git a/entgql/internal/todo/ent/gql_collection.go b/entgql/internal/todo/ent/gql_collection.go index 49949cc1f..5a570a133 100644 --- a/entgql/internal/todo/ent/gql_collection.go +++ b/entgql/internal/todo/ent/gql_collection.go @@ -203,7 +203,7 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Todo")...); err != nil { return err } } @@ -291,7 +291,7 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Category")...); err != nil { return err } } @@ -604,7 +604,7 @@ func (gr *GroupQuery) collectField(ctx context.Context, opCtx *graphql.Operation } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "User")...); err != nil { return err } } @@ -870,7 +870,7 @@ func (pr *ProjectQuery) collectField(ctx context.Context, opCtx *graphql.Operati } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Todo")...); err != nil { return err } } @@ -1019,7 +1019,7 @@ func (t *TodoQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Todo")...); err != nil { return err } } @@ -1245,7 +1245,7 @@ func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Group")...); err != nil { return err } } @@ -1333,7 +1333,7 @@ func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "User")...); err != nil { return err } } @@ -1417,7 +1417,7 @@ func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Friendship")...); err != nil { return err } } @@ -1599,3 +1599,17 @@ func limitRows(partitionBy string, limit int, orderBy ...sql.Querier) func(s *sq Prefix(with) } } + +// mayAddCondition appends another type condition to the satisfies list +// if condition is enabled (Node/Nodes) and it does not exist in the list. +func mayAddCondition(satisfies []string, typeCond string) []string { + if len(satisfies) == 0 { + return satisfies + } + for _, s := range satisfies { + if typeCond == s { + return satisfies + } + } + return append(satisfies, typeCond) +} diff --git a/entgql/internal/todo/todo_test.go b/entgql/internal/todo/todo_test.go index 657201676..8526385c2 100644 --- a/entgql/internal/todo/todo_test.go +++ b/entgql/internal/todo/todo_test.go @@ -2582,3 +2582,73 @@ func TestOrderByEdgeCount(t *testing.T) { require.Equal(t, rsp.Categories.Edges[1].Node.TodosCount, 3) }) } + +func TestSatisfiesFragments(t *testing.T) { + ctx := context.Background() + ec := enttest.Open( + t, dialect.SQLite, + fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", t.Name()), + enttest.WithMigrateOptions(migrate.WithGlobalUniqueID(true)), + ) + gqlc := client.New(handler.NewDefaultServer(gen.NewSchema(ec))) + cat := ec.Category.Create().SetText("cat").SetStatus(category.StatusEnabled).SaveX(ctx) + todos := ec.Todo.CreateBulk( + ec.Todo.Create().SetText("t1").SetStatus(todo.StatusPending).SetCategory(cat), + ec.Todo.Create().SetText("t2").SetStatus(todo.StatusInProgress).SetCategory(cat), + ec.Todo.Create().SetText("t3").SetStatus(todo.StatusCompleted).SetCategory(cat), + ).SaveX(ctx) + var ( + // language=GraphQL + query = `query CategoryTodo($id: ID!) { + category: node(id: $id) { + __typename + id + ... on Category { + text + ...CategoryTodos + } + } + } + + fragment CategoryTodos on Category { + todos (orderBy: {field: TEXT}) { + edges { + node { + id + ...TodoFields + } + } + } + } + + fragment TodoFields on Todo { + id + text + createdAt + } + ` + rsp struct { + Category struct { + TypeName string `json:"__typename"` + ID, Text string + Todos struct { + Edges []struct { + Node struct { + ID, Text, CreatedAt string + } + } + } + } + } + ) + gqlc.MustPost(query, &rsp, client.Var("id", cat.ID)) + require.Equal(t, strconv.Itoa(cat.ID), rsp.Category.ID) + require.Len(t, rsp.Category.Todos.Edges, 3) + for i := range todos { + require.Equal(t, strconv.Itoa(todos[i].ID), rsp.Category.Todos.Edges[i].Node.ID) + require.Equal(t, todos[i].Text, rsp.Category.Todos.Edges[i].Node.Text) + ts, err := todos[i].CreatedAt.MarshalText() + require.NoError(t, err) + require.Equal(t, string(ts), rsp.Category.Todos.Edges[i].Node.CreatedAt) + } +} diff --git a/entgql/internal/todofed/ent/gql_collection.go b/entgql/internal/todofed/ent/gql_collection.go index d2ddd2ba0..2cc10fff7 100644 --- a/entgql/internal/todofed/ent/gql_collection.go +++ b/entgql/internal/todofed/ent/gql_collection.go @@ -378,3 +378,17 @@ func limitRows(partitionBy string, limit int, orderBy ...sql.Querier) func(s *sq Prefix(with) } } + +// mayAddCondition appends another type condition to the satisfies list +// if condition is enabled (Node/Nodes) and it does not exist in the list. +func mayAddCondition(satisfies []string, typeCond string) []string { + if len(satisfies) == 0 { + return satisfies + } + for _, s := range satisfies { + if typeCond == s { + return satisfies + } + } + return append(satisfies, typeCond) +} diff --git a/entgql/internal/todogotype/ent/gql_collection.go b/entgql/internal/todogotype/ent/gql_collection.go index ff7f88081..6137e0e09 100644 --- a/entgql/internal/todogotype/ent/gql_collection.go +++ b/entgql/internal/todogotype/ent/gql_collection.go @@ -203,7 +203,7 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Todo")...); err != nil { return err } } @@ -291,7 +291,7 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Category")...); err != nil { return err } } @@ -604,7 +604,7 @@ func (gr *GroupQuery) collectField(ctx context.Context, opCtx *graphql.Operation } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "User")...); err != nil { return err } } @@ -832,7 +832,7 @@ func (t *TodoQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Todo")...); err != nil { return err } } @@ -1058,7 +1058,7 @@ func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Group")...); err != nil { return err } } @@ -1254,3 +1254,17 @@ func limitRows(partitionBy string, limit int, orderBy ...sql.Querier) func(s *sq Prefix(with) } } + +// mayAddCondition appends another type condition to the satisfies list +// if condition is enabled (Node/Nodes) and it does not exist in the list. +func mayAddCondition(satisfies []string, typeCond string) []string { + if len(satisfies) == 0 { + return satisfies + } + for _, s := range satisfies { + if typeCond == s { + return satisfies + } + } + return append(satisfies, typeCond) +} diff --git a/entgql/internal/todopulid/ent/gql_collection.go b/entgql/internal/todopulid/ent/gql_collection.go index dca48c227..314c36f5d 100644 --- a/entgql/internal/todopulid/ent/gql_collection.go +++ b/entgql/internal/todopulid/ent/gql_collection.go @@ -202,7 +202,7 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Todo")...); err != nil { return err } } @@ -290,7 +290,7 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Category")...); err != nil { return err } } @@ -603,7 +603,7 @@ func (gr *GroupQuery) collectField(ctx context.Context, opCtx *graphql.Operation } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "User")...); err != nil { return err } } @@ -764,7 +764,7 @@ func (t *TodoQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Todo")...); err != nil { return err } } @@ -990,7 +990,7 @@ func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Group")...); err != nil { return err } } @@ -1196,3 +1196,17 @@ func limitRows(partitionBy string, limit int, orderBy ...sql.Querier) func(s *sq Prefix(with) } } + +// mayAddCondition appends another type condition to the satisfies list +// if condition is enabled (Node/Nodes) and it does not exist in the list. +func mayAddCondition(satisfies []string, typeCond string) []string { + if len(satisfies) == 0 { + return satisfies + } + for _, s := range satisfies { + if typeCond == s { + return satisfies + } + } + return append(satisfies, typeCond) +} diff --git a/entgql/internal/todouuid/ent/gql_collection.go b/entgql/internal/todouuid/ent/gql_collection.go index 2e52313f2..11b5f253a 100644 --- a/entgql/internal/todouuid/ent/gql_collection.go +++ b/entgql/internal/todouuid/ent/gql_collection.go @@ -202,7 +202,7 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Todo")...); err != nil { return err } } @@ -290,7 +290,7 @@ func (c *CategoryQuery) collectField(ctx context.Context, opCtx *graphql.Operati } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Category")...); err != nil { return err } } @@ -603,7 +603,7 @@ func (gr *GroupQuery) collectField(ctx context.Context, opCtx *graphql.Operation } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "User")...); err != nil { return err } } @@ -764,7 +764,7 @@ func (t *TodoQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Todo")...); err != nil { return err } } @@ -990,7 +990,7 @@ func (u *UserQuery) collectField(ctx context.Context, opCtx *graphql.OperationCo } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Group")...); err != nil { return err } } @@ -1196,3 +1196,17 @@ func limitRows(partitionBy string, limit int, orderBy ...sql.Querier) func(s *sq Prefix(with) } } + +// mayAddCondition appends another type condition to the satisfies list +// if condition is enabled (Node/Nodes) and it does not exist in the list. +func mayAddCondition(satisfies []string, typeCond string) []string { + if len(satisfies) == 0 { + return satisfies + } + for _, s := range satisfies { + if typeCond == s { + return satisfies + } + } + return append(satisfies, typeCond) +} diff --git a/entgql/template/collection.tmpl b/entgql/template/collection.tmpl index b20962de0..b391ad8da 100644 --- a/entgql/template/collection.tmpl +++ b/entgql/template/collection.tmpl @@ -112,7 +112,7 @@ func ({{ $receiver }} *{{ $query }}) collectField(ctx context.Context, opCtx *gr } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, opCtx, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "{{ $e.Type.Name }}")...); err != nil { return err } } @@ -357,6 +357,20 @@ func limitRows(partitionBy string, limit int, orderBy ...sql.Querier) func(s *sq Prefix(with) } } + +// mayAddCondition appends another type condition to the satisfies list +// if condition is enabled (Node/Nodes) and it does not exist in the list. +func mayAddCondition(satisfies []string, typeCond string) []string { + if len(satisfies) == 0 { + return satisfies + } + for _, s := range satisfies { + if typeCond == s { + return satisfies + } + } + return append(satisfies, typeCond) +} {{ end }} {{ define "gql_pagination/helper/load_total" }}