From 6d0ed44fbf2fea5b5ee57d605572d26e0e8c9a4b Mon Sep 17 00:00:00 2001 From: doug-martin Date: Mon, 30 Sep 2019 05:08:13 -0500 Subject: [PATCH] v9.4.0 * [ADDED] Ability to scan into struct fields from multiple tables #160 --- HISTORY.md | 4 + database_example_test.go | 16 +- docs/selecting.md | 150 ++++++++++++++ exec/query_executor_test.go | 78 ++++++- exp/col.go | 7 +- exp/exp.go | 8 + exp/ident.go | 18 ++ exp/ident_test.go | 166 +++++++++++++++ internal/errors/error.go | 11 +- internal/tag/tags.go | 3 + internal/util/reflect.go | 89 ++++++-- internal/util/reflect_test.go | 365 +++++++++++++++++++++++++++++++++ select_dataset_example_test.go | 229 ++++++++++++++++++++- 13 files changed, 1092 insertions(+), 52 deletions(-) create mode 100644 exp/ident_test.go diff --git a/HISTORY.md b/HISTORY.md index f369d059..9d1f3060 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,7 @@ +# v9.4.0 + +* [ADDED] Ability to scan into struct fields from multiple tables [#160](https://github.com/doug-martin/goqu/issues/160) + # v9.3.0 * [ADDED] Using Update, Insert, or Delete datasets in sub selects and CTEs [#164](https://github.com/doug-martin/goqu/issues/164) diff --git a/database_example_test.go b/database_example_test.go index 5a759b33..746ce198 100644 --- a/database_example_test.go +++ b/database_example_test.go @@ -109,13 +109,13 @@ func ExampleDatabase_Dialect() { func ExampleDatabase_Exec() { db := getDb() - _, err := db.Exec(`DROP TABLE "goqu_user"`) + _, err := db.Exec(`DROP TABLE "user_role"; DROP TABLE "goqu_user"`) if err != nil { - fmt.Println("Error occurred while dropping table", err.Error()) + fmt.Println("Error occurred while dropping tables", err.Error()) } - fmt.Println("Dropped table goqu_user") + fmt.Println("Dropped tables user_role and goqu_user") // Output: - // Dropped table goqu_user + // Dropped tables user_role and goqu_user } func ExampleDatabase_ExecContext() { @@ -123,13 +123,13 @@ func ExampleDatabase_ExecContext() { d := time.Now().Add(50 * time.Millisecond) ctx, cancel := context.WithDeadline(context.Background(), d) defer cancel() - _, err := db.ExecContext(ctx, `DROP TABLE "goqu_user"`) + _, err := db.ExecContext(ctx, `DROP TABLE "user_role"; DROP TABLE "goqu_user"`) if err != nil { - fmt.Println("Error occurred while dropping table", err.Error()) + fmt.Println("Error occurred while dropping tables", err.Error()) } - fmt.Println("Dropped table goqu_user") + fmt.Println("Dropped tables user_role and goqu_user") // Output: - // Dropped table goqu_user + // Dropped tables user_role and goqu_user } func ExampleDatabase_From() { diff --git a/docs/selecting.md b/docs/selecting.md index d2b35d1a..30835603 100644 --- a/docs/selecting.md +++ b/docs/selecting.md @@ -842,6 +842,78 @@ if err := db.From("user").Select("first_name").ScanStructs(&users); err != nil{ fmt.Printf("\n%+v", users) ``` +`goqu` also supports scanning into multiple structs. In the example below we define a `Role` and `User` struct that could both be used individually to scan into. However, you can also create a new struct that adds both structs as fields that can be populated in a single query. + +**NOTE** When calling `ScanStructs` without a select already defined it will automatically only `SELECT` the columns found in the struct + + ```go +type Role struct { + Id uint64 `db:"id"` + UserID uint64 `db:"user_id"` + Name string `db:"name"` +} +type User struct { + Id uint64 `db:"id"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` +} +type UserAndRole struct { + User User `db:"goqu_user"` // tag as the "goqu_user" table + Role Role `db:"user_role"` // tag as "user_role" table +} +db := getDb() + +ds := db. + From("goqu_user"). + Join(goqu.T("user_role"), goqu.On(goqu.I("goqu_user.id").Eq(goqu.I("user_role.user_id")))) +var users []UserAndRole + // Scan structs will auto build the +if err := ds.ScanStructs(&users); err != nil { + fmt.Println(err.Error()) + return +} +for _, u := range users { + fmt.Printf("\n%+v", u) +} +``` + +You can alternatively manually select the columns with the appropriate aliases using the `goqu.C` method to create the alias. + +```go +type Role struct { + UserID uint64 `db:"user_id"` + Name string `db:"name"` +} +type User struct { + Id uint64 `db:"id"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Role Role `db:"user_role"` // tag as "user_role" table +} +db := getDb() + +ds := db. + Select( + "goqu_user.id", + "goqu_user.first_name", + "goqu_user.last_name", + // alias the fully qualified identifier `C` is important here so it doesnt parse it + goqu.I("user_role.user_id").As(goqu.C("user_role.user_id")), + goqu.I("user_role.name").As(goqu.C("user_role.name")), + ). + From("goqu_user"). + Join(goqu.T("user_role"), goqu.On(goqu.I("goqu_user.id").Eq(goqu.I("user_role.user_id")))) + +var users []User +if err := ds.ScanStructs(&users); err != nil { + fmt.Println(err.Error()) + return +} +for _, u := range users { + fmt.Printf("\n%+v", u) +} +``` + **[`ScanStruct`](http://godoc.org/github.com/doug-martin/goqu#SelectDataset.ScanStruct)** @@ -869,6 +941,83 @@ if !found { } ``` +`goqu` also supports scanning into multiple structs. In the example below we define a `Role` and `User` struct that could both be used individually to scan into. However, you can also create a new struct that adds both structs as fields that can be populated in a single query. + +**NOTE** When calling `ScanStruct` without a select already defined it will automatically only `SELECT` the columns found in the struct + + ```go +type Role struct { + UserID uint64 `db:"user_id"` + Name string `db:"name"` +} +type User struct { + ID uint64 `db:"id"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` +} +type UserAndRole struct { + User User `db:"goqu_user"` // tag as the "goqu_user" table + Role Role `db:"user_role"` // tag as "user_role" table +} +db := getDb() +var userAndRole UserAndRole +ds := db. + From("goqu_user"). + Join(goqu.T("user_role"),goqu.On(goqu.I("goqu_user.id").Eq(goqu.I("user_role.user_id")))). + Where(goqu.C("first_name").Eq("Bob")) + +found, err := ds.ScanStruct(&userAndRole) +if err != nil{ + fmt.Println(err.Error()) + return +} +if !found { + fmt.Println("No user found") +} else { + fmt.Printf("\nFound user: %+v", user) +} +``` + +You can alternatively manually select the columns with the appropriate aliases using the `goqu.C` method to create the alias. + +```go +type Role struct { + UserID uint64 `db:"user_id"` + Name string `db:"name"` +} +type User struct { + ID uint64 `db:"id"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Role Role `db:"user_role"` // tag as "user_role" table +} +db := getDb() +var userAndRole UserAndRole +ds := db. + Select( + "goqu_user.id", + "goqu_user.first_name", + "goqu_user.last_name", + // alias the fully qualified identifier `C` is important here so it doesnt parse it + goqu.I("user_role.user_id").As(goqu.C("user_role.user_id")), + goqu.I("user_role.name").As(goqu.C("user_role.name")), + ). + From("goqu_user"). + Join(goqu.T("user_role"),goqu.On(goqu.I("goqu_user.id").Eq(goqu.I("user_role.user_id")))). + Where(goqu.C("first_name").Eq("Bob")) + +found, err := ds.ScanStruct(&userAndRole) +if err != nil{ + fmt.Println(err.Error()) + return +} +if !found { + fmt.Println("No user found") +} else { + fmt.Printf("\nFound user: %+v", user) +} +``` + **NOTE** Using the `goqu.SetColumnRenameFunction` function, you can change the function that's used to rename struct fields when struct tags aren't defined @@ -1037,3 +1186,4 @@ fmt.Printf("\nIds := %+v", ids) + diff --git a/exec/query_executor_test.go b/exec/query_executor_test.go index 816c5028..1e66e6ff 100644 --- a/exec/query_executor_test.go +++ b/exec/query_executor_test.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "fmt" + "strings" "testing" "time" @@ -415,8 +416,8 @@ func (qes *queryExecutorSuite) TestScanStructs_withIgnoredEmbeddedPointerStruct( var composed []ComposedIgnoredPointerStruct qes.NoError(e.ScanStructs(&composed)) qes.Equal([]ComposedIgnoredPointerStruct{ - {StructWithTags: &StructWithTags{}, PhoneNumber: testPhone1, Age: testAge1}, - {StructWithTags: &StructWithTags{}, PhoneNumber: testPhone2, Age: testAge2}, + {PhoneNumber: testPhone1, Age: testAge1}, + {PhoneNumber: testPhone2, Age: testAge2}, }, composed) } @@ -944,6 +945,79 @@ func (qes *queryExecutorSuite) TestScanStruct() { }, noTag) } +func (qes *queryExecutorSuite) TestScanStruct_taggedStructs() { + type StructWithNoTags struct { + Address string + Name string + } + + type StructWithTags struct { + Address string `db:"address"` + Name string `db:"name"` + } + + type ComposedStruct struct { + StructWithTags + PhoneNumber string `db:"phone_number"` + Age int64 `db:"age"` + } + type ComposedWithPointerStruct struct { + *StructWithTags + PhoneNumber string `db:"phone_number"` + Age int64 `db:"age"` + } + + type StructWithTaggedStructs struct { + NoTags StructWithNoTags `db:"notags"` + Tags StructWithTags `db:"tags"` + Composed ComposedStruct `db:"composedstruct"` + ComposedPointer ComposedWithPointerStruct `db:"composedptrstruct"` + } + + db, mock, err := sqlmock.New() + qes.NoError(err) + + cols := []string{ + "notags.address", "notags.name", + "tags.address", "tags.name", + "composedstruct.address", "composedstruct.name", "composedstruct.phone_number", "composedstruct.age", + "composedptrstruct.address", "composedptrstruct.name", "composedptrstruct.phone_number", "composedptrstruct.age", + } + + q := `SELECT` + strings.Join(cols, ", ") + ` FROM "items"` + + mock.ExpectQuery(q). + WithArgs(). + WillReturnRows(sqlmock.NewRows(cols).AddRow( + testAddr1, testName1, + testAddr2, testName2, + testAddr1, testName1, testPhone1, testAge1, + testAddr2, testName2, testPhone2, testAge2, + )) + + e := newQueryExecutor(db, nil, q) + + var item StructWithTaggedStructs + found, err := e.ScanStruct(&item) + qes.NoError(err) + qes.True(found) + qes.Equal(StructWithTaggedStructs{ + NoTags: StructWithNoTags{Address: testAddr1, Name: testName1}, + Tags: StructWithTags{Address: testAddr2, Name: testName2}, + Composed: ComposedStruct{ + StructWithTags: StructWithTags{Address: testAddr1, Name: testName1}, + PhoneNumber: testPhone1, + Age: testAge1, + }, + ComposedPointer: ComposedWithPointerStruct{ + StructWithTags: &StructWithTags{Address: testAddr2, Name: testName2}, + PhoneNumber: testPhone2, + Age: testAge2, + }, + }, item) + +} + func (qes *queryExecutorSuite) TestScanVals() { db, mock, err := sqlmock.New() qes.NoError(err) diff --git a/exp/col.go b/exp/col.go index c7f99e8b..d3931c36 100644 --- a/exp/col.go +++ b/exp/col.go @@ -32,7 +32,12 @@ func NewColumnListExpression(vals ...interface{}) ColumnListExpression { } structCols := cm.Cols() for _, col := range structCols { - cols = append(cols, ParseIdentifier(col)) + i := ParseIdentifier(col) + var sc Expression = i + if i.IsQualified() { + sc = i.As(NewIdentifierExpression("", "", col)) + } + cols = append(cols, sc) } } else { panic(fmt.Sprintf("Cannot created expression from %+v", val)) diff --git a/exp/exp.go b/exp/exp.go index 8f0c6978..2246aacd 100644 --- a/exp/exp.go +++ b/exp/exp.go @@ -252,6 +252,14 @@ type ( Updateable Distinctable Castable + // returns true if this identifier has more more than on part (Schema, Table or Col) + // "schema" -> true //cant qualify anymore + // "schema.table" -> true + // "table" -> false + // "schema"."table"."col" -> true + // "table"."col" -> true + // "col" -> false + IsQualified() bool // Returns a new IdentifierExpression with the specified schema Schema(string) IdentifierExpression // Returns the current schema diff --git a/exp/ident.go b/exp/ident.go index d0b9a7af..a2882c51 100644 --- a/exp/ident.go +++ b/exp/ident.go @@ -33,6 +33,24 @@ func (i identifier) Clone() Expression { return i.clone() } +func (i identifier) IsQualified() bool { + schema, table, col := i.schema, i.table, i.col + switch c := col.(type) { + case string: + if c != "" { + return len(table) > 0 || len(schema) > 0 + } + default: + if c != nil { + return len(table) > 0 || len(schema) > 0 + } + } + if len(table) > 0 { + return len(schema) > 0 + } + return false +} + // Sets the table on the current identifier // I("col").Table("table") -> "table"."col" //postgres // I("col").Table("table") -> `table`.`col` //mysql diff --git a/exp/ident_test.go b/exp/ident_test.go new file mode 100644 index 00000000..52966a1d --- /dev/null +++ b/exp/ident_test.go @@ -0,0 +1,166 @@ +package exp + +import ( + "testing" + + "github.com/stretchr/testify/suite" +) + +type identifierExpressionSuite struct { + suite.Suite +} + +func TestIdentifierExpressionSuite(t *testing.T) { + suite.Run(t, new(identifierExpressionSuite)) +} + +func (ies *identifierExpressionSuite) TestParseIdentifier() { + cases := []struct { + ToParse string + Expected IdentifierExpression + }{ + {ToParse: "one", Expected: NewIdentifierExpression("", "", "one")}, + {ToParse: "one.two", Expected: NewIdentifierExpression("", "one", "two")}, + {ToParse: "one.two.three", Expected: NewIdentifierExpression("one", "two", "three")}, + } + for _, tc := range cases { + ies.Equal(tc.Expected, ParseIdentifier(tc.ToParse)) + } +} + +func (ies *identifierExpressionSuite) TestClone() { + cases := []struct { + Expected IdentifierExpression + }{ + {Expected: NewIdentifierExpression("", "", "one")}, + {Expected: NewIdentifierExpression("", "two", "one")}, + {Expected: NewIdentifierExpression("three", "two", "one")}, + } + for _, tc := range cases { + ies.Equal(tc.Expected, tc.Expected.Clone()) + } +} + +func (ies *identifierExpressionSuite) TestIsQualified() { + cases := []struct { + Ident IdentifierExpression + IsQualified bool + }{ + {Ident: NewIdentifierExpression("", "", "col"), IsQualified: false}, + {Ident: NewIdentifierExpression("", "table", ""), IsQualified: false}, + {Ident: NewIdentifierExpression("", "table", nil), IsQualified: false}, + {Ident: NewIdentifierExpression("schema", "", ""), IsQualified: false}, + {Ident: NewIdentifierExpression("schema", "", nil), IsQualified: false}, + {Ident: NewIdentifierExpression("", "table", NewLiteralExpression("*")), IsQualified: true}, + {Ident: NewIdentifierExpression("", "table", "col"), IsQualified: true}, + {Ident: NewIdentifierExpression("schema", "table", nil), IsQualified: true}, + {Ident: NewIdentifierExpression("schema", "table", NewLiteralExpression("*")), IsQualified: true}, + {Ident: NewIdentifierExpression("schema", "table", ""), IsQualified: true}, + {Ident: NewIdentifierExpression("schema", "table", "col"), IsQualified: true}, + {Ident: NewIdentifierExpression("schema", "", "col"), IsQualified: true}, + {Ident: NewIdentifierExpression("schema", "", NewLiteralExpression("*")), IsQualified: true}, + } + for _, tc := range cases { + ies.Equal(tc.IsQualified, tc.Ident.IsQualified(), "expected %s IsQualified to be %b", tc.Ident, tc.IsQualified) + } +} + +func (ies *identifierExpressionSuite) TestGetTable() { + cases := []struct { + Ident IdentifierExpression + Table string + }{ + {Ident: NewIdentifierExpression("", "", "col"), Table: ""}, + {Ident: NewIdentifierExpression("", "table", "col"), Table: "table"}, + {Ident: NewIdentifierExpression("schema", "", "col"), Table: ""}, + {Ident: NewIdentifierExpression("schema", "table", nil), Table: "table"}, + {Ident: NewIdentifierExpression("schema", "table", "col"), Table: "table"}, + } + for _, tc := range cases { + ies.Equal(tc.Table, tc.Ident.GetTable()) + } +} + +func (ies *identifierExpressionSuite) TestGetSchema() { + cases := []struct { + Ident IdentifierExpression + Schema string + }{ + {Ident: NewIdentifierExpression("", "", "col"), Schema: ""}, + {Ident: NewIdentifierExpression("", "table", "col"), Schema: ""}, + {Ident: NewIdentifierExpression("schema", "", "col"), Schema: "schema"}, + {Ident: NewIdentifierExpression("schema", "table", nil), Schema: "schema"}, + {Ident: NewIdentifierExpression("schema", "table", "col"), Schema: "schema"}, + } + for _, tc := range cases { + ies.Equal(tc.Schema, tc.Ident.GetSchema()) + } +} + +func (ies *identifierExpressionSuite) TestGetCol() { + cases := []struct { + Ident IdentifierExpression + Col interface{} + }{ + {Ident: NewIdentifierExpression("", "", "col"), Col: "col"}, + {Ident: NewIdentifierExpression("", "", "*"), Col: NewLiteralExpression("*")}, + {Ident: NewIdentifierExpression("", "table", "col"), Col: "col"}, + {Ident: NewIdentifierExpression("schema", "", "col"), Col: "col"}, + {Ident: NewIdentifierExpression("schema", "table", nil), Col: nil}, + {Ident: NewIdentifierExpression("schema", "table", "col"), Col: "col"}, + } + for _, tc := range cases { + ies.Equal(tc.Col, tc.Ident.GetCol()) + } +} + +func (ies *identifierExpressionSuite) TestExpression() { + i := NewIdentifierExpression("", "", "col") + ies.Equal(i, i.Expression()) +} + +func (ies *identifierExpressionSuite) TestAll() { + cases := []struct { + Ident IdentifierExpression + }{ + {Ident: NewIdentifierExpression("", "", "col")}, + {Ident: NewIdentifierExpression("", "table", "col")}, + {Ident: NewIdentifierExpression("schema", "table", "col")}, + {Ident: NewIdentifierExpression("", "", nil)}, + {Ident: NewIdentifierExpression("", "table", nil)}, + {Ident: NewIdentifierExpression("schema", "table", nil)}, + } + for _, tc := range cases { + ies.Equal( + NewIdentifierExpression(tc.Ident.GetSchema(), tc.Ident.GetTable(), NewLiteralExpression("*")), + tc.Ident.All(), + ) + } +} + +func (ies *identifierExpressionSuite) TestIsEmpty() { + cases := []struct { + Ident IdentifierExpression + IsEmpty bool + }{ + {Ident: NewIdentifierExpression("", "", ""), IsEmpty: true}, + {Ident: NewIdentifierExpression("", "", nil), IsEmpty: true}, + {Ident: NewIdentifierExpression("", "", "col"), IsEmpty: false}, + {Ident: NewIdentifierExpression("", "", NewLiteralExpression("*")), IsEmpty: false}, + {Ident: NewIdentifierExpression("", "table", ""), IsEmpty: false}, + {Ident: NewIdentifierExpression("", "table", nil), IsEmpty: false}, + {Ident: NewIdentifierExpression("schema", "", ""), IsEmpty: false}, + {Ident: NewIdentifierExpression("schema", "", nil), IsEmpty: false}, + {Ident: NewIdentifierExpression("", "table", NewLiteralExpression("*")), IsEmpty: false}, + {Ident: NewIdentifierExpression("", "table", "col"), IsEmpty: false}, + {Ident: NewIdentifierExpression("schema", "table", nil), IsEmpty: false}, + {Ident: NewIdentifierExpression("schema", "table", NewLiteralExpression("*")), IsEmpty: false}, + {Ident: NewIdentifierExpression("schema", "table", ""), IsEmpty: false}, + {Ident: NewIdentifierExpression("schema", "table", "col"), IsEmpty: false}, + {Ident: NewIdentifierExpression("schema", "", "col"), IsEmpty: false}, + {Ident: NewIdentifierExpression("schema", "", NewLiteralExpression("*")), IsEmpty: false}, + } + for _, tc := range cases { + ies.Equal(tc.IsEmpty, tc.Ident.IsEmpty(), "expected %s IsEmpty to be %b", tc.Ident, tc.IsEmpty) + } +} diff --git a/internal/errors/error.go b/internal/errors/error.go index 6aa74333..d0cb8072 100644 --- a/internal/errors/error.go +++ b/internal/errors/error.go @@ -10,19 +10,10 @@ func New(message string, args ...interface{}) error { return Error{err: "goqu: " + fmt.Sprintf(message, args...)} } -func (e Error) Error() string { - return e.err -} - -type EncodeError struct { - error - err string -} - func NewEncodeError(t interface{}) error { return Error{err: "goqu_encode_error: " + fmt.Sprintf("Unable to encode value %+v", t)} } -func (e EncodeError) Error() string { +func (e Error) Error() string { return e.err } diff --git a/internal/tag/tags.go b/internal/tag/tags.go index bb61afc2..03652e9e 100644 --- a/internal/tag/tags.go +++ b/internal/tag/tags.go @@ -14,6 +14,9 @@ func New(tagName string, st reflect.StructTag) Options { } func (o Options) Values() []string { + if string(o) == "" { + return []string{} + } return strings.Split(string(o), ",") } diff --git a/internal/util/reflect.go b/internal/util/reflect.go index f7eaa29c..2e5ddef3 100644 --- a/internal/util/reflect.go +++ b/internal/util/reflect.go @@ -1,6 +1,7 @@ package util import ( + "database/sql" "reflect" "sort" "strings" @@ -28,6 +29,8 @@ const ( defaultIfEmptyTagName = "defaultifempty" ) +var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + func IsUint(k reflect.Kind) bool { return (k == reflect.Uint) || (k == reflect.Uint8) || @@ -156,31 +159,42 @@ func SafeGetFieldByIndex(v reflect.Value, fieldIndex []int) (result reflect.Valu return reflect.ValueOf(nil), false } +func SafeSetFieldByIndex(v reflect.Value, fieldIndex []int, src interface{}) (result reflect.Value) { + v = reflect.Indirect(v) + switch len(fieldIndex) { + case 0: + return v + case 1: + f := v.FieldByIndex(fieldIndex) + srcVal := reflect.ValueOf(src) + f.Set(reflect.Indirect(srcVal)) + default: + f := v.Field(fieldIndex[0]) + switch f.Kind() { + case reflect.Ptr: + s := f + if f.IsNil() || !f.IsValid() { + s = reflect.New(f.Type().Elem()) + f.Set(s) + } + SafeSetFieldByIndex(reflect.Indirect(s), fieldIndex[1:], src) + case reflect.Struct: + SafeSetFieldByIndex(f, fieldIndex[1:], src) + } + } + return v +} + type rowData = map[string]interface{} // AssignStructVals will assign the data from rd to i. func AssignStructVals(i interface{}, rd rowData, cm ColumnMap) { val := reflect.Indirect(reflect.ValueOf(i)) - initEmbeddedPtr(val) for name, data := range cm { src, ok := rd[name] if ok { - f := val.FieldByIndex(data.FieldIndex) - srcVal := reflect.ValueOf(src) - f.Set(reflect.Indirect(srcVal)) - } - } -} - -func initEmbeddedPtr(value reflect.Value) { - for i := 0; i < value.NumField(); i++ { - v := value.Field(i) - kind := v.Kind() - t := value.Type().Field(i) - if t.Anonymous && kind == reflect.Ptr { - z := reflect.New(t.Type.Elem()) - v.Set(z) + SafeSetFieldByIndex(val, data.FieldIndex, src) } } } @@ -195,12 +209,12 @@ func GetColumnMap(i interface{}) (ColumnMap, error) { structMapCacheLock.Lock() defer structMapCacheLock.Unlock() if _, ok := structMapCache[t]; !ok { - structMapCache[t] = createColumnMap(t, []int{}) + structMapCache[t] = createColumnMap(t, []int{}, []string{}) } return structMapCache[t], nil } -func createColumnMap(t reflect.Type, fieldIndex []int) ColumnMap { +func createColumnMap(t reflect.Type, fieldIndex []int, prefixes []string) ColumnMap { cm, n := ColumnMap{}, t.NumField() var subColMaps []ColumnMap for i := 0; i < n; i++ { @@ -208,23 +222,40 @@ func createColumnMap(t reflect.Type, fieldIndex []int) ColumnMap { if f.Anonymous && (f.Type.Kind() == reflect.Struct || f.Type.Kind() == reflect.Ptr) { goquTag := tag.New("db", f.Tag) if !goquTag.Contains("-") { + subFieldIndexes := append(fieldIndex, f.Index...) + subPrefixes := append(prefixes, goquTag.Values()...) if f.Type.Kind() == reflect.Ptr { - subColMaps = append(subColMaps, createColumnMap(f.Type.Elem(), append(fieldIndex, f.Index...))) + subColMaps = append(subColMaps, createColumnMap(f.Type.Elem(), subFieldIndexes, subPrefixes)) } else { - subColMaps = append(subColMaps, createColumnMap(f.Type, append(fieldIndex, f.Index...))) + subColMaps = append(subColMaps, createColumnMap(f.Type, subFieldIndexes, subPrefixes)) } } } else if f.PkgPath == "" { - // if PkgPath is empty then it is an exported field dbTag := tag.New("db", f.Tag) + // if PkgPath is empty then it is an exported field var columnName string if dbTag.IsEmpty() { columnName = columnRenameFunction(f.Name) } else { columnName = dbTag.Values()[0] } - goquTag := tag.New("goqu", f.Tag) if !dbTag.Equals("-") { + if !implementsScanner(f.Type) { + subFieldIndexes := append(fieldIndex, f.Index...) + subPrefixes := append(prefixes, columnName) + var subCm ColumnMap + if f.Type.Kind() == reflect.Ptr { + subCm = createColumnMap(f.Type.Elem(), subFieldIndexes, subPrefixes) + } else { + subCm = createColumnMap(f.Type, subFieldIndexes, subPrefixes) + } + if len(subCm) != 0 { + subColMaps = append(subColMaps, subCm) + continue + } + } + goquTag := tag.New("goqu", f.Tag) + columnName = strings.Join(append(prefixes, columnName), ".") cm[columnName] = ColumnData{ ColumnName: columnName, ShouldInsert: !goquTag.Contains(skipInsertTagName), @@ -254,3 +285,17 @@ func (cm ColumnMap) Cols() []string { sort.Strings(structCols) return structCols } + +func implementsScanner(t reflect.Type) bool { + if IsPointer(t.Kind()) { + t = t.Elem() + } + if reflect.PtrTo(t).Implements(scannerType) { + return true + } + if !IsStruct(t.Kind()) { + return true + } + + return false +} diff --git a/internal/util/reflect_test.go b/internal/util/reflect_test.go index 5eabc1e0..8fc692b0 100644 --- a/internal/util/reflect_test.go +++ b/internal/util/reflect_test.go @@ -2,6 +2,7 @@ package util import ( "database/sql" + "fmt" "reflect" "strings" "sync" @@ -549,6 +550,126 @@ func (rt *reflectTest) TestAssignStructVals_withStructWithEmbeddedStructPointer( }) } +func (rt *reflectTest) TestAssignStructVals_withStructWithTaggedEmbeddedStruct() { + + type EmbeddedStruct struct { + Str string + } + type TestStruct struct { + EmbeddedStruct `db:"embedded"` + Int int64 + Bool bool + Valuer *sql.NullString + } + var ts TestStruct + cm, err := GetColumnMap(&ts) + rt.NoError(err) + ns := &sql.NullString{String: "null_str1", Valid: true} + data := map[string]interface{}{ + "embedded.str": "string", + "int": int64(10), + "bool": true, + "valuer": &ns, + } + AssignStructVals(&ts, data, cm) + rt.Equal(ts, TestStruct{ + EmbeddedStruct: EmbeddedStruct{Str: "string"}, + Int: 10, + Bool: true, + Valuer: ns, + }) +} + +func (rt *reflectTest) TestAssignStructVals_withStructWithTaggedEmbeddedPointer() { + + type EmbeddedStruct struct { + Str string + } + type TestStruct struct { + *EmbeddedStruct `db:"embedded"` + Int int64 + Bool bool + Valuer *sql.NullString + } + var ts TestStruct + cm, err := GetColumnMap(&ts) + rt.NoError(err) + ns := &sql.NullString{String: "null_str1", Valid: true} + data := map[string]interface{}{ + "embedded.str": "string", + "int": int64(10), + "bool": true, + "valuer": &ns, + } + AssignStructVals(&ts, data, cm) + rt.Equal(ts, TestStruct{ + EmbeddedStruct: &EmbeddedStruct{Str: "string"}, + Int: 10, + Bool: true, + Valuer: ns, + }) +} + +func (rt *reflectTest) TestAssignStructVals_withStructWithTaggedStructField() { + + type EmbeddedStruct struct { + Str string + } + type TestStruct struct { + Embedded EmbeddedStruct `db:"embedded"` + Int int64 + Bool bool + Valuer *sql.NullString + } + var ts TestStruct + cm, err := GetColumnMap(&ts) + rt.NoError(err) + ns := &sql.NullString{String: "null_str1", Valid: true} + data := map[string]interface{}{ + "embedded.str": "string", + "int": int64(10), + "bool": true, + "valuer": &ns, + } + AssignStructVals(&ts, data, cm) + rt.Equal(ts, TestStruct{ + Embedded: EmbeddedStruct{Str: "string"}, + Int: 10, + Bool: true, + Valuer: ns, + }) +} + +func (rt *reflectTest) TestAssignStructVals_withStructWithTaggedPointerField() { + + type EmbeddedStruct struct { + Str string + } + type TestStruct struct { + Embedded *EmbeddedStruct `db:"embedded"` + Int int64 + Bool bool + Valuer *sql.NullString + } + var ts TestStruct + cm, err := GetColumnMap(&ts) + rt.NoError(err) + ns := &sql.NullString{String: "null_str1", Valid: true} + data := map[string]interface{}{ + "embedded.str": "string", + "int": int64(10), + "bool": true, + "valuer": &ns, + } + AssignStructVals(&ts, data, cm) + rt.Equal(ts, TestStruct{ + Embedded: &EmbeddedStruct{Str: "string"}, + Int: 10, + Bool: true, + Valuer: ns, + }) +} + func (rt *reflectTest) TestGetColumnMap_withStruct() { type TestStruct struct { @@ -805,6 +926,197 @@ func (rt *reflectTest) TestGetColumnMap_withPrivateEmbeddedFields() { }, cm) } +func (rt *reflectTest) TestGetColumnMap_withEmbeddedTaggedStruct() { + + type TestEmbedded struct { + Bool bool + Valuer *sql.NullString + } + + type TestStruct struct { + TestEmbedded `db:"test_embedded"` + Bool bool + Valuer *sql.NullString + } + var ts TestStruct + cm, err := GetColumnMap(&ts) + rt.NoError(err) + fmt.Println(cm) + rt.Equal(ColumnMap{ + "test_embedded.bool": { + ColumnName: "test_embedded.bool", + FieldIndex: []int{0, 0}, + ShouldInsert: true, + ShouldUpdate: true, + GoType: reflect.TypeOf(true), + }, + "test_embedded.valuer": { + ColumnName: "test_embedded.valuer", + FieldIndex: []int{0, 1}, + ShouldInsert: true, + ShouldUpdate: true, + GoType: reflect.TypeOf(&sql.NullString{}), + }, + "bool": { + ColumnName: "bool", + FieldIndex: []int{1}, + ShouldInsert: true, + ShouldUpdate: true, + GoType: reflect.TypeOf(true), + }, + "valuer": { + ColumnName: "valuer", + FieldIndex: []int{2}, + ShouldInsert: true, + ShouldUpdate: true, + GoType: reflect.TypeOf(&sql.NullString{}), + }, + }, cm) +} + +func (rt *reflectTest) TestGetColumnMap_withEmbeddedTaggedStructPointer() { + + type TestEmbedded struct { + Bool bool + Valuer *sql.NullString + } + + type TestStruct struct { + *TestEmbedded `db:"test_embedded"` + Bool bool + Valuer *sql.NullString + } + var ts TestStruct + cm, err := GetColumnMap(&ts) + rt.NoError(err) + fmt.Println(cm) + rt.Equal(ColumnMap{ + "test_embedded.bool": { + ColumnName: "test_embedded.bool", + FieldIndex: []int{0, 0}, + ShouldInsert: true, + ShouldUpdate: true, + GoType: reflect.TypeOf(true), + }, + "test_embedded.valuer": { + ColumnName: "test_embedded.valuer", + FieldIndex: []int{0, 1}, + ShouldInsert: true, ShouldUpdate: true, + GoType: reflect.TypeOf(&sql.NullString{}), + }, + "bool": { + ColumnName: "bool", + FieldIndex: []int{1}, + ShouldInsert: true, + ShouldUpdate: true, + GoType: reflect.TypeOf(true), + }, + "valuer": { + ColumnName: "valuer", + FieldIndex: []int{2}, + ShouldInsert: true, + ShouldUpdate: true, + GoType: reflect.TypeOf(&sql.NullString{}), + }, + }, cm) +} + +func (rt *reflectTest) TestGetColumnMap_withTaggedStructField() { + + type TestEmbedded struct { + Bool bool + Valuer *sql.NullString + } + + type TestStruct struct { + Embedded TestEmbedded `db:"test_embedded"` + Bool bool + Valuer *sql.NullString + } + var ts TestStruct + cm, err := GetColumnMap(&ts) + rt.NoError(err) + fmt.Println(cm) + rt.Equal(ColumnMap{ + "test_embedded.bool": { + ColumnName: "test_embedded.bool", + FieldIndex: []int{0, 0}, + ShouldInsert: true, + ShouldUpdate: true, + GoType: reflect.TypeOf(true), + }, + "test_embedded.valuer": { + ColumnName: "test_embedded.valuer", + FieldIndex: []int{0, 1}, + ShouldInsert: true, + ShouldUpdate: true, + GoType: reflect.TypeOf(&sql.NullString{}), + }, + "bool": { + ColumnName: "bool", + FieldIndex: []int{1}, + ShouldInsert: true, + ShouldUpdate: true, + GoType: reflect.TypeOf(true), + }, + "valuer": { + ColumnName: "valuer", + FieldIndex: []int{2}, + ShouldInsert: true, + ShouldUpdate: true, + GoType: reflect.TypeOf(&sql.NullString{}), + }, + }, cm) +} + +func (rt *reflectTest) TestGetColumnMap_withTaggedStructPointerField() { + + type TestEmbedded struct { + Bool bool + Valuer *sql.NullString + } + + type TestStruct struct { + Embedded *TestEmbedded `db:"test_embedded"` + Bool bool + Valuer *sql.NullString + } + var ts TestStruct + cm, err := GetColumnMap(&ts) + rt.NoError(err) + fmt.Println(cm) + rt.Equal(ColumnMap{ + "test_embedded.bool": { + ColumnName: "test_embedded.bool", + FieldIndex: []int{0, 0}, + ShouldInsert: true, + ShouldUpdate: true, + GoType: reflect.TypeOf(true), + }, + "test_embedded.valuer": { + ColumnName: "test_embedded.valuer", + FieldIndex: []int{0, 1}, + ShouldInsert: true, + ShouldUpdate: true, + GoType: reflect.TypeOf(&sql.NullString{}), + }, + "bool": { + ColumnName: "bool", + FieldIndex: []int{1}, + ShouldInsert: true, + ShouldUpdate: true, + GoType: reflect.TypeOf(true), + }, + "valuer": { + ColumnName: "valuer", + FieldIndex: []int{2}, + ShouldInsert: true, + ShouldUpdate: true, + GoType: reflect.TypeOf(&sql.NullString{}), + }, + }, cm) +} + func (rt *reflectTest) TestGetTypeInfo() { var a int64 var b []int64 @@ -886,6 +1198,59 @@ func (rt *reflectTest) TestSafeGetFieldByIndex() { rt.Equal(v, f) } +func (rt *reflectTest) TestSafeSetFieldByIndex() { + type TestEmbedded struct { + FieldA int + } + type TestEmbeddedPointerStruct struct { + *TestEmbedded + FieldB string + } + type TestEmbeddedStruct struct { + TestEmbedded + FieldB string + } + var teps TestEmbeddedPointerStruct + v := reflect.ValueOf(&teps) + f := SafeSetFieldByIndex(v, []int{}, nil) + rt.Equal(TestEmbeddedPointerStruct{}, f.Interface()) + + f = SafeSetFieldByIndex(v, []int{0, 0}, 1) + rt.Equal(TestEmbeddedPointerStruct{ + TestEmbedded: &TestEmbedded{FieldA: 1}, + }, f.Interface()) + + f = SafeSetFieldByIndex(v, []int{1}, "hello") + rt.Equal(TestEmbeddedPointerStruct{ + TestEmbedded: &TestEmbedded{FieldA: 1}, + FieldB: "hello", + }, f.Interface()) + rt.Equal(TestEmbeddedPointerStruct{ + TestEmbedded: &TestEmbedded{FieldA: 1}, + FieldB: "hello", + }, teps) + + var tes TestEmbeddedStruct + v = reflect.ValueOf(&tes) + f = SafeSetFieldByIndex(v, []int{}, nil) + rt.Equal(TestEmbeddedStruct{}, f.Interface()) + + f = SafeSetFieldByIndex(v, []int{0, 0}, 1) + rt.Equal(TestEmbeddedStruct{ + TestEmbedded: TestEmbedded{FieldA: 1}, + }, f.Interface()) + + f = SafeSetFieldByIndex(v, []int{1}, "hello") + rt.Equal(TestEmbeddedStruct{ + TestEmbedded: TestEmbedded{FieldA: 1}, + FieldB: "hello", + }, f.Interface()) + rt.Equal(TestEmbeddedStruct{ + TestEmbedded: TestEmbedded{FieldA: 1}, + FieldB: "hello", + }, tes) +} + func (rt *reflectTest) TestGetSliceElementType() { type MyStruct struct{} diff --git a/select_dataset_example_test.go b/select_dataset_example_test.go index 33d59ce3..416b2de0 100644 --- a/select_dataset_example_test.go +++ b/select_dataset_example_test.go @@ -6,24 +6,27 @@ import ( "fmt" "os" "regexp" + "time" "github.com/doug-martin/goqu/v9" "github.com/lib/pq" ) const schema = ` - DROP TABLE IF EXISTS "goqu_user"; - CREATE TABLE "goqu_user" ( - "id" SERIAL PRIMARY KEY NOT NULL, - "first_name" VARCHAR(45) NOT NULL, + DROP TABLE IF EXISTS "user_role"; + DROP TABLE IF EXISTS "goqu_user"; + CREATE TABLE "goqu_user" ( + "id" SERIAL PRIMARY KEY NOT NULL, + "first_name" VARCHAR(45) NOT NULL, "last_name" VARCHAR(45) NOT NULL, "created" TIMESTAMP NOT NULL DEFAULT now() ); - INSERT INTO "goqu_user" ("first_name", "last_name") VALUES - ('Bob', 'Yukon'), - ('Sally', 'Yukon'), - ('Vinita', 'Yukon'), - ('John', 'Doe') + CREATE TABLE "user_role" ( + "id" SERIAL PRIMARY KEY NOT NULL, + "user_id" BIGINT NOT NULL REFERENCES goqu_user(id) ON DELETE CASCADE, + "name" VARCHAR(45) NOT NULL, + "created" TIMESTAMP NOT NULL DEFAULT now() + ); ` const defaultDbURI = "postgres://postgres:@localhost:5435/goqupostgres?sslmode=disable" @@ -50,6 +53,41 @@ func getDb() *goqu.Database { if _, err := goquDb.Exec(schema); err != nil { panic(err) } + type goquUser struct { + ID int64 `db:"id" goqu:"skipinsert"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Created time.Time `db:"created" goqu:"skipupdate"` + } + + var users = []goquUser{ + {FirstName: "Bob", LastName: "Yukon"}, + {FirstName: "Sally", LastName: "Yukon"}, + {FirstName: "Vinita", LastName: "Yukon"}, + {FirstName: "John", LastName: "Doe"}, + } + var userIds []int64 + err := goquDb.Insert("goqu_user").Rows(users).Returning("id").Executor().ScanVals(&userIds) + if err != nil { + panic(err) + } + type userRole struct { + ID int64 `db:"id" goqu:"skipinsert"` + UserID int64 `db:"user_id"` + Name string `db:"name"` + Created time.Time `db:"created" goqu:"skipupdate"` + } + + var roles = []userRole{ + {UserID: userIds[0], Name: "Admin"}, + {UserID: userIds[1], Name: "Manager"}, + {UserID: userIds[2], Name: "Manager"}, + {UserID: userIds[3], Name: "User"}, + } + _, err = goquDb.Insert("user_role").Rows(roles).Executor().Exec() + if err != nil { + panic(err) + } return goquDb } @@ -1244,6 +1282,85 @@ func ExampleSelectDataset_ScanStructs_prepared() { // [{FirstName:Bob LastName:Yukon} {FirstName:Sally LastName:Yukon} {FirstName:Vinita LastName:Yukon}] } +// In this example we create a new struct that has two structs that represent two table +// the User and Role fields are tagged with the table name +func ExampleSelectDataset_ScanStructs_withJoinAutoSelect() { + type Role struct { + UserID uint64 `db:"user_id"` + Name string `db:"name"` + } + type User struct { + ID uint64 `db:"id"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + } + type UserAndRole struct { + User User `db:"goqu_user"` // tag as the "goqu_user" table + Role Role `db:"user_role"` // tag as "user_role" table + } + db := getDb() + + ds := db. + From("goqu_user"). + Join(goqu.T("user_role"), goqu.On(goqu.I("goqu_user.id").Eq(goqu.I("user_role.user_id")))) + var users []UserAndRole + // Scan structs will auto build the + if err := ds.ScanStructs(&users); err != nil { + fmt.Println(err.Error()) + return + } + for _, u := range users { + fmt.Printf("\n%+v", u) + } + // Output: + // {User:{ID:1 FirstName:Bob LastName:Yukon} Role:{UserID:1 Name:Admin}} + // {User:{ID:2 FirstName:Sally LastName:Yukon} Role:{UserID:2 Name:Manager}} + // {User:{ID:3 FirstName:Vinita LastName:Yukon} Role:{UserID:3 Name:Manager}} + // {User:{ID:4 FirstName:John LastName:Doe} Role:{UserID:4 Name:User}} +} + +// In this example we create a new struct that has the user properties as well as a nested +// Role struct from the join table +func ExampleSelectDataset_ScanStructs_withJoinManualSelect() { + type Role struct { + UserID uint64 `db:"user_id"` + Name string `db:"name"` + } + type User struct { + ID uint64 `db:"id"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Role Role `db:"user_role"` // tag as "user_role" table + } + db := getDb() + + ds := db. + Select( + "goqu_user.id", + "goqu_user.first_name", + "goqu_user.last_name", + // alias the fully qualified identifier `C` is important here so it doesnt parse it + goqu.I("user_role.user_id").As(goqu.C("user_role.user_id")), + goqu.I("user_role.name").As(goqu.C("user_role.name")), + ). + From("goqu_user"). + Join(goqu.T("user_role"), goqu.On(goqu.I("goqu_user.id").Eq(goqu.I("user_role.user_id")))) + var users []User + if err := ds.ScanStructs(&users); err != nil { + fmt.Println(err.Error()) + return + } + for _, u := range users { + fmt.Printf("\n%+v", u) + } + + // Output: + // {ID:1 FirstName:Bob LastName:Yukon Role:{UserID:1 Name:Admin}} + // {ID:2 FirstName:Sally LastName:Yukon Role:{UserID:2 Name:Manager}} + // {ID:3 FirstName:Vinita LastName:Yukon Role:{UserID:3 Name:Manager}} + // {ID:4 FirstName:John LastName:Doe Role:{UserID:4 Name:User}} +} + func ExampleSelectDataset_ScanStruct() { type User struct { FirstName string `db:"first_name"` @@ -1272,6 +1389,100 @@ func ExampleSelectDataset_ScanStruct() { // No user found for first_name Zeb } +// In this example we create a new struct that has two structs that represent two table +// the User and Role fields are tagged with the table name +func ExampleSelectDataset_ScanStruct_withJoinAutoSelect() { + type Role struct { + UserID uint64 `db:"user_id"` + Name string `db:"name"` + } + type User struct { + ID uint64 `db:"id"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + } + type UserAndRole struct { + User User `db:"goqu_user"` // tag as the "goqu_user" table + Role Role `db:"user_role"` // tag as "user_role" table + } + db := getDb() + findUserAndRoleByName := func(name string) { + var userAndRole UserAndRole + ds := db. + From("goqu_user"). + Join( + goqu.T("user_role"), + goqu.On(goqu.I("goqu_user.id").Eq(goqu.I("user_role.user_id"))), + ). + Where(goqu.C("first_name").Eq(name)) + found, err := ds.ScanStruct(&userAndRole) + switch { + case err != nil: + fmt.Println(err.Error()) + case !found: + fmt.Printf("No user found for first_name %s\n", name) + default: + fmt.Printf("Found user and role: %+v\n", userAndRole) + } + } + + findUserAndRoleByName("Bob") + findUserAndRoleByName("Zeb") + // Output: + // Found user and role: {User:{ID:1 FirstName:Bob LastName:Yukon} Role:{UserID:1 Name:Admin}} + // No user found for first_name Zeb +} + +// In this example we create a new struct that has the user properties as well as a nested +// Role struct from the join table +func ExampleSelectDataset_ScanStruct_withJoinManualSelect() { + type Role struct { + UserID uint64 `db:"user_id"` + Name string `db:"name"` + } + type User struct { + ID uint64 `db:"id"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Role Role `db:"user_role"` // tag as "user_role" table + } + db := getDb() + findUserByName := func(name string) { + var userAndRole User + ds := db. + Select( + "goqu_user.id", + "goqu_user.first_name", + "goqu_user.last_name", + // alias the fully qualified identifier `C` is important here so it doesnt parse it + goqu.I("user_role.user_id").As(goqu.C("user_role.user_id")), + goqu.I("user_role.name").As(goqu.C("user_role.name")), + ). + From("goqu_user"). + Join( + goqu.T("user_role"), + goqu.On(goqu.I("goqu_user.id").Eq(goqu.I("user_role.user_id"))), + ). + Where(goqu.C("first_name").Eq(name)) + found, err := ds.ScanStruct(&userAndRole) + switch { + case err != nil: + fmt.Println(err.Error()) + case !found: + fmt.Printf("No user found for first_name %s\n", name) + default: + fmt.Printf("Found user and role: %+v\n", userAndRole) + } + } + + findUserByName("Bob") + findUserByName("Zeb") + + // Output: + // Found user and role: {ID:1 FirstName:Bob LastName:Yukon Role:{UserID:1 Name:Admin}} + // No user found for first_name Zeb +} + func ExampleSelectDataset_ScanVals() { var ids []int64 if err := getDb().From("goqu_user").Select("id").ScanVals(&ids); err != nil {