Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
kyleconroy committed Jan 4, 2024
1 parent 8af8f79 commit 253a45a
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 14 deletions.
2 changes: 2 additions & 0 deletions internal/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type tmplCtx struct {
EmitAllEnumValues bool
UsesCopyFrom bool
UsesBatch bool
OmitSqlcVersion bool
BuildTags string
}

Expand Down Expand Up @@ -185,6 +186,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
Structs: structs,
SqlcVersion: req.SqlcVersion,
BuildTags: options.BuildTags,
OmitSqlcVersion: options.OmitSqlcVersion,
}

if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && options.SqlDriver != SQLDriverGoSQLDriverMySQL {
Expand Down
2 changes: 1 addition & 1 deletion internal/go_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func goInnerType(req *plugin.GenerateRequest, options *opts.Options, col *plugin
case "postgresql":
return postgresType(req, options, col)
case "sqlite":
return sqliteType(req, col)
return sqliteType(req, options, col)
default:
return "interface{}"
}
Expand Down
2 changes: 1 addition & 1 deletion internal/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ func (i *importer) queryImports(filename string) fileImports {
}

sqlpkg := parseDriver(i.Options.SqlPackage)
if sqlcSliceScan() {
if sqlcSliceScan() && !sqlpkg.IsPGX() {
std["strings"] = struct{}{}
}
if sliceScan() && !sqlpkg.IsPGX() {
Expand Down
4 changes: 3 additions & 1 deletion internal/opts/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type Options struct {
EmitPointersForNullTypes bool `json:"emit_pointers_for_null_types" yaml:"emit_pointers_for_null_types"`
EmitEnumValidMethod bool `json:"emit_enum_valid_method,omitempty" yaml:"emit_enum_valid_method"`
EmitAllEnumValues bool `json:"emit_all_enum_values,omitempty" yaml:"emit_all_enum_values"`
EmitSqlAsComment bool `json:"emit_sql_as_comment,omitempty" yaml:"emit_sql_as_comment"`
JsonTagsCaseStyle string `json:"json_tags_case_style,omitempty" yaml:"json_tags_case_style"`
Package string `json:"package" yaml:"package"`
Out string `json:"out" yaml:"out"`
Expand All @@ -34,11 +35,12 @@ type Options struct {
OutputBatchFileName string `json:"output_batch_file_name,omitempty" yaml:"output_batch_file_name"`
OutputDbFileName string `json:"output_db_file_name,omitempty" yaml:"output_db_file_name"`
OutputModelsFileName string `json:"output_models_file_name,omitempty" yaml:"output_models_file_name"`
OutputQuerierFileName string `json:"output_querier_file_name,omitempty" yaml:"output_queries_file_name"`
OutputQuerierFileName string `json:"output_querier_file_name,omitempty" yaml:"output_querier_file_name"`
OutputCopyfromFileName string `json:"output_copyfrom_file_name,omitempty" yaml:"output_copyfrom_file_name"`
OutputFilesSuffix string `json:"output_files_suffix,omitempty" yaml:"output_files_suffix"`
InflectionExcludeTableNames []string `json:"inflection_exclude_table_names,omitempty" yaml:"inflection_exclude_table_names"`
QueryParameterLimit *int32 `json:"query_parameter_limit,omitempty" yaml:"query_parameter_limit"`
OmitSqlcVersion bool `json:"omit_sqlc_version,omitempty" yaml:"omit_sqlc_version"`
OmitUnusedStructs bool `json:"omit_unused_structs,omitempty" yaml:"omit_unused_structs"`
BuildTags string `json:"build_tags,omitempty" yaml:"build_tags"`
}
Expand Down
19 changes: 18 additions & 1 deletion internal/result.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package golang

import (
"bufio"
"fmt"
"sort"
"strings"
Expand Down Expand Up @@ -199,14 +200,30 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs []
constantName = sdk.LowerTitle(query.Name)
}

comments := query.Comments
if options.EmitSqlAsComment {
if len(comments) == 0 {
comments = append(comments, query.Name)
}
comments = append(comments, " ")
scanner := bufio.NewScanner(strings.NewReader(query.Text))
for scanner.Scan() {
line := scanner.Text()
comments = append(comments, " "+line)
}
if err := scanner.Err(); err != nil {
return nil, err
}
}

gq := Query{
Cmd: query.Cmd,
ConstantName: constantName,
FieldName: sdk.LowerTitle(query.Name) + "Stmt",
MethodName: query.Name,
SourceName: query.Filename,
SQL: query.Text,
Comments: query.Comments,
Comments: comments,
Table: query.InsertIntoTable,
}
sqlpkg := parseDriver(options.SqlPackage)
Expand Down
22 changes: 21 additions & 1 deletion internal/sqlite_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,23 @@ import (
"github.com/sqlc-dev/plugin-sdk-go/plugin"
"github.com/sqlc-dev/plugin-sdk-go/sdk"
"github.com/sqlc-dev/sqlc-gen-go/internal/debug"
"github.com/sqlc-dev/sqlc-gen-go/internal/opts"
)

func sqliteType(req *plugin.GenerateRequest, col *plugin.Column) string {
func sqliteType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string {
dt := strings.ToLower(sdk.DataType(col.Type))
notNull := col.NotNull || col.IsArray
emitPointersForNull := options.EmitPointersForNullTypes

switch dt {

case "int", "integer", "tinyint", "smallint", "mediumint", "bigint", "unsignedbigint", "int2", "int8":
if notNull {
return "int64"
}
if emitPointersForNull {
return "*int64"
}
return "sql.NullInt64"

case "blob":
Expand All @@ -28,18 +33,27 @@ func sqliteType(req *plugin.GenerateRequest, col *plugin.Column) string {
if notNull {
return "float64"
}
if emitPointersForNull {
return "*float64"
}
return "sql.NullFloat64"

case "boolean", "bool":
if notNull {
return "bool"
}
if emitPointersForNull {
return "*bool"
}
return "sql.NullBool"

case "date", "datetime", "timestamp":
if notNull {
return "time.Time"
}
if emitPointersForNull {
return "*time.Time"
}
return "sql.NullTime"

case "any":
Expand All @@ -60,12 +74,18 @@ func sqliteType(req *plugin.GenerateRequest, col *plugin.Column) string {
if notNull {
return "string"
}
if emitPointersForNull {
return "*string"
}
return "sql.NullString"

case strings.HasPrefix(dt, "decimal"), dt == "numeric":
if notNull {
return "float64"
}
if emitPointersForNull {
return "*float64"
}
return "sql.NullFloat64"

default:
Expand Down
21 changes: 12 additions & 9 deletions internal/templates/template.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
//go:build {{.BuildTags}}

{{end}}// Code generated by sqlc. DO NOT EDIT.
// versions:
{{if not .OmitSqlcVersion}}// versions:
// sqlc {{.SqlcVersion}}
{{end}}

package {{.Package}}

Expand Down Expand Up @@ -33,8 +34,9 @@ import (
//go:build {{.BuildTags}}

{{end}}// Code generated by sqlc. DO NOT EDIT.
// versions:
{{if not .OmitSqlcVersion}}// versions:
// sqlc {{.SqlcVersion}}
{{end}}

package {{.Package}}

Expand All @@ -61,8 +63,9 @@ import (
//go:build {{.BuildTags}}

{{end}}// Code generated by sqlc. DO NOT EDIT.
// versions:
{{if not .OmitSqlcVersion}}// versions:
// sqlc {{.SqlcVersion}}
{{end}}

package {{.Package}}

Expand Down Expand Up @@ -158,9 +161,9 @@ type {{.Name}} struct { {{- range .Fields}}
//go:build {{.BuildTags}}

{{end}}// Code generated by sqlc. DO NOT EDIT.
// versions:
{{if not .OmitSqlcVersion}}// versions:
// sqlc {{.SqlcVersion}}
// source: {{.SourceName}}
{{end}}// source: {{.SourceName}}

package {{.Package}}

Expand All @@ -187,9 +190,9 @@ import (
//go:build {{.BuildTags}}

{{end}}// Code generated by sqlc. DO NOT EDIT.
// versions:
{{if not .OmitSqlcVersion}}// versions:
// sqlc {{.SqlcVersion}}
// source: {{.SourceName}}
{{end}}// source: {{.SourceName}}

package {{.Package}}

Expand All @@ -216,9 +219,9 @@ import (
//go:build {{.BuildTags}}

{{end}}// Code generated by sqlc. DO NOT EDIT.
// versions:
{{if not .OmitSqlcVersion}}// versions:
// sqlc {{.SqlcVersion}}
// source: {{.SourceName}}
{{end}}// source: {{.SourceName}}

package {{.Package}}

Expand Down

0 comments on commit 253a45a

Please sign in to comment.