diff --git a/cmd/generate.go b/cmd/generate.go index a131e39..925a47e 100755 --- a/cmd/generate.go +++ b/cmd/generate.go @@ -13,6 +13,7 @@ const ( keyDebug = "debug" keyConnectionString = "connectionString" keyConfig = "config" + keyUseRoutinesFile = "useRoutinesFile" ) var generateCmd = &cobra.Command{ @@ -21,6 +22,7 @@ var generateCmd = &cobra.Command{ Long: "Generate code for calling database stored procedures", Run: func(cmd *cobra.Command, args []string) { common.BindBoolFlag(cmd, keyDebug) + common.BindBoolFlag(cmd, keyUseRoutinesFile) common.BindStringFlag(cmd, keyConnectionString) common.BindStringFlag(cmd, keyConfig) @@ -43,6 +45,7 @@ func init() { // set cli flags common.DefineBoolFlag(generateCmd, keyDebug, "d", false, "Print debug logs and create debug files") + common.DefineBoolFlag(generateCmd, keyUseRoutinesFile, "", false, "Use routines file to generate code") common.DefineStringFlag(generateCmd, keyConnectionString, "s", "", "Connection string used to connect to database") common.DefineStringFlag(generateCmd, keyConfig, "c", "", "Path to configuration file") } @@ -57,14 +60,10 @@ func doGenerate() error { common.LogDebug("Debug logging is enabled") - log.Printf("Connecting to database...") - conn, err := dbGen.Connect(config.ConnectionString) - if err != nil { - return fmt.Errorf("error connecting to database: %s", err) - } + var routines []dbGen.DbRoutine log.Printf("Getting routines...") - routines, err := dbGen.GetRoutines(conn, config) + routines, err = dbGen.GetRoutines(config) if err != nil { return fmt.Errorf("error getting routines: %s", err) } diff --git a/cmd/getRoutines.go b/cmd/getRoutines.go index 0ca8716..a7cabb4 100755 --- a/cmd/getRoutines.go +++ b/cmd/getRoutines.go @@ -15,6 +15,7 @@ var getRoutinesCmd = &cobra.Command{ Long: "Get routines from database to generate later", Run: func(cmd *cobra.Command, args []string) { common.BindBoolFlag(cmd, keyDebug) + common.BindBoolFlag(cmd, keyUseRoutinesFile) common.BindStringFlag(cmd, keyConnectionString) common.BindStringFlag(cmd, keyConfig) @@ -42,6 +43,7 @@ func init() { // set cli flags common.DefineBoolFlag(getRoutinesCmd, keyDebug, "d", false, "Print debug logs and create debug files") + common.DefineBoolFlag(getRoutinesCmd, keyUseRoutinesFile, "", false, "Use routines file to generate code") common.DefineStringFlag(getRoutinesCmd, keyConnectionString, "s", "", "Connection string used to connect to database") common.DefineStringFlag(getRoutinesCmd, keyConfig, "c", "", "Path to configuration file") } @@ -56,14 +58,10 @@ func doGetRoutines() error { common.LogDebug("Debug logging is enabled") - log.Printf("Connecting to database...") - conn, err := dbGen.Connect(config.ConnectionString) - if err != nil { - return fmt.Errorf("error connecting to database: %s", err) - } - + // because we use shared config, we need to set this + config.UseRoutinesFile = false log.Printf("Getting routines...") - routines, err := dbGen.GetRoutines(conn, config) + routines, err := dbGen.GetRoutines(config) if err != nil { return fmt.Errorf("error getting routines: %s", err) } diff --git a/common/fs.go b/common/fs.go index 78a535d..0195feb 100755 --- a/common/fs.go +++ b/common/fs.go @@ -59,3 +59,19 @@ func SaveAsJson(path string, data interface{}) error { return nil } + +func LoadFromJson(path string, data interface{}) error { + LogDebug("Loading data from json file %s", path) + + fileContent, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("reading file: %s", err) + } + + err = json.Unmarshal(fileContent, data) + if err != nil { + return fmt.Errorf("parsing json: %s", err) + } + return nil + +} diff --git a/src/config.go b/src/config.go index 1d76b47..f6dd812 100755 --- a/src/config.go +++ b/src/config.go @@ -31,6 +31,7 @@ type Config struct { Debug bool `mapstructure:"Debug"` ClearOutputFolder bool `mapstructure:"ClearOutputFolder"` RoutinesFile string `mapstructure:"RoutinesFile"` + UseRoutinesFile bool `mapstructure:"UseRoutinesFile"` Generate []SchemaConfig `mapstructure:"Generate"` Mappings []Mapping `mapstructure:"Mappings"` } diff --git a/src/database.go b/src/database.go index 53e97f8..52fa6e9 100755 --- a/src/database.go +++ b/src/database.go @@ -2,37 +2,11 @@ package dbGen import ( "fmt" - _ "github.com/jackc/pgx/v5/stdlib" - "github.com/jmoiron/sqlx" + "github.com/keenmate/db-gen/common" + "log" "slices" ) -type DbConn struct { - conn *sqlx.DB -} - -func Connect(connectionString string) (*DbConn, error) { - connection, err := sqlx.Connect("pgx", connectionString) - if err != nil { - return nil, err - } - - err = connection.Ping() - - if err != nil { - return nil, err - } - - conn := &DbConn{conn: connection} - - return conn, nil -} - -func (conn DbConn) Select(output interface{}, query string, params ...interface{}) error { - err := conn.conn.Select(output, query, params...) - return err -} - type DbRoutine struct { RowNumber int `db:"row_number"` RoutineSchema string `db:"routine_schema"` @@ -61,7 +35,32 @@ const ( Procedure = "procedure" ) -func GetRoutines(conn *DbConn, config *Config) ([]DbRoutine, error) { +func GetRoutines(config *Config) ([]DbRoutine, error) { + if config.UseRoutinesFile { + return LoadRoutinesFromFile(config) + } + + return getRoutinesFromDatabase(config) +} + +func LoadRoutinesFromFile(config *Config) ([]DbRoutine, error) { + routines := new([]DbRoutine) + err := common.LoadFromJson(config.RoutinesFile, routines) + if err != nil { + return nil, fmt.Errorf("loading routines from file: %s", err) + } + + return *routines, nil +} + +func getRoutinesFromDatabase(config *Config) ([]DbRoutine, error) { + log.Printf("Connecting to database...") + conn, err := common.Connect(config.ConnectionString) + if err != nil { + + return nil, fmt.Errorf("error connecting to database: %s", err) + } + schemas := getSchemas(config) routines := make([]DbRoutine, 0) @@ -86,6 +85,7 @@ func GetRoutines(conn *DbConn, config *Config) ([]DbRoutine, error) { return routines, nil } + func getSchemas(config *Config) []string { schemas := make([]string, 0) for _, schemaConfig := range config.Generate { @@ -97,7 +97,7 @@ func getSchemas(config *Config) []string { return schemas } -func getFunctionsInSchema(conn *DbConn, schema string) ([]DbRoutine, error) { +func getFunctionsInSchema(conn *common.DbConn, schema string) ([]DbRoutine, error) { routines := new([]DbRoutine) // I am coalescing @@ -127,7 +127,7 @@ func getFunctionsInSchema(conn *DbConn, schema string) ([]DbRoutine, error) { return *routines, nil } -func addParamsToRoutine(conn *DbConn, routine *DbRoutine) error { +func addParamsToRoutine(conn *common.DbConn, routine *DbRoutine) error { q := ` select ordinal_position::int, parameter_name::text, diff --git a/src/generator.go b/src/generator.go index 2e9a71d..b5e2c5d 100755 --- a/src/generator.go +++ b/src/generator.go @@ -22,37 +22,41 @@ func Generate(routines []Routine, config *Config) error { } common.LogDebug("Got %d file hashes", len(*fileHashes)) + log.Printf("Ensuring output folder...") + err = ensureOutputFolder(config) if err != nil { return fmt.Errorf("ensuring output folder: %s", err) } - log.Printf("Ensured output folder") + + log.Printf("Generating dbcontext...") err = generateDbContext(routines, fileHashes, config) if err != nil { return fmt.Errorf("generating dbcontext: %s", err) } - log.Printf("Generated dbcontext") if config.GenerateModels { + log.Printf("Generating models...") + err = generateModels(routines, fileHashes, config) if err != nil { return fmt.Errorf("generating models: %s", err) } - log.Printf("Generated models") } else { log.Printf("Skipping generating models") } if config.GenerateProcessors { + log.Printf("Generating processors...") + err = generateProcessors(routines, fileHashes, config) if err != nil { return fmt.Errorf("generating processors: %s", err) } - log.Printf("Generated processors") } else { log.Printf("Skipping generating processors") } diff --git a/testing/db-gen.json b/testing/db-gen.json index f7f4f36..424bbad 100755 --- a/testing/db-gen.json +++ b/testing/db-gen.json @@ -1,137 +1,137 @@ { - "OutputFolder": "C:\\Testbench\\CSharp\\AspNetCoreDbGen\\AspNetCoreDbGen\\output", - "GenerateModels": true, - "GenerateProcessors": true, - "GenerateProcessorsForVoidReturns": false, - "ClearOutputFolder": false, - "DbContextTemplate": "./templates/dbcontext.gotmpl", - "ModelTemplate": "./templates/model.gotmpl", - "ProcessorTemplate": "./templates/processor.gotmpl", - "GeneratedFileExtension": ".cs", - "GeneratedFileCase": "camelcase", - "Generate": [ - { - "Schema": "public", - "AllFunctions": true, - "IgnoredFunctions": [ - "ignored" - ] - }, - { - "Schema": "test", - "AllFunctions": false, - "Functions": [ - "explicitly_included" - ] - } - ], - "Mappings": [ - { - "DatabaseTypes": [ - "boolean", - "bool" - ], - "MappedType": "bool", - "MappingFunction": "GetBoolean" - }, - { - "DatabaseTypes": [ - "smallint", - "int2" - ], - "MappedType": "short", - "MappingFunction": "GetInt16" - }, - { - "DatabaseTypes": [ - "integer", - "int4" - ], - "MappedType": "int", - "MappingFunction": "GetInt32" - }, - { - "DatabaseTypes": [ - "bigint", - "int8" - ], - "MappedType": "long", - "MappingFunction": "GetInt64" - }, - { - "DatabaseTypes": [ - "real", - "float4" - ], - "MappedType": "float", - "MappingFunction": "GetFloat" - }, - { - "DatabaseTypes": [ - "double precision", - "float8" - ], - "MappedType": "double", - "MappingFunction": "GetDouble" - }, - { - "DatabaseTypes": [ - "numeric", - "money" - ], - "MappedType": "decimal", - "MappingFunction": "GetDecimal" - }, - { - "DatabaseTypes": [ - "text", - "character varying", - "character", - "citext", - "json", - "jsonb", - "xml", - "varchar" - ], - "MappedType": "string", - "MappingFunction": "GetString" - }, - { - "DatabaseTypes": [ - "uuid" - ], - "MappedType": "Guid", - "MappingFunction": "GetGuid" - }, - { - "DatabaseTypes": [ - "bytea" - ], - "MappedType": "byte[]", - "MappingFunction": "GetByteArray" - }, - { - "DatabaseTypes": [ - "timestamptz", - "date", - "timestamp" - ], - "MappedType": "DateTime", - "MappingFunction": "GetDateTime" - }, - { - "DatabaseTypes": [ - "interval" - ], - "MappedType": "TimeSpan", - "MappingFunction": "GetTimeSpan" - }, - { - "DatabaseTypes": [ - "ltree" - ], - "MappedType": "String", - "MappingFunction": "GetString" - } - ] + "OutputFolder": "./output", + "GenerateModels": true, + "GenerateProcessors": true, + "GenerateProcessorsForVoidReturns": false, + "ClearOutputFolder": false, + "DbContextTemplate": "./templates/dbcontext.gotmpl", + "ModelTemplate": "./templates/model.gotmpl", + "ProcessorTemplate": "./templates/processor.gotmpl", + "GeneratedFileExtension": ".cs", + "GeneratedFileCase": "camelcase", + "Generate": [ + { + "Schema": "public", + "AllFunctions": true, + "IgnoredFunctions": [ + "ignored" + ] + }, + { + "Schema": "test", + "AllFunctions": false, + "Functions": [ + "explicitly_included" + ] + } + ], + "Mappings": [ + { + "DatabaseTypes": [ + "boolean", + "bool" + ], + "MappedType": "bool", + "MappingFunction": "GetBoolean" + }, + { + "DatabaseTypes": [ + "smallint", + "int2" + ], + "MappedType": "short", + "MappingFunction": "GetInt16" + }, + { + "DatabaseTypes": [ + "integer", + "int4" + ], + "MappedType": "int", + "MappingFunction": "GetInt32" + }, + { + "DatabaseTypes": [ + "bigint", + "int8" + ], + "MappedType": "long", + "MappingFunction": "GetInt64" + }, + { + "DatabaseTypes": [ + "real", + "float4" + ], + "MappedType": "float", + "MappingFunction": "GetFloat" + }, + { + "DatabaseTypes": [ + "double precision", + "float8" + ], + "MappedType": "double", + "MappingFunction": "GetDouble" + }, + { + "DatabaseTypes": [ + "numeric", + "money" + ], + "MappedType": "decimal", + "MappingFunction": "GetDecimal" + }, + { + "DatabaseTypes": [ + "text", + "character varying", + "character", + "citext", + "json", + "jsonb", + "xml", + "varchar" + ], + "MappedType": "string", + "MappingFunction": "GetString" + }, + { + "DatabaseTypes": [ + "uuid" + ], + "MappedType": "Guid", + "MappingFunction": "GetGuid" + }, + { + "DatabaseTypes": [ + "bytea" + ], + "MappedType": "byte[]", + "MappingFunction": "GetByteArray" + }, + { + "DatabaseTypes": [ + "timestamptz", + "date", + "timestamp" + ], + "MappedType": "DateTime", + "MappingFunction": "GetDateTime" + }, + { + "DatabaseTypes": [ + "interval" + ], + "MappedType": "TimeSpan", + "MappingFunction": "GetTimeSpan" + }, + { + "DatabaseTypes": [ + "ltree" + ], + "MappedType": "String", + "MappingFunction": "GetString" + } + ] } diff --git a/testing/local.db-gen.json b/testing/local.db-gen.json index 368fbd9..4703f93 100755 --- a/testing/local.db-gen.json +++ b/testing/local.db-gen.json @@ -1,3 +1,3 @@ { "ConnectionString": "postgresql://postgres:Password3000!!@localhost:5432/db_gen" -} \ No newline at end of file +}