Skip to content

Commit

Permalink
Update xsql TagValuesMap
Browse files Browse the repository at this point in the history
  • Loading branch information
onanying committed Jul 1, 2024
1 parent 81eabdc commit 9cea85a
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 44 deletions.
6 changes: 4 additions & 2 deletions src/xsql/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,10 @@ Update specific columns by struct pointer

```go
test := Test{}
data, err := xsql.BuildTagValues(DB.Options.Tag, &test,
&test.Foo, "test",
data, err := xsql.TagValuesMap(DB.Options.Tag, &test,
xsql.TagValues{
{&test.Foo, "test"},
},
)
if err != nil {
log.Fatal(err)
Expand Down
16 changes: 10 additions & 6 deletions src/xsql/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,29 +271,33 @@ func TestUpdateColumns(t *testing.T) {
a.Empty(err)
}

func TestUpdateBuildTagValues(t *testing.T) {
func TestUpdateTagValuesMap(t *testing.T) {
a := assert.New(t)

DB := newDB()

test := Test{}
data, err := xsql.BuildTagValues(DB.Options.Tag, &test,
&test.Foo, "test_update_3",
data, err := xsql.TagValuesMap(DB.Options.Tag, &test,
xsql.TagValues{
{&test.Foo, "test_update_3"},
},
)
a.Empty(err)

_, err = DB.Model(&test).Update(data, "id = ?", 8)
a.Empty(err)
}

func TestEmbeddingUpdateBuildTagValues(t *testing.T) {
func TestEmbeddingUpdateTagValuesMap(t *testing.T) {
a := assert.New(t)

DB := newDB()

test := EmbeddingTest{}
data, err := xsql.BuildTagValues(DB.Options.Tag, &test,
&test.Foo, "test_update_4",
data, err := xsql.TagValuesMap(DB.Options.Tag, &test,
xsql.TagValues{
{&test.Foo, "test_update_4"},
},
)
a.Empty(err)

Expand Down
62 changes: 26 additions & 36 deletions src/xsql/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,69 +5,59 @@ import (
"reflect"
)

// BuildTagValues takes a tag key, a pointer to a struct, and a series of pointers to struct fields with their corresponding values.
// It returns a map where each key is the tag value of the struct field, and each value is the corresponding value from the pairs.
// If the number of arguments in pairs is not even, an error is returned since they should be provided as pointer-value pairs.
func BuildTagValues(tagKey string, ptr interface{}, pairs ...interface{}) (map[string]interface{}, error) {
if len(pairs)%2 != 0 {
return nil, fmt.Errorf("xsql: arguments should be in pairs")
}
type TagValues []TagValue

type TagValue struct {
Key interface{}
Value interface{}
}

// TagValuesMap takes a tag key, a pointer to a struct, and TagValues.
// It constructs a map where each key is the struct field's tag value, paired with the corresponding value from TagValues.
func TagValuesMap(tagKey string, ptr interface{}, values TagValues) (map[string]interface{}, error) {
result := make(map[string]interface{})
value := reflect.ValueOf(ptr).Elem()
structValue := reflect.ValueOf(ptr).Elem()

if value.Kind() != reflect.Struct {
if structValue.Kind() != reflect.Struct {
return nil, fmt.Errorf("xsql: ptr must be a pointer to a struct")
}

fieldsMap := map[string]reflect.Value{}
populateFieldsMap(tagKey, value, fieldsMap)

for i := 0; i < len(pairs); i += 2 {
fieldPtr, ok := pairs[i].(interface{})
if !ok {
return nil, fmt.Errorf("xsql: argument at index %d is not a pointer", i)
}
fieldsMap := make(map[string]reflect.Value)
populateFieldsMap(tagKey, structValue, fieldsMap)

fieldValue := reflect.ValueOf(fieldPtr)
for i, tagValue := range values {
fieldPtr, fieldValue := tagValue.Key, reflect.ValueOf(tagValue.Key)
if fieldValue.Kind() != reflect.Ptr || fieldValue.IsNil() {
return nil, fmt.Errorf("xsql: argument at index %d must be a non-nil pointer to a struct field", i)
return nil, fmt.Errorf("xsql: error at item %d in values slice: key is not a non-nil pointer to a struct field", i)
}

var fieldName string
var found bool
for name, field := range fieldsMap {
foundFieldName := ""
for tagName, field := range fieldsMap {
if field.Addr().Interface() == fieldPtr {
fieldName = name
found = true
foundFieldName = tagName
break
}
}

if !found {
return nil, fmt.Errorf("xsql: no matching struct field found for pointer at index %d", i)
if foundFieldName == "" {
return nil, fmt.Errorf("xsql: no matching struct field found for item %d in values slice", i)
}

result[fieldName] = pairs[i+1]
result[foundFieldName] = tagValue.Value
}

return result, nil
}

// populateFieldsMap is a recursive function that maps field names to their values,
// including fields from embedded structs.
func populateFieldsMap(tagKey string, v reflect.Value, fieldsMap map[string]reflect.Value) {
for i := 0; i < v.NumField(); i++ {
fieldValue := v.Field(i)
field := v.Field(i)
fieldType := v.Type().Field(i)
tag := fieldType.Tag.Get(tagKey)
// If it's an embedded struct, we need to recurse into it
if fieldType.Anonymous && fieldValue.Type().Kind() == reflect.Struct {
populateFieldsMap(tagKey, fieldValue, fieldsMap)
if fieldType.Anonymous && field.Type().Kind() == reflect.Struct {
populateFieldsMap(tagKey, field, fieldsMap)
} else if tag != "" {
// Only add the field if it has the xsql tag
fieldName := tag
fieldsMap[fieldName] = fieldValue
fieldsMap[tag] = field
}
}
}

0 comments on commit 9cea85a

Please sign in to comment.