forked from pressly/goose
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathglobals.go
104 lines (98 loc) · 3.25 KB
/
globals.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
package goose
import (
"errors"
"fmt"
"path/filepath"
)
// registeredGoMigrations is the package-level registry of Go migrations, keyed by migration
// version. It is not guarded by any lock.
var registeredGoMigrations = map[int64]*Migration{}
// ResetGlobalMigrations resets the global Go migrations registry.
//
// Not safe for concurrent use.
func ResetGlobalMigrations() {
registeredGoMigrations = make(map[int64]*Migration)
}
// SetGlobalMigrations registers Go migrations globally. It returns an error if a migration is
// nil, fails validation, or has the same version as a migration that has already been
// registered. Go migrations must be constructed using the [NewGoMigration] function.
//
// Migrations are processed one at a time in argument order, so when an error is returned the
// migrations that precede the offending one remain registered.
//
// Not safe for concurrent use.
func SetGlobalMigrations(migrations ...*Migration) error {
	for _, m := range migrations {
		// Reject nil entries up front: reading m.Version below would otherwise panic instead of
		// reporting the problem through the error return.
		if m == nil {
			return errors.New("invalid go migration: must not be nil")
		}
		if _, ok := registeredGoMigrations[m.Version]; ok {
			return fmt.Errorf("go migration with version %d already registered", m.Version)
		}
		if err := checkGoMigration(m); err != nil {
			return fmt.Errorf("invalid go migration: %w", err)
		}
		registeredGoMigrations[m.Version] = m
	}
	return nil
}
// checkGoMigration validates a Go migration before it is registered. It returns the first
// problem found, checking in order: construction and registration flags, type, version, the
// optional source path, the up and down functions, and finally conflicting legacy function
// fields.
func checkGoMigration(m *Migration) error {
	// Basic invariants of the migration itself.
	switch {
	case !m.construct:
		return errors.New("must use NewGoMigration to construct migrations")
	case !m.Registered:
		return errors.New("must be registered")
	case m.Type != TypeGo:
		return fmt.Errorf("type must be %q", TypeGo)
	case m.Version < 1:
		return errors.New("version must be greater than zero")
	}

	// The source is optional. When present it must be a .go path whose numeric component equals
	// the migration version; it is not a free-form description field.
	if src := m.Source; src != "" {
		if filepath.Ext(src) != ".go" {
			return fmt.Errorf("source must have .go extension: %q", src)
		}
		parsed, err := NumericComponent(src)
		if err != nil {
			return fmt.Errorf("invalid source: %w", err)
		}
		if parsed != m.Version {
			return fmt.Errorf("version:%d does not match numeric component in source %q", m.Version, src)
		}
	}

	// Validate the up and down functions individually.
	if err := checkGoFunc(m.goUp); err != nil {
		return fmt.Errorf("up function: %w", err)
	}
	if err := checkGoFunc(m.goDown); err != nil {
		return fmt.Errorf("down function: %w", err)
	}

	// Within each pair, the transactional and non-transactional variants must not both be set.
	switch {
	case m.UpFnContext != nil && m.UpFnNoTxContext != nil:
		return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext")
	case m.UpFn != nil && m.UpFnNoTx != nil:
		return errors.New("must specify exactly one of UpFn or UpFnNoTx")
	case m.DownFnContext != nil && m.DownFnNoTxContext != nil:
		return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext")
	case m.DownFn != nil && m.DownFnNoTx != nil:
		return errors.New("must specify exactly one of DownFn or DownFnNoTx")
	}
	return nil
}
// checkGoFunc validates a single Go migration function. RunTx and RunDB must not both be set,
// the mode must be TransactionEnabled or TransactionDisabled, and the mode must agree with
// whichever function is set.
func checkGoFunc(f *GoFunc) error {
	hasTx, hasDB := f.RunTx != nil, f.RunDB != nil
	if hasTx && hasDB {
		return errors.New("must specify exactly one of RunTx or RunDB")
	}

	// A valid mode with neither function set is not an error: it means the user wants to record
	// a version with the given mode but not run any functions.
	if f.Mode != TransactionEnabled && f.Mode != TransactionDisabled {
		return fmt.Errorf("invalid mode: %d", f.Mode)
	}

	// At most one function is set at this point, so at most one case below can match.
	switch {
	case hasDB && f.Mode != TransactionDisabled:
		return errors.New("transaction mode must be disabled or unspecified when RunDB is set")
	case hasTx && f.Mode != TransactionEnabled:
		return errors.New("transaction mode must be enabled or unspecified when RunTx is set")
	}
	return nil
}