-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsession.go
134 lines (127 loc) · 3.52 KB
/
session.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
package orm
import (
"context"
"database/sql"
"fmt"
"reflect"
)
// Session is the handle through which statements are run: Table binds a
// schema to it, and its query/exec helpers send the rendered SQL through db.
//
// NOTE(review): ExecutorIfc is declared elsewhere; from its use in this file it
// exposes QueryContext and ExecContext with database/sql-style signatures
// (e.g. satisfied by *sql.DB or *sql.Tx) — confirm. db must be non-nil before
// any query or exec helper runs, since they call it without a nil check.
type Session struct {
	db ExecutorIfc
}
// Table starts a new Action bound to this session that operates on schema.
// Every other Action field is left at its zero value.
func (s *Session) Table(schema Schema) *Action {
	action := new(Action)
	action.schema = schema
	action.session = s
	return action
}
// payloadIfcType is the reflect.Type of the PayloadIfc interface itself, used
// by queryPayload to test whether a nested reference implements PayloadIfc.
var payloadIfcType = reflect.TypeOf((*PayloadIfc)(nil)).Elem()

// queryPayload runs stmt as a SELECT whose column list is derived from the
// fields bound to payloadRef (plus any nested payloads) and scans the result
// rows into those fields.
//
// Each element of nestedPayloadRef must be a non-nil pointer that either
// implements PayloadIfc itself (e.g. *T) or points to a pointer that does
// (e.g. **T). In the **T case a nil inner pointer is allocated first, so the
// caller's variable ends up pointing at the scanned value.
//
// stmt.selectField is overwritten. Every row is scanned into the same fields,
// so if the query yields several rows the last one wins. After each scan every
// field's current value is recorded via setPreVal. Query, scan and iteration
// errors are returned unwrapped. The row set is always closed.
//
// TODO: detect nested payloads automatically instead of requiring them to be
// listed in nestedPayloadRef.
func (s *Session) queryPayload(ctx context.Context, stmt *Stmt, payloadRef PayloadIfc, nestedPayloadRef ...any) error {
	bindFields := boundFields(payloadRef)
	for _, item := range nestedPayloadRef {
		itemV := reflect.ValueOf(item)
		// Kind() (unlike Type().Kind()) is safe on the zero Value produced by an
		// untyped nil, so such an item is rejected instead of panicking.
		if itemV.Kind() != reflect.Ptr {
			return fmt.Errorf("payload must be pointer")
		}
		// A nil pointer has nothing to scan into, and Elem() on it yields a
		// zero Value whose IsNil would panic below.
		if itemV.IsNil() {
			return fmt.Errorf("nestedPayloadRef must not be a nil pointer, find :%T", item)
		}
		itemType := itemV.Type()
		switch {
		case itemType.Implements(payloadIfcType):
			bindFields = append(bindFields, boundFields(item.(PayloadIfc))...)
		case itemType.Elem().Kind() == reflect.Ptr && itemType.Elem().Implements(payloadIfcType):
			inner := itemV.Elem()
			if inner.IsNil() && inner.CanSet() {
				inner.Set(reflect.New(itemType.Elem().Elem()))
			}
			bindFields = append(bindFields, boundFields(inner.Interface().(PayloadIfc))...)
		default:
			return fmt.Errorf("nestedPayloadRef must be PayloadIfc, find :%T", item)
		}
	}

	fields := make([]FieldIfc, 0, len(bindFields))
	for _, field := range bindFields {
		fields = append(fields, field.field)
	}
	stmt.selectField = fields

	expr, err := stmt.completeSelect()
	if err != nil {
		return err
	}
	sqlRaw, argsRaw := expr.Expr()
	rows, err := s.db.QueryContext(ctx, sqlRaw, argsRaw...)
	if err != nil {
		return err
	}
	// Release the underlying connection even when Scan fails part-way.
	defer rows.Close()

	for rows.Next() {
		values := make([]any, 0, len(bindFields))
		for _, field := range bindFields {
			values = append(values, field.RefVal())
		}
		if err := rows.Scan(values...); err != nil {
			return err
		}
		for _, field := range bindFields {
			field.setPreVal(field.Val())
		}
	}
	// Next returns false both at the end of the set and on failure; Err tells
	// the two apart.
	return rows.Err()
}
// queryPayloadSlice runs stmt as a SELECT and appends one freshly allocated
// payload per returned row to the slice that payloadSliceRef points to.
//
// payloadSliceRef must be a non-nil pointer to a slice whose element type is a
// pointer implementing PayloadIfc (e.g. *[]*T). The column list comes from the
// fields bound to a throwaway probe instance of T, and stmt.selectField is
// overwritten with it. Existing slice elements are kept and new rows are
// appended after them.
//
// The caller's slice is updated only if every row scans successfully and
// iteration ends without error; on error it is left unchanged. The row set is
// always closed.
func (s *Session) queryPayloadSlice(ctx context.Context, stmt *Stmt, payloadSliceRef any) error {
	rv := reflect.ValueOf(payloadSliceRef)
	if rv.Kind() != reflect.Ptr {
		// Format payloadSliceRef directly: rv.Interface() panics on the zero
		// Value produced by an untyped nil.
		return fmt.Errorf("must be ptr, find :%T", payloadSliceRef)
	}
	if rv.IsNil() {
		return fmt.Errorf("must be non-nil ptr, find :%T", payloadSliceRef)
	}
	rvElem := rv.Elem()
	if rvElem.Kind() != reflect.Slice {
		return fmt.Errorf("must be slice, find :%T", rvElem.Interface())
	}
	// Elements must themselves be pointers (e.g. []*T). Calling Elem on a
	// non-pointer element type such as the T of []T would panic.
	itemType := rvElem.Type().Elem()
	if itemType.Kind() != reflect.Ptr {
		return fmt.Errorf("slice element must be ptr, find :%v", itemType)
	}
	payloadType := itemType.Elem()

	newPayload := reflect.New(payloadType)
	probe, ok := newPayload.Interface().(PayloadIfc)
	if !ok {
		return fmt.Errorf("must be PayloadIfc, find :%T", newPayload.Interface())
	}
	probeFields := boundFields(probe)
	fields := make([]FieldIfc, 0, len(probeFields))
	for _, field := range probeFields {
		fields = append(fields, field.field)
	}
	stmt.selectField = fields

	expr, err := stmt.completeSelect()
	if err != nil {
		return err
	}
	sqlRaw, argsRaw := expr.Expr()
	rows, err := s.db.QueryContext(ctx, sqlRaw, argsRaw...)
	if err != nil {
		return err
	}
	// Release the underlying connection even when Scan fails part-way.
	defer rows.Close()

	result := rvElem
	for rows.Next() {
		rvPayload := reflect.New(payloadType)
		// Cannot fail: the probe of the identical type passed this assertion.
		payload := rvPayload.Interface().(PayloadIfc)
		rowFields := boundFields(payload)
		values := make([]any, 0, len(rowFields))
		for _, field := range rowFields {
			values = append(values, field.RefVal())
		}
		if err := rows.Scan(values...); err != nil {
			return err
		}
		for _, field := range rowFields {
			field.setPreVal(field.Val())
		}
		result = reflect.Append(result, rvPayload)
	}
	// Next returns false both at the end of the set and on failure; Err tells
	// the two apart.
	if err := rows.Err(); err != nil {
		return err
	}
	rvElem.Set(result)
	return nil
}
// exec renders stmt via complete and sends the resulting SQL and arguments to
// the session executor's ExecContext. The driver's result and error are
// returned as-is. A rendering error short-circuits with a nil result.
func (s *Session) exec(ctx context.Context, stmt *Stmt) (sql.Result, error) {
	rendered, renderErr := stmt.complete()
	if renderErr != nil {
		return nil, renderErr
	}
	query, args := rendered.Expr()
	return s.db.ExecContext(ctx, query, args...)
}