-
Notifications
You must be signed in to change notification settings - Fork 0
/
context.go
76 lines (65 loc) · 1.3 KB
/
context.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
package dbtxn
import (
"context"
"database/sql"
"errors"
"strings"
"go.uber.org/multierr"
)
// Context tracks the transactions opened through Begin, at most one per
// *sql.DB handle, together with the errors recorded against them. Commit
// inspects Errs to decide whether to commit or roll back every transaction.
type Context struct {
	// TxMap maps each database handle to the transaction Begin opened on it.
	TxMap map[*sql.DB]Tx
	// Errs holds errors recorded via AppendError (Begin also records its
	// BeginTx failures here); any entry makes Commit roll back.
	Errs []error
}
// NewContext returns an empty Context with an allocated TxMap and no
// recorded errors.
func NewContext() *Context {
	c := new(Context)
	c.TxMap = map[*sql.DB]Tx{}
	return c
}
// Begin returns the transaction this Context holds for db, opening one with
// db.BeginTx(ctx, nil) on first use so that repeated calls for the same
// handle share a single transaction.
//
// If BeginTx fails, the error is recorded via AppendError (which makes a later
// Commit roll back the other transactions) and returned with a nil DB.
//
// TxMap is allocated on demand, so a Context that was not created with
// NewContext (e.g. &Context{}) does not panic on a nil-map write.
func (c *Context) Begin(ctx context.Context, db *sql.DB) (DB, error) {
	// Reading from a nil map is safe in Go, so the lookup needs no guard.
	if tx, ok := c.TxMap[db]; ok {
		return tx, nil
	}
	sqlTx, err := db.BeginTx(ctx, nil)
	if err != nil {
		c.AppendError(err)
		// Return a literal nil so callers never see a non-nil interface
		// wrapping a nil *sql.Tx.
		return nil, err
	}
	if c.TxMap == nil {
		c.TxMap = make(map[*sql.DB]Tx)
	}
	var tx Tx = sqlTx
	c.TxMap[db] = tx
	return tx, nil
}
// Commit finalizes every transaction held in c.TxMap.
//
// If any error has been recorded in c.Errs, every transaction is rolled back;
// otherwise every transaction is committed. The error messages of failed
// Commit/Rollback calls are joined with ErrSep into a single returned error.
// Note that a successful rollback returns nil; callers that need to know
// whether the work was persisted must inspect c.Errs themselves.
//
// Every finalized transaction is removed from c.TxMap, whatever its outcome,
// because a committed or rolled-back *sql.Tx cannot be used again. A later
// Begin on the same handle then opens a fresh transaction, and a repeated
// Commit does not try to finalize the same transactions twice. c.Errs is left
// as is. Calling Commit on a nil *Context is a no-op that returns nil.
func (c *Context) Commit() error {
	if c == nil {
		return nil
	}
	rollback := len(c.Errs) > 0

	var errMsgs []string
	for db, tx := range c.TxMap {
		var err error
		if rollback {
			err = tx.Rollback()
		} else {
			err = tx.Commit()
		}
		if err != nil {
			errMsgs = append(errMsgs, err.Error())
		}
		// Deleting the current entry during range is permitted by the Go spec.
		delete(c.TxMap, db)
	}

	// Test the count rather than the joined text, so a failure whose message
	// happens to be empty is still reported.
	if len(errMsgs) == 0 {
		return nil
	}
	return errors.New(strings.Join(errMsgs, ErrSep))
}
// CommitWithError calls Commit and merges whatever it returns into *err
// with multierr.Append, leaving *err unchanged when Commit succeeds.
// err must be a non-nil pointer. Its shape suggests use as
// `defer txn.CommitWithError(&err)` with a named error result (an assumption;
// confirm against callers). Only c.Errs, not *err, decides commit vs rollback.
func (c *Context) CommitWithError(err *error) {
	// Finalize first, so the transactions are committed or rolled back before
	// the pointer is dereferenced.
	commitErr := c.Commit()
	*err = multierr.Append(*err, commitErr)
}
// AppendError records err on c, which makes a later Commit roll back instead
// of commit. It reports whether err was recorded. A nil err or a nil *Context
// is ignored and yields false.
func (c *Context) AppendError(err error) bool {
	if c == nil || err == nil {
		return false
	}
	c.Errs = append(c.Errs, err)
	return true
}