-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgeeorm.go
147 lines (131 loc) · 4.28 KB
/
geeorm.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
// Engine 是 GeeORM 与用户交互的入口
// 负责交互前的准备工作(比如连接/测试数据库),交互后的收尾工作(关闭连接)等
// 封装了数据库连接的创建、关闭以及会话管理的功能
package geeorm
import (
"database/sql"
"fmt"
"strings"
"geeorm/dialect"
"geeorm/log"
"geeorm/session"
)
// Engine 结构体是整个库的入口点,用于与数据库进行交互。它包含一个指向数据库连接的指针
type Engine struct {
db *sql.DB
dialect dialect.Dialect
}
// NewEngine 构造函数,用于创建并初始化一个 Engine 实例
func NewEngine(driver, source string) (e *Engine, err error) {
db, err := sql.Open(driver, source)
if err != nil {
log.Error(err)
return
}
// Send a ping to make sure the database connection is alive.
if err = db.Ping(); err != nil {
log.Error(err)
return
}
// make sure the specific dialect exists
dial, ok := dialect.GetDialect(driver)
if !ok {
log.Errorf("dialect %s Not Found", driver)
return
}
e = &Engine{db: db, dialect: dial}
log.Info("Connect database success")
return
}
// Close 关闭数据库连接
func (engine *Engine) Close() {
if err := engine.db.Close(); err != nil {
log.Error("Failed to close database")
}
log.Info("Close database success")
}
// NewSession 创建一个新的数据库会话,该会话将用于执行数据库操作
func (engine *Engine) NewSession() *session.Session {
return session.New(engine.db, engine.dialect)
}
// 自定义的函数类型,它接受一个 *session.Session 类型的参数,并返回一个 interface{} 类型的结果和一个 error 类型的错误。
// 这个函数类型用于表示一个数据库事务操作
type TxFunc func(*session.Session) (interface{}, error)
// 在 geeorm.go 中为用户提供傻瓜式/一键式使用的事务接口
// 用于执行数据库事务
func (engine *Engine) Transaction(f TxFunc) (result interface{}, err error) {
s := engine.NewSession()
if err := s.Begin(); err != nil {
return nil, err
}
defer func() {
// 首先,使用 recover() 函数来捕获可能在事务操作中引发的panic
if p := recover(); p != nil {
_ = s.Rollback()
panic(p) // re-throw panic after Rollback
} else if err != nil {
_ = s.Rollback() // err is non-nil; don't change it
} else {
// commit失败需要再回滚一次
defer func() {
if err != nil {
_ = s.Rollback()
}
}()
err = s.Commit() // err is nil; if Commit returns error update err
}
}()
return f(s)
}
// difference returns a - b
func difference(a []string, b []string) (diff []string) {
mapB := make(map[string]bool)
for _, v := range b {
mapB[v] = true
}
for _, v := range a {
if _, ok := mapB[v]; !ok {
diff = append(diff, v)
}
}
return
}
// Migrate 执行数据库迁移操作
func (engine *Engine) Migrate(value interface{}) error {
// 接受一个回调函数,该函数在一个数据库事务中执行
_, err := engine.Transaction(func(s *session.Session) (result interface{}, err error) {
// 首先检查数据库中是否存在与给定结构体对应的表格
if !s.Model(value).HasTable() {
log.Infof("table %s doesn't exist", s.RefTable().Name)
return nil, s.CreateTable()
}
// 如果表格存在,获取表格的结构信息
// 这一步获取关联的go结构体的信息
table := s.RefTable()
// 利用查询获取了查询结果的列名,即数据库中表的字段名
rows, _ := s.Raw(fmt.Sprintf("SELECT * FROM %s LIMIT 1", table.Name)).QueryRows()
columns, _ := rows.Columns()
addCols := difference(table.FieldNames, columns)
delCols := difference(columns, table.FieldNames)
log.Infof("added cols %v, deleted cols %v", addCols, delCols)
for _, col := range addCols {
f := table.GetField(col)
sqlStr := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s;", table.Name, f.Name, f.Type)
if _, err = s.Raw(sqlStr).Exec(); err != nil {
return
}
}
if len(delCols) == 0 {
return
}
tmp := "tmp_" + table.Name
fieldStr := strings.Join(table.FieldNames, ", ")
// 获取需要保存的字段
s.Raw(fmt.Sprintf("CREATE TABLE %s AS SELECT %s from %s;", tmp, fieldStr, table.Name))
s.Raw(fmt.Sprintf("DROP TABLE %s;", table.Name))
s.Raw(fmt.Sprintf("ALTER TABLE %s RENAME TO %s;", tmp, table.Name))
_, err = s.Exec()
return
})
return err
}