Skip to content

Commit

Permalink
Merge pull request #2 from thisissoon/feature/import-config
Browse files Browse the repository at this point in the history
Feature: Import config
  • Loading branch information
krak3n authored Apr 3, 2018
2 parents ad03a47 + 1de6102 commit 14e9e2f
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 19 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ scaneo [options] paths...
-f, -funcs
Generate SQL helper functions. Default is false.
-i, -import
Override package to import type from.
-v, -version
Print version and exit.
Expand Down
29 changes: 25 additions & 4 deletions scaneo.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@ OPTIONS
-f, -funcs
Generate SQL helper functions.
-i, -import
Override package to import type from.
-v, -version
Print version and exit.
-h, -help
Print help and exit.
Expand Down Expand Up @@ -89,13 +91,15 @@ func main() {
unexport := flag.Bool("u", false, "")
whitelist := flag.String("w", "", "")
genFuncs := flag.Bool("f", false, "")
pkgImport := flag.String("i", "", "")
version := flag.Bool("v", false, "")
help := flag.Bool("h", false, "")
flag.StringVar(outFilename, "output", "scans.go", "")
flag.StringVar(packName, "package", "current directory", "")
flag.BoolVar(unexport, "unexport", false, "")
flag.StringVar(whitelist, "whitelist", "", "")
flag.BoolVar(genFuncs, "funcs", false, "")
flag.StringVar(pkgImport, "import", "", "")
flag.BoolVar(version, "version", false, "")
flag.BoolVar(help, "help", false, "")
flag.Usage = func() { log.Println(usageText) } // call on flag error
Expand Down Expand Up @@ -139,7 +143,7 @@ func main() {
structToks = append(structToks, toks...)
}

if err := genFile(*outFilename, *packName, *unexport, structToks, *genFuncs); err != nil {
if err := genFile(*outFilename, *packName, *unexport, structToks, *genFuncs, *pkgImport); err != nil {
log.Fatal("couldn't generate file:", err)
}
}
Expand Down Expand Up @@ -347,7 +351,7 @@ func parseStar(fieldType *ast.StarExpr) string {
return fmt.Sprintf("*%s", starType)
}

func genFile(outFile, pkg string, unexport bool, toks []structToken, genFuncs bool) error {
func genFile(outFile, pkg string, unexport bool, toks []structToken, genFuncs bool, pkgImport string) error {
if len(toks) < 1 {
return errors.New("no structs found")
}
Expand All @@ -363,19 +367,36 @@ func genFile(outFile, pkg string, unexport bool, toks []structToken, genFuncs bo
Tokens []structToken
Visibility string
Funcs bool
ImportPkg string
}{
PackageName: pkg,
Visibility: "S",
Tokens: toks,
Funcs: genFuncs,
ImportPkg: pkgImport,
}

if unexport {
// func name will be scanFoo instead of ScanFoo
data.Visibility = "s"
}

fnMap := template.FuncMap{"title": strings.Title}
// Construct type prefix from pkgImport
var typePrefix string
if pkgImport != "" {
pkgPathParts := strings.Split(pkgImport, "/")
typePrefix = pkgPathParts[len(pkgPathParts)-1]
}

fnMap := template.FuncMap{
"title": strings.Title,
"pkg": func(s string) string {
if typePrefix != "" {
return fmt.Sprintf("%s.%s", typePrefix, s)
}
return s
},
}
scansTmpl, err := template.New("scans").Funcs(fnMap).Parse(scansText)
if err != nil {
return err
Expand Down
22 changes: 21 additions & 1 deletion scaneo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ func TestGenFile(t *testing.T) {
tokens []structToken
unexport bool
funcs bool
pkgImport string
assert func(*testing.T, error)
expectedFuncs []string
}{
Expand All @@ -323,6 +324,7 @@ func TestGenFile(t *testing.T) {
[]structToken{},
true,
false,
"",
func(t *testing.T, err error) {
if err == nil {
t.Error("no struct tokens passed")
Expand All @@ -338,6 +340,7 @@ func TestGenFile(t *testing.T) {
toks,
true,
false,
"",
func(t *testing.T, err error) {
if err == nil {
t.Error("no output file path passed")
Expand All @@ -353,6 +356,7 @@ func TestGenFile(t *testing.T) {
toks,
true,
false,
"",
func(t *testing.T, err error) {
if err != nil {
t.Error(err)
Expand All @@ -367,6 +371,7 @@ func TestGenFile(t *testing.T) {
toks,
true,
true,
"",
func(t *testing.T, err error) {
if err != nil {
t.Error(err)
Expand All @@ -390,6 +395,21 @@ func TestGenFile(t *testing.T) {
"UpdateUnexported",
},
},
{
"pkg import",
true,
toks,
true,
false,
"testsvc/storage/user",
func(t *testing.T, err error) {
if err != nil {
t.Error(err)
t.FailNow()
}
},
expectedFuncNames,
},
}

for _, tc := range tt {
Expand All @@ -401,7 +421,7 @@ func TestGenFile(t *testing.T) {
fmt.Sprintf("scaneo-test-%d", time.Now().UnixNano()))
}
// genFile(file, package, unexport, tokens, funcs)
err := genFile(outFile, "testing", tc.unexport, tc.tokens, tc.funcs)
err := genFile(outFile, "testing", tc.unexport, tc.tokens, tc.funcs, tc.pkgImport)
defer os.Remove(outFile) // comment this line to examine generated code

tc.assert(t, err)
Expand Down
30 changes: 16 additions & 14 deletions tmpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,26 @@ const (
package {{.PackageName}}
import "database/sql"
{{range .Tokens}}func {{$.Visibility}}can{{title .Name}}(r *sql.Row) (*{{.Name}}, error) {
var s *{{.Name}}
import (
"database/sql"{{ if .ImportPkg }}
"{{.ImportPkg}}"{{end}}
)
{{range .Tokens}}
func {{$.Visibility}}can{{title .Name}}(r *sql.Row) (*{{pkg .Name}}, error) {
var s *{{pkg .Name}}
if err := r.Scan({{range .Fields}}
&s.{{.Name}},{{end}}
); err != nil {
return &{{.Name}}{}, err
return &{{pkg .Name}}{}, err
}
return s, nil
}
func {{$.Visibility}}can{{title .Name}}s(rs *sql.Rows) ([]*{{.Name}}, error) {
structs := make([]*{{.Name}}, 0, 16)
func {{$.Visibility}}can{{title .Name}}s(rs *sql.Rows) ([]*{{pkg .Name}}, error) {
structs := make([]*{{pkg .Name}}, 0, 16)
var err error
for rs.Next() {
var s *{{.Name}}
var s *{{pkg .Name}}
if err = rs.Scan({{range .Fields}}
&s.{{.Name}},{{end}}
); err != nil {
Expand All @@ -36,13 +39,13 @@ func {{$.Visibility}}can{{title .Name}}s(rs *sql.Rows) ([]*{{.Name}}, error) {
}
{{if $.Funcs}}
// Select{{title .Name}} selects a single {{title .Name}} row from the database
func Select{{title .Name}}(db *sql.DB, query string, args ...interface{}) (*{{title .Name}}, error) {
func Select{{title .Name}}(db *sql.DB, query string, args ...interface{}) (*{{pkg .Name}}, error) {
row := db.QueryRow(query, args...)
return {{$.Visibility}}can{{title .Name}}(row)
}
// Select{{title .Name}}s selects multiple {{title .Name}} rows from the database
func Select{{title .Name}}s(db *sql.DB, query string, args ...interface{}) ([]*{{title .Name}}, error) {
func Select{{title .Name}}s(db *sql.DB, query string, args ...interface{}) ([]*{{pkg .Name}}, error) {
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
Expand All @@ -52,25 +55,24 @@ func Select{{title .Name}}s(db *sql.DB, query string, args ...interface{}) ([]*{
}
// slice{{title .Name}} returns a slice of arguments from {{title .Name}} struct values
func slice{{title .Name}}(v *{{title .Name}}) []interface{} {
func slice{{title .Name}}(v *{{pkg .Name}}) []interface{} {
return []interface{}{ {{range .Fields}}
&v.{{.Name}},{{end}}
}
}
// Insert{{title .Name}} inserts a single {{title .Name}} row
func Insert{{title .Name}}(db *sql.DB, query string, v *{{title .Name}}) error {
func Insert{{title .Name}}(db *sql.DB, query string, v *{{pkg .Name}}) error {
_, err := db.Exec(query, slice{{title .Name}}(v)[1:]...)
return err
}
// Update{{title .Name}} updates a single {{title .Name}} row
func Update{{title .Name}}(db *sql.DB, query string, v *{{title .Name}}) error {
func Update{{title .Name}}(db *sql.DB, query string, v *{{pkg .Name}}) error {
args := slice{{title .Name}}(v)[1:]
args = append(args, v.ID)
_, err := db.Exec(query, args...)
return err
}
{{end}}{{end}}{{end}}`
)

0 comments on commit 14e9e2f

Please sign in to comment.