-
Notifications
You must be signed in to change notification settings - Fork 0
/
sqload.go
87 lines (68 loc) · 1.38 KB
/
sqload.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
package sqload
import (
"bytes"
"embed"
"io/fs"
)
// Loader reads files out of an embedded file system and hands file content
// to a Dialector for parsing.
type Loader struct {
	// d is the dialector consulted by Parse. It is only used when it
	// provides a Parse(string, *[]string) error method.
	d Dialector
}
// New returns a Loader whose Parse method delegates to dialector.
func New(dialector Dialector) Loader {
	var l Loader
	l.d = dialector
	return l
}
// Load gathers every non-directory entry found under the root of content and
// replaces the contents of to with the concatenation of those files.
//
// If listing or reading fails, the error is returned and to is not modified.
func (l Loader) Load(content *embed.FS, to *bytes.Buffer) error {
	names, err := getAllFilenames(content)
	if err == nil {
		err = load(content, names, to)
	}
	return err
}
// LoadFrom reads the named files from content, in the order given, and
// replaces the contents of to with their concatenation.
//
// If any file cannot be read, the error is returned and to is not modified.
func (l Loader) LoadFrom(content *embed.FS, to *bytes.Buffer, fileNames ...string) error {
	return load(content, fileNames, to)
}
// Parse passes sqlfile to the Loader's dialector, provided the dialector has
// a Parse(string, *[]string) error method, and stores the resulting slice in
// to.
//
// When the dialector has no such method, to is set to nil and nil is
// returned. When the dialector's Parse fails, its error is returned and to is
// left untouched.
func (l Loader) Parse(sqlfile string, to *[]string) error {
	// parser is the optional capability a dialector may provide.
	type parser interface {
		Parse(string, *[]string) error
	}

	var stmts []string
	if p, ok := l.d.(parser); ok {
		if err := p.Parse(sqlfile, &stmts); err != nil {
			return err
		}
	}

	*to = stmts
	return nil
}
// getAllFilenames walks efs from its root (".") and returns the path of every
// non-directory entry, in the order fs.WalkDir visits them.
//
// Any error reported during the walk aborts it; the error is returned and the
// partial file list is discarded.
func getAllFilenames(efs *embed.FS) ([]string, error) {
	var files []string

	visit := func(path string, d fs.DirEntry, err error) error {
		// WalkDir reports Stat/ReadDir failures through err, and d may be nil
		// when that happens, so err has to be checked before d is used.
		if err != nil {
			return err
		}
		if !d.IsDir() {
			files = append(files, path)
		}
		return nil
	}

	if err := fs.WalkDir(efs, ".", visit); err != nil {
		return nil, err
	}
	return files, nil
}
// load reads each file named in fileNames from content, in order, and
// replaces the contents of to with their concatenation.
//
// The data is accumulated in a scratch buffer first, so to is only
// overwritten once every file has been read successfully; on the first read
// error that error is returned and to keeps its previous contents.
func load(content *embed.FS, fileNames []string, to *bytes.Buffer) error {
	var merged bytes.Buffer

	for i := range fileNames {
		data, readErr := content.ReadFile(fileNames[i])
		if readErr != nil {
			return readErr
		}
		merged.Write(data)
	}

	*to = merged
	return nil
}