From 104d38e36545176914f59408a5e1ef7ee12200ab Mon Sep 17 00:00:00 2001 From: Jacob Brewer Date: Sat, 26 Oct 2024 19:21:56 +0100 Subject: [PATCH] feat(inserter): Allowing fields to be ignored on the inserter package (#60) updating the inserter package --- inserter/batch.go | 27 +++++++++++++ inserter/batch_opts.go | 20 +++++++++- inserter/sql.go | 72 +++++++++++++++++++++++++++++----- inserter/sql_test.go | 87 ++++++++++++++++++++++++++++++++++++++++-- loader.go | 25 ------------ patch.go | 26 +++++++++++++ sql.go | 5 +-- 7 files changed, 220 insertions(+), 42 deletions(-) diff --git a/inserter/batch.go b/inserter/batch.go index 6f2a55b..3a22047 100644 --- a/inserter/batch.go +++ b/inserter/batch.go @@ -3,6 +3,8 @@ package inserter import ( "database/sql" "errors" + + "github.com/jacobbrewer1/patcher" ) var ( @@ -34,6 +36,31 @@ type SQLBatch struct { // table is the table name to use in the SQL statement table string + + // ignoreFields is a list of fields to ignore when patching + ignoreFields []string + + // ignoreFieldsFunc is a function that determines whether a field should be ignored + // + // This func should return true is the field is to be ignored + ignoreFieldsFunc patcher.IgnoreFieldsFunc +} + +// newBatchDefaults returns a new SQLBatch with default values +func newBatchDefaults(opts ...BatchOpt) *SQLBatch { + b := &SQLBatch{ + fields: nil, + args: nil, + db: nil, + tagName: patcher.DefaultDbTagName, + table: "", + } + + for _, opt := range opts { + opt(b) + } + + return b } func (b *SQLBatch) Fields() []string { diff --git a/inserter/batch_opts.go b/inserter/batch_opts.go index 4a7e2f5..039b17c 100644 --- a/inserter/batch_opts.go +++ b/inserter/batch_opts.go @@ -1,6 +1,10 @@ package inserter -import "database/sql" +import ( + "database/sql" + + "github.com/jacobbrewer1/patcher" +) type BatchOpt func(*SQLBatch) @@ -24,3 +28,17 @@ func WithDB(db *sql.DB) BatchOpt { b.db = db } } + +// WithIgnoreFields sets the fields to ignore when patching +func WithIgnoreFields(fields ...string) BatchOpt { + return func(b *SQLBatch) { + b.ignoreFields = fields + } +} + +// WithIgnoreFieldsFunc sets the function that determines whether a field should be ignored +func WithIgnoreFieldsFunc(f patcher.IgnoreFieldsFunc) BatchOpt { + return func(b *SQLBatch) { + b.ignoreFieldsFunc = f + } +} diff --git a/inserter/sql.go b/inserter/sql.go index b9fc8d4..659a9a2 100644 --- a/inserter/sql.go +++ b/inserter/sql.go @@ -4,17 +4,15 @@ import ( "database/sql" "fmt" "reflect" + "slices" "strings" -) -const ( - // defaultTagName is the default tag name to look for in the struct - defaultTagName = "db" + "github.com/jacobbrewer1/patcher" ) func NewBatch(resources []any, opts ...BatchOpt) *SQLBatch { - b := new(SQLBatch) - b.tagName = defaultTagName + b := newBatchDefaults(opts...) + for _, opt := range opts { opt(b) } @@ -49,20 +47,34 @@ func (b *SQLBatch) genBatch(resources []any) { for i := 0; i < t.NumField(); i++ { f := t.Field(i) tag := f.Tag.Get(b.tagName) - if tag == "-" { + if tag == patcher.TagOptSkip { continue } + // Skip unexported fields if !f.IsExported() { continue } + // Skip fields that are to be ignored + if b.checkSkipField(f) { + continue + } + + patcherOptsTag := f.Tag.Get(patcher.TagOptsName) + if patcherOptsTag != "" { + patcherOpts := strings.Split(patcherOptsTag, patcher.TagOptSeparator) + if slices.Contains(patcherOpts, patcher.TagOptSkip) { + continue + } + } + // if no tag is set, use the field name if tag == "" { - tag = strings.ToLower(f.Name) + tag = f.Name } - b.args = append(b.args, v.Field(i).Interface()) + b.args = append(b.args, b.getFieldValue(v.Field(i), f)) // if the field is not unique, skip it if _, ok := uniqueFields[tag]; ok { @@ -76,6 +88,17 @@ func (b *SQLBatch) genBatch(resources []any) { } } +func (b *SQLBatch) getFieldValue(v reflect.Value, f reflect.StructField) any { + if f.Type.Kind() == reflect.Ptr { + if v.IsNil() { + return nil + } + return v.Elem().Interface() + } + + return v.Interface() +} + func (b *SQLBatch) GenerateSQL() (string, []any, error) { if err := b.validateSQLGen(); err != nil { return "", nil, err @@ -116,3 +139,34 @@ func (b *SQLBatch) Perform() (sql.Result, error) { return b.db.Exec(sqlStr, args...) } + +func (b *SQLBatch) checkSkipField(field reflect.StructField) bool { + // The ignore fields tag takes precedence over the ignore fields list + if b.checkSkipTag(field) { + return true + } + + return b.ignoredFieldsCheck(field) +} + +func (b *SQLBatch) checkSkipTag(field reflect.StructField) bool { + val, ok := field.Tag.Lookup(patcher.TagOptsName) + if !ok { + return false + } + + tags := strings.Split(val, patcher.TagOptSeparator) + return slices.Contains(tags, patcher.TagOptSkip) +} + +func (b *SQLBatch) ignoredFieldsCheck(field reflect.StructField) bool { + return b.checkIgnoredFields(strings.ToLower(field.Name)) || b.checkIgnoreFunc(field) +} + +func (b *SQLBatch) checkIgnoreFunc(field reflect.StructField) bool { + return b.ignoreFieldsFunc != nil && b.ignoreFieldsFunc(field) +} + +func (b *SQLBatch) checkIgnoredFields(field string) bool { + return len(b.ignoreFields) > 0 && slices.Contains(b.ignoreFields, strings.ToLower(field)) +} diff --git a/inserter/sql_test.go b/inserter/sql_test.go index 5b85023..79032e6 100644 --- a/inserter/sql_test.go +++ b/inserter/sql_test.go @@ -1,8 +1,11 @@ package inserter import ( + "reflect" "testing" + "github.com/jacobbrewer1/patcher" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" ) @@ -219,7 +222,7 @@ func (s *generateSQLSuite) TestGenerateSQL_noDbTag() { sql, args, err := NewBatch(resources, WithTable("temp")).GenerateSQL() s.Require().NoError(err) - s.Require().Equal("INSERT INTO temp (id, name) VALUES (?, ?), (?, ?), (?, ?), (?, ?), (?, ?)", sql) + s.Require().Equal("INSERT INTO temp (ID, Name) VALUES (?, ?), (?, ?), (?, ?), (?, ?), (?, ?)", sql) s.Require().Len(args, 10) } @@ -357,9 +360,85 @@ func (s *generateSQLSuite) TestGenerateSQL_Success_WithPointedFields() { sql, args, err := NewBatch(resources, WithTable("temp"), WithTagName("db")).GenerateSQL() s.Require().NoError(err) - s.Require().Equal("INSERT INTO temp (id, name) VALUES (?, ?), (?, ?), (?, ?), (?, ?), (?, ?)", sql) - s.Require().Len(args, 10) + s.Equal("INSERT INTO temp (id, name) VALUES (?, ?), (?, ?), (?, ?), (?, ?), (?, ?)", sql) - expectedArgs := []any{resources[0].(*temp).ID, resources[0].(*temp).Name, (*int)(nil), resources[1].(*temp).Name, resources[2].(*temp).ID, resources[2].(*temp).Name, resources[3].(*temp).ID, resources[3].(*temp).Name, resources[4].(*temp).ID, resources[4].(*temp).Name} + expectedArgs := []any{1, "test", interface{}(nil), "test2", 3, "test3", 4, "test4", 5, "test5"} s.Require().Equal(expectedArgs, args) } + +func (s *generateSQLSuite) TestGenerateSQL_Success_WithPointedFields_noDbTag() { + type temp struct { + ID *int + Name *string + unexported string + } + + resources := []any{ + &temp{ID: ptr(1), Name: ptr("test")}, + &temp{ID: nil, Name: ptr("test2")}, + &temp{ID: ptr(3), Name: ptr("test3")}, + &temp{ID: ptr(4), Name: ptr("test4")}, + &temp{ID: ptr(5), Name: ptr("test5"), unexported: "test"}, + } + + sql, args, err := NewBatch(resources, WithTable("temp")).GenerateSQL() + s.Require().NoError(err) + + s.Equal("INSERT INTO temp (ID, Name) VALUES (?, ?), (?, ?), (?, ?), (?, ?), (?, ?)", sql) + + expectedArgs := []any{1, "test", interface{}(nil), "test2", 3, "test3", 4, "test4", 5, "test5"} + s.Require().Equal(expectedArgs, args) +} + +func (s *generateSQLSuite) TestGenerateSQL_Success_IgnoredFields() { + type temp struct { + ID int `db:"id"` + Name string `db:"name"` + unexported string `db:"unexported"` + } + + resources := []any{ + &temp{ID: 1, Name: "test"}, + &temp{ID: 2, Name: "test2"}, + &temp{ID: 3, Name: "test3"}, + &temp{ID: 4, Name: "test4"}, + &temp{ID: 5, Name: "test5", unexported: "test"}, + } + + b := NewBatch(resources, WithTable("temp"), WithTagName("db"), WithIgnoreFields("unexported")) + + sql, args, err := b.GenerateSQL() + s.Require().NoError(err) + + s.Equal("INSERT INTO temp (id, name) VALUES (?, ?), (?, ?), (?, ?), (?, ?), (?, ?)", sql) + s.Len(args, 10) +} + +func (s *generateSQLSuite) TestGenerateSQL_Success_IgnoredFieldsFunc() { + type temp struct { + ID int `db:"id"` + Name string `db:"name"` + unexported string `db:"unexported"` + } + + resources := []any{ + &temp{ID: 1, Name: "test"}, + &temp{ID: 2, Name: "test2"}, + &temp{ID: 3, Name: "test3"}, + &temp{ID: 4, Name: "test4"}, + &temp{ID: 5, Name: "test5", unexported: "test"}, + } + + mif := patcher.NewMockIgnoreFieldsFunc(s.T()) + mif.On("Execute", mock.Anything).Return(func(f reflect.StructField) bool { + return f.Name == "ID" + }) + + b := NewBatch(resources, WithTable("temp"), WithTagName("db"), WithIgnoreFieldsFunc(mif.Execute)) + + sql, args, err := b.GenerateSQL() + s.Require().NoError(err) + + s.Equal("INSERT INTO temp (name) VALUES (?), (?), (?), (?), (?)", sql) + s.Len(args, 5) +} diff --git a/loader.go b/loader.go index 13111e5..0f2e077 100644 --- a/loader.go +++ b/loader.go @@ -12,31 +12,6 @@ var ( ErrInvalidType = errors.New("invalid type: must pointer to struct") ) -func newPatchDefaults(opts ...PatchOpt) *SQLPatch { - // Default options - p := &SQLPatch{ - fields: nil, - args: nil, - db: nil, - tagName: defaultDbTagName, - table: "", - whereSql: new(strings.Builder), - whereArgs: nil, - joinSql: new(strings.Builder), - joinArgs: nil, - includeZeroValues: false, - includeNilValues: false, - ignoreFields: nil, - ignoreFieldsFunc: nil, - } - - for _, opt := range opts { - opt(p) - } - - return p -} - // LoadDiff inserts the fields provided in the new struct pointer into the old struct pointer and injects the new // values into the old struct // diff --git a/patch.go b/patch.go index 772ee17..4196677 100644 --- a/patch.go +++ b/patch.go @@ -69,6 +69,32 @@ type SQLPatch struct { ignoreFieldsFunc IgnoreFieldsFunc } +// newPatchDefaults creates a new SQLPatch with default options. +func newPatchDefaults(opts ...PatchOpt) *SQLPatch { + // Default options + p := &SQLPatch{ + fields: nil, + args: nil, + db: nil, + tagName: DefaultDbTagName, + table: "", + whereSql: new(strings.Builder), + whereArgs: nil, + joinSql: new(strings.Builder), + joinArgs: nil, + includeZeroValues: false, + includeNilValues: false, + ignoreFields: nil, + ignoreFieldsFunc: nil, + } + + for _, opt := range opts { + opt(p) + } + + return p +} + func (s *SQLPatch) Fields() []string { return s.fields } diff --git a/sql.go b/sql.go index 6a0b119..a725e4d 100644 --- a/sql.go +++ b/sql.go @@ -10,7 +10,7 @@ import ( ) const ( - defaultDbTagName = "db" + DefaultDbTagName = "db" ) var ( @@ -49,8 +49,7 @@ func (s *SQLPatch) patchGen(resource any) { tag := fType.Tag.Get(s.tagName) // Skip unexported fields - if fType.PkgPath != "" { - // This is an unexported field + if !fType.IsExported() { continue }