From 5213da6366f337801fc44ea0880ebf48e423b43f Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Wed, 15 Jun 2022 08:37:21 +0200 Subject: [PATCH 01/12] Squashed commit of the following: commit 0ba627c47a425a9b464cc31a7066aaeadad9f972 Author: Erik Unger Date: Mon Jun 13 13:57:54 2022 +0200 adapt tests commit 38be8228e5fc3df7cd3b2f3a9c8621878fbdc543 Author: Erik Unger Date: Fri Jun 10 17:29:40 2022 +0200 moved insert to package db --- connection.go | 32 ++------ {impl => db}/insert.go | 88 +++++++++++++-------- db/reflectstruct.go | 133 ++++++++++++++++++++++++++++++++ errors.go | 24 ++---- examples/user_demo/user_demo.go | 10 ++- impl/connection.go | 44 ++++------- impl/reflectstruct.go | 102 +----------------------- impl/transaction.go | 56 ++++++-------- impl/upsert.go | 21 +++++ mockconn/connection.go | 30 +++---- mockconn/connection_test.go | 26 ++++--- pqconn/connection.go | 28 ++----- pqconn/transaction.go | 28 ++----- 13 files changed, 308 insertions(+), 314 deletions(-) rename {impl => db}/insert.go (68%) create mode 100644 db/reflectstruct.go diff --git a/connection.go b/connection.go index eff08b7..f73faf2 100644 --- a/connection.go +++ b/connection.go @@ -25,7 +25,7 @@ type Connection interface { WithContext(ctx context.Context) Connection // WithStructFieldMapper returns a copy of the connection - // that will use the passed StructFieldMapper. + // that will use the passed StructFieldNamer. WithStructFieldMapper(StructFieldMapper) Connection // StructFieldMapper used by methods of this Connection. @@ -50,6 +50,12 @@ type Connection interface { // column of the connection's database. ValidateColumnName(name string) error + // ArgFmt returns the format for SQL query arguments + ArgFmt() string + + // Err returns any current error of the connection + Err() error + // Now returns the result of the SQL now() // function for the current connection. // Useful for getting the timestamp of a @@ -59,30 +65,6 @@ type Connection interface { // Exec executes a query with optional args. Exec(query string, args ...any) error - // Insert a new row into table using the values. - Insert(table string, values Values) error - - // InsertUnique inserts a new row into table using the passed values - // or does nothing if the onConflict statement applies. - // Returns if a row was inserted. - InsertUnique(table string, values Values, onConflict string) (inserted bool, err error) - - // InsertReturning inserts a new row into table using values - // and returns values from the inserted row listed in returning. - InsertReturning(table string, values Values, returning string) RowScanner - - // InsertStruct inserts a new row into table using the connection's - // StructFieldMapper to map struct fields to column names. - // Optional ColumnFilter can be passed to ignore mapped columns. - InsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error - - // InsertUniqueStruct inserts a new row into table using the connection's - // StructFieldMapper to map struct fields to column names. - // Optional ColumnFilter can be passed to ignore mapped columns. - // Does nothing if the onConflict statement applies - // and returns if a row was inserted. - InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...ColumnFilter) (inserted bool, err error) - // Update table rows(s) with values using the where statement with passed in args starting at $1. Update(table string, values Values, where string, args ...any) error diff --git a/impl/insert.go b/db/insert.go similarity index 68% rename from impl/insert.go rename to db/insert.go index 8254b4f..05ba3d2 100644 --- a/impl/insert.go +++ b/db/insert.go @@ -1,19 +1,27 @@ -package impl +package db import ( + "context" "fmt" "reflect" "strings" - sqldb "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/impl" ) +type Values = sqldb.Values + +var WrapNonNilErrorWithQuery = impl.WrapNonNilErrorWithQuery + // Insert a new row into table using the values. -func Insert(conn sqldb.Connection, table, argFmt string, values sqldb.Values) error { +func Insert(ctx context.Context, table string, values Values) error { if len(values) == 0 { return fmt.Errorf("Insert into table %s: no values", table) } + conn := Conn(ctx) + argFmt := conn.ArgFmt() names, vals := values.Sorted() b := strings.Builder{} writeInsertQuery(&b, table, argFmt, names) @@ -27,7 +35,7 @@ func Insert(conn sqldb.Connection, table, argFmt string, values sqldb.Values) er // InsertUnique inserts a new row into table using the passed values // or does nothing if the onConflict statement applies. // Returns if a row was inserted. -func InsertUnique(conn sqldb.Connection, table, argFmt string, values sqldb.Values, onConflict string) (inserted bool, err error) { +func InsertUnique(ctx context.Context, table string, values Values, onConflict string) (inserted bool, err error) { if len(values) == 0 { return false, fmt.Errorf("InsertUnique into table %s: no values", table) } @@ -36,6 +44,8 @@ func InsertUnique(conn sqldb.Connection, table, argFmt string, values sqldb.Valu onConflict = onConflict[1 : len(onConflict)-1] } + conn := Conn(ctx) + argFmt := conn.ArgFmt() names, vals := values.Sorted() var query strings.Builder writeInsertQuery(&query, table, argFmt, names) @@ -50,11 +60,13 @@ func InsertUnique(conn sqldb.Connection, table, argFmt string, values sqldb.Valu // InsertReturning inserts a new row into table using values // and returns values from the inserted row listed in returning. -func InsertReturning(conn sqldb.Connection, table, argFmt string, values sqldb.Values, returning string) sqldb.RowScanner { +func InsertReturning(ctx context.Context, table string, values Values, returning string) sqldb.RowScanner { if len(values) == 0 { return sqldb.RowScannerWithError(fmt.Errorf("InsertReturning into table %s: no values", table)) } + conn := Conn(ctx) + argFmt := conn.ArgFmt() names, vals := values.Sorted() var query strings.Builder writeInsertQuery(&query, table, argFmt, names) @@ -63,31 +75,15 @@ func InsertReturning(conn sqldb.Connection, table, argFmt string, values sqldb.V return conn.QueryRow(query.String(), vals...) } -func writeInsertQuery(w *strings.Builder, table, argFmt string, names []string) { - fmt.Fprintf(w, `INSERT INTO %s(`, table) - for i, name := range names { - if i > 0 { - w.WriteByte(',') - } - w.WriteByte('"') - w.WriteString(name) - w.WriteByte('"') - } - w.WriteString(`) VALUES(`) - for i := range names { - if i > 0 { - w.WriteByte(',') - } - fmt.Fprintf(w, argFmt, i+1) - } - w.WriteByte(')') -} - // InsertStruct inserts a new row into table using the connection's // StructFieldMapper to map struct fields to column names. // Optional ColumnFilter can be passed to ignore mapped columns. -func InsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) error { - columns, vals, err := insertStructValues(table, rowStruct, namer, ignoreColumns) +func InsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { + conn := Conn(ctx) + argFmt := conn.ArgFmt() + mapper := conn.StructFieldMapper() + + table, columns, vals, err := insertStructValues(rowStruct, mapper, ignoreColumns) if err != nil { return err } @@ -106,8 +102,12 @@ func InsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqld // Optional ColumnFilter can be passed to ignore mapped columns. // Does nothing if the onConflict statement applies // and returns if a row was inserted. -func InsertUniqueStruct(conn sqldb.Connection, table string, rowStruct any, onConflict string, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) (inserted bool, err error) { - columns, vals, err := insertStructValues(table, rowStruct, namer, ignoreColumns) +func InsertUniqueStruct(ctx context.Context, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { + conn := Conn(ctx) + argFmt := conn.ArgFmt() + mapper := conn.StructFieldMapper() + + table, columns, vals, err := insertStructValues(rowStruct, mapper, ignoreColumns) if err != nil { return false, err } @@ -127,18 +127,38 @@ func InsertUniqueStruct(conn sqldb.Connection, table string, rowStruct any, onCo return inserted, WrapNonNilErrorWithQuery(err, query, argFmt, vals) } -func insertStructValues(table string, rowStruct any, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, vals []any, err error) { +func writeInsertQuery(w *strings.Builder, table, argFmt string, names []string) { + fmt.Fprintf(w, `INSERT INTO %s(`, table) + for i, name := range names { + if i > 0 { + w.WriteByte(',') + } + w.WriteByte('"') + w.WriteString(name) + w.WriteByte('"') + } + w.WriteString(`) VALUES(`) + for i := range names { + if i > 0 { + w.WriteByte(',') + } + fmt.Fprintf(w, argFmt, i+1) + } + w.WriteByte(')') +} + +func insertStructValues(rowStruct any, mapper sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (table string, columns []string, vals []any, err error) { v := reflect.ValueOf(rowStruct) for v.Kind() == reflect.Ptr && !v.IsNil() { v = v.Elem() } switch { case v.Kind() == reflect.Ptr && v.IsNil(): - return nil, nil, fmt.Errorf("InsertStruct into table %s: can't insert nil", table) + return "", nil, nil, fmt.Errorf("can't insert nil") case v.Kind() != reflect.Struct: - return nil, nil, fmt.Errorf("InsertStruct into table %s: expected struct but got %T", table, rowStruct) + return "", nil, nil, fmt.Errorf("expected struct but got %T", rowStruct) } - columns, _, vals = ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) - return columns, vals, nil + table, columns, _, vals, err = ReflectStructValues(v, mapper, append(ignoreColumns, sqldb.IgnoreReadOnly)) + return table, columns, vals, err } diff --git a/db/reflectstruct.go b/db/reflectstruct.go new file mode 100644 index 0000000..88f438f --- /dev/null +++ b/db/reflectstruct.go @@ -0,0 +1,133 @@ +package db + +import ( + "errors" + "fmt" + "reflect" + "strings" + + "golang.org/x/exp/slices" + + "github.com/domonda/go-sqldb" +) + +func ReflectStructValues(structVal reflect.Value, mapper sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (table string, columns []string, pkCols []int, values []any, err error) { + structType := structVal.Type() + for i := 0; i < structType.NumField(); i++ { + fieldType := structType.Field(i) + fieldTable, column, flags, use := mapper.MapStructField(fieldType) + if !use { + continue + } + fieldValue := structVal.Field(i) + + if column == "" { + // Embedded struct field + fieldTable, columnsEmbed, pkColsEmbed, valuesEmbed, err := ReflectStructValues(fieldValue, mapper, ignoreColumns) + if err != nil { + return "", nil, nil, nil, err + } + if fieldTable != "" && fieldTable != table { + if table != "" { + return "", nil, nil, nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, structType) + } + table = fieldTable + } + for _, pkCol := range pkColsEmbed { + pkCols = append(pkCols, pkCol+len(columns)) + } + columns = append(columns, columnsEmbed...) + values = append(values, valuesEmbed...) + continue + } + + if ignoreColumn(ignoreColumns, column, flags, fieldType, fieldValue) { + continue + } + + if fieldTable != "" && fieldTable != table { + if table != "" { + return "", nil, nil, nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, structType) + } + table = fieldTable + } + if flags.PrimaryKey() { + pkCols = append(pkCols, len(columns)) + } + columns = append(columns, column) + values = append(values, fieldValue.Interface()) + } + return table, columns, pkCols, values, nil +} + +func ReflectStructColumnPointers(structVal reflect.Value, mapper sqldb.StructFieldMapper, columns []string) (pointers []any, err error) { + if len(columns) == 0 { + return nil, errors.New("no columns") + } + pointers = make([]any, len(columns)) + err = reflectStructColumnPointers(structVal, mapper, columns, pointers) + if err != nil { + return nil, err + } + for _, ptr := range pointers { + if ptr != nil { + continue + } + nilCols := new(strings.Builder) + for i, ptr := range pointers { + if ptr != nil { + continue + } + if nilCols.Len() > 0 { + nilCols.WriteString(", ") + } + fmt.Fprintf(nilCols, "column=%s, index=%d", columns[i], i) + } + return nil, fmt.Errorf("columns have no mapped struct fields in %s: %s", structVal.Type(), nilCols) + } + return pointers, nil +} + +func reflectStructColumnPointers(structVal reflect.Value, mapper sqldb.StructFieldMapper, columns []string, pointers []any) error { + var ( + structType = structVal.Type() + ) + for i := 0; i < structType.NumField(); i++ { + fieldType := structType.Field(i) + _, column, _, use := mapper.MapStructField(fieldType) + if !use { + continue + } + fieldValue := structVal.Field(i) + + if column == "" { + // Embedded struct field + err := reflectStructColumnPointers(fieldValue, mapper, columns, pointers) + if err != nil { + return err + } + continue + } + + colIndex := slices.Index(columns, column) + if colIndex == -1 { + continue + } + + if pointers[colIndex] != nil { + return fmt.Errorf("duplicate mapped column %s onto field %s of struct %s", column, fieldType.Name, structType) + } + + pointers[colIndex] = fieldValue.Addr().Interface() + } + return nil +} + +func ignoreColumn(filters []sqldb.ColumnFilter, name string, flags sqldb.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { + for _, filter := range filters { + if filter.IgnoreColumn(name, flags, fieldType, fieldValue) { + return true + } + } + return false +} diff --git a/errors.go b/errors.go index 124d0f3..d0ed89f 100644 --- a/errors.go +++ b/errors.go @@ -104,34 +104,22 @@ func (e connectionWithError) ValidateColumnName(name string) error { return e.err } -func (e connectionWithError) Now() (time.Time, error) { - return time.Time{}, e.err +func (e connectionWithError) ArgFmt() string { + return "" } -func (e connectionWithError) Exec(query string, args ...any) error { +func (e connectionWithError) Err() error { return e.err } -func (e connectionWithError) Insert(table string, values Values) error { - return e.err -} - -func (e connectionWithError) InsertUnique(table string, values Values, onConflict string) (inserted bool, err error) { - return false, e.err -} - -func (e connectionWithError) InsertReturning(table string, values Values, returning string) RowScanner { - return RowScannerWithError(e.err) +func (e connectionWithError) Now() (time.Time, error) { + return time.Time{}, e.err } -func (e connectionWithError) InsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error { +func (e connectionWithError) Exec(query string, args ...any) error { return e.err } -func (e connectionWithError) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...ColumnFilter) (inserted bool, err error) { - return false, e.err -} - func (e connectionWithError) Update(table string, values Values, where string, args ...any) error { return e.err } diff --git a/examples/user_demo/user_demo.go b/examples/user_demo/user_demo.go index dacf487..2b32e21 100644 --- a/examples/user_demo/user_demo.go +++ b/examples/user_demo/user_demo.go @@ -16,7 +16,7 @@ import ( ) type User struct { - ID uu.ID `db:"id,pk,default"` + ID uu.ID `db:"id,pk=public.user,default"` Email email.NullableAddress `db:"email"` Title nullable.NonEmptyString `db:"title"` @@ -87,18 +87,20 @@ func main() { panic(err) } + ctx := context.Background() + newUser := &User{ /* ... */ } - err = conn.InsertStruct("public.user", newUser) + err = db.InsertStruct(ctx, newUser) if err != nil { panic(err) } - err = conn.InsertStruct("public.user", newUser, sqldb.IgnoreNullOrZeroDefault) + err = db.InsertStruct(ctx, newUser, sqldb.IgnoreNullOrZeroDefault) if err != nil { panic(err) } - err = conn.Insert("public.user", sqldb.Values{ + err = db.Insert(ctx, "public.user", sqldb.Values{ "name": "Erik Unger", "email": "erik@domonda.com", }) diff --git a/impl/connection.go b/impl/connection.go index e50f3b2..c2cd710 100644 --- a/impl/connection.go +++ b/impl/connection.go @@ -18,7 +18,7 @@ func Connection(ctx context.Context, db *sql.DB, config *sqldb.Config, validateC ctx: ctx, db: db, config: config, - structFieldNamer: sqldb.DefaultStructFieldMapping, + structFieldMapper: sqldb.DefaultStructFieldMapping, argFmt: argFmt, validateColumnName: validateColumnName, } @@ -28,7 +28,7 @@ type connection struct { ctx context.Context db *sql.DB config *sqldb.Config - structFieldNamer sqldb.StructFieldMapper + structFieldMapper sqldb.StructFieldMapper argFmt string validateColumnName func(string) error } @@ -51,12 +51,12 @@ func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { func (conn *connection) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { c := conn.clone() - c.structFieldNamer = namer + c.structFieldMapper = namer return c } func (conn *connection) StructFieldMapper() sqldb.StructFieldMapper { - return conn.structFieldNamer + return conn.structFieldMapper } func (conn *connection) Ping(timeout time.Duration) error { @@ -81,6 +81,14 @@ func (conn *connection) ValidateColumnName(name string) error { return conn.validateColumnName(name) } +func (conn *connection) ArgFmt() string { + return conn.argFmt +} + +func (conn *connection) Err() error { + return nil +} + func (conn *connection) Now() (time.Time, error) { return Now(conn) } @@ -90,26 +98,6 @@ func (conn *connection) Exec(query string, args ...any) error { return WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) } -func (conn *connection) Insert(table string, columValues sqldb.Values) error { - return Insert(conn, table, conn.argFmt, columValues) -} - -func (conn *connection) InsertUnique(table string, values sqldb.Values, onConflict string) (inserted bool, err error) { - return InsertUnique(conn, table, conn.argFmt, values, onConflict) -} - -func (conn *connection) InsertReturning(table string, values sqldb.Values, returning string) sqldb.RowScanner { - return InsertReturning(conn, table, conn.argFmt, values, returning) -} - -func (conn *connection) InsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return InsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { - return InsertUniqueStruct(conn, table, rowStruct, onConflict, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - func (conn *connection) Update(table string, values sqldb.Values, where string, args ...any) error { return Update(conn, table, values, where, conn.argFmt, args) } @@ -123,11 +111,11 @@ func (conn *connection) UpdateReturningRows(table string, values sqldb.Values, r } func (conn *connection) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) + return UpdateStruct(conn, table, rowStruct, conn.structFieldMapper, conn.argFmt, ignoreColumns) } func (conn *connection) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) + return UpsertStruct(conn, table, rowStruct, conn.structFieldMapper, conn.argFmt, ignoreColumns) } func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { @@ -136,7 +124,7 @@ func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { err = WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) return sqldb.RowScannerWithError(err) } - return NewRowScanner(rows, conn.structFieldNamer, query, conn.argFmt, args) + return NewRowScanner(rows, conn.structFieldMapper, query, conn.argFmt, args) } func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { @@ -145,7 +133,7 @@ func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { err = WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) return sqldb.RowsScannerWithError(err) } - return NewRowsScanner(conn.ctx, rows, conn.structFieldNamer, query, conn.argFmt, args) + return NewRowsScanner(conn.ctx, rows, conn.structFieldMapper, query, conn.argFmt, args) } func (conn *connection) IsTransaction() bool { diff --git a/impl/reflectstruct.go b/impl/reflectstruct.go index c92c230..8860f69 100644 --- a/impl/reflectstruct.go +++ b/impl/reflectstruct.go @@ -1,117 +1,23 @@ package impl import ( - "errors" - "fmt" "reflect" - "strings" - - "golang.org/x/exp/slices" "github.com/domonda/go-sqldb" ) func ReflectStructValues(structVal reflect.Value, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, pkCols []int, values []any) { - for i := 0; i < structVal.NumField(); i++ { - fieldType := structVal.Type().Field(i) - _, column, flags, use := namer.MapStructField(fieldType) - if !use { - continue - } - fieldValue := structVal.Field(i) - - if column == "" { - // Embedded struct field - columnsEmbed, pkColsEmbed, valuesEmbed := ReflectStructValues(fieldValue, namer, ignoreColumns) - for _, pkCol := range pkColsEmbed { - pkCols = append(pkCols, pkCol+len(columns)) - } - columns = append(columns, columnsEmbed...) - values = append(values, valuesEmbed...) - continue - } - - if ignoreColumn(ignoreColumns, column, flags, fieldType, fieldValue) { - continue - } - - if flags.PrimaryKey() { - pkCols = append(pkCols, len(columns)) - } - columns = append(columns, column) - values = append(values, fieldValue.Interface()) - } - return columns, pkCols, values + panic("TODO remove") } func ReflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFieldMapper, columns []string) (pointers []any, err error) { - if len(columns) == 0 { - return nil, errors.New("no columns") - } - pointers = make([]any, len(columns)) - err = reflectStructColumnPointers(structVal, namer, columns, pointers) - if err != nil { - return nil, err - } - for _, ptr := range pointers { - if ptr != nil { - continue - } - nilCols := new(strings.Builder) - for i, ptr := range pointers { - if ptr != nil { - continue - } - if nilCols.Len() > 0 { - nilCols.WriteString(", ") - } - fmt.Fprintf(nilCols, "column=%s, index=%d", columns[i], i) - } - return nil, fmt.Errorf("columns have no mapped struct fields in %s: %s", structVal.Type(), nilCols) - } - return pointers, nil + panic("TODO remove") } func reflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFieldMapper, columns []string, pointers []any) error { - var ( - structType = structVal.Type() - ) - for i := 0; i < structType.NumField(); i++ { - fieldType := structType.Field(i) - _, column, _, use := namer.MapStructField(fieldType) - if !use { - continue - } - fieldValue := structVal.Field(i) - - if column == "" { - // Embedded struct field - err := reflectStructColumnPointers(fieldValue, namer, columns, pointers) - if err != nil { - return err - } - continue - } - - colIndex := slices.Index(columns, column) - if colIndex == -1 { - continue - } - - if pointers[colIndex] != nil { - return fmt.Errorf("duplicate mapped column %s onto field %s of struct %s", column, fieldType.Name, structType) - } - - pointers[colIndex] = fieldValue.Addr().Interface() - } - return nil + panic("TODO remove") } func ignoreColumn(filters []sqldb.ColumnFilter, name string, flags sqldb.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - for _, filter := range filters { - if filter.IgnoreColumn(name, flags, fieldType, fieldValue) { - return true - } - } - return false + panic("TODO remove") } diff --git a/impl/transaction.go b/impl/transaction.go index 1363fe9..db19ad0 100644 --- a/impl/transaction.go +++ b/impl/transaction.go @@ -12,18 +12,18 @@ import ( type transaction struct { // The parent non-transaction connection is needed // for its ctx, Ping(), Stats(), and Config() - parent *connection - tx *sql.Tx - opts *sql.TxOptions - structFieldNamer sqldb.StructFieldMapper + parent *connection + tx *sql.Tx + opts *sql.TxOptions + structFieldMapper sqldb.StructFieldMapper } func newTransaction(parent *connection, tx *sql.Tx, opts *sql.TxOptions) *transaction { return &transaction{ - parent: parent, - tx: tx, - opts: opts, - structFieldNamer: parent.structFieldNamer, + parent: parent, + tx: tx, + opts: opts, + structFieldMapper: parent.structFieldMapper, } } @@ -45,12 +45,12 @@ func (conn *transaction) WithContext(ctx context.Context) sqldb.Connection { func (conn *transaction) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { c := conn.clone() - c.structFieldNamer = namer + c.structFieldMapper = namer return c } func (conn *transaction) StructFieldMapper() sqldb.StructFieldMapper { - return conn.structFieldNamer + return conn.structFieldMapper } func (conn *transaction) Ping(timeout time.Duration) error { return conn.parent.Ping(timeout) } @@ -61,6 +61,14 @@ func (conn *transaction) ValidateColumnName(name string) error { return conn.parent.validateColumnName(name) } +func (conn *transaction) ArgFmt() string { + return conn.parent.argFmt +} + +func (conn *transaction) Err() error { + return conn.parent.Err() +} + func (conn *transaction) Now() (time.Time, error) { return Now(conn) } @@ -70,26 +78,6 @@ func (conn *transaction) Exec(query string, args ...any) error { return WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) } -func (conn *transaction) Insert(table string, columValues sqldb.Values) error { - return Insert(conn, table, conn.parent.argFmt, columValues) -} - -func (conn *transaction) InsertUnique(table string, values sqldb.Values, onConflict string) (inserted bool, err error) { - return InsertUnique(conn, table, conn.parent.argFmt, values, onConflict) -} - -func (conn *transaction) InsertReturning(table string, values sqldb.Values, returning string) sqldb.RowScanner { - return InsertReturning(conn, table, conn.parent.argFmt, values, returning) -} - -func (conn *transaction) InsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return InsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.parent.argFmt, ignoreColumns) -} - -func (conn *transaction) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { - return InsertUniqueStruct(conn, table, rowStruct, onConflict, conn.structFieldNamer, conn.parent.argFmt, ignoreColumns) -} - func (conn *transaction) Update(table string, values sqldb.Values, where string, args ...any) error { return Update(conn, table, values, where, conn.parent.argFmt, args) } @@ -103,11 +91,11 @@ func (conn *transaction) UpdateReturningRows(table string, values sqldb.Values, } func (conn *transaction) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, conn.parent.argFmt, ignoreColumns) + return UpdateStruct(conn, table, rowStruct, conn.structFieldMapper, conn.parent.argFmt, ignoreColumns) } func (conn *transaction) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.parent.argFmt, ignoreColumns) + return UpsertStruct(conn, table, rowStruct, conn.structFieldMapper, conn.parent.argFmt, ignoreColumns) } func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { @@ -116,7 +104,7 @@ func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { err = WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) return sqldb.RowScannerWithError(err) } - return NewRowScanner(rows, conn.structFieldNamer, query, conn.parent.argFmt, args) + return NewRowScanner(rows, conn.structFieldMapper, query, conn.parent.argFmt, args) } func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner { @@ -125,7 +113,7 @@ func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner err = WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) return sqldb.RowsScannerWithError(err) } - return NewRowsScanner(conn.parent.ctx, rows, conn.structFieldNamer, query, conn.parent.argFmt, args) + return NewRowsScanner(conn.parent.ctx, rows, conn.structFieldMapper, query, conn.parent.argFmt, args) } func (conn *transaction) IsTransaction() bool { diff --git a/impl/upsert.go b/impl/upsert.go index 7aa92f9..9e3058e 100644 --- a/impl/upsert.go +++ b/impl/upsert.go @@ -61,3 +61,24 @@ func UpsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqld return WrapNonNilErrorWithQuery(err, query, argFmt, vals) } + +// TODO replace +func writeInsertQuery(w *strings.Builder, table, argFmt string, names []string) { + fmt.Fprintf(w, `INSERT INTO %s(`, table) + for i, name := range names { + if i > 0 { + w.WriteByte(',') + } + w.WriteByte('"') + w.WriteString(name) + w.WriteByte('"') + } + w.WriteString(`) VALUES(`) + for i := range names { + if i > 0 { + w.WriteByte(',') + } + fmt.Fprintf(w, argFmt, i+1) + } + w.WriteByte(')') +} diff --git a/mockconn/connection.go b/mockconn/connection.go index de4457e..a74da12 100644 --- a/mockconn/connection.go +++ b/mockconn/connection.go @@ -68,6 +68,10 @@ func (conn *connection) Stats() sql.DBStats { return sql.DBStats{} } +func (conn *connection) Ping(time.Duration) error { + return nil +} + func (conn *connection) Config() *sqldb.Config { return &sqldb.Config{Driver: "mockconn", Host: "localhost", Database: "mock"} } @@ -76,7 +80,11 @@ func (conn *connection) ValidateColumnName(name string) error { return validateColumnName(name) } -func (conn *connection) Ping(time.Duration) error { +func (conn *connection) ArgFmt() string { + return conn.argFmt +} + +func (conn *connection) Err() error { return nil } @@ -91,26 +99,6 @@ func (conn *connection) Exec(query string, args ...any) error { return nil } -func (conn *connection) Insert(table string, columValues sqldb.Values) error { - return impl.Insert(conn, table, conn.argFmt, columValues) -} - -func (conn *connection) InsertUnique(table string, values sqldb.Values, onConflict string) (inserted bool, err error) { - return impl.InsertUnique(conn, table, conn.argFmt, values, onConflict) -} - -func (conn *connection) InsertReturning(table string, values sqldb.Values, returning string) sqldb.RowScanner { - return impl.InsertReturning(conn, table, conn.argFmt, values, returning) -} - -func (conn *connection) InsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.InsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { - return impl.InsertUniqueStruct(conn, table, rowStruct, onConflict, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - func (conn *connection) Update(table string, values sqldb.Values, where string, args ...any) error { return impl.Update(conn, table, values, where, conn.argFmt, args) } diff --git a/mockconn/connection_test.go b/mockconn/connection_test.go index e99e034..d8387a0 100644 --- a/mockconn/connection_test.go +++ b/mockconn/connection_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" sqldb "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/db" "github.com/domonda/go-types/uu" ) @@ -18,7 +19,7 @@ type embed struct { } type testRow struct { - ID uu.ID `db:"id,pk"` + ID uu.ID `db:"id,pk=public.table"` Int int `db:"int"` embed Str string `db:"str"` @@ -32,10 +33,11 @@ type testRow struct { } func TestInsertQuery(t *testing.T) { + context.Background() naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} queryOutput := bytes.NewBuffer(nil) rowProvider := NewSingleRowProvider(NewRow(struct{ True bool }{true}, naming)) - conn := New(context.Background(), queryOutput, rowProvider).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, rowProvider).WithStructFieldMapper(naming)) str := "Hello World!" values := sqldb.Values{ @@ -51,13 +53,13 @@ func TestInsertQuery(t *testing.T) { } expected := `INSERT INTO public.table("bool","bools","created_at","id","int","nil_ptr","str","str_ptr","untagged_field") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9)` - err := conn.Insert("public.table", values) + err := db.Insert(ctx, "public.table", values) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `INSERT INTO public.table("bool","bools","created_at","id","int","nil_ptr","str","str_ptr","untagged_field") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9) ON CONFLICT (id) DO NOTHING RETURNING TRUE` - inserted, err := conn.InsertUnique("public.table", values, "id") + inserted, err := db.InsertUnique(ctx, "public.table", values, "id") assert.NoError(t, err) assert.True(t, inserted) assert.Equal(t, expected, queryOutput.String()) @@ -73,24 +75,24 @@ func TestInsertStructQuery(t *testing.T) { Default: "default", UntaggedNameFunc: sqldb.ToSnakeCase, } - conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) row := new(testRow) expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9)` - err := conn.InsertStruct("public.table", row) + err := db.InsertStruct(ctx, row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `INSERT INTO public.table("id","untagged_field","bools") VALUES($1,$2,$3)` - err = conn.InsertStruct("public.table", row, sqldb.OnlyColumns("id", "untagged_field", "bools")) + err = db.InsertStruct(ctx, row, sqldb.OnlyColumns("id", "untagged_field", "bools")) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `INSERT INTO public.table("id","int","bool","str","str_ptr","bools") VALUES($1,$2,$3,$4,$5,$6)` - err = conn.InsertStruct("public.table", row, sqldb.IgnoreColumns("nil_ptr", "untagged_field", "created_at")) + err = db.InsertStruct(ctx, row, sqldb.IgnoreColumns("nil_ptr", "untagged_field", "created_at")) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } @@ -106,26 +108,26 @@ func TestInsertUniqueStructQuery(t *testing.T) { UntaggedNameFunc: sqldb.ToSnakeCase, } rowProvider := NewSingleRowProvider(NewRow(struct{ True bool }{true}, naming)) - conn := New(context.Background(), queryOutput, rowProvider).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, rowProvider).WithStructFieldMapper(naming)) row := new(testRow) expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9) ON CONFLICT (id) DO NOTHING RETURNING TRUE` - inserted, err := conn.InsertUniqueStruct("public.table", row, "(id)") + inserted, err := db.InsertUniqueStruct(ctx, row, "(id)") assert.NoError(t, err) assert.True(t, inserted) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `INSERT INTO public.table("id","untagged_field","bools") VALUES($1,$2,$3) ON CONFLICT (id, untagged_field) DO NOTHING RETURNING TRUE` - inserted, err = conn.InsertUniqueStruct("public.table", row, "(id, untagged_field)", sqldb.OnlyColumns("id", "untagged_field", "bools")) + inserted, err = db.InsertUniqueStruct(ctx, row, "(id, untagged_field)", sqldb.OnlyColumns("id", "untagged_field", "bools")) assert.NoError(t, err) assert.True(t, inserted) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `INSERT INTO public.table("id","int","bool","str","str_ptr","bools") VALUES($1,$2,$3,$4,$5,$6) ON CONFLICT (id) DO NOTHING RETURNING TRUE` - inserted, err = conn.InsertUniqueStruct("public.table", row, "(id)", sqldb.IgnoreColumns("nil_ptr", "untagged_field", "created_at")) + inserted, err = db.InsertUniqueStruct(ctx, row, "(id)", sqldb.IgnoreColumns("nil_ptr", "untagged_field", "created_at")) assert.NoError(t, err) assert.True(t, inserted) assert.Equal(t, expected, queryOutput.String()) diff --git a/pqconn/connection.go b/pqconn/connection.go index 65c01a8..159f7e1 100644 --- a/pqconn/connection.go +++ b/pqconn/connection.go @@ -102,6 +102,14 @@ func (conn *connection) ValidateColumnName(name string) error { return validateColumnName(name) } +func (*connection) ArgFmt() string { + return argFmt +} + +func (conn *connection) Err() error { + return conn.config.Err +} + func (conn *connection) Now() (time.Time, error) { return impl.Now(conn) } @@ -111,26 +119,6 @@ func (conn *connection) Exec(query string, args ...any) error { return impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) } -func (conn *connection) Insert(table string, columValues sqldb.Values) error { - return impl.Insert(conn, table, argFmt, columValues) -} - -func (conn *connection) InsertUnique(table string, values sqldb.Values, onConflict string) (inserted bool, err error) { - return impl.InsertUnique(conn, table, argFmt, values, onConflict) -} - -func (conn *connection) InsertReturning(table string, values sqldb.Values, returning string) sqldb.RowScanner { - return impl.InsertReturning(conn, table, argFmt, values, returning) -} - -func (conn *connection) InsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.InsertStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) -} - -func (conn *connection) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { - return impl.InsertUniqueStruct(conn, table, rowStruct, onConflict, conn.structFieldNamer, argFmt, ignoreColumns) -} - func (conn *connection) Update(table string, values sqldb.Values, where string, args ...any) error { return impl.Update(conn, table, values, where, argFmt, args) } diff --git a/pqconn/transaction.go b/pqconn/transaction.go index 4019f5e..736a545 100644 --- a/pqconn/transaction.go +++ b/pqconn/transaction.go @@ -61,6 +61,14 @@ func (conn *transaction) ValidateColumnName(name string) error { return validateColumnName(name) } +func (*transaction) ArgFmt() string { + return argFmt +} + +func (conn *transaction) Err() error { + return conn.parent.config.Err +} + func (conn *transaction) Now() (time.Time, error) { return impl.Now(conn) } @@ -70,26 +78,6 @@ func (conn *transaction) Exec(query string, args ...any) error { return impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) } -func (conn *transaction) Insert(table string, columValues sqldb.Values) error { - return impl.Insert(conn, table, argFmt, columValues) -} - -func (conn *transaction) InsertUnique(table string, values sqldb.Values, onConflict string) (inserted bool, err error) { - return impl.InsertUnique(conn, table, argFmt, values, onConflict) -} - -func (conn *transaction) InsertReturning(table string, values sqldb.Values, returning string) sqldb.RowScanner { - return impl.InsertReturning(conn, table, argFmt, values, returning) -} - -func (conn *transaction) InsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.InsertStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) -} - -func (conn *transaction) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { - return impl.InsertUniqueStruct(conn, table, rowStruct, onConflict, conn.structFieldNamer, argFmt, ignoreColumns) -} - func (conn *transaction) Update(table string, values sqldb.Values, where string, args ...any) error { return impl.Update(conn, table, values, where, argFmt, args) } From c932c8a82413a1a3ff33e57050d620eaf3b62f90 Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Wed, 15 Jun 2022 09:41:53 +0200 Subject: [PATCH 02/12] move update and upsert to pkg db --- connection.go | 28 ----------- db/config.go | 4 +- db/conn.go | 8 ++-- db/query.go | 38 +++++++++++++++ db/reflectstruct.go | 14 ++++++ db/transaction.go | 8 ++-- db/transaction_test.go | 4 +- {impl => db}/update.go | 51 +++++++++++--------- db/upsert.go | 67 ++++++++++++++++++++++++++ examples/user_demo/user_demo.go | 2 +- impl/connection.go | 28 +++-------- impl/now.go | 15 ------ impl/transaction.go | 28 +++-------- impl/upsert.go | 84 --------------------------------- mockconn/connection.go | 21 --------- mockconn/connection_test.go | 34 ++++++------- pqconn/connection.go | 28 +++-------- pqconn/transaction.go | 28 +++-------- 18 files changed, 202 insertions(+), 288 deletions(-) create mode 100644 db/query.go rename {impl => db}/update.go (63%) create mode 100644 db/upsert.go delete mode 100644 impl/now.go delete mode 100644 impl/upsert.go diff --git a/connection.go b/connection.go index f73faf2..f22ad49 100644 --- a/connection.go +++ b/connection.go @@ -65,34 +65,6 @@ type Connection interface { // Exec executes a query with optional args. Exec(query string, args ...any) error - // Update table rows(s) with values using the where statement with passed in args starting at $1. - Update(table string, values Values, where string, args ...any) error - - // UpdateReturningRow updates a table row with values using the where statement with passed in args starting at $1 - // and returning a single row with the columns specified in returning argument. - UpdateReturningRow(table string, values Values, returning, where string, args ...any) RowScanner - - // UpdateReturningRows updates table rows with values using the where statement with passed in args starting at $1 - // and returning multiple rows with the columns specified in returning argument. - UpdateReturningRows(table string, values Values, returning, where string, args ...any) RowsScanner - - // UpdateStruct updates a row in a table using the exported fields - // of rowStruct which have a `db` tag that is not "-". - // If restrictToColumns are provided, then only struct fields with a `db` tag - // matching any of the passed column names will be used. - // The struct must have at least one field with a `db` tag value having a ",pk" suffix - // to mark primary key column(s). - UpdateStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error - - // UpsertStruct upserts a row to table using the exported fields - // of rowStruct which have a `db` tag that is not "-". - // If restrictToColumns are provided, then only struct fields with a `db` tag - // matching any of the passed column names will be used. - // The struct must have at least one field with a `db` tag value having a ",pk" suffix - // to mark primary key column(s). - // If inserting conflicts on the primary key column(s), then an update is performed. - UpsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error - // QueryRow queries a single row and returns a RowScanner for the results. QueryRow(query string, args ...any) RowScanner diff --git a/db/config.go b/db/config.go index 3e63375..bd6c1b4 100644 --- a/db/config.go +++ b/db/config.go @@ -14,10 +14,10 @@ var ( ) var ( - conn = sqldb.ConnectionWithError( + globalConn = sqldb.ConnectionWithError( context.Background(), errors.New("database connection not initialized"), ) - connCtxKey int + globalConnCtxKey int serializedTransactionCtxKey int ) diff --git a/db/conn.go b/db/conn.go index 241cfb3..c43dbf8 100644 --- a/db/conn.go +++ b/db/conn.go @@ -14,7 +14,7 @@ func SetConn(c sqldb.Connection) { if c == nil { panic("must not set nil sqldb.Connection") } - conn = c + globalConn = c } // Conn returns a non nil sqldb.Connection from ctx @@ -22,7 +22,7 @@ func SetConn(c sqldb.Connection) { // The returned connection will use the passed context. // See sqldb.Connection.WithContext func Conn(ctx context.Context) sqldb.Connection { - return ConnDefault(ctx, conn) + return ConnDefault(ctx, globalConn) } // ConnDefault returns a non nil sqldb.Connection from ctx @@ -30,7 +30,7 @@ func Conn(ctx context.Context) sqldb.Connection { // The returned connection will use the passed context. // See sqldb.Connection.WithContext func ConnDefault(ctx context.Context, defaultConn sqldb.Connection) sqldb.Connection { - c, _ := ctx.Value(&connCtxKey).(sqldb.Connection) + c, _ := ctx.Value(&globalConnCtxKey).(sqldb.Connection) if c == nil { c = defaultConn } @@ -45,7 +45,7 @@ func ConnDefault(ctx context.Context, defaultConn sqldb.Connection) sqldb.Connec // Passing a nil connection causes Conn(ctx) // to return the global connection set with SetConn. func ContextWithConn(ctx context.Context, conn sqldb.Connection) context.Context { - return context.WithValue(ctx, &connCtxKey, conn) + return context.WithValue(ctx, &globalConnCtxKey, conn) } // ContextWithoutCancel returns a new context that inherits diff --git a/db/query.go b/db/query.go new file mode 100644 index 0000000..08bbcbd --- /dev/null +++ b/db/query.go @@ -0,0 +1,38 @@ +package db + +import ( + "context" + "time" + + "github.com/domonda/go-sqldb" +) + +// Now returns the result of the SQL now() +// function for the current connection. +// Useful for getting the timestamp of a +// SQL transaction for use in Go code. +func Now(ctx context.Context) (time.Time, error) { + return Conn(ctx).Now() +} + +// Exec executes a query with optional args. +func Exec(ctx context.Context, query string, args ...any) error { + return Conn(ctx).Exec(query, args...) +} + +// QueryRow queries a single row and returns a RowScanner for the results. +func QueryRow(ctx context.Context, query string, args ...any) sqldb.RowScanner { + return Conn(ctx).QueryRow(query, args...) +} + +// QueryRows queries multiple rows and returns a RowsScanner for the results. +func QueryRows(ctx context.Context, query string, args ...any) sqldb.RowsScanner { + return Conn(ctx).QueryRows(query, args...) +} + +// QueryValue queries a single value of type T. +func QueryValue[T any](ctx context.Context, query string, args ...any) (T, error) { + var val T + err := Conn(ctx).QueryRow(query, args...).Scan(&val) + return val, err +} diff --git a/db/reflectstruct.go b/db/reflectstruct.go index 88f438f..7df5bb2 100644 --- a/db/reflectstruct.go +++ b/db/reflectstruct.go @@ -131,3 +131,17 @@ func ignoreColumn(filters []sqldb.ColumnFilter, name string, flags sqldb.FieldFl } return false } + +func derefStruct(rowStruct any) (reflect.Value, error) { + v := reflect.ValueOf(rowStruct) + for v.Kind() == reflect.Ptr && !v.IsNil() { + v = v.Elem() + } + switch { + case v.Kind() == reflect.Ptr && v.IsNil(): + return reflect.Value{}, errors.New("can't use nil pointer") + case v.Kind() != reflect.Struct: + return reflect.Value{}, fmt.Errorf("expected struct but got %T", rowStruct) + } + return v, nil +} diff --git a/db/transaction.go b/db/transaction.go index 58ea7fd..fb85419 100644 --- a/db/transaction.go +++ b/db/transaction.go @@ -51,7 +51,7 @@ func DebugNoTransaction(ctx context.Context, nonTxFunc func(context.Context) err // Recovered panics are re-paniced and rollback errors after a panic are logged with ErrLogger. func IsolatedTransaction(ctx context.Context, txFunc func(context.Context) error) error { return sqldb.IsolatedTransaction(Conn(ctx), nil, func(tx sqldb.Connection) error { - return txFunc(context.WithValue(ctx, &connCtxKey, tx)) + return txFunc(context.WithValue(ctx, &globalConnCtxKey, tx)) }) } @@ -64,7 +64,7 @@ func IsolatedTransaction(ctx context.Context, txFunc func(context.Context) error // Recovered panics are re-paniced and rollback errors after a panic are logged with sqldb.ErrLogger. func Transaction(ctx context.Context, txFunc func(context.Context) error) error { return sqldb.Transaction(Conn(ctx), nil, func(tx sqldb.Connection) error { - return txFunc(context.WithValue(ctx, &connCtxKey, tx)) + return txFunc(context.WithValue(ctx, &globalConnCtxKey, tx)) }) } @@ -135,7 +135,7 @@ func SerializedTransaction(ctx context.Context, txFunc func(context.Context) err // Recovered panics are re-paniced and rollback errors after a panic are logged with sqldb.ErrLogger. func TransactionOpts(ctx context.Context, opts *sql.TxOptions, txFunc func(context.Context) error) error { return sqldb.Transaction(Conn(ctx), opts, func(tx sqldb.Connection) error { - return txFunc(context.WithValue(ctx, &connCtxKey, tx)) + return txFunc(context.WithValue(ctx, &globalConnCtxKey, tx)) }) } @@ -149,7 +149,7 @@ func TransactionOpts(ctx context.Context, opts *sql.TxOptions, txFunc func(conte func TransactionReadOnly(ctx context.Context, txFunc func(context.Context) error) error { opts := sql.TxOptions{ReadOnly: true} return sqldb.Transaction(Conn(ctx), &opts, func(tx sqldb.Connection) error { - return txFunc(context.WithValue(ctx, &connCtxKey, tx)) + return txFunc(context.WithValue(ctx, &globalConnCtxKey, tx)) }) } diff --git a/db/transaction_test.go b/db/transaction_test.go index 16adfda..a6e16ac 100644 --- a/db/transaction_test.go +++ b/db/transaction_test.go @@ -10,7 +10,7 @@ import ( ) func TestSerializedTransaction(t *testing.T) { - conn = mockconn.New(context.Background(), os.Stdout, nil) + globalConn = mockconn.New(context.Background(), os.Stdout, nil) expectSerialized := func(ctx context.Context) error { if !Conn(ctx).IsTransaction() { @@ -64,7 +64,7 @@ func TestSerializedTransaction(t *testing.T) { } func TestTransaction(t *testing.T) { - conn = mockconn.New(context.Background(), os.Stdout, nil) + globalConn = mockconn.New(context.Background(), os.Stdout, nil) expectNonSerialized := func(ctx context.Context) error { if !Conn(ctx).IsTransaction() { diff --git a/impl/update.go b/db/update.go similarity index 63% rename from impl/update.go rename to db/update.go index 2d5976f..10e86e9 100644 --- a/impl/update.go +++ b/db/update.go @@ -1,8 +1,8 @@ -package impl +package db import ( + "context" "fmt" - "reflect" "strings" sqldb "github.com/domonda/go-sqldb" @@ -10,41 +10,45 @@ import ( ) // Update table rows(s) with values using the where statement with passed in args starting at $1. -func Update(conn sqldb.Connection, table string, values sqldb.Values, where, argFmt string, args []any) error { +func Update(ctx context.Context, table string, values sqldb.Values, where string, args ...any) error { if len(values) == 0 { return fmt.Errorf("Update table %s: no values passed", table) } - query, vals := buildUpdateQuery(table, values, where, args) + conn := Conn(ctx) + argFmt := conn.ArgFmt() + query, vals := buildUpdateQuery(table, values, where, argFmt, args) err := conn.Exec(query, vals...) return WrapNonNilErrorWithQuery(err, query, argFmt, vals) } // UpdateReturningRow updates a table row with values using the where statement with passed in args starting at $1 // and returning a single row with the columns specified in returning argument. -func UpdateReturningRow(conn sqldb.Connection, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { +func UpdateReturningRow(ctx context.Context, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { if len(values) == 0 { return sqldb.RowScannerWithError(fmt.Errorf("UpdateReturningRow table %s: no values passed", table)) } - query, vals := buildUpdateQuery(table, values, where, args) + conn := Conn(ctx) + query, vals := buildUpdateQuery(table, values, where, conn.ArgFmt(), args) query += " RETURNING " + returning return conn.QueryRow(query, vals...) } // UpdateReturningRows updates table rows with values using the where statement with passed in args starting at $1 // and returning multiple rows with the columns specified in returning argument. -func UpdateReturningRows(conn sqldb.Connection, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { +func UpdateReturningRows(ctx context.Context, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { if len(values) == 0 { return sqldb.RowsScannerWithError(fmt.Errorf("UpdateReturningRows table %s: no values passed", table)) } - query, vals := buildUpdateQuery(table, values, where, args) + conn := Conn(ctx) + query, vals := buildUpdateQuery(table, values, where, conn.ArgFmt(), args) query += " RETURNING " + returning return conn.QueryRows(query, vals...) } -func buildUpdateQuery(table string, values sqldb.Values, where string, args []any) (string, []any) { +func buildUpdateQuery(table string, values sqldb.Values, where, argFmt string, args []any) (string, []any) { names, vals := values.Sorted() var query strings.Builder @@ -53,7 +57,7 @@ func buildUpdateQuery(table string, values sqldb.Values, where string, args []an if i > 0 { query.WriteByte(',') } - fmt.Fprintf(&query, `"%s"=$%d`, names[i], 1+len(args)+i) + fmt.Fprintf(&query, `"%s"=%s`, names[i], fmt.Sprintf(argFmt, 1+len(args)+i)) } fmt.Fprintf(&query, ` WHERE %s`, where) @@ -65,19 +69,22 @@ func buildUpdateQuery(table string, values sqldb.Values, where string, args []an // Struct fields with a `db` tag matching any of the passed ignoreColumns will not be used. // If restrictToColumns are provided, then only struct fields with a `db` tag // matching any of the passed column names will be used. -func UpdateStruct(conn sqldb.Connection, table string, rowStruct any, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) error { - v := reflect.ValueOf(rowStruct) - for v.Kind() == reflect.Ptr && !v.IsNil() { - v = v.Elem() - } - switch { - case v.Kind() == reflect.Ptr && v.IsNil(): - return fmt.Errorf("UpdateStruct of table %s: can't insert nil", table) - case v.Kind() != reflect.Struct: - return fmt.Errorf("UpdateStruct of table %s: expected struct but got %T", table, rowStruct) +func UpdateStruct(ctx context.Context, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { + v, err := derefStruct(rowStruct) + if err != nil { + return err } - columns, pkCols, vals := ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) + conn := Conn(ctx) + argFmt := conn.ArgFmt() + mapper := conn.StructFieldMapper() + table, columns, pkCols, vals, err := ReflectStructValues(v, mapper, append(ignoreColumns, sqldb.IgnoreReadOnly)) + if err != nil { + return err + } + if table == "" { + return fmt.Errorf("UpdateStruct: %s has no table name", v.Type()) + } if len(pkCols) == 0 { return fmt.Errorf("UpdateStruct of table %s: %s has no mapped primary key field", table, v.Type()) } @@ -107,7 +114,7 @@ func UpdateStruct(conn sqldb.Connection, table string, rowStruct any, namer sqld query := b.String() - err := conn.Exec(query, vals...) + err = conn.Exec(query, vals...) return WrapNonNilErrorWithQuery(err, query, argFmt, vals) } diff --git a/db/upsert.go b/db/upsert.go new file mode 100644 index 0000000..3c03a8e --- /dev/null +++ b/db/upsert.go @@ -0,0 +1,67 @@ +package db + +import ( + "context" + "fmt" + "strings" + + "github.com/domonda/go-sqldb" + "golang.org/x/exp/slices" +) + +// UpsertStruct upserts a row to table using the exported fields +// of rowStruct which have a `db` tag that is not "-". +// If restrictToColumns are provided, then only struct fields with a `db` tag +// matching any of the passed column names will be used. +// The struct must have at least one field with a `db` tag value having a ",pk" suffix +// to mark primary key column(s). +// If inserting conflicts on the primary key column(s), then an update is performed. +func UpsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { + v, err := derefStruct(rowStruct) + if err != nil { + return err + } + + conn := Conn(ctx) + argFmt := conn.ArgFmt() + mapper := conn.StructFieldMapper() + table, columns, pkCols, vals, err := ReflectStructValues(v, mapper, append(ignoreColumns, sqldb.IgnoreReadOnly)) + if err != nil { + return err + } + if table == "" { + return fmt.Errorf("UpsertStruct: %s has no table name", v.Type()) + } + if len(pkCols) == 0 { + return fmt.Errorf("UpsertStruct: %s has no mapped primary key field", v.Type()) + } + + var b strings.Builder + writeInsertQuery(&b, table, argFmt, columns) + b.WriteString(` ON CONFLICT(`) + for i, pkCol := range pkCols { + if i > 0 { + b.WriteByte(',') + } + fmt.Fprintf(&b, `"%s"`, columns[pkCol]) + } + + b.WriteString(`) DO UPDATE SET `) + first := true + for i := range columns { + if slices.Contains(pkCols, i) { + continue + } + if first { + first = false + } else { + b.WriteByte(',') + } + fmt.Fprintf(&b, `"%s"=$%d`, columns[i], i+1) + } + query := b.String() + + err = conn.Exec(query, vals...) + + return WrapNonNilErrorWithQuery(err, query, argFmt, vals) +} diff --git a/examples/user_demo/user_demo.go b/examples/user_demo/user_demo.go index 2b32e21..3a76761 100644 --- a/examples/user_demo/user_demo.go +++ b/examples/user_demo/user_demo.go @@ -108,7 +108,7 @@ func main() { panic(err) } - err = conn.UpsertStruct("public.user", newUser, sqldb.IgnoreColumns("created_at")) + err = db.UpsertStruct(ctx, newUser, sqldb.IgnoreColumns("created_at")) if err != nil { panic(err) } diff --git a/impl/connection.go b/impl/connection.go index c2cd710..d330fa7 100644 --- a/impl/connection.go +++ b/impl/connection.go @@ -89,8 +89,12 @@ func (conn *connection) Err() error { return nil } -func (conn *connection) Now() (time.Time, error) { - return Now(conn) +func (conn *connection) Now() (now time.Time, err error) { + err = conn.QueryRow(`select now()`).Scan(&now) + if err != nil { + return time.Time{}, err + } + return now, nil } func (conn *connection) Exec(query string, args ...any) error { @@ -98,26 +102,6 @@ func (conn *connection) Exec(query string, args ...any) error { return WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) } -func (conn *connection) Update(table string, values sqldb.Values, where string, args ...any) error { - return Update(conn, table, values, where, conn.argFmt, args) -} - -func (conn *connection) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpdateStruct(conn, table, rowStruct, conn.structFieldMapper, conn.argFmt, ignoreColumns) -} - -func (conn *connection) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpsertStruct(conn, table, rowStruct, conn.structFieldMapper, conn.argFmt, ignoreColumns) -} - func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { rows, err := conn.db.QueryContext(conn.ctx, query, args...) if err != nil { diff --git a/impl/now.go b/impl/now.go deleted file mode 100644 index 4a1bd2f..0000000 --- a/impl/now.go +++ /dev/null @@ -1,15 +0,0 @@ -package impl - -import ( - "time" - - "github.com/domonda/go-sqldb" -) - -func Now(conn sqldb.Connection) (now time.Time, err error) { - err = conn.QueryRow(`select now()`).Scan(&now) - if err != nil { - return time.Time{}, err - } - return now, nil -} diff --git a/impl/transaction.go b/impl/transaction.go index db19ad0..cc2be4e 100644 --- a/impl/transaction.go +++ b/impl/transaction.go @@ -69,8 +69,12 @@ func (conn *transaction) Err() error { return conn.parent.Err() } -func (conn *transaction) Now() (time.Time, error) { - return Now(conn) +func (conn *transaction) Now() (now time.Time, err error) { + err = conn.QueryRow(`select now()`).Scan(&now) + if err != nil { + return time.Time{}, err + } + return now, nil } func (conn *transaction) Exec(query string, args ...any) error { @@ -78,26 +82,6 @@ func (conn *transaction) Exec(query string, args ...any) error { return WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) } -func (conn *transaction) Update(table string, values sqldb.Values, where string, args ...any) error { - return Update(conn, table, values, where, conn.parent.argFmt, args) -} - -func (conn *transaction) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *transaction) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *transaction) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpdateStruct(conn, table, rowStruct, conn.structFieldMapper, conn.parent.argFmt, ignoreColumns) -} - -func (conn *transaction) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpsertStruct(conn, table, rowStruct, conn.structFieldMapper, conn.parent.argFmt, ignoreColumns) -} - func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { diff --git a/impl/upsert.go b/impl/upsert.go deleted file mode 100644 index 9e3058e..0000000 --- a/impl/upsert.go +++ /dev/null @@ -1,84 +0,0 @@ -package impl - -import ( - "fmt" - "reflect" - "strings" - - sqldb "github.com/domonda/go-sqldb" - "golang.org/x/exp/slices" -) - -// UpsertStruct upserts a row to table using the exported fields -// of rowStruct which have a `db` tag that is not "-". -// Struct fields with a `db` tag matching any of the passed ignoreColumns will not be used. -// If restrictToColumns are provided, then only struct fields with a `db` tag -// matching any of the passed column names will be used. -// If inserting conflicts on pkColumn, then an update of the existing row is performed. -func UpsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) error { - v := reflect.ValueOf(rowStruct) - for v.Kind() == reflect.Ptr && !v.IsNil() { - v = v.Elem() - } - switch { - case v.Kind() == reflect.Ptr && v.IsNil(): - return fmt.Errorf("UpsertStruct to table %s: can't insert nil", table) - case v.Kind() != reflect.Struct: - return fmt.Errorf("UpsertStruct to table %s: expected struct but got %T", table, rowStruct) - } - - columns, pkCols, vals := ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) - if len(pkCols) == 0 { - return fmt.Errorf("UpsertStruct of table %s: %s has no mapped primary key field", table, v.Type()) - } - - var b strings.Builder - writeInsertQuery(&b, table, argFmt, columns) - b.WriteString(` ON CONFLICT(`) - for i, pkCol := range pkCols { - if i > 0 { - b.WriteByte(',') - } - fmt.Fprintf(&b, `"%s"`, columns[pkCol]) - } - - b.WriteString(`) DO UPDATE SET `) - first := true - for i := range columns { - if slices.Contains(pkCols, i) { - continue - } - if first { - first = false - } else { - b.WriteByte(',') - } - fmt.Fprintf(&b, `"%s"=$%d`, columns[i], i+1) - } - query := b.String() - - err := conn.Exec(query, vals...) - - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) -} - -// TODO replace -func writeInsertQuery(w *strings.Builder, table, argFmt string, names []string) { - fmt.Fprintf(w, `INSERT INTO %s(`, table) - for i, name := range names { - if i > 0 { - w.WriteByte(',') - } - w.WriteByte('"') - w.WriteString(name) - w.WriteByte('"') - } - w.WriteString(`) VALUES(`) - for i := range names { - if i > 0 { - w.WriteByte(',') - } - fmt.Fprintf(w, argFmt, i+1) - } - w.WriteByte(')') -} diff --git a/mockconn/connection.go b/mockconn/connection.go index a74da12..694322d 100644 --- a/mockconn/connection.go +++ b/mockconn/connection.go @@ -8,7 +8,6 @@ import ( "time" "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/impl" ) var DefaultArgFmt = "$%d" @@ -99,26 +98,6 @@ func (conn *connection) Exec(query string, args ...any) error { return nil } -func (conn *connection) Update(table string, values sqldb.Values, where string, args ...any) error { - return impl.Update(conn, table, values, where, conn.argFmt, args) -} - -func (conn *connection) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return impl.UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return impl.UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { if conn.ctx.Err() != nil { return sqldb.RowScannerWithError(conn.ctx.Err()) diff --git a/mockconn/connection_test.go b/mockconn/connection_test.go index d8387a0..cd646e2 100644 --- a/mockconn/connection_test.go +++ b/mockconn/connection_test.go @@ -136,7 +136,7 @@ func TestInsertUniqueStructQuery(t *testing.T) { func TestUpdateQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} - conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) str := "Hello World!" values := sqldb.Values{ @@ -151,13 +151,13 @@ func TestUpdateQuery(t *testing.T) { } expected := `UPDATE public.table SET "bool"=$2,"bools"=$3,"created_at"=$4,"int"=$5,"nil_ptr"=$6,"str"=$7,"str_ptr"=$8,"untagged_field"=$9 WHERE id = $1` - err := conn.Update("public.table", values, "id = $1", 1) + err := db.Update(ctx, "public.table", values, "id = $1", 1) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `UPDATE public.table SET "bool"=$3,"bools"=$4,"created_at"=$5,"int"=$6,"nil_ptr"=$7,"str"=$8,"str_ptr"=$9,"untagged_field"=$10 WHERE a = $1 AND b = $2` - err = conn.Update("public.table", values, "a = $1 AND b = $2", 1, 2) + err = db.Update(ctx, "public.table", values, "a = $1 AND b = $2", 1, 2) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } @@ -165,7 +165,7 @@ func TestUpdateQuery(t *testing.T) { func TestUpdateReturningQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} - conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) str := "Hello World!" values := sqldb.Values{ @@ -180,13 +180,13 @@ func TestUpdateReturningQuery(t *testing.T) { } expected := `UPDATE public.table SET "bool"=$2,"bools"=$3,"created_at"=$4,"int"=$5,"nil_ptr"=$6,"str"=$7,"str_ptr"=$8,"untagged_field"=$9 WHERE id = $1 RETURNING *` - err := conn.UpdateReturningRow("public.table", values, "*", "id = $1", 1).Scan() + err := db.UpdateReturningRow(ctx, "public.table", values, "*", "id = $1", 1).Scan() assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `UPDATE public.table SET "bool"=$2,"bools"=$3,"created_at"=$4,"int"=$5,"nil_ptr"=$6,"str"=$7,"str_ptr"=$8,"untagged_field"=$9 WHERE id = $1 RETURNING created_at,untagged_field` - err = conn.UpdateReturningRows("public.table", values, "created_at,untagged_field", "id = $1", 1, 2).ScanSlice(nil) + err = db.UpdateReturningRows(ctx, "public.table", values, "created_at,untagged_field", "id = $1", 1, 2).ScanSlice(nil) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } @@ -201,24 +201,24 @@ func TestUpdateStructQuery(t *testing.T) { Default: "default", UntaggedNameFunc: sqldb.ToSnakeCase, } - conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) row := new(testRow) expected := `UPDATE public.table SET "int"=$2,"bool"=$3,"str"=$4,"str_ptr"=$5,"nil_ptr"=$6,"untagged_field"=$7,"created_at"=$8,"bools"=$9 WHERE "id"=$1` - err := conn.UpdateStruct("public.table", row) + err := db.UpdateStruct(ctx, row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `UPDATE public.table SET "bool"=$2,"str"=$3,"created_at"=$4 WHERE "id"=$1` - err = conn.UpdateStruct("public.table", row, sqldb.OnlyColumns("id", "bool", "str", "created_at")) + err = db.UpdateStruct(ctx, row, sqldb.OnlyColumns("id", "bool", "str", "created_at")) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `UPDATE public.table SET "int"=$2,"bool"=$3,"str_ptr"=$4,"nil_ptr"=$5,"created_at"=$6 WHERE "id"=$1` - err = conn.UpdateStruct("public.table", row, sqldb.IgnoreColumns("untagged_field", "str", "bools")) + err = db.UpdateStruct(ctx, row, sqldb.IgnoreColumns("untagged_field", "str", "bools")) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } @@ -233,19 +233,19 @@ func TestUpsertStructQuery(t *testing.T) { Default: "default", UntaggedNameFunc: sqldb.ToSnakeCase, } - conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) row := new(testRow) expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9)` + ` ON CONFLICT("id") DO UPDATE SET "int"=$2,"bool"=$3,"str"=$4,"str_ptr"=$5,"nil_ptr"=$6,"untagged_field"=$7,"created_at"=$8,"bools"=$9` - err := conn.UpsertStruct("public.table", row) + err := db.UpsertStruct(ctx, row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } type multiPrimaryKeyRow struct { - FirstID string `db:"first_id,pk"` + FirstID string `db:"first_id,pk=public.multi_pk"` SecondID string `db:"second_id,pk"` ThirdID string `db:"third_id,pk"` @@ -262,12 +262,12 @@ func TestUpsertStructMultiPKQuery(t *testing.T) { Default: "default", UntaggedNameFunc: sqldb.ToSnakeCase, } - conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) row := new(multiPrimaryKeyRow) expected := `INSERT INTO public.multi_pk("first_id","second_id","third_id","created_at") VALUES($1,$2,$3,$4) ON CONFLICT("first_id","second_id","third_id") DO UPDATE SET "created_at"=$4` - err := conn.UpsertStruct("public.multi_pk", row) + err := db.UpsertStruct(ctx, row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } @@ -282,12 +282,12 @@ func TestUpdateStructMultiPKQuery(t *testing.T) { Default: "default", UntaggedNameFunc: sqldb.ToSnakeCase, } - conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) row := new(multiPrimaryKeyRow) expected := `UPDATE public.multi_pk SET "created_at"=$4 WHERE "first_id"=$1 AND "second_id"=$2 AND "third_id"=$3` - err := conn.UpdateStruct("public.multi_pk", row) + err := db.UpdateStruct(ctx, row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } diff --git a/pqconn/connection.go b/pqconn/connection.go index 159f7e1..4502412 100644 --- a/pqconn/connection.go +++ b/pqconn/connection.go @@ -110,8 +110,12 @@ func (conn *connection) Err() error { return conn.config.Err } -func (conn *connection) Now() (time.Time, error) { - return impl.Now(conn) +func (conn *connection) Now() (now time.Time, err error) { + err = conn.QueryRow(`select now()`).Scan(&now) + if err != nil { + return time.Time{}, err + } + return now, nil } func (conn *connection) Exec(query string, args ...any) error { @@ -119,26 +123,6 @@ func (conn *connection) Exec(query string, args ...any) error { return impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) } -func (conn *connection) Update(table string, values sqldb.Values, where string, args ...any) error { - return impl.Update(conn, table, values, where, argFmt, args) -} - -func (conn *connection) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return impl.UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return impl.UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) -} - -func (conn *connection) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) -} - func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { rows, err := conn.db.QueryContext(conn.ctx, query, args...) if err != nil { diff --git a/pqconn/transaction.go b/pqconn/transaction.go index 736a545..c9a4f29 100644 --- a/pqconn/transaction.go +++ b/pqconn/transaction.go @@ -69,8 +69,12 @@ func (conn *transaction) Err() error { return conn.parent.config.Err } -func (conn *transaction) Now() (time.Time, error) { - return impl.Now(conn) +func (conn *transaction) Now() (now time.Time, err error) { + err = conn.QueryRow(`select now()`).Scan(&now) + if err != nil { + return time.Time{}, err + } + return now, nil } func (conn *transaction) Exec(query string, args ...any) error { @@ -78,26 +82,6 @@ func (conn *transaction) Exec(query string, args ...any) error { return impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) } -func (conn *transaction) Update(table string, values sqldb.Values, where string, args ...any) error { - return impl.Update(conn, table, values, where, argFmt, args) -} - -func (conn *transaction) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return impl.UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *transaction) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return impl.UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *transaction) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) -} - -func (conn *transaction) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) -} - func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { From aad5ef9b64e1323df6672e8919b6f577191e3cef Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Wed, 15 Jun 2022 09:43:11 +0200 Subject: [PATCH 03/12] join source files --- db/query.go | 82 +++++++++++++++++++++++++++++++++++++++++++ db/querystruct.go | 89 ----------------------------------------------- 2 files changed, 82 insertions(+), 89 deletions(-) delete mode 100644 db/querystruct.go diff --git a/db/query.go b/db/query.go index 08bbcbd..a0ad4ec 100644 --- a/db/query.go +++ b/db/query.go @@ -2,6 +2,9 @@ package db import ( "context" + "errors" + "fmt" + "reflect" "time" "github.com/domonda/go-sqldb" @@ -36,3 +39,82 @@ func QueryValue[T any](ctx context.Context, query string, args ...any) (T, error err := Conn(ctx).QueryRow(query, args...).Scan(&val) return val, err } + +// QueryStruct uses the passed pkValues to query a table row +// and scan it into a struct of type S that must have tagged fields +// with primary key flags to identify the primary key column names +// for the passed pkValues and a table name. +func QueryStruct[S any](ctx context.Context, pkValues ...any) (row *S, err error) { + if len(pkValues) == 0 { + return nil, errors.New("missing primary key values") + } + t := reflect.TypeOf(row).Elem() + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("expected struct template type instead of %s", t) + } + conn := Conn(ctx) + table, pkColumns, err := pkColumnsOfStruct(conn, t) + if err != nil { + return nil, err + } + if len(pkColumns) != len(pkValues) { + return nil, fmt.Errorf("got %d primary key values, but struct %s has %d primary key fields", len(pkValues), t, len(pkColumns)) + } + query := fmt.Sprintf(`SELECT * FROM %s WHERE "%s" = $1`, table, pkColumns[0]) + for i := 1; i < len(pkColumns); i++ { + query += fmt.Sprintf(` AND "%s" = $%d`, pkColumns[i], i+1) + } + err = conn.QueryRow(query, pkValues...).ScanStruct(&row) + if err != nil { + return nil, err + } + return row, nil +} + +// QueryStructOrNil uses the passed pkValues to query a table row +// and scan it into a struct of type S that must have tagged fields +// with primary key flags to identify the primary key column names +// for the passed pkValues and a table name. +// Returns nil as row and error if no row could be found with the +// passed pkValues. +func QueryStructOrNil[S any](ctx context.Context, pkValues ...any) (row *S, err error) { + row, err = QueryStruct[S](ctx, pkValues...) + return row, ReplaceErrNoRows(err, nil) +} + +func pkColumnsOfStruct(conn sqldb.Connection, t reflect.Type) (table string, columns []string, err error) { + mapper := conn.StructFieldMapper() + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + fieldTable, column, flags, ok := mapper.MapStructField(field) + if !ok { + continue + } + if fieldTable != "" && fieldTable != table { + if table != "" { + return "", nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, t) + } + table = fieldTable + } + + if column == "" { + fieldTable, columnsEmbed, err := pkColumnsOfStruct(conn, field.Type) + if err != nil { + return "", nil, err + } + if fieldTable != "" && fieldTable != table { + if table != "" { + return "", nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, t) + } + table = fieldTable + } + columns = append(columns, columnsEmbed...) + } else if flags.PrimaryKey() { + if err = conn.ValidateColumnName(column); err != nil { + return "", nil, fmt.Errorf("%w in struct field %s.%s", err, t, field.Name) + } + columns = append(columns, column) + } + } + return table, columns, nil +} diff --git a/db/querystruct.go b/db/querystruct.go deleted file mode 100644 index 1c20342..0000000 --- a/db/querystruct.go +++ /dev/null @@ -1,89 +0,0 @@ -package db - -import ( - "context" - "errors" - "fmt" - "reflect" - - "github.com/domonda/go-sqldb" -) - -// QueryStruct uses the passed pkValues to query a table row -// and scan it into a struct of type S that must have tagged fields -// with primary key flags to identify the primary key column names -// for the passed pkValues and a table name. -func QueryStruct[S any](ctx context.Context, pkValues ...any) (row *S, err error) { - if len(pkValues) == 0 { - return nil, errors.New("missing primary key values") - } - t := reflect.TypeOf(row).Elem() - if t.Kind() != reflect.Struct { - return nil, fmt.Errorf("expected struct template type instead of %s", t) - } - conn := Conn(ctx) - table, pkColumns, err := pkColumnsOfStruct(conn, t) - if err != nil { - return nil, err - } - if len(pkColumns) != len(pkValues) { - return nil, fmt.Errorf("got %d primary key values, but struct %s has %d primary key fields", len(pkValues), t, len(pkColumns)) - } - query := fmt.Sprintf(`SELECT * FROM %s WHERE "%s" = $1`, table, pkColumns[0]) - for i := 1; i < len(pkColumns); i++ { - query += fmt.Sprintf(` AND "%s" = $%d`, pkColumns[i], i+1) - } - err = conn.QueryRow(query, pkValues...).ScanStruct(&row) - if err != nil { - return nil, err - } - return row, nil -} - -// QueryStructOrNil uses the passed pkValues to query a table row -// and scan it into a struct of type S that must have tagged fields -// with primary key flags to identify the primary key column names -// for the passed pkValues and a table name. -// Returns nil as row and error if no row could be found with the -// passed pkValues. -func QueryStructOrNil[S any](ctx context.Context, pkValues ...any) (row *S, err error) { - row, err = QueryStruct[S](ctx, pkValues...) - return row, ReplaceErrNoRows(err, nil) -} - -func pkColumnsOfStruct(conn sqldb.Connection, t reflect.Type) (table string, columns []string, err error) { - mapper := conn.StructFieldMapper() - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - fieldTable, column, flags, ok := mapper.MapStructField(field) - if !ok { - continue - } - if fieldTable != "" && fieldTable != table { - if table != "" { - return "", nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, t) - } - table = fieldTable - } - - if column == "" { - fieldTable, columnsEmbed, err := pkColumnsOfStruct(conn, field.Type) - if err != nil { - return "", nil, err - } - if fieldTable != "" && fieldTable != table { - if table != "" { - return "", nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, t) - } - table = fieldTable - } - columns = append(columns, columnsEmbed...) - } else if flags.PrimaryKey() { - if err = conn.ValidateColumnName(column); err != nil { - return "", nil, fmt.Errorf("%w in struct field %s.%s", err, t, field.Name) - } - columns = append(columns, column) - } - } - return table, columns, nil -} From 07bea6c5717d7ed6d9f4c9103229cd65ed04e145 Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Wed, 15 Jun 2022 12:13:09 +0200 Subject: [PATCH 04/12] remove package impl --- columnfilter.go | 32 +-- config.go | 7 + connection.go | 8 +- connectionwitherror.go | 205 +++++++++++++++++ db/insert.go | 12 +- db/update.go | 21 +- db/upsert.go | 5 +- errors.go | 217 +++--------------- impl/errors_test.go => errors_test.go | 2 +- examples/user_demo/user_demo.go | 5 +- impl/foreachrow.go => foreachrow.go | 22 +- impl/foreachrow_test.go => foreachrow_test.go | 2 +- impl/format.go => format.go | 2 +- impl/format_test.go => format_test.go | 2 +- impl/connection.go | 161 ------------- impl/errors.go | 42 ---- impl/reflectstruct.go | 23 -- impl/rowscanner.go | 128 ----------- impl/rowsscanner.go | 95 -------- mockconn/connection.go | 59 ++--- mockconn/connection_test.go | 31 +-- mockconn/onetimerowsprovider.go | 5 +- mockconn/row.go | 18 +- mockconn/row_test.go | 7 +- mockconn/rows.go | 4 +- mockconn/rows_test.go | 7 +- mockconn/rowsprovider.go | 5 +- mockconn/singlerowprovider.go | 10 +- mysqlconn/connection.go | 147 +++++++++++- {impl => mysqlconn}/transaction.go | 23 +- pqconn/connection.go | 20 +- pqconn/transaction.go | 20 +- reflection/columnfilter.go | 9 + {db => reflection}/reflectstruct.go | 26 +-- reflection/row.go | 13 ++ reflection/rows.go | 26 +++ {impl => reflection}/scanslice.go | 9 +- {impl => reflection}/scanstruct.go | 6 +- {impl => reflection}/scanstruct_test.go | 2 +- .../structfieldmapping.go | 7 +- .../structfieldmapping_test.go | 2 +- impl/row.go => row.go | 2 +- impl/rows.go => rows.go | 2 +- rowscanner.go | 127 ++++++++++ rowsscanner.go | 94 ++++++++ impl/scanresult.go => scanresult.go | 8 +- 46 files changed, 848 insertions(+), 832 deletions(-) create mode 100644 connectionwitherror.go rename impl/errors_test.go => errors_test.go (98%) rename impl/foreachrow.go => foreachrow.go (73%) rename impl/foreachrow_test.go => foreachrow_test.go (98%) rename impl/format.go => format.go (99%) rename impl/format_test.go => format_test.go (99%) delete mode 100644 impl/connection.go delete mode 100644 impl/errors.go delete mode 100644 impl/reflectstruct.go delete mode 100644 impl/rowscanner.go delete mode 100644 impl/rowsscanner.go rename {impl => mysqlconn}/transaction.go (82%) create mode 100644 reflection/columnfilter.go rename {db => reflection}/reflectstruct.go (77%) create mode 100644 reflection/row.go create mode 100644 reflection/rows.go rename {impl => reflection}/scanslice.go (97%) rename {impl => reflection}/scanstruct.go (88%) rename {impl => reflection}/scanstruct_test.go (98%) rename structfieldmapping.go => reflection/structfieldmapping.go (94%) rename structfieldmapping_test.go => reflection/structfieldmapping_test.go (99%) rename impl/row.go => row.go (96%) rename impl/rows.go => rows.go (98%) rename impl/scanresult.go => scanresult.go (87%) diff --git a/columnfilter.go b/columnfilter.go index a70ea89..fab3fe1 100644 --- a/columnfilter.go +++ b/columnfilter.go @@ -2,20 +2,22 @@ package sqldb import ( "reflect" + + "github.com/domonda/go-sqldb/reflection" ) type ColumnFilter interface { - IgnoreColumn(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool + IgnoreColumn(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool } -type ColumnFilterFunc func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool +type ColumnFilterFunc func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool -func (f ColumnFilterFunc) IgnoreColumn(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +func (f ColumnFilterFunc) IgnoreColumn(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { return f(name, flags, fieldType, fieldValue) } func IgnoreColumns(names ...string) ColumnFilter { - return ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { + return ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { for _, ignore := range names { if name == ignore { return true @@ -26,7 +28,7 @@ func IgnoreColumns(names ...string) ColumnFilter { } func OnlyColumns(names ...string) ColumnFilter { - return ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { + return ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { for _, include := range names { if name == include { return false @@ -37,7 +39,7 @@ func OnlyColumns(names ...string) ColumnFilter { } func IgnoreStructFields(names ...string) ColumnFilter { - return ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { + return ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { for _, ignore := range names { if fieldType.Name == ignore { return true @@ -48,7 +50,7 @@ func IgnoreStructFields(names ...string) ColumnFilter { } func OnlyStructFields(names ...string) ColumnFilter { - return ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { + return ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { for _, include := range names { if fieldType.Name == include { return false @@ -58,32 +60,32 @@ func OnlyStructFields(names ...string) ColumnFilter { }) } -func IgnoreFlags(ignore FieldFlag) ColumnFilter { - return ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +func IgnoreFlags(ignore reflection.FieldFlag) ColumnFilter { + return ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { return flags&ignore != 0 }) } -var IgnoreDefault ColumnFilter = ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +var IgnoreDefault ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { return flags.Default() }) -var IgnorePrimaryKey ColumnFilter = ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +var IgnorePrimaryKey ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { return flags.PrimaryKey() }) -var IgnoreReadOnly ColumnFilter = ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +var IgnoreReadOnly ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { return flags.ReadOnly() }) -var IgnoreNull ColumnFilter = ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +var IgnoreNull ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { return IsNull(fieldValue) }) -var IgnoreNullOrZero ColumnFilter = ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +var IgnoreNullOrZero ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { return IsNullOrZero(fieldValue) }) -var IgnoreNullOrZeroDefault ColumnFilter = ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +var IgnoreNullOrZeroDefault ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { return flags.Default() && IsNullOrZero(fieldValue) }) diff --git a/config.go b/config.go index 978225f..3e1ab65 100644 --- a/config.go +++ b/config.go @@ -6,8 +6,15 @@ import ( "fmt" "net/url" "time" + + "github.com/domonda/go-sqldb/reflection" ) +// DefaultStructFieldMapping provides the default StructFieldTagNaming +// using "db" as NameTag and IgnoreStructField as UntaggedNameFunc. +// Implements StructFieldMapper. +var DefaultStructFieldMapping = reflection.NewTaggedStructFieldMapping() + // Config for a connection. // For tips see https://www.alexedwards.net/blog/configuring-sqldb type Config struct { diff --git a/connection.go b/connection.go index f22ad49..fe5de96 100644 --- a/connection.go +++ b/connection.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "time" + + "github.com/domonda/go-sqldb/reflection" ) type ( @@ -25,11 +27,11 @@ type Connection interface { WithContext(ctx context.Context) Connection // WithStructFieldMapper returns a copy of the connection - // that will use the passed StructFieldNamer. - WithStructFieldMapper(StructFieldMapper) Connection + // that will use the passed reflection.StructFieldMapper. + WithStructFieldMapper(reflection.StructFieldMapper) Connection // StructFieldMapper used by methods of this Connection. - StructFieldMapper() StructFieldMapper + StructFieldMapper() reflection.StructFieldMapper // Ping returns an error if the database // does not answer on this connection diff --git a/connectionwitherror.go b/connectionwitherror.go new file mode 100644 index 0000000..223c7dc --- /dev/null +++ b/connectionwitherror.go @@ -0,0 +1,205 @@ +package sqldb + +import ( + "context" + "database/sql" + "time" + + "github.com/domonda/go-sqldb/reflection" +) + +// ConnectionWithError returns a dummy Connection +// where all methods return the passed error. +func ConnectionWithError(ctx context.Context, err error) Connection { + if err == nil { + panic("ConnectionWithError needs an error") + } + return connectionWithError{ctx, err} +} + +type connectionWithError struct { + ctx context.Context + err error +} + +func (e connectionWithError) Context() context.Context { return e.ctx } + +func (e connectionWithError) WithContext(ctx context.Context) Connection { + return connectionWithError{ctx: ctx, err: e.err} +} + +func (e connectionWithError) WithStructFieldMapper(reflection.StructFieldMapper) Connection { + return e +} + +func (e connectionWithError) StructFieldMapper() reflection.StructFieldMapper { + return DefaultStructFieldMapping +} + +func (e connectionWithError) Ping(time.Duration) error { + return e.err +} + +func (e connectionWithError) Stats() sql.DBStats { + return sql.DBStats{} +} + +func (e connectionWithError) Config() *Config { + return &Config{Err: e.err} +} + +func (e connectionWithError) ValidateColumnName(name string) error { + return e.err +} + +func (e connectionWithError) ArgFmt() string { + return "" +} + +func (e connectionWithError) Err() error { + return e.err +} + +func (e connectionWithError) Now() (time.Time, error) { + return time.Time{}, e.err +} + +func (e connectionWithError) Exec(query string, args ...any) error { + return e.err +} + +func (e connectionWithError) Update(table string, values Values, where string, args ...any) error { + return e.err +} + +func (e connectionWithError) UpdateReturningRow(table string, values Values, returning, where string, args ...any) RowScanner { + return RowScannerWithError(e.err) +} + +func (e connectionWithError) UpdateReturningRows(table string, values Values, returning, where string, args ...any) RowsScanner { + return RowsScannerWithError(e.err) +} + +func (e connectionWithError) UpdateStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error { + return e.err +} + +func (e connectionWithError) UpsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error { + return e.err +} + +func (e connectionWithError) QueryRow(query string, args ...any) RowScanner { + return RowScannerWithError(e.err) +} + +func (e connectionWithError) QueryRows(query string, args ...any) RowsScanner { + return RowsScannerWithError(e.err) +} + +func (e connectionWithError) IsTransaction() bool { + return false +} + +func (ce connectionWithError) TransactionOptions() (*sql.TxOptions, bool) { + return nil, false +} + +func (e connectionWithError) Begin(opts *sql.TxOptions) (Connection, error) { + return nil, e.err +} + +func (e connectionWithError) Commit() error { + return e.err +} + +func (e connectionWithError) Rollback() error { + return e.err +} + +func (e connectionWithError) Transaction(opts *sql.TxOptions, txFunc func(tx Connection) error) error { + return e.err +} + +func (e connectionWithError) ListenOnChannel(channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) error { + return e.err +} + +func (e connectionWithError) UnlistenChannel(channel string) error { + return e.err +} + +func (e connectionWithError) IsListeningOnChannel(channel string) bool { + return false +} + +func (e connectionWithError) Close() error { + return e.err +} + +// RowScannerWithError + +// RowScannerWithError returns a dummy RowScanner +// where all methods return the passed error. +func RowScannerWithError(err error) RowScanner { + return rowScannerWithError{err} +} + +type rowScannerWithError struct { + err error +} + +func (e rowScannerWithError) Scan(dest ...any) error { + return e.err +} + +func (e rowScannerWithError) ScanStruct(dest any) error { + return e.err +} + +func (e rowScannerWithError) ScanValues() ([]any, error) { + return nil, e.err +} + +func (e rowScannerWithError) ScanStrings() ([]string, error) { + return nil, e.err +} + +func (e rowScannerWithError) Columns() ([]string, error) { + return nil, e.err +} + +// RowsScannerWithError + +// RowsScannerWithError returns a dummy RowsScanner +// where all methods return the passed error. +func RowsScannerWithError(err error) RowsScanner { + return rowsScannerWithError{err} +} + +type rowsScannerWithError struct { + err error +} + +func (e rowsScannerWithError) ScanSlice(dest any) error { + return e.err +} + +func (e rowsScannerWithError) ScanStructSlice(dest any) error { + return e.err +} + +func (e rowsScannerWithError) Columns() ([]string, error) { + return nil, e.err +} + +func (e rowsScannerWithError) ScanAllRowsAsStrings(headerRow bool) ([][]string, error) { + return nil, e.err +} + +func (e rowsScannerWithError) ForEachRow(callback func(RowScanner) error) error { + return e.err +} + +func (e rowsScannerWithError) ForEachRowCall(callback any) error { + return e.err +} diff --git a/db/insert.go b/db/insert.go index 05ba3d2..4676fa7 100644 --- a/db/insert.go +++ b/db/insert.go @@ -7,12 +7,12 @@ import ( "strings" "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/impl" + "github.com/domonda/go-sqldb/reflection" ) type Values = sqldb.Values -var WrapNonNilErrorWithQuery = impl.WrapNonNilErrorWithQuery +var WrapNonNilErrorWithQuery = sqldb.WrapNonNilErrorWithQuery // Insert a new row into table using the values. func Insert(ctx context.Context, table string, values Values) error { @@ -78,7 +78,7 @@ func InsertReturning(ctx context.Context, table string, values Values, returning // InsertStruct inserts a new row into table using the connection's // StructFieldMapper to map struct fields to column names. // Optional ColumnFilter can be passed to ignore mapped columns. -func InsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { +func InsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflection.ColumnFilter) error { conn := Conn(ctx) argFmt := conn.ArgFmt() mapper := conn.StructFieldMapper() @@ -102,7 +102,7 @@ func InsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...sqldb.Col // Optional ColumnFilter can be passed to ignore mapped columns. // Does nothing if the onConflict statement applies // and returns if a row was inserted. -func InsertUniqueStruct(ctx context.Context, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { +func InsertUniqueStruct(ctx context.Context, rowStruct any, onConflict string, ignoreColumns ...reflection.ColumnFilter) (inserted bool, err error) { conn := Conn(ctx) argFmt := conn.ArgFmt() mapper := conn.StructFieldMapper() @@ -147,7 +147,7 @@ func writeInsertQuery(w *strings.Builder, table, argFmt string, names []string) w.WriteByte(')') } -func insertStructValues(rowStruct any, mapper sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (table string, columns []string, vals []any, err error) { +func insertStructValues(rowStruct any, mapper reflection.StructFieldMapper, ignoreColumns []reflection.ColumnFilter) (table string, columns []string, vals []any, err error) { v := reflect.ValueOf(rowStruct) for v.Kind() == reflect.Ptr && !v.IsNil() { v = v.Elem() @@ -159,6 +159,6 @@ func insertStructValues(rowStruct any, mapper sqldb.StructFieldMapper, ignoreCol return "", nil, nil, fmt.Errorf("expected struct but got %T", rowStruct) } - table, columns, _, vals, err = ReflectStructValues(v, mapper, append(ignoreColumns, sqldb.IgnoreReadOnly)) + table, columns, _, vals, err = reflection.ReflectStructValues(v, mapper, append(ignoreColumns, sqldb.IgnoreReadOnly)) return table, columns, vals, err } diff --git a/db/update.go b/db/update.go index 10e86e9..da2b450 100644 --- a/db/update.go +++ b/db/update.go @@ -2,10 +2,13 @@ package db import ( "context" + "errors" "fmt" + "reflect" "strings" sqldb "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/reflection" "golang.org/x/exp/slices" ) @@ -69,7 +72,7 @@ func buildUpdateQuery(table string, values sqldb.Values, where, argFmt string, a // Struct fields with a `db` tag matching any of the passed ignoreColumns will not be used. // If restrictToColumns are provided, then only struct fields with a `db` tag // matching any of the passed column names will be used. -func UpdateStruct(ctx context.Context, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { +func UpdateStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflection.ColumnFilter) error { v, err := derefStruct(rowStruct) if err != nil { return err @@ -78,7 +81,7 @@ func UpdateStruct(ctx context.Context, rowStruct any, ignoreColumns ...sqldb.Col conn := Conn(ctx) argFmt := conn.ArgFmt() mapper := conn.StructFieldMapper() - table, columns, pkCols, vals, err := ReflectStructValues(v, mapper, append(ignoreColumns, sqldb.IgnoreReadOnly)) + table, columns, pkCols, vals, err := reflection.ReflectStructValues(v, mapper, append(ignoreColumns, sqldb.IgnoreReadOnly)) if err != nil { return err } @@ -118,3 +121,17 @@ func UpdateStruct(ctx context.Context, rowStruct any, ignoreColumns ...sqldb.Col return WrapNonNilErrorWithQuery(err, query, argFmt, vals) } + +func derefStruct(rowStruct any) (reflect.Value, error) { + v := reflect.ValueOf(rowStruct) + for v.Kind() == reflect.Ptr && !v.IsNil() { + v = v.Elem() + } + switch { + case v.Kind() == reflect.Ptr && v.IsNil(): + return reflect.Value{}, errors.New("can't use nil pointer") + case v.Kind() != reflect.Struct: + return reflect.Value{}, fmt.Errorf("expected struct but got %T", rowStruct) + } + return v, nil +} diff --git a/db/upsert.go b/db/upsert.go index 3c03a8e..6b9e9fa 100644 --- a/db/upsert.go +++ b/db/upsert.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/reflection" "golang.org/x/exp/slices" ) @@ -16,7 +17,7 @@ import ( // The struct must have at least one field with a `db` tag value having a ",pk" suffix // to mark primary key column(s). // If inserting conflicts on the primary key column(s), then an update is performed. -func UpsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { +func UpsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflection.ColumnFilter) error { v, err := derefStruct(rowStruct) if err != nil { return err @@ -25,7 +26,7 @@ func UpsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...sqldb.Col conn := Conn(ctx) argFmt := conn.ArgFmt() mapper := conn.StructFieldMapper() - table, columns, pkCols, vals, err := ReflectStructValues(v, mapper, append(ignoreColumns, sqldb.IgnoreReadOnly)) + table, columns, pkCols, vals, err := reflection.ReflectStructValues(v, mapper, append(ignoreColumns, sqldb.IgnoreReadOnly)) if err != nil { return err } diff --git a/errors.go b/errors.go index d0ed89f..c082b92 100644 --- a/errors.go +++ b/errors.go @@ -1,10 +1,9 @@ package sqldb import ( - "context" "database/sql" "errors" - "time" + "fmt" ) var ( @@ -58,200 +57,40 @@ const ( ErrNotSupported sentinelError = "not supported" ) -// ConnectionWithError - -// ConnectionWithError returns a dummy Connection -// where all methods return the passed error. -func ConnectionWithError(ctx context.Context, err error) Connection { - if err == nil { - panic("ConnectionWithError needs an error") +// WrapNonNilErrorWithQuery wraps non nil errors with a formatted query +// if the error was not already wrapped with a query. +// If the passed error is nil, then nil will be returned. +func WrapNonNilErrorWithQuery(err error, query, argFmt string, args []any) error { + var wrapped errWithQuery + if err == nil || errors.As(err, &wrapped) { + return err } - return connectionWithError{ctx, err} -} - -type connectionWithError struct { - ctx context.Context - err error -} - -func (e connectionWithError) Context() context.Context { return e.ctx } - -func (e connectionWithError) WithContext(ctx context.Context) Connection { - return connectionWithError{ctx: ctx, err: e.err} -} - -func (e connectionWithError) WithStructFieldMapper(namer StructFieldMapper) Connection { - return e -} - -func (e connectionWithError) StructFieldMapper() StructFieldMapper { - return DefaultStructFieldMapping -} - -func (e connectionWithError) Ping(time.Duration) error { - return e.err -} - -func (e connectionWithError) Stats() sql.DBStats { - return sql.DBStats{} -} - -func (e connectionWithError) Config() *Config { - return &Config{Err: e.err} -} - -func (e connectionWithError) ValidateColumnName(name string) error { - return e.err -} - -func (e connectionWithError) ArgFmt() string { - return "" -} - -func (e connectionWithError) Err() error { - return e.err -} - -func (e connectionWithError) Now() (time.Time, error) { - return time.Time{}, e.err -} - -func (e connectionWithError) Exec(query string, args ...any) error { - return e.err -} - -func (e connectionWithError) Update(table string, values Values, where string, args ...any) error { - return e.err -} - -func (e connectionWithError) UpdateReturningRow(table string, values Values, returning, where string, args ...any) RowScanner { - return RowScannerWithError(e.err) -} - -func (e connectionWithError) UpdateReturningRows(table string, values Values, returning, where string, args ...any) RowsScanner { - return RowsScannerWithError(e.err) + return errWithQuery{err, query, argFmt, args} } -func (e connectionWithError) UpdateStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error { - return e.err +type errWithQuery struct { + err error + query string + argFmt string + args []any } -func (e connectionWithError) UpsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error { - return e.err -} - -func (e connectionWithError) QueryRow(query string, args ...any) RowScanner { - return RowScannerWithError(e.err) -} - -func (e connectionWithError) QueryRows(query string, args ...any) RowsScanner { - return RowsScannerWithError(e.err) -} - -func (e connectionWithError) IsTransaction() bool { - return false -} - -func (ce connectionWithError) TransactionOptions() (*sql.TxOptions, bool) { - return nil, false -} - -func (e connectionWithError) Begin(opts *sql.TxOptions) (Connection, error) { - return nil, e.err -} +func (e errWithQuery) Unwrap() error { return e.err } -func (e connectionWithError) Commit() error { - return e.err +func (e errWithQuery) Error() string { + return fmt.Sprintf("%s from query: %s", e.err, FormatQuery(e.query, e.argFmt, e.args...)) } -func (e connectionWithError) Rollback() error { - return e.err -} - -func (e connectionWithError) Transaction(opts *sql.TxOptions, txFunc func(tx Connection) error) error { - return e.err -} - -func (e connectionWithError) ListenOnChannel(channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) error { - return e.err -} - -func (e connectionWithError) UnlistenChannel(channel string) error { - return e.err -} - -func (e connectionWithError) IsListeningOnChannel(channel string) bool { - return false -} - -func (e connectionWithError) Close() error { - return e.err -} - -// RowScannerWithError - -// RowScannerWithError returns a dummy RowScanner -// where all methods return the passed error. -func RowScannerWithError(err error) RowScanner { - return rowScannerWithError{err} -} - -type rowScannerWithError struct { - err error -} - -func (e rowScannerWithError) Scan(dest ...any) error { - return e.err -} - -func (e rowScannerWithError) ScanStruct(dest any) error { - return e.err -} - -func (e rowScannerWithError) ScanValues() ([]any, error) { - return nil, e.err -} - -func (e rowScannerWithError) ScanStrings() ([]string, error) { - return nil, e.err -} - -func (e rowScannerWithError) Columns() ([]string, error) { - return nil, e.err -} - -// RowsScannerWithError - -// RowsScannerWithError returns a dummy RowsScanner -// where all methods return the passed error. -func RowsScannerWithError(err error) RowsScanner { - return rowsScannerWithError{err} -} - -type rowsScannerWithError struct { - err error -} - -func (e rowsScannerWithError) ScanSlice(dest any) error { - return e.err -} - -func (e rowsScannerWithError) ScanStructSlice(dest any) error { - return e.err -} - -func (e rowsScannerWithError) Columns() ([]string, error) { - return nil, e.err -} - -func (e rowsScannerWithError) ScanAllRowsAsStrings(headerRow bool) ([][]string, error) { - return nil, e.err -} - -func (e rowsScannerWithError) ForEachRow(callback func(RowScanner) error) error { - return e.err +func combineErrors(prim, sec error) error { + switch { + case prim != nil && sec != nil: + return fmt.Errorf("%w\n%s", prim, sec) + case prim != nil: + return prim + case sec != nil: + return sec + } + return nil } -func (e rowsScannerWithError) ForEachRowCall(callback any) error { - return e.err -} +// ConnectionWithError diff --git a/impl/errors_test.go b/errors_test.go similarity index 98% rename from impl/errors_test.go rename to errors_test.go index 6e544e5..9c9bed6 100644 --- a/impl/errors_test.go +++ b/errors_test.go @@ -1,4 +1,4 @@ -package impl +package sqldb import ( "database/sql" diff --git a/examples/user_demo/user_demo.go b/examples/user_demo/user_demo.go index 3a76761..9f0a17f 100644 --- a/examples/user_demo/user_demo.go +++ b/examples/user_demo/user_demo.go @@ -10,6 +10,7 @@ import ( "github.com/domonda/go-sqldb" "github.com/domonda/go-sqldb/db" "github.com/domonda/go-sqldb/pqconn" + "github.com/domonda/go-sqldb/reflection" "github.com/domonda/go-types/email" "github.com/domonda/go-types/nullable" "github.com/domonda/go-types/uu" @@ -45,10 +46,10 @@ func main() { panic(err) } - conn = conn.WithStructFieldMapper(&sqldb.TaggedStructFieldMapping{ + conn = conn.WithStructFieldMapper(&reflection.TaggedStructFieldMapping{ NameTag: "col", Ignore: "ignore", - UntaggedNameFunc: sqldb.ToSnakeCase, + UntaggedNameFunc: reflection.ToSnakeCase, }) var users []User diff --git a/impl/foreachrow.go b/foreachrow.go similarity index 73% rename from impl/foreachrow.go rename to foreachrow.go index 2be2b6b..ab639a4 100644 --- a/impl/foreachrow.go +++ b/foreachrow.go @@ -1,4 +1,4 @@ -package impl +package sqldb import ( "context" @@ -6,8 +6,6 @@ import ( "fmt" "reflect" "time" - - sqldb "github.com/domonda/go-sqldb" ) var ( @@ -27,17 +25,17 @@ var ( // If a non nil error is returned from the callback, then this error // is returned immediately by this function without scanning further rows. // In case of zero rows, no error will be returned. -func ForEachRowCallFunc(ctx context.Context, callback any) (f func(sqldb.RowScanner) error, err error) { +func forEachRowCallFunc(ctx context.Context, callback any) (f func(RowScanner) error, err error) { val := reflect.ValueOf(callback) typ := val.Type() if typ.Kind() != reflect.Func { - return nil, fmt.Errorf("ForEachRowCall expected callback function, got %s", typ) + return nil, fmt.Errorf("ForEachRowCallFunc expected callback function, got %s", typ) } if typ.IsVariadic() { - return nil, fmt.Errorf("ForEachRowCall callback function must not be varidic: %s", typ) + return nil, fmt.Errorf("ForEachRowCallFunc callback function must not be varidic: %s", typ) } if typ.NumIn() == 0 || (typ.NumIn() == 1 && typ.In(0) == typeOfContext) { - return nil, fmt.Errorf("ForEachRowCall callback function has no arguments: %s", typ) + return nil, fmt.Errorf("ForEachRowCallFunc callback function has no arguments: %s", typ) } firstArg := 0 if typ.In(0) == typeOfContext { @@ -58,21 +56,21 @@ func ForEachRowCallFunc(ctx context.Context, callback any) (f func(sqldb.RowScan continue } if structArg { - return nil, fmt.Errorf("ForEachRowCall callback function must not have further argument after struct: %s", typ) + return nil, fmt.Errorf("ForEachRowCallFunc callback function must not have further argument after struct: %s", typ) } structArg = true case reflect.Chan, reflect.Func: - return nil, fmt.Errorf("ForEachRowCall callback function has invalid argument type: %s", typ.In(i)) + return nil, fmt.Errorf("ForEachRowCallFunc callback function has invalid argument type: %s", typ.In(i)) } } if typ.NumOut() > 1 { - return nil, fmt.Errorf("ForEachRowCall callback function can only have one result value: %s", typ) + return nil, fmt.Errorf("ForEachRowCallFunc callback function can only have one result value: %s", typ) } if typ.NumOut() == 1 && typ.Out(0) != typeOfError { - return nil, fmt.Errorf("ForEachRowCall callback function result must be of type error: %s", typ) + return nil, fmt.Errorf("ForEachRowCallFunc callback function result must be of type error: %s", typ) } - f = func(row sqldb.RowScanner) (err error) { + f = func(row RowScanner) (err error) { // First scan row scannedValPtrs := make([]any, typ.NumIn()-firstArg) for i := range scannedValPtrs { diff --git a/impl/foreachrow_test.go b/foreachrow_test.go similarity index 98% rename from impl/foreachrow_test.go rename to foreachrow_test.go index 7509553..5f0f3cb 100644 --- a/impl/foreachrow_test.go +++ b/foreachrow_test.go @@ -1,4 +1,4 @@ -package impl +package sqldb import ( "testing" diff --git a/impl/format.go b/format.go similarity index 99% rename from impl/format.go rename to format.go index 5f78ba3..cb46958 100644 --- a/impl/format.go +++ b/format.go @@ -1,4 +1,4 @@ -package impl +package sqldb import ( "database/sql/driver" diff --git a/impl/format_test.go b/format_test.go similarity index 99% rename from impl/format_test.go rename to format_test.go index 8d6ab7c..d30e394 100644 --- a/impl/format_test.go +++ b/format_test.go @@ -1,4 +1,4 @@ -package impl +package sqldb import ( "database/sql/driver" diff --git a/impl/connection.go b/impl/connection.go deleted file mode 100644 index d330fa7..0000000 --- a/impl/connection.go +++ /dev/null @@ -1,161 +0,0 @@ -package impl - -import ( - "context" - "database/sql" - "fmt" - "time" - - "github.com/domonda/go-sqldb" -) - -// Connection returns a generic sqldb.Connection implementation -// for an existing sql.DB connection. -// argFmt is the format string for argument placeholders like "?" or "$%d" -// that will be replaced error messages to format a complete query. -func Connection(ctx context.Context, db *sql.DB, config *sqldb.Config, validateColumnName func(string) error, argFmt string) sqldb.Connection { - return &connection{ - ctx: ctx, - db: db, - config: config, - structFieldMapper: sqldb.DefaultStructFieldMapping, - argFmt: argFmt, - validateColumnName: validateColumnName, - } -} - -type connection struct { - ctx context.Context - db *sql.DB - config *sqldb.Config - structFieldMapper sqldb.StructFieldMapper - argFmt string - validateColumnName func(string) error -} - -func (conn *connection) clone() *connection { - c := *conn - return &c -} - -func (conn *connection) Context() context.Context { return conn.ctx } - -func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { - if ctx == conn.ctx { - return conn - } - c := conn.clone() - c.ctx = ctx - return c -} - -func (conn *connection) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { - c := conn.clone() - c.structFieldMapper = namer - return c -} - -func (conn *connection) StructFieldMapper() sqldb.StructFieldMapper { - return conn.structFieldMapper -} - -func (conn *connection) Ping(timeout time.Duration) error { - ctx := conn.ctx - if timeout > 0 { - var cancel func() - ctx, cancel = context.WithTimeout(ctx, timeout) - defer cancel() - } - return conn.db.PingContext(ctx) -} - -func (conn *connection) Stats() sql.DBStats { - return conn.db.Stats() -} - -func (conn *connection) Config() *sqldb.Config { - return conn.config -} - -func (conn *connection) ValidateColumnName(name string) error { - return conn.validateColumnName(name) -} - -func (conn *connection) ArgFmt() string { - return conn.argFmt -} - -func (conn *connection) Err() error { - return nil -} - -func (conn *connection) Now() (now time.Time, err error) { - err = conn.QueryRow(`select now()`).Scan(&now) - if err != nil { - return time.Time{}, err - } - return now, nil -} - -func (conn *connection) Exec(query string, args ...any) error { - _, err := conn.db.ExecContext(conn.ctx, query, args...) - return WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) -} - -func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { - rows, err := conn.db.QueryContext(conn.ctx, query, args...) - if err != nil { - err = WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) - return sqldb.RowScannerWithError(err) - } - return NewRowScanner(rows, conn.structFieldMapper, query, conn.argFmt, args) -} - -func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { - rows, err := conn.db.QueryContext(conn.ctx, query, args...) - if err != nil { - err = WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) - return sqldb.RowsScannerWithError(err) - } - return NewRowsScanner(conn.ctx, rows, conn.structFieldMapper, query, conn.argFmt, args) -} - -func (conn *connection) IsTransaction() bool { - return false -} - -func (conn *connection) TransactionOptions() (*sql.TxOptions, bool) { - return nil, false -} - -func (conn *connection) Begin(opts *sql.TxOptions) (sqldb.Connection, error) { - tx, err := conn.db.BeginTx(conn.ctx, opts) - if err != nil { - return nil, err - } - return newTransaction(conn, tx, opts), nil -} - -func (conn *connection) Commit() error { - return sqldb.ErrNotWithinTransaction -} - -func (conn *connection) Rollback() error { - return sqldb.ErrNotWithinTransaction -} - -func (conn *connection) ListenOnChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { - return fmt.Errorf("notifications %w", sqldb.ErrNotSupported) -} - -func (conn *connection) UnlistenChannel(channel string) (err error) { - return fmt.Errorf("notifications %w", sqldb.ErrNotSupported) -} - -func (conn *connection) IsListeningOnChannel(channel string) bool { - return false -} - -func (conn *connection) Close() error { - return conn.db.Close() -} diff --git a/impl/errors.go b/impl/errors.go deleted file mode 100644 index b6028a5..0000000 --- a/impl/errors.go +++ /dev/null @@ -1,42 +0,0 @@ -package impl - -import ( - "errors" - "fmt" -) - -// WrapNonNilErrorWithQuery wraps non nil errors with a formatted query -// if the error was not already wrapped with a query. -// If the passed error is nil, then nil will be returned. -func WrapNonNilErrorWithQuery(err error, query, argFmt string, args []any) error { - var wrapped errWithQuery - if err == nil || errors.As(err, &wrapped) { - return err - } - return errWithQuery{err, query, argFmt, args} -} - -type errWithQuery struct { - err error - query string - argFmt string - args []any -} - -func (e errWithQuery) Unwrap() error { return e.err } - -func (e errWithQuery) Error() string { - return fmt.Sprintf("%s from query: %s", e.err, FormatQuery(e.query, e.argFmt, e.args...)) -} - -func combineErrors(prim, sec error) error { - switch { - case prim != nil && sec != nil: - return fmt.Errorf("%w\n%s", prim, sec) - case prim != nil: - return prim - case sec != nil: - return sec - } - return nil -} diff --git a/impl/reflectstruct.go b/impl/reflectstruct.go deleted file mode 100644 index 8860f69..0000000 --- a/impl/reflectstruct.go +++ /dev/null @@ -1,23 +0,0 @@ -package impl - -import ( - "reflect" - - "github.com/domonda/go-sqldb" -) - -func ReflectStructValues(structVal reflect.Value, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, pkCols []int, values []any) { - panic("TODO remove") -} - -func ReflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFieldMapper, columns []string) (pointers []any, err error) { - panic("TODO remove") -} - -func reflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFieldMapper, columns []string, pointers []any) error { - panic("TODO remove") -} - -func ignoreColumn(filters []sqldb.ColumnFilter, name string, flags sqldb.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - panic("TODO remove") -} diff --git a/impl/rowscanner.go b/impl/rowscanner.go deleted file mode 100644 index 53956b7..0000000 --- a/impl/rowscanner.go +++ /dev/null @@ -1,128 +0,0 @@ -package impl - -import ( - "database/sql" - - sqldb "github.com/domonda/go-sqldb" -) - -var ( - _ sqldb.RowScanner = &RowScanner{} - _ sqldb.RowScanner = CurrentRowScanner{} - _ sqldb.RowScanner = SingleRowScanner{} -) - -// RowScanner implements sqldb.RowScanner for a sql.Row -type RowScanner struct { - rows Rows - structFieldNamer sqldb.StructFieldMapper - query string // for error wrapping - argFmt string // for error wrapping - args []any // for error wrapping -} - -func NewRowScanner(rows Rows, structFieldNamer sqldb.StructFieldMapper, query, argFmt string, args []any) *RowScanner { - return &RowScanner{rows, structFieldNamer, query, argFmt, args} -} - -func (s *RowScanner) Scan(dest ...any) (err error) { - defer func() { - err = combineErrors(err, s.rows.Close()) - err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) - }() - - if s.rows.Err() != nil { - return s.rows.Err() - } - if !s.rows.Next() { - if s.rows.Err() != nil { - return s.rows.Err() - } - return sql.ErrNoRows - } - - return s.rows.Scan(dest...) -} - -func (s *RowScanner) ScanStruct(dest any) (err error) { - defer func() { - err = combineErrors(err, s.rows.Close()) - err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) - }() - - if s.rows.Err() != nil { - return s.rows.Err() - } - if !s.rows.Next() { - if s.rows.Err() != nil { - return s.rows.Err() - } - return sql.ErrNoRows - } - - return ScanStruct(s.rows, dest, s.structFieldNamer) -} - -func (s *RowScanner) ScanValues() ([]any, error) { - return ScanValues(s.rows) -} - -func (s *RowScanner) ScanStrings() ([]string, error) { - return ScanStrings(s.rows) -} - -func (s *RowScanner) Columns() ([]string, error) { - return s.rows.Columns() -} - -// CurrentRowScanner calls Rows.Scan without Rows.Next and Rows.Close -type CurrentRowScanner struct { - Rows Rows - StructFieldMapper sqldb.StructFieldMapper -} - -func (s CurrentRowScanner) Scan(dest ...any) error { - return s.Rows.Scan(dest...) -} - -func (s CurrentRowScanner) ScanStruct(dest any) error { - return ScanStruct(s.Rows, dest, s.StructFieldMapper) -} - -func (s CurrentRowScanner) ScanValues() ([]any, error) { - return ScanValues(s.Rows) -} - -func (s CurrentRowScanner) ScanStrings() ([]string, error) { - return ScanStrings(s.Rows) -} - -func (s CurrentRowScanner) Columns() ([]string, error) { - return s.Rows.Columns() -} - -// SingleRowScanner always uses the same Row -type SingleRowScanner struct { - Row Row - StructFieldMapper sqldb.StructFieldMapper -} - -func (s SingleRowScanner) Scan(dest ...any) error { - return s.Row.Scan(dest...) -} - -func (s SingleRowScanner) ScanStruct(dest any) error { - return ScanStruct(s.Row, dest, s.StructFieldMapper) -} - -func (s SingleRowScanner) ScanValues() ([]any, error) { - return ScanValues(s.Row) -} - -func (s SingleRowScanner) ScanStrings() ([]string, error) { - return ScanStrings(s.Row) -} - -func (s SingleRowScanner) Columns() ([]string, error) { - return s.Row.Columns() -} diff --git a/impl/rowsscanner.go b/impl/rowsscanner.go deleted file mode 100644 index f833e59..0000000 --- a/impl/rowsscanner.go +++ /dev/null @@ -1,95 +0,0 @@ -package impl - -import ( - "context" - "fmt" - - sqldb "github.com/domonda/go-sqldb" -) - -var _ sqldb.RowsScanner = &RowsScanner{} - -// RowsScanner implements sqldb.RowsScanner with Rows -type RowsScanner struct { - ctx context.Context // ctx is checked for every row and passed through to callbacks - rows Rows - structFieldNamer sqldb.StructFieldMapper - query string // for error wrapping - argFmt string // for error wrapping - args []any // for error wrapping -} - -func NewRowsScanner(ctx context.Context, rows Rows, structFieldNamer sqldb.StructFieldMapper, query, argFmt string, args []any) *RowsScanner { - return &RowsScanner{ctx, rows, structFieldNamer, query, argFmt, args} -} - -func (s *RowsScanner) ScanSlice(dest any) error { - err := ScanRowsAsSlice(s.ctx, s.rows, dest, nil) - if err != nil { - return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) - } - return nil -} - -func (s *RowsScanner) ScanStructSlice(dest any) error { - err := ScanRowsAsSlice(s.ctx, s.rows, dest, s.structFieldNamer) - if err != nil { - return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) - } - return nil -} - -func (s *RowsScanner) Columns() ([]string, error) { - return s.rows.Columns() -} - -func (s *RowsScanner) ScanAllRowsAsStrings(headerRow bool) (rows [][]string, err error) { - cols, err := s.rows.Columns() - if err != nil { - return nil, err - } - if headerRow { - rows = [][]string{cols} - } - stringScannablePtrs := make([]any, len(cols)) - err = s.ForEachRow(func(rowScanner sqldb.RowScanner) error { - row := make([]string, len(cols)) - for i := range stringScannablePtrs { - stringScannablePtrs[i] = (*sqldb.StringScannable)(&row[i]) - } - err := rowScanner.Scan(stringScannablePtrs...) - if err != nil { - return err - } - rows = append(rows, row) - return nil - }) - return rows, err -} - -func (s *RowsScanner) ForEachRow(callback func(sqldb.RowScanner) error) (err error) { - defer func() { - err = combineErrors(err, s.rows.Close()) - err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) - }() - - for s.rows.Next() { - if s.ctx.Err() != nil { - return s.ctx.Err() - } - - err := callback(CurrentRowScanner{s.rows, s.structFieldNamer}) - if err != nil { - return err - } - } - return s.rows.Err() -} - -func (s *RowsScanner) ForEachRowCall(callback any) error { - forEachRowFunc, err := ForEachRowCallFunc(s.ctx, callback) - if err != nil { - return err - } - return s.ForEachRow(forEachRowFunc) -} diff --git a/mockconn/connection.go b/mockconn/connection.go index 694322d..550b870 100644 --- a/mockconn/connection.go +++ b/mockconn/connection.go @@ -8,28 +8,29 @@ import ( "time" "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/reflection" ) var DefaultArgFmt = "$%d" func New(ctx context.Context, queryWriter io.Writer, rowsProvider RowsProvider) sqldb.Connection { return &connection{ - ctx: ctx, - queryWriter: queryWriter, - listening: newBoolMap(), - rowsProvider: rowsProvider, - structFieldNamer: sqldb.DefaultStructFieldMapping, - argFmt: DefaultArgFmt, + ctx: ctx, + queryWriter: queryWriter, + listening: newBoolMap(), + rowsProvider: rowsProvider, + structFieldMapper: sqldb.DefaultStructFieldMapping, + argFmt: DefaultArgFmt, } } type connection struct { - ctx context.Context - queryWriter io.Writer - listening *boolMap - rowsProvider RowsProvider - structFieldNamer sqldb.StructFieldMapper - argFmt string + ctx context.Context + queryWriter io.Writer + listening *boolMap + rowsProvider RowsProvider + structFieldMapper reflection.StructFieldMapper + argFmt string } func (conn *connection) Context() context.Context { return conn.ctx } @@ -39,28 +40,28 @@ func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { return conn } return &connection{ - ctx: ctx, - queryWriter: conn.queryWriter, - listening: conn.listening, - rowsProvider: conn.rowsProvider, - structFieldNamer: conn.structFieldNamer, - argFmt: conn.argFmt, + ctx: ctx, + queryWriter: conn.queryWriter, + listening: conn.listening, + rowsProvider: conn.rowsProvider, + structFieldMapper: conn.structFieldMapper, + argFmt: conn.argFmt, } } -func (conn *connection) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { +func (conn *connection) WithStructFieldMapper(mapper reflection.StructFieldMapper) sqldb.Connection { return &connection{ - ctx: conn.ctx, - queryWriter: conn.queryWriter, - listening: conn.listening, - rowsProvider: conn.rowsProvider, - structFieldNamer: namer, - argFmt: conn.argFmt, + ctx: conn.ctx, + queryWriter: conn.queryWriter, + listening: conn.listening, + rowsProvider: conn.rowsProvider, + structFieldMapper: mapper, + argFmt: conn.argFmt, } } -func (conn *connection) StructFieldMapper() sqldb.StructFieldMapper { - return conn.structFieldNamer +func (conn *connection) StructFieldMapper() reflection.StructFieldMapper { + return conn.structFieldMapper } func (conn *connection) Stats() sql.DBStats { @@ -108,7 +109,7 @@ func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { if conn.rowsProvider == nil { return sqldb.RowScannerWithError(nil) } - return conn.rowsProvider.QueryRow(conn.structFieldNamer, query, args...) + return conn.rowsProvider.QueryRow(conn.structFieldMapper, query, args...) } func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { @@ -121,7 +122,7 @@ func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { if conn.rowsProvider == nil { return sqldb.RowsScannerWithError(nil) } - return conn.rowsProvider.QueryRows(conn.structFieldNamer, query, args...) + return conn.rowsProvider.QueryRows(conn.structFieldMapper, query, args...) } func (conn *connection) IsTransaction() bool { diff --git a/mockconn/connection_test.go b/mockconn/connection_test.go index cd646e2..0eb2366 100644 --- a/mockconn/connection_test.go +++ b/mockconn/connection_test.go @@ -11,6 +11,7 @@ import ( sqldb "github.com/domonda/go-sqldb" "github.com/domonda/go-sqldb/db" + "github.com/domonda/go-sqldb/reflection" "github.com/domonda/go-types/uu" ) @@ -34,7 +35,7 @@ type testRow struct { func TestInsertQuery(t *testing.T) { context.Background() - naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} + naming := &reflection.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: reflection.ToSnakeCase} queryOutput := bytes.NewBuffer(nil) rowProvider := NewSingleRowProvider(NewRow(struct{ True bool }{true}, naming)) ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, rowProvider).WithStructFieldMapper(naming)) @@ -67,13 +68,13 @@ func TestInsertQuery(t *testing.T) { func TestInsertStructQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{ + naming := &reflection.TaggedStructFieldMapping{ NameTag: "db", Ignore: "-", PrimaryKey: "pk", ReadOnly: "readonly", Default: "default", - UntaggedNameFunc: sqldb.ToSnakeCase, + UntaggedNameFunc: reflection.ToSnakeCase, } ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) @@ -99,13 +100,13 @@ func TestInsertStructQuery(t *testing.T) { func TestInsertUniqueStructQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{ + naming := &reflection.TaggedStructFieldMapping{ NameTag: "db", Ignore: "-", PrimaryKey: "pk", ReadOnly: "readonly", Default: "default", - UntaggedNameFunc: sqldb.ToSnakeCase, + UntaggedNameFunc: reflection.ToSnakeCase, } rowProvider := NewSingleRowProvider(NewRow(struct{ True bool }{true}, naming)) ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, rowProvider).WithStructFieldMapper(naming)) @@ -135,7 +136,7 @@ func TestInsertUniqueStructQuery(t *testing.T) { func TestUpdateQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} + naming := &reflection.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: reflection.ToSnakeCase} ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) str := "Hello World!" @@ -164,7 +165,7 @@ func TestUpdateQuery(t *testing.T) { func TestUpdateReturningQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} + naming := &reflection.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: reflection.ToSnakeCase} ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) str := "Hello World!" @@ -193,13 +194,13 @@ func TestUpdateReturningQuery(t *testing.T) { func TestUpdateStructQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{ + naming := &reflection.TaggedStructFieldMapping{ NameTag: "db", Ignore: "-", PrimaryKey: "pk", ReadOnly: "readonly", Default: "default", - UntaggedNameFunc: sqldb.ToSnakeCase, + UntaggedNameFunc: reflection.ToSnakeCase, } ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) @@ -225,13 +226,13 @@ func TestUpdateStructQuery(t *testing.T) { func TestUpsertStructQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{ + naming := &reflection.TaggedStructFieldMapping{ NameTag: "db", Ignore: "-", PrimaryKey: "pk", ReadOnly: "readonly", Default: "default", - UntaggedNameFunc: sqldb.ToSnakeCase, + UntaggedNameFunc: reflection.ToSnakeCase, } ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) @@ -254,13 +255,13 @@ type multiPrimaryKeyRow struct { func TestUpsertStructMultiPKQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{ + naming := &reflection.TaggedStructFieldMapping{ NameTag: "db", Ignore: "-", PrimaryKey: "pk", ReadOnly: "readonly", Default: "default", - UntaggedNameFunc: sqldb.ToSnakeCase, + UntaggedNameFunc: reflection.ToSnakeCase, } ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) @@ -274,13 +275,13 @@ func TestUpsertStructMultiPKQuery(t *testing.T) { func TestUpdateStructMultiPKQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{ + naming := &reflection.TaggedStructFieldMapping{ NameTag: "db", Ignore: "-", PrimaryKey: "pk", ReadOnly: "readonly", Default: "default", - UntaggedNameFunc: sqldb.ToSnakeCase, + UntaggedNameFunc: reflection.ToSnakeCase, } ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) diff --git a/mockconn/onetimerowsprovider.go b/mockconn/onetimerowsprovider.go index 51ca5ce..11a6822 100644 --- a/mockconn/onetimerowsprovider.go +++ b/mockconn/onetimerowsprovider.go @@ -7,6 +7,7 @@ import ( "sync" sqldb "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/reflection" ) type OneTimeRowsProvider struct { @@ -44,7 +45,7 @@ func (p *OneTimeRowsProvider) AddRowsScannerQuery(scanner sqldb.RowsScanner, que p.rowsScanners[key] = scanner } -func (p *OneTimeRowsProvider) QueryRow(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowScanner { +func (p *OneTimeRowsProvider) QueryRow(structFieldNamer reflection.StructFieldMapper, query string, args ...any) sqldb.RowScanner { p.mtx.Lock() defer p.mtx.Unlock() @@ -54,7 +55,7 @@ func (p *OneTimeRowsProvider) QueryRow(structFieldNamer sqldb.StructFieldMapper, return scanner } -func (p *OneTimeRowsProvider) QueryRows(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowsScanner { +func (p *OneTimeRowsProvider) QueryRows(structFieldNamer reflection.StructFieldMapper, query string, args ...any) sqldb.RowsScanner { p.mtx.Lock() defer p.mtx.Unlock() diff --git a/mockconn/row.go b/mockconn/row.go index 928b47e..054d0fc 100644 --- a/mockconn/row.go +++ b/mockconn/row.go @@ -9,35 +9,35 @@ import ( "strconv" "time" - sqldb "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/reflection" ) // Row implements impl.Row with the fields of a struct as column values. type Row struct { - rowStructVal reflect.Value - columnNamer sqldb.StructFieldMapper + rowStructVal reflect.Value + structFieldMapper reflection.StructFieldMapper } -func NewRow(rowStruct any, columnNamer sqldb.StructFieldMapper) *Row { +func NewRow(rowStruct any, mapper reflection.StructFieldMapper) *Row { val := reflect.ValueOf(rowStruct) for val.Kind() == reflect.Ptr { val = val.Elem() } return &Row{ - rowStructVal: val, - columnNamer: columnNamer, + rowStructVal: val, + structFieldMapper: mapper, } } -func (r *Row) StructFieldMapper() sqldb.StructFieldMapper { - return r.columnNamer +func (r *Row) StructFieldMapper() reflection.StructFieldMapper { + return r.structFieldMapper } func (r *Row) Columns() ([]string, error) { columns := make([]string, r.rowStructVal.NumField()) for i := range columns { field := r.rowStructVal.Type().Field(i) - _, columns[i], _, _ = r.columnNamer.MapStructField(field) + _, columns[i], _, _ = r.structFieldMapper.MapStructField(field) } return columns, nil } diff --git a/mockconn/row_test.go b/mockconn/row_test.go index f63ec70..ee0c8bc 100644 --- a/mockconn/row_test.go +++ b/mockconn/row_test.go @@ -3,10 +3,9 @@ package mockconn import ( "testing" + "github.com/domonda/go-sqldb/reflection" "github.com/lib/pq" "github.com/stretchr/testify/assert" - - sqldb "github.com/domonda/go-sqldb" ) func TestRow(t *testing.T) { @@ -21,8 +20,8 @@ func TestRow(t *testing.T) { str := "Hello World!" input := Struct{"myID", 66, -1, &str, nil, pq.BoolArray{true, false}} - naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} - row := NewRow(input, naming) + mapping := &reflection.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: reflection.ToSnakeCase} + row := NewRow(input, mapping) cols, err := row.Columns() assert.NoError(t, err) diff --git a/mockconn/rows.go b/mockconn/rows.go index 8698fe5..14da2e0 100644 --- a/mockconn/rows.go +++ b/mockconn/rows.go @@ -4,7 +4,7 @@ import ( "errors" "reflect" - sqldb "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/reflection" ) type Rows struct { @@ -14,7 +14,7 @@ type Rows struct { err error } -func NewRowsFromStructs(rowStructs any, columnNamer sqldb.StructFieldMapper) *Rows { +func NewRowsFromStructs(rowStructs any, columnNamer reflection.StructFieldMapper) *Rows { v := reflect.ValueOf(rowStructs) t := v.Type() if t.Kind() != reflect.Array && t.Kind() != reflect.Slice { diff --git a/mockconn/rows_test.go b/mockconn/rows_test.go index 608cbb1..f64ec9f 100644 --- a/mockconn/rows_test.go +++ b/mockconn/rows_test.go @@ -4,10 +4,9 @@ import ( "fmt" "testing" + "github.com/domonda/go-sqldb/reflection" "github.com/lib/pq" "github.com/stretchr/testify/assert" - - sqldb "github.com/domonda/go-sqldb" ) func TestRows(t *testing.T) { @@ -26,8 +25,8 @@ func TestRows(t *testing.T) { input = append(input, &Struct{"myID", i, -1, &str, nil, pq.BoolArray{true, false, i%2 == 0}}) } - naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} - rows := NewRowsFromStructs(input, naming) + mapping := &reflection.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: reflection.ToSnakeCase} + rows := NewRowsFromStructs(input, mapping) cols, err := rows.Columns() assert.NoError(t, err) diff --git a/mockconn/rowsprovider.go b/mockconn/rowsprovider.go index 9cf8394..3c08384 100644 --- a/mockconn/rowsprovider.go +++ b/mockconn/rowsprovider.go @@ -2,9 +2,10 @@ package mockconn import ( sqldb "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/reflection" ) type RowsProvider interface { - QueryRow(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowScanner - QueryRows(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowsScanner + QueryRow(structFieldNamer reflection.StructFieldMapper, query string, args ...any) sqldb.RowScanner + QueryRows(structFieldNamer reflection.StructFieldMapper, query string, args ...any) sqldb.RowsScanner } diff --git a/mockconn/singlerowprovider.go b/mockconn/singlerowprovider.go index 8c8e39e..68cd8f0 100644 --- a/mockconn/singlerowprovider.go +++ b/mockconn/singlerowprovider.go @@ -4,7 +4,7 @@ import ( "context" sqldb "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/impl" + "github.com/domonda/go-sqldb/reflection" ) // NewSingleRowProvider a RowsProvider implementation @@ -20,10 +20,10 @@ type singleRowProvider struct { argFmt string } -func (p *singleRowProvider) QueryRow(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowScanner { - return impl.NewRowScanner(impl.RowAsRows(p.row), structFieldNamer, query, p.argFmt, args) +func (p *singleRowProvider) QueryRow(mapper reflection.StructFieldMapper, query string, args ...any) sqldb.RowScanner { + return sqldb.NewRowScanner(sqldb.RowAsRows(p.row), mapper, query, p.argFmt, args) } -func (p *singleRowProvider) QueryRows(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowsScanner { - return impl.NewRowsScanner(context.Background(), NewRows(p.row), structFieldNamer, query, p.argFmt, args) +func (p *singleRowProvider) QueryRows(mapper reflection.StructFieldMapper, query string, args ...any) sqldb.RowsScanner { + return sqldb.NewRowsScanner(context.Background(), NewRows(p.row), mapper, query, p.argFmt, args) } diff --git a/mysqlconn/connection.go b/mysqlconn/connection.go index 46be636..9c0435d 100644 --- a/mysqlconn/connection.go +++ b/mysqlconn/connection.go @@ -4,9 +4,10 @@ import ( "context" "database/sql" "fmt" + "time" "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/impl" + "github.com/domonda/go-sqldb/reflection" ) // New creates a new sqldb.Connection using the passed sqldb.Config @@ -23,7 +24,14 @@ func New(ctx context.Context, config *sqldb.Config) (sqldb.Connection, error) { if err != nil { return nil, err } - return impl.Connection(ctx, db, config, validateColumnName, argFmt), nil + conn := &connection{ + ctx: ctx, + db: db, + config: config, + structFieldMapper: sqldb.DefaultStructFieldMapping, + argFmt: argFmt, + } + return conn, nil } // MustNew creates a new sqldb.Connection using the passed sqldb.Config @@ -38,3 +46,138 @@ func MustNew(ctx context.Context, config *sqldb.Config) sqldb.Connection { } return conn } + +type connection struct { + ctx context.Context + db *sql.DB + config *sqldb.Config + structFieldMapper reflection.StructFieldMapper + argFmt string +} + +func (conn *connection) clone() *connection { + c := *conn + return &c +} + +func (conn *connection) Context() context.Context { return conn.ctx } + +func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { + if ctx == conn.ctx { + return conn + } + c := conn.clone() + c.ctx = ctx + return c +} + +func (conn *connection) WithStructFieldMapper(mapper reflection.StructFieldMapper) sqldb.Connection { + c := conn.clone() + c.structFieldMapper = mapper + return c +} + +func (conn *connection) StructFieldMapper() reflection.StructFieldMapper { + return conn.structFieldMapper +} + +func (conn *connection) Ping(timeout time.Duration) error { + ctx := conn.ctx + if timeout > 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + return conn.db.PingContext(ctx) +} + +func (conn *connection) Stats() sql.DBStats { + return conn.db.Stats() +} + +func (conn *connection) Config() *sqldb.Config { + return conn.config +} + +func (conn *connection) ValidateColumnName(name string) error { + return validateColumnName(name) +} + +func (conn *connection) ArgFmt() string { + return conn.argFmt +} + +func (conn *connection) Err() error { + return nil +} + +func (conn *connection) Now() (now time.Time, err error) { + err = conn.QueryRow(`select now()`).Scan(&now) + if err != nil { + return time.Time{}, err + } + return now, nil +} + +func (conn *connection) Exec(query string, args ...any) error { + _, err := conn.db.ExecContext(conn.ctx, query, args...) + return sqldb.WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) +} + +func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { + rows, err := conn.db.QueryContext(conn.ctx, query, args...) + if err != nil { + err = sqldb.WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) + return sqldb.RowScannerWithError(err) + } + return sqldb.NewRowScanner(rows, conn.structFieldMapper, query, conn.argFmt, args) +} + +func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { + rows, err := conn.db.QueryContext(conn.ctx, query, args...) + if err != nil { + err = sqldb.WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) + return sqldb.RowsScannerWithError(err) + } + return sqldb.NewRowsScanner(conn.ctx, rows, conn.structFieldMapper, query, conn.argFmt, args) +} + +func (conn *connection) IsTransaction() bool { + return false +} + +func (conn *connection) TransactionOptions() (*sql.TxOptions, bool) { + return nil, false +} + +func (conn *connection) Begin(opts *sql.TxOptions) (sqldb.Connection, error) { + tx, err := conn.db.BeginTx(conn.ctx, opts) + if err != nil { + return nil, err + } + return newTransaction(conn, tx, opts), nil +} + +func (conn *connection) Commit() error { + return sqldb.ErrNotWithinTransaction +} + +func (conn *connection) Rollback() error { + return sqldb.ErrNotWithinTransaction +} + +func (conn *connection) ListenOnChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { + return fmt.Errorf("notifications %w", sqldb.ErrNotSupported) +} + +func (conn *connection) UnlistenChannel(channel string) (err error) { + return fmt.Errorf("notifications %w", sqldb.ErrNotSupported) +} + +func (conn *connection) IsListeningOnChannel(channel string) bool { + return false +} + +func (conn *connection) Close() error { + return conn.db.Close() +} diff --git a/impl/transaction.go b/mysqlconn/transaction.go similarity index 82% rename from impl/transaction.go rename to mysqlconn/transaction.go index cc2be4e..95d4a03 100644 --- a/impl/transaction.go +++ b/mysqlconn/transaction.go @@ -1,4 +1,4 @@ -package impl +package mysqlconn import ( "context" @@ -7,6 +7,7 @@ import ( "time" "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/reflection" ) type transaction struct { @@ -15,7 +16,7 @@ type transaction struct { parent *connection tx *sql.Tx opts *sql.TxOptions - structFieldMapper sqldb.StructFieldMapper + structFieldMapper reflection.StructFieldMapper } func newTransaction(parent *connection, tx *sql.Tx, opts *sql.TxOptions) *transaction { @@ -43,13 +44,13 @@ func (conn *transaction) WithContext(ctx context.Context) sqldb.Connection { return newTransaction(parent, conn.tx, conn.opts) } -func (conn *transaction) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { +func (conn *transaction) WithStructFieldMapper(mapper reflection.StructFieldMapper) sqldb.Connection { c := conn.clone() - c.structFieldMapper = namer + c.structFieldMapper = mapper return c } -func (conn *transaction) StructFieldMapper() sqldb.StructFieldMapper { +func (conn *transaction) StructFieldMapper() reflection.StructFieldMapper { return conn.structFieldMapper } @@ -58,7 +59,7 @@ func (conn *transaction) Stats() sql.DBStats { return conn.parent. func (conn *transaction) Config() *sqldb.Config { return conn.parent.Config() } func (conn *transaction) ValidateColumnName(name string) error { - return conn.parent.validateColumnName(name) + return validateColumnName(name) } func (conn *transaction) ArgFmt() string { @@ -79,25 +80,25 @@ func (conn *transaction) Now() (now time.Time, err error) { func (conn *transaction) Exec(query string, args ...any) error { _, err := conn.tx.Exec(query, args...) - return WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) + return sqldb.WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) } func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { - err = WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) + err = sqldb.WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) return sqldb.RowScannerWithError(err) } - return NewRowScanner(rows, conn.structFieldMapper, query, conn.parent.argFmt, args) + return sqldb.NewRowScanner(rows, conn.structFieldMapper, query, conn.parent.argFmt, args) } func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { - err = WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) + err = sqldb.WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) return sqldb.RowsScannerWithError(err) } - return NewRowsScanner(conn.parent.ctx, rows, conn.structFieldMapper, query, conn.parent.argFmt, args) + return sqldb.NewRowsScanner(conn.parent.ctx, rows, conn.structFieldMapper, query, conn.parent.argFmt, args) } func (conn *transaction) IsTransaction() bool { diff --git a/pqconn/connection.go b/pqconn/connection.go index 4502412..9ca8c85 100644 --- a/pqconn/connection.go +++ b/pqconn/connection.go @@ -7,7 +7,7 @@ import ( "time" "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/impl" + "github.com/domonda/go-sqldb/reflection" ) const argFmt = "$%d" @@ -51,7 +51,7 @@ type connection struct { ctx context.Context db *sql.DB config *sqldb.Config - structFieldNamer sqldb.StructFieldMapper + structFieldNamer reflection.StructFieldMapper } func (conn *connection) clone() *connection { @@ -70,13 +70,13 @@ func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { return c } -func (conn *connection) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { +func (conn *connection) WithStructFieldMapper(mapper reflection.StructFieldMapper) sqldb.Connection { c := conn.clone() - c.structFieldNamer = namer + c.structFieldNamer = mapper return c } -func (conn *connection) StructFieldMapper() sqldb.StructFieldMapper { +func (conn *connection) StructFieldMapper() reflection.StructFieldMapper { return conn.structFieldNamer } @@ -120,25 +120,25 @@ func (conn *connection) Now() (now time.Time, err error) { func (conn *connection) Exec(query string, args ...any) error { _, err := conn.db.ExecContext(conn.ctx, query, args...) - return impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) + return sqldb.WrapNonNilErrorWithQuery(err, query, argFmt, args) } func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { rows, err := conn.db.QueryContext(conn.ctx, query, args...) if err != nil { - err = impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) + err = sqldb.WrapNonNilErrorWithQuery(err, query, argFmt, args) return sqldb.RowScannerWithError(err) } - return impl.NewRowScanner(rows, conn.structFieldNamer, query, argFmt, args) + return sqldb.NewRowScanner(rows, conn.structFieldNamer, query, argFmt, args) } func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { rows, err := conn.db.QueryContext(conn.ctx, query, args...) if err != nil { - err = impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) + err = sqldb.WrapNonNilErrorWithQuery(err, query, argFmt, args) return sqldb.RowsScannerWithError(err) } - return impl.NewRowsScanner(conn.ctx, rows, conn.structFieldNamer, query, argFmt, args) + return sqldb.NewRowsScanner(conn.ctx, rows, conn.structFieldNamer, query, argFmt, args) } func (conn *connection) IsTransaction() bool { diff --git a/pqconn/transaction.go b/pqconn/transaction.go index c9a4f29..624e10b 100644 --- a/pqconn/transaction.go +++ b/pqconn/transaction.go @@ -6,7 +6,7 @@ import ( "time" "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/impl" + "github.com/domonda/go-sqldb/reflection" ) type transaction struct { @@ -15,7 +15,7 @@ type transaction struct { parent *connection tx *sql.Tx opts *sql.TxOptions - structFieldNamer sqldb.StructFieldMapper + structFieldNamer reflection.StructFieldMapper } func newTransaction(parent *connection, tx *sql.Tx, opts *sql.TxOptions) *transaction { @@ -43,13 +43,13 @@ func (conn *transaction) WithContext(ctx context.Context) sqldb.Connection { return newTransaction(parent, conn.tx, conn.opts) } -func (conn *transaction) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { +func (conn *transaction) WithStructFieldMapper(mapper reflection.StructFieldMapper) sqldb.Connection { c := conn.clone() - c.structFieldNamer = namer + c.structFieldNamer = mapper return c } -func (conn *transaction) StructFieldMapper() sqldb.StructFieldMapper { +func (conn *transaction) StructFieldMapper() reflection.StructFieldMapper { return conn.structFieldNamer } @@ -79,25 +79,25 @@ func (conn *transaction) Now() (now time.Time, err error) { func (conn *transaction) Exec(query string, args ...any) error { _, err := conn.tx.Exec(query, args...) - return impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) + return sqldb.WrapNonNilErrorWithQuery(err, query, argFmt, args) } func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { - err = impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) + err = sqldb.WrapNonNilErrorWithQuery(err, query, argFmt, args) return sqldb.RowScannerWithError(err) } - return impl.NewRowScanner(rows, conn.structFieldNamer, query, argFmt, args) + return sqldb.NewRowScanner(rows, conn.structFieldNamer, query, argFmt, args) } func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { - err = impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) + err = sqldb.WrapNonNilErrorWithQuery(err, query, argFmt, args) return sqldb.RowsScannerWithError(err) } - return impl.NewRowsScanner(conn.parent.ctx, rows, conn.structFieldNamer, query, argFmt, args) + return sqldb.NewRowsScanner(conn.parent.ctx, rows, conn.structFieldNamer, query, argFmt, args) } func (conn *transaction) IsTransaction() bool { diff --git a/reflection/columnfilter.go b/reflection/columnfilter.go new file mode 100644 index 0000000..a94321b --- /dev/null +++ b/reflection/columnfilter.go @@ -0,0 +1,9 @@ +package reflection + +import ( + "reflect" +) + +type ColumnFilter interface { + IgnoreColumn(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool +} diff --git a/db/reflectstruct.go b/reflection/reflectstruct.go similarity index 77% rename from db/reflectstruct.go rename to reflection/reflectstruct.go index 7df5bb2..1c5e36d 100644 --- a/db/reflectstruct.go +++ b/reflection/reflectstruct.go @@ -1,4 +1,4 @@ -package db +package reflection import ( "errors" @@ -7,11 +7,9 @@ import ( "strings" "golang.org/x/exp/slices" - - "github.com/domonda/go-sqldb" ) -func ReflectStructValues(structVal reflect.Value, mapper sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (table string, columns []string, pkCols []int, values []any, err error) { +func ReflectStructValues(structVal reflect.Value, mapper StructFieldMapper, ignoreColumns []ColumnFilter) (table string, columns []string, pkCols []int, values []any, err error) { structType := structVal.Type() for i := 0; i < structType.NumField(); i++ { fieldType := structType.Field(i) @@ -60,7 +58,7 @@ func ReflectStructValues(structVal reflect.Value, mapper sqldb.StructFieldMapper return table, columns, pkCols, values, nil } -func ReflectStructColumnPointers(structVal reflect.Value, mapper sqldb.StructFieldMapper, columns []string) (pointers []any, err error) { +func ReflectStructColumnPointers(structVal reflect.Value, mapper StructFieldMapper, columns []string) (pointers []any, err error) { if len(columns) == 0 { return nil, errors.New("no columns") } @@ -88,7 +86,7 @@ func ReflectStructColumnPointers(structVal reflect.Value, mapper sqldb.StructFie return pointers, nil } -func reflectStructColumnPointers(structVal reflect.Value, mapper sqldb.StructFieldMapper, columns []string, pointers []any) error { +func reflectStructColumnPointers(structVal reflect.Value, mapper StructFieldMapper, columns []string, pointers []any) error { var ( structType = structVal.Type() ) @@ -123,7 +121,7 @@ func reflectStructColumnPointers(structVal reflect.Value, mapper sqldb.StructFie return nil } -func ignoreColumn(filters []sqldb.ColumnFilter, name string, flags sqldb.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +func ignoreColumn(filters []ColumnFilter, name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { for _, filter := range filters { if filter.IgnoreColumn(name, flags, fieldType, fieldValue) { return true @@ -131,17 +129,3 @@ func ignoreColumn(filters []sqldb.ColumnFilter, name string, flags sqldb.FieldFl } return false } - -func derefStruct(rowStruct any) (reflect.Value, error) { - v := reflect.ValueOf(rowStruct) - for v.Kind() == reflect.Ptr && !v.IsNil() { - v = v.Elem() - } - switch { - case v.Kind() == reflect.Ptr && v.IsNil(): - return reflect.Value{}, errors.New("can't use nil pointer") - case v.Kind() != reflect.Struct: - return reflect.Value{}, fmt.Errorf("expected struct but got %T", rowStruct) - } - return v, nil -} diff --git a/reflection/row.go b/reflection/row.go new file mode 100644 index 0000000..ba394e0 --- /dev/null +++ b/reflection/row.go @@ -0,0 +1,13 @@ +package reflection + +// Row is an interface with the methods of sql.Rows +// that are needed for ScanStruct. +// Allows mocking for tests without an SQL driver. +type Row interface { + // Columns returns the column names. + Columns() ([]string, error) + // Scan copies the columns in the current row into the values pointed + // at by dest. The number of values in dest must be the same as the + // number of columns in Rows. + Scan(dest ...any) error +} diff --git a/reflection/rows.go b/reflection/rows.go new file mode 100644 index 0000000..2133ee3 --- /dev/null +++ b/reflection/rows.go @@ -0,0 +1,26 @@ +package reflection + +// Rows is an interface with the methods of sql.Rows +// that are needed for ScanSlice. +// Allows mocking for tests without an SQL driver. +type Rows interface { + Row + + // Close closes the Rows, preventing further enumeration. If Next is called + // and returns false and there are no further result sets, + // the Rows are closed automatically and it will suffice to check the + // result of Err. Close is idempotent and does not affect the result of Err. + Close() error + + // Next prepares the next result row for reading with the Scan method. It + // returns true on success, or false if there is no next result row or an error + // happened while preparing it. Err should be consulted to distinguish between + // the two cases. + // + // Every call to Scan, even the first one, must be preceded by a call to Next. + Next() bool + + // Err returns the error, if any, that was encountered during iteration. + // Err may be called after an explicit or implicit Close. + Err() error +} diff --git a/impl/scanslice.go b/reflection/scanslice.go similarity index 97% rename from impl/scanslice.go rename to reflection/scanslice.go index 1aead27..c4d7dcc 100644 --- a/impl/scanslice.go +++ b/reflection/scanslice.go @@ -1,4 +1,4 @@ -package impl +package reflection import ( "context" @@ -8,17 +8,20 @@ import ( "reflect" "time" - sqldb "github.com/domonda/go-sqldb" "github.com/domonda/go-types/nullable" ) +var ( + typeOfSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() +) + // ScanRowsAsSlice scans all srcRows as slice into dest. // The rows must either have only one column compatible with the element type of the slice, // or if multiple columns are returned then the slice element type must me a struct or struction pointer // so that every column maps on exactly one struct field using structFieldNamer. // In case of single column rows, nil must be passed for structFieldNamer. // ScanRowsAsSlice calls srcRows.Close(). -func ScanRowsAsSlice(ctx context.Context, srcRows Rows, dest any, structFieldNamer sqldb.StructFieldMapper) error { +func ScanRowsAsSlice(ctx context.Context, srcRows Rows, dest any, structFieldNamer StructFieldMapper) error { defer srcRows.Close() destVal := reflect.ValueOf(dest) diff --git a/impl/scanstruct.go b/reflection/scanstruct.go similarity index 88% rename from impl/scanstruct.go rename to reflection/scanstruct.go index ce2dd56..3c5a77a 100644 --- a/impl/scanstruct.go +++ b/reflection/scanstruct.go @@ -1,13 +1,11 @@ -package impl +package reflection import ( "fmt" "reflect" - - sqldb "github.com/domonda/go-sqldb" ) -func ScanStruct(srcRow Row, destStruct any, namer sqldb.StructFieldMapper) error { +func ScanStruct(srcRow Row, destStruct any, namer StructFieldMapper) error { v := reflect.ValueOf(destStruct) for v.Kind() == reflect.Ptr && !v.IsNil() { v = v.Elem() diff --git a/impl/scanstruct_test.go b/reflection/scanstruct_test.go similarity index 98% rename from impl/scanstruct_test.go rename to reflection/scanstruct_test.go index 787e76e..d81e484 100644 --- a/impl/scanstruct_test.go +++ b/reflection/scanstruct_test.go @@ -1,4 +1,4 @@ -package impl +package reflection // func TestGetStructFieldIndices(t *testing.T) { // type DeepEmbeddedStruct struct { diff --git a/structfieldmapping.go b/reflection/structfieldmapping.go similarity index 94% rename from structfieldmapping.go rename to reflection/structfieldmapping.go index fde2ca9..c0c836e 100644 --- a/structfieldmapping.go +++ b/reflection/structfieldmapping.go @@ -1,4 +1,4 @@ -package sqldb +package reflection import ( "fmt" @@ -54,11 +54,6 @@ func NewTaggedStructFieldMapping() *TaggedStructFieldMapping { } } -// DefaultStructFieldMapping provides the default StructFieldTagNaming -// using "db" as NameTag and IgnoreStructField as UntaggedNameFunc. -// Implements StructFieldMapper. -var DefaultStructFieldMapping = NewTaggedStructFieldMapping() - // TaggedStructFieldMapping implements StructFieldMapper with a struct field NameTag // to be used for naming and a UntaggedNameFunc in case the NameTag is not set. type TaggedStructFieldMapping struct { diff --git a/structfieldmapping_test.go b/reflection/structfieldmapping_test.go similarity index 99% rename from structfieldmapping_test.go rename to reflection/structfieldmapping_test.go index 4fcece5..2473639 100644 --- a/structfieldmapping_test.go +++ b/reflection/structfieldmapping_test.go @@ -1,4 +1,4 @@ -package sqldb +package reflection import ( "reflect" diff --git a/impl/row.go b/row.go similarity index 96% rename from impl/row.go rename to row.go index a3b34c7..a82431a 100644 --- a/impl/row.go +++ b/row.go @@ -1,4 +1,4 @@ -package impl +package sqldb // Row is an interface with the methods of sql.Rows // that are needed for ScanStruct. diff --git a/impl/rows.go b/rows.go similarity index 98% rename from impl/rows.go rename to rows.go index 84ab136..f1f75f8 100644 --- a/impl/rows.go +++ b/rows.go @@ -1,4 +1,4 @@ -package impl +package sqldb // Rows is an interface with the methods of sql.Rows // that are needed for ScanSlice. diff --git a/rowscanner.go b/rowscanner.go index 0507364..3d1192a 100644 --- a/rowscanner.go +++ b/rowscanner.go @@ -1,5 +1,11 @@ package sqldb +import ( + "database/sql" + + "github.com/domonda/go-sqldb/reflection" +) + // RowScanner scans the values from a single row. type RowScanner interface { // Scan values of a row into dest variables, which must be passed as pointers. @@ -22,3 +28,124 @@ type RowScanner interface { // Columns returns the column names. Columns() ([]string, error) } + +var ( + _ RowScanner = &rowScanner{} + _ RowScanner = CurrentRowScanner{} + _ RowScanner = SingleRowScanner{} +) + +// rowScanner implements rowScanner for a sql.Row +type rowScanner struct { + rows Rows + structFieldNamer reflection.StructFieldMapper + query string // for error wrapping + argFmt string // for error wrapping + args []any // for error wrapping +} + +func NewRowScanner(rows Rows, structFieldNamer reflection.StructFieldMapper, query, argFmt string, args []any) *rowScanner { + return &rowScanner{rows, structFieldNamer, query, argFmt, args} +} + +func (s *rowScanner) Scan(dest ...any) (err error) { + defer func() { + err = combineErrors(err, s.rows.Close()) + err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) + }() + + if s.rows.Err() != nil { + return s.rows.Err() + } + if !s.rows.Next() { + if s.rows.Err() != nil { + return s.rows.Err() + } + return sql.ErrNoRows + } + + return s.rows.Scan(dest...) +} + +func (s *rowScanner) ScanStruct(dest any) (err error) { + defer func() { + err = combineErrors(err, s.rows.Close()) + err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) + }() + + if s.rows.Err() != nil { + return s.rows.Err() + } + if !s.rows.Next() { + if s.rows.Err() != nil { + return s.rows.Err() + } + return sql.ErrNoRows + } + + return reflection.ScanStruct(s.rows, dest, s.structFieldNamer) +} + +func (s *rowScanner) ScanValues() ([]any, error) { + return ScanValues(s.rows) +} + +func (s *rowScanner) ScanStrings() ([]string, error) { + return ScanStrings(s.rows) +} + +func (s *rowScanner) Columns() ([]string, error) { + return s.rows.Columns() +} + +// CurrentRowScanner calls Rows.Scan without Rows.Next and Rows.Close +type CurrentRowScanner struct { + Rows Rows + StructFieldMapper reflection.StructFieldMapper +} + +func (s CurrentRowScanner) Scan(dest ...any) error { + return s.Rows.Scan(dest...) +} + +func (s CurrentRowScanner) ScanStruct(dest any) error { + return reflection.ScanStruct(s.Rows, dest, s.StructFieldMapper) +} + +func (s CurrentRowScanner) ScanValues() ([]any, error) { + return ScanValues(s.Rows) +} + +func (s CurrentRowScanner) ScanStrings() ([]string, error) { + return ScanStrings(s.Rows) +} + +func (s CurrentRowScanner) Columns() ([]string, error) { + return s.Rows.Columns() +} + +// SingleRowScanner always uses the same Row +type SingleRowScanner struct { + Row Row + StructFieldMapper reflection.StructFieldMapper +} + +func (s SingleRowScanner) Scan(dest ...any) error { + return s.Row.Scan(dest...) +} + +func (s SingleRowScanner) ScanStruct(dest any) error { + return reflection.ScanStruct(s.Row, dest, s.StructFieldMapper) +} + +func (s SingleRowScanner) ScanValues() ([]any, error) { + return ScanValues(s.Row) +} + +func (s SingleRowScanner) ScanStrings() ([]string, error) { + return ScanStrings(s.Row) +} + +func (s SingleRowScanner) Columns() ([]string, error) { + return s.Row.Columns() +} diff --git a/rowsscanner.go b/rowsscanner.go index c318deb..645f621 100644 --- a/rowsscanner.go +++ b/rowsscanner.go @@ -1,5 +1,12 @@ package sqldb +import ( + "context" + "fmt" + + "github.com/domonda/go-sqldb/reflection" +) + // RowsScanner scans the values from multiple rows. type RowsScanner interface { // ScanSlice scans one value per row into one slice element of dest. @@ -43,3 +50,90 @@ type RowsScanner interface { // In case of zero rows, no error will be returned. ForEachRowCall(callback any) error } + +var _ RowsScanner = &rowsScanner{} + +// rowsScanner implements rowsScanner with Rows +type rowsScanner struct { + ctx context.Context // ctx is checked for every row and passed through to callbacks + rows Rows + structFieldNamer reflection.StructFieldMapper + query string // for error wrapping + argFmt string // for error wrapping + args []any // for error wrapping +} + +func NewRowsScanner(ctx context.Context, rows Rows, structFieldNamer reflection.StructFieldMapper, query, argFmt string, args []any) *rowsScanner { + return &rowsScanner{ctx, rows, structFieldNamer, query, argFmt, args} +} + +func (s *rowsScanner) ScanSlice(dest any) error { + err := reflection.ScanRowsAsSlice(s.ctx, s.rows, dest, nil) + if err != nil { + return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) + } + return nil +} + +func (s *rowsScanner) ScanStructSlice(dest any) error { + err := reflection.ScanRowsAsSlice(s.ctx, s.rows, dest, s.structFieldNamer) + if err != nil { + return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) + } + return nil +} + +func (s *rowsScanner) Columns() ([]string, error) { + return s.rows.Columns() +} + +func (s *rowsScanner) ScanAllRowsAsStrings(headerRow bool) (rows [][]string, err error) { + cols, err := s.rows.Columns() + if err != nil { + return nil, err + } + if headerRow { + rows = [][]string{cols} + } + stringScannablePtrs := make([]any, len(cols)) + err = s.ForEachRow(func(rowScanner RowScanner) error { + row := make([]string, len(cols)) + for i := range stringScannablePtrs { + stringScannablePtrs[i] = (*StringScannable)(&row[i]) + } + err := rowScanner.Scan(stringScannablePtrs...) + if err != nil { + return err + } + rows = append(rows, row) + return nil + }) + return rows, err +} + +func (s *rowsScanner) ForEachRow(callback func(RowScanner) error) (err error) { + defer func() { + err = combineErrors(err, s.rows.Close()) + err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) + }() + + for s.rows.Next() { + if s.ctx.Err() != nil { + return s.ctx.Err() + } + + err := callback(CurrentRowScanner{s.rows, s.structFieldNamer}) + if err != nil { + return err + } + } + return s.rows.Err() +} + +func (s *rowsScanner) ForEachRowCall(callback any) error { + forEachRowFunc, err := forEachRowCallFunc(s.ctx, callback) + if err != nil { + return err + } + return s.ForEachRow(forEachRowFunc) +} diff --git a/impl/scanresult.go b/scanresult.go similarity index 87% rename from impl/scanresult.go rename to scanresult.go index c51c336..266297d 100644 --- a/impl/scanresult.go +++ b/scanresult.go @@ -1,6 +1,4 @@ -package impl - -import "github.com/domonda/go-sqldb" +package sqldb // ScanValues returns the values of a row exactly how they are // passed from the database driver to an sql.Scanner. @@ -11,7 +9,7 @@ func ScanValues(src Row) ([]any, error) { return nil, err } var ( - anys = make([]sqldb.AnyValue, len(cols)) + anys = make([]AnyValue, len(cols)) vals = make([]any, len(cols)) ) for i := range vals { @@ -41,7 +39,7 @@ func ScanStrings(src Row) ([]string, error) { args = make([]any, len(cols)) ) for i := range args { - args[i] = (*sqldb.StringScannable)(&strs[i]) + args[i] = (*StringScannable)(&strs[i]) } err = src.Scan(args...) if err != nil { From e76370b8f547acc7c64318a49689a235a26baff6 Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Wed, 15 Jun 2022 12:15:43 +0200 Subject: [PATCH 05/12] renamings --- mockconn/onetimerowsprovider.go | 4 ++-- mockconn/rowsprovider.go | 4 ++-- pqconn/connection.go | 24 ++++++++++++------------ pqconn/transaction.go | 24 ++++++++++++------------ reflection/scanslice.go | 10 +++++----- rowscanner.go | 16 ++++++++-------- rowsscanner.go | 20 ++++++++++---------- 7 files changed, 51 insertions(+), 51 deletions(-) diff --git a/mockconn/onetimerowsprovider.go b/mockconn/onetimerowsprovider.go index 11a6822..44c4b39 100644 --- a/mockconn/onetimerowsprovider.go +++ b/mockconn/onetimerowsprovider.go @@ -45,7 +45,7 @@ func (p *OneTimeRowsProvider) AddRowsScannerQuery(scanner sqldb.RowsScanner, que p.rowsScanners[key] = scanner } -func (p *OneTimeRowsProvider) QueryRow(structFieldNamer reflection.StructFieldMapper, query string, args ...any) sqldb.RowScanner { +func (p *OneTimeRowsProvider) QueryRow(structFieldMapper reflection.StructFieldMapper, query string, args ...any) sqldb.RowScanner { p.mtx.Lock() defer p.mtx.Unlock() @@ -55,7 +55,7 @@ func (p *OneTimeRowsProvider) QueryRow(structFieldNamer reflection.StructFieldMa return scanner } -func (p *OneTimeRowsProvider) QueryRows(structFieldNamer reflection.StructFieldMapper, query string, args ...any) sqldb.RowsScanner { +func (p *OneTimeRowsProvider) QueryRows(structFieldMapper reflection.StructFieldMapper, query string, args ...any) sqldb.RowsScanner { p.mtx.Lock() defer p.mtx.Unlock() diff --git a/mockconn/rowsprovider.go b/mockconn/rowsprovider.go index 3c08384..a3074e6 100644 --- a/mockconn/rowsprovider.go +++ b/mockconn/rowsprovider.go @@ -6,6 +6,6 @@ import ( ) type RowsProvider interface { - QueryRow(structFieldNamer reflection.StructFieldMapper, query string, args ...any) sqldb.RowScanner - QueryRows(structFieldNamer reflection.StructFieldMapper, query string, args ...any) sqldb.RowsScanner + QueryRow(structFieldMapper reflection.StructFieldMapper, query string, args ...any) sqldb.RowScanner + QueryRows(structFieldMapper reflection.StructFieldMapper, query string, args ...any) sqldb.RowsScanner } diff --git a/pqconn/connection.go b/pqconn/connection.go index 9ca8c85..e309ab7 100644 --- a/pqconn/connection.go +++ b/pqconn/connection.go @@ -27,10 +27,10 @@ func New(ctx context.Context, config *sqldb.Config) (sqldb.Connection, error) { return nil, err } return &connection{ - ctx: ctx, - db: db, - config: config, - structFieldNamer: sqldb.DefaultStructFieldMapping, + ctx: ctx, + db: db, + config: config, + structFieldMapper: sqldb.DefaultStructFieldMapping, }, nil } @@ -48,10 +48,10 @@ func MustNew(ctx context.Context, config *sqldb.Config) sqldb.Connection { } type connection struct { - ctx context.Context - db *sql.DB - config *sqldb.Config - structFieldNamer reflection.StructFieldMapper + ctx context.Context + db *sql.DB + config *sqldb.Config + structFieldMapper reflection.StructFieldMapper } func (conn *connection) clone() *connection { @@ -72,12 +72,12 @@ func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { func (conn *connection) WithStructFieldMapper(mapper reflection.StructFieldMapper) sqldb.Connection { c := conn.clone() - c.structFieldNamer = mapper + c.structFieldMapper = mapper return c } func (conn *connection) StructFieldMapper() reflection.StructFieldMapper { - return conn.structFieldNamer + return conn.structFieldMapper } func (conn *connection) Ping(timeout time.Duration) error { @@ -129,7 +129,7 @@ func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { err = sqldb.WrapNonNilErrorWithQuery(err, query, argFmt, args) return sqldb.RowScannerWithError(err) } - return sqldb.NewRowScanner(rows, conn.structFieldNamer, query, argFmt, args) + return sqldb.NewRowScanner(rows, conn.structFieldMapper, query, argFmt, args) } func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { @@ -138,7 +138,7 @@ func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { err = sqldb.WrapNonNilErrorWithQuery(err, query, argFmt, args) return sqldb.RowsScannerWithError(err) } - return sqldb.NewRowsScanner(conn.ctx, rows, conn.structFieldNamer, query, argFmt, args) + return sqldb.NewRowsScanner(conn.ctx, rows, conn.structFieldMapper, query, argFmt, args) } func (conn *connection) IsTransaction() bool { diff --git a/pqconn/transaction.go b/pqconn/transaction.go index 624e10b..098c125 100644 --- a/pqconn/transaction.go +++ b/pqconn/transaction.go @@ -12,18 +12,18 @@ import ( type transaction struct { // The parent non-transaction connection is needed // for its ctx, Ping(), Stats(), and Config() - parent *connection - tx *sql.Tx - opts *sql.TxOptions - structFieldNamer reflection.StructFieldMapper + parent *connection + tx *sql.Tx + opts *sql.TxOptions + structFieldMapper reflection.StructFieldMapper } func newTransaction(parent *connection, tx *sql.Tx, opts *sql.TxOptions) *transaction { return &transaction{ - parent: parent, - tx: tx, - opts: opts, - structFieldNamer: parent.structFieldNamer, + parent: parent, + tx: tx, + opts: opts, + structFieldMapper: parent.structFieldMapper, } } @@ -45,12 +45,12 @@ func (conn *transaction) WithContext(ctx context.Context) sqldb.Connection { func (conn *transaction) WithStructFieldMapper(mapper reflection.StructFieldMapper) sqldb.Connection { c := conn.clone() - c.structFieldNamer = mapper + c.structFieldMapper = mapper return c } func (conn *transaction) StructFieldMapper() reflection.StructFieldMapper { - return conn.structFieldNamer + return conn.structFieldMapper } func (conn *transaction) Ping(timeout time.Duration) error { return conn.parent.Ping(timeout) } @@ -88,7 +88,7 @@ func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { err = sqldb.WrapNonNilErrorWithQuery(err, query, argFmt, args) return sqldb.RowScannerWithError(err) } - return sqldb.NewRowScanner(rows, conn.structFieldNamer, query, argFmt, args) + return sqldb.NewRowScanner(rows, conn.structFieldMapper, query, argFmt, args) } func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner { @@ -97,7 +97,7 @@ func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner err = sqldb.WrapNonNilErrorWithQuery(err, query, argFmt, args) return sqldb.RowsScannerWithError(err) } - return sqldb.NewRowsScanner(conn.parent.ctx, rows, conn.structFieldNamer, query, argFmt, args) + return sqldb.NewRowsScanner(conn.parent.ctx, rows, conn.structFieldMapper, query, argFmt, args) } func (conn *transaction) IsTransaction() bool { diff --git a/reflection/scanslice.go b/reflection/scanslice.go index c4d7dcc..acac89a 100644 --- a/reflection/scanslice.go +++ b/reflection/scanslice.go @@ -18,10 +18,10 @@ var ( // ScanRowsAsSlice scans all srcRows as slice into dest. // The rows must either have only one column compatible with the element type of the slice, // or if multiple columns are returned then the slice element type must me a struct or struction pointer -// so that every column maps on exactly one struct field using structFieldNamer. -// In case of single column rows, nil must be passed for structFieldNamer. +// so that every column maps on exactly one struct field using structFieldMapper. +// In case of single column rows, nil must be passed for structFieldMapper. // ScanRowsAsSlice calls srcRows.Close(). -func ScanRowsAsSlice(ctx context.Context, srcRows Rows, dest any, structFieldNamer StructFieldMapper) error { +func ScanRowsAsSlice(ctx context.Context, srcRows Rows, dest any, structFieldMapper StructFieldMapper) error { defer srcRows.Close() destVal := reflect.ValueOf(dest) @@ -46,8 +46,8 @@ func ScanRowsAsSlice(ctx context.Context, srcRows Rows, dest any, structFieldNam newSlice = reflect.Append(newSlice, reflect.Zero(sliceElemType)) target := newSlice.Index(newSlice.Len() - 1).Addr() - if structFieldNamer != nil { - err := ScanStruct(srcRows, target.Interface(), structFieldNamer) + if structFieldMapper != nil { + err := ScanStruct(srcRows, target.Interface(), structFieldMapper) if err != nil { return err } diff --git a/rowscanner.go b/rowscanner.go index 3d1192a..df6d792 100644 --- a/rowscanner.go +++ b/rowscanner.go @@ -37,15 +37,15 @@ var ( // rowScanner implements rowScanner for a sql.Row type rowScanner struct { - rows Rows - structFieldNamer reflection.StructFieldMapper - query string // for error wrapping - argFmt string // for error wrapping - args []any // for error wrapping + rows Rows + structFieldMapper reflection.StructFieldMapper + query string // for error wrapping + argFmt string // for error wrapping + args []any // for error wrapping } -func NewRowScanner(rows Rows, structFieldNamer reflection.StructFieldMapper, query, argFmt string, args []any) *rowScanner { - return &rowScanner{rows, structFieldNamer, query, argFmt, args} +func NewRowScanner(rows Rows, structFieldMapper reflection.StructFieldMapper, query, argFmt string, args []any) *rowScanner { + return &rowScanner{rows, structFieldMapper, query, argFmt, args} } func (s *rowScanner) Scan(dest ...any) (err error) { @@ -83,7 +83,7 @@ func (s *rowScanner) ScanStruct(dest any) (err error) { return sql.ErrNoRows } - return reflection.ScanStruct(s.rows, dest, s.structFieldNamer) + return reflection.ScanStruct(s.rows, dest, s.structFieldMapper) } func (s *rowScanner) ScanValues() ([]any, error) { diff --git a/rowsscanner.go b/rowsscanner.go index 645f621..09f4960 100644 --- a/rowsscanner.go +++ b/rowsscanner.go @@ -55,16 +55,16 @@ var _ RowsScanner = &rowsScanner{} // rowsScanner implements rowsScanner with Rows type rowsScanner struct { - ctx context.Context // ctx is checked for every row and passed through to callbacks - rows Rows - structFieldNamer reflection.StructFieldMapper - query string // for error wrapping - argFmt string // for error wrapping - args []any // for error wrapping + ctx context.Context // ctx is checked for every row and passed through to callbacks + rows Rows + structFieldMapper reflection.StructFieldMapper + query string // for error wrapping + argFmt string // for error wrapping + args []any // for error wrapping } -func NewRowsScanner(ctx context.Context, rows Rows, structFieldNamer reflection.StructFieldMapper, query, argFmt string, args []any) *rowsScanner { - return &rowsScanner{ctx, rows, structFieldNamer, query, argFmt, args} +func NewRowsScanner(ctx context.Context, rows Rows, structFieldMapper reflection.StructFieldMapper, query, argFmt string, args []any) *rowsScanner { + return &rowsScanner{ctx, rows, structFieldMapper, query, argFmt, args} } func (s *rowsScanner) ScanSlice(dest any) error { @@ -76,7 +76,7 @@ func (s *rowsScanner) ScanSlice(dest any) error { } func (s *rowsScanner) ScanStructSlice(dest any) error { - err := reflection.ScanRowsAsSlice(s.ctx, s.rows, dest, s.structFieldNamer) + err := reflection.ScanRowsAsSlice(s.ctx, s.rows, dest, s.structFieldMapper) if err != nil { return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) } @@ -122,7 +122,7 @@ func (s *rowsScanner) ForEachRow(callback func(RowScanner) error) (err error) { return s.ctx.Err() } - err := callback(CurrentRowScanner{s.rows, s.structFieldNamer}) + err := callback(CurrentRowScanner{s.rows, s.structFieldMapper}) if err != nil { return err } From 4df1670ca253a1bc392d8058721f1acb0841f240 Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Wed, 15 Jun 2022 13:00:36 +0200 Subject: [PATCH 06/12] replace Connection.ArgFmt with ParamPlaceholder --- connection.go | 6 ++++-- connectionwitherror.go | 5 +++-- db/insert.go | 29 +++++++++++------------------ db/update.go | 16 +++++++--------- db/upsert.go | 5 ++--- errors.go | 4 ++-- errors_test.go | 4 ++-- format.go | 4 ++-- format_test.go | 6 +++--- mockconn/connection.go | 10 +++------- mockconn/singlerowprovider.go | 4 ++-- mysqlconn/config.go | 2 -- mysqlconn/connection.go | 16 +++++++--------- mysqlconn/transaction.go | 14 +++++++------- paramplaceholder.go | 23 +++++++++++++++++++++++ pqconn/connection.go | 16 +++++++--------- pqconn/transaction.go | 14 +++++++------- rowscanner.go | 8 ++++---- rowsscanner.go | 8 ++++---- 19 files changed, 100 insertions(+), 94 deletions(-) create mode 100644 paramplaceholder.go diff --git a/connection.go b/connection.go index fe5de96..d3ed0c9 100644 --- a/connection.go +++ b/connection.go @@ -52,8 +52,10 @@ type Connection interface { // column of the connection's database. ValidateColumnName(name string) error - // ArgFmt returns the format for SQL query arguments - ArgFmt() string + // ParamPlaceholder returns a parameter value placeholder + // for the parameter with the passed zero based index + // specific to the database type of the connection. + ParamPlaceholder(index int) string // Err returns any current error of the connection Err() error diff --git a/connectionwitherror.go b/connectionwitherror.go index 223c7dc..886b2bd 100644 --- a/connectionwitherror.go +++ b/connectionwitherror.go @@ -3,6 +3,7 @@ package sqldb import ( "context" "database/sql" + "fmt" "time" "github.com/domonda/go-sqldb/reflection" @@ -52,8 +53,8 @@ func (e connectionWithError) ValidateColumnName(name string) error { return e.err } -func (e connectionWithError) ArgFmt() string { - return "" +func (e connectionWithError) ParamPlaceholder(index int) string { + return fmt.Sprintf(":%d", index+1) } func (e connectionWithError) Err() error { diff --git a/db/insert.go b/db/insert.go index 4676fa7..634c580 100644 --- a/db/insert.go +++ b/db/insert.go @@ -12,8 +12,6 @@ import ( type Values = sqldb.Values -var WrapNonNilErrorWithQuery = sqldb.WrapNonNilErrorWithQuery - // Insert a new row into table using the values. func Insert(ctx context.Context, table string, values Values) error { if len(values) == 0 { @@ -21,15 +19,14 @@ func Insert(ctx context.Context, table string, values Values) error { } conn := Conn(ctx) - argFmt := conn.ArgFmt() names, vals := values.Sorted() b := strings.Builder{} - writeInsertQuery(&b, table, argFmt, names) + writeInsertQuery(&b, table, conn, names) query := b.String() err := conn.Exec(query, vals...) - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) + return sqldb.WrapNonNilErrorWithQuery(err, query, conn, vals) } // InsertUnique inserts a new row into table using the passed values @@ -45,16 +42,15 @@ func InsertUnique(ctx context.Context, table string, values Values, onConflict s } conn := Conn(ctx) - argFmt := conn.ArgFmt() names, vals := values.Sorted() var query strings.Builder - writeInsertQuery(&query, table, argFmt, names) + writeInsertQuery(&query, table, conn, names) fmt.Fprintf(&query, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict) err = conn.QueryRow(query.String(), vals...).Scan(&inserted) err = sqldb.ReplaceErrNoRows(err, nil) - err = WrapNonNilErrorWithQuery(err, query.String(), argFmt, vals) + err = sqldb.WrapNonNilErrorWithQuery(err, query.String(), conn, vals) return inserted, err } @@ -66,10 +62,9 @@ func InsertReturning(ctx context.Context, table string, values Values, returning } conn := Conn(ctx) - argFmt := conn.ArgFmt() names, vals := values.Sorted() var query strings.Builder - writeInsertQuery(&query, table, argFmt, names) + writeInsertQuery(&query, table, conn, names) query.WriteString(" RETURNING ") query.WriteString(returning) return conn.QueryRow(query.String(), vals...) @@ -80,7 +75,6 @@ func InsertReturning(ctx context.Context, table string, values Values, returning // Optional ColumnFilter can be passed to ignore mapped columns. func InsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflection.ColumnFilter) error { conn := Conn(ctx) - argFmt := conn.ArgFmt() mapper := conn.StructFieldMapper() table, columns, vals, err := insertStructValues(rowStruct, mapper, ignoreColumns) @@ -89,12 +83,12 @@ func InsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflectio } var b strings.Builder - writeInsertQuery(&b, table, argFmt, columns) + writeInsertQuery(&b, table, conn, columns) query := b.String() err = conn.Exec(query, vals...) - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) + return sqldb.WrapNonNilErrorWithQuery(err, query, conn, vals) } // InsertUniqueStruct inserts a new row into table using the connection's @@ -104,7 +98,6 @@ func InsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflectio // and returns if a row was inserted. func InsertUniqueStruct(ctx context.Context, rowStruct any, onConflict string, ignoreColumns ...reflection.ColumnFilter) (inserted bool, err error) { conn := Conn(ctx) - argFmt := conn.ArgFmt() mapper := conn.StructFieldMapper() table, columns, vals, err := insertStructValues(rowStruct, mapper, ignoreColumns) @@ -117,17 +110,17 @@ func InsertUniqueStruct(ctx context.Context, rowStruct any, onConflict string, i } var b strings.Builder - writeInsertQuery(&b, table, argFmt, columns) + writeInsertQuery(&b, table, conn, columns) fmt.Fprintf(&b, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict) query := b.String() err = conn.QueryRow(query, vals...).Scan(&inserted) err = sqldb.ReplaceErrNoRows(err, nil) - return inserted, WrapNonNilErrorWithQuery(err, query, argFmt, vals) + return inserted, sqldb.WrapNonNilErrorWithQuery(err, query, conn, vals) } -func writeInsertQuery(w *strings.Builder, table, argFmt string, names []string) { +func writeInsertQuery(w *strings.Builder, table string, argFmt sqldb.ParamPlaceholderFormatter, names []string) { fmt.Fprintf(w, `INSERT INTO %s(`, table) for i, name := range names { if i > 0 { @@ -142,7 +135,7 @@ func writeInsertQuery(w *strings.Builder, table, argFmt string, names []string) if i > 0 { w.WriteByte(',') } - fmt.Fprintf(w, argFmt, i+1) + w.WriteString(argFmt.ParamPlaceholder(i)) } w.WriteByte(')') } diff --git a/db/update.go b/db/update.go index da2b450..35fe9ae 100644 --- a/db/update.go +++ b/db/update.go @@ -19,10 +19,9 @@ func Update(ctx context.Context, table string, values sqldb.Values, where string } conn := Conn(ctx) - argFmt := conn.ArgFmt() - query, vals := buildUpdateQuery(table, values, where, argFmt, args) + query, vals := buildUpdateQuery(table, values, where, conn, args) err := conn.Exec(query, vals...) - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) + return sqldb.WrapNonNilErrorWithQuery(err, query, conn, vals) } // UpdateReturningRow updates a table row with values using the where statement with passed in args starting at $1 @@ -33,7 +32,7 @@ func UpdateReturningRow(ctx context.Context, table string, values sqldb.Values, } conn := Conn(ctx) - query, vals := buildUpdateQuery(table, values, where, conn.ArgFmt(), args) + query, vals := buildUpdateQuery(table, values, where, conn, args) query += " RETURNING " + returning return conn.QueryRow(query, vals...) } @@ -46,12 +45,12 @@ func UpdateReturningRows(ctx context.Context, table string, values sqldb.Values, } conn := Conn(ctx) - query, vals := buildUpdateQuery(table, values, where, conn.ArgFmt(), args) + query, vals := buildUpdateQuery(table, values, where, conn, args) query += " RETURNING " + returning return conn.QueryRows(query, vals...) } -func buildUpdateQuery(table string, values sqldb.Values, where, argFmt string, args []any) (string, []any) { +func buildUpdateQuery(table string, values sqldb.Values, where string, argFmt sqldb.ParamPlaceholderFormatter, args []any) (string, []any) { names, vals := values.Sorted() var query strings.Builder @@ -60,7 +59,7 @@ func buildUpdateQuery(table string, values sqldb.Values, where, argFmt string, a if i > 0 { query.WriteByte(',') } - fmt.Fprintf(&query, `"%s"=%s`, names[i], fmt.Sprintf(argFmt, 1+len(args)+i)) + fmt.Fprintf(&query, `"%s"=%s`, names[i], argFmt.ParamPlaceholder(len(args)+i)) } fmt.Fprintf(&query, ` WHERE %s`, where) @@ -79,7 +78,6 @@ func UpdateStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflectio } conn := Conn(ctx) - argFmt := conn.ArgFmt() mapper := conn.StructFieldMapper() table, columns, pkCols, vals, err := reflection.ReflectStructValues(v, mapper, append(ignoreColumns, sqldb.IgnoreReadOnly)) if err != nil { @@ -119,7 +117,7 @@ func UpdateStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflectio err = conn.Exec(query, vals...) - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) + return sqldb.WrapNonNilErrorWithQuery(err, query, conn, vals) } func derefStruct(rowStruct any) (reflect.Value, error) { diff --git a/db/upsert.go b/db/upsert.go index 6b9e9fa..9ed6964 100644 --- a/db/upsert.go +++ b/db/upsert.go @@ -24,7 +24,6 @@ func UpsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflectio } conn := Conn(ctx) - argFmt := conn.ArgFmt() mapper := conn.StructFieldMapper() table, columns, pkCols, vals, err := reflection.ReflectStructValues(v, mapper, append(ignoreColumns, sqldb.IgnoreReadOnly)) if err != nil { @@ -38,7 +37,7 @@ func UpsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflectio } var b strings.Builder - writeInsertQuery(&b, table, argFmt, columns) + writeInsertQuery(&b, table, conn, columns) b.WriteString(` ON CONFLICT(`) for i, pkCol := range pkCols { if i > 0 { @@ -64,5 +63,5 @@ func UpsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflectio err = conn.Exec(query, vals...) - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) + return sqldb.WrapNonNilErrorWithQuery(err, query, conn, vals) } diff --git a/errors.go b/errors.go index c082b92..bbe2f87 100644 --- a/errors.go +++ b/errors.go @@ -60,7 +60,7 @@ const ( // WrapNonNilErrorWithQuery wraps non nil errors with a formatted query // if the error was not already wrapped with a query. // If the passed error is nil, then nil will be returned. -func WrapNonNilErrorWithQuery(err error, query, argFmt string, args []any) error { +func WrapNonNilErrorWithQuery(err error, query string, argFmt ParamPlaceholderFormatter, args []any) error { var wrapped errWithQuery if err == nil || errors.As(err, &wrapped) { return err @@ -71,7 +71,7 @@ func WrapNonNilErrorWithQuery(err error, query, argFmt string, args []any) error type errWithQuery struct { err error query string - argFmt string + argFmt ParamPlaceholderFormatter args []any } diff --git a/errors_test.go b/errors_test.go index 9c9bed6..fa0f181 100644 --- a/errors_test.go +++ b/errors_test.go @@ -11,7 +11,7 @@ func TestWrapNonNilErrorWithQuery(t *testing.T) { type args struct { err error query string - argFmt string + argFmt ParamPlaceholderFormatter args []any } tests := []struct { @@ -25,7 +25,7 @@ func TestWrapNonNilErrorWithQuery(t *testing.T) { args: args{ err: sql.ErrNoRows, query: `SELECT * FROM table WHERE b = $2 and a = $1`, - argFmt: "$%d", + argFmt: NewParamPlaceholderFormatter("$%d", 1), args: []any{1, "2"}, }, wantError: fmt.Sprintf("%s from query: %s", sql.ErrNoRows, `SELECT * FROM table WHERE b = '2' and a = 1`), diff --git a/format.go b/format.go index cb46958..22a080e 100644 --- a/format.go +++ b/format.go @@ -77,9 +77,9 @@ func FormatValue(val any) (string, error) { return fmt.Sprint(val), nil } -func FormatQuery(query, argFmt string, args ...any) string { +func FormatQuery(query string, argFmt ParamPlaceholderFormatter, args ...any) string { for i := len(args) - 1; i >= 0; i-- { - placeholder := fmt.Sprintf(argFmt, i+1) + placeholder := argFmt.ParamPlaceholder(i) value, err := FormatValue(args[i]) if err != nil { value = "FORMATERROR:" + err.Error() diff --git a/format_test.go b/format_test.go index d30e394..64b81a2 100644 --- a/format_test.go +++ b/format_test.go @@ -81,12 +81,12 @@ WHERE tests := []struct { name string query string - argFmt string + argFmt ParamPlaceholderFormatter args []any want string }{ - {name: "query1", query: query1, argFmt: "$%d", args: []any{createdAt, true, `Erik's Test`}, want: query1formatted}, - {name: "query2", query: query2, argFmt: "$%d", args: []any{"", 2, "3"}, want: query2formatted}, + {name: "query1", query: query1, argFmt: NewParamPlaceholderFormatter("$%d", 1), args: []any{createdAt, true, `Erik's Test`}, want: query1formatted}, + {name: "query2", query: query2, argFmt: NewParamPlaceholderFormatter("$%d", 1), args: []any{"", 2, "3"}, want: query2formatted}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/mockconn/connection.go b/mockconn/connection.go index 550b870..54e0573 100644 --- a/mockconn/connection.go +++ b/mockconn/connection.go @@ -11,7 +11,7 @@ import ( "github.com/domonda/go-sqldb/reflection" ) -var DefaultArgFmt = "$%d" +var DefaultParamPlaceholderFormatter = sqldb.NewParamPlaceholderFormatter("$%d", 1) func New(ctx context.Context, queryWriter io.Writer, rowsProvider RowsProvider) sqldb.Connection { return &connection{ @@ -20,7 +20,6 @@ func New(ctx context.Context, queryWriter io.Writer, rowsProvider RowsProvider) listening: newBoolMap(), rowsProvider: rowsProvider, structFieldMapper: sqldb.DefaultStructFieldMapping, - argFmt: DefaultArgFmt, } } @@ -30,7 +29,6 @@ type connection struct { listening *boolMap rowsProvider RowsProvider structFieldMapper reflection.StructFieldMapper - argFmt string } func (conn *connection) Context() context.Context { return conn.ctx } @@ -45,7 +43,6 @@ func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { listening: conn.listening, rowsProvider: conn.rowsProvider, structFieldMapper: conn.structFieldMapper, - argFmt: conn.argFmt, } } @@ -56,7 +53,6 @@ func (conn *connection) WithStructFieldMapper(mapper reflection.StructFieldMappe listening: conn.listening, rowsProvider: conn.rowsProvider, structFieldMapper: mapper, - argFmt: conn.argFmt, } } @@ -80,8 +76,8 @@ func (conn *connection) ValidateColumnName(name string) error { return validateColumnName(name) } -func (conn *connection) ArgFmt() string { - return conn.argFmt +func (*connection) ParamPlaceholder(index int) string { + return fmt.Sprintf("$%d", index+1) } func (conn *connection) Err() error { diff --git a/mockconn/singlerowprovider.go b/mockconn/singlerowprovider.go index 68cd8f0..0bbec5c 100644 --- a/mockconn/singlerowprovider.go +++ b/mockconn/singlerowprovider.go @@ -10,14 +10,14 @@ import ( // NewSingleRowProvider a RowsProvider implementation // with a single row that will be re-used for every query. func NewSingleRowProvider(row *Row) RowsProvider { - return &singleRowProvider{row: row, argFmt: DefaultArgFmt} + return &singleRowProvider{row: row, argFmt: DefaultParamPlaceholderFormatter} } // SingleRowProvider implements RowsProvider with a single Row // that will be re-used for every query. type singleRowProvider struct { row *Row - argFmt string + argFmt sqldb.ParamPlaceholderFormatter } func (p *singleRowProvider) QueryRow(mapper reflection.StructFieldMapper, query string, args ...any) sqldb.RowScanner { diff --git a/mysqlconn/config.go b/mysqlconn/config.go index 3622647..18b48f5 100644 --- a/mysqlconn/config.go +++ b/mysqlconn/config.go @@ -2,8 +2,6 @@ package mysqlconn import "github.com/go-sql-driver/mysql" -const argFmt = "?" - type Config = mysql.Config // NewConfig creates a new Config and sets default values. diff --git a/mysqlconn/connection.go b/mysqlconn/connection.go index 9c0435d..ca09094 100644 --- a/mysqlconn/connection.go +++ b/mysqlconn/connection.go @@ -29,7 +29,6 @@ func New(ctx context.Context, config *sqldb.Config) (sqldb.Connection, error) { db: db, config: config, structFieldMapper: sqldb.DefaultStructFieldMapping, - argFmt: argFmt, } return conn, nil } @@ -52,7 +51,6 @@ type connection struct { db *sql.DB config *sqldb.Config structFieldMapper reflection.StructFieldMapper - argFmt string } func (conn *connection) clone() *connection { @@ -103,8 +101,8 @@ func (conn *connection) ValidateColumnName(name string) error { return validateColumnName(name) } -func (conn *connection) ArgFmt() string { - return conn.argFmt +func (conn *connection) ParamPlaceholder(index int) string { + return "?" } func (conn *connection) Err() error { @@ -121,25 +119,25 @@ func (conn *connection) Now() (now time.Time, err error) { func (conn *connection) Exec(query string, args ...any) error { _, err := conn.db.ExecContext(conn.ctx, query, args...) - return sqldb.WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) + return sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) } func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { rows, err := conn.db.QueryContext(conn.ctx, query, args...) if err != nil { - err = sqldb.WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) + err = sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) return sqldb.RowScannerWithError(err) } - return sqldb.NewRowScanner(rows, conn.structFieldMapper, query, conn.argFmt, args) + return sqldb.NewRowScanner(rows, conn.structFieldMapper, query, conn, args) } func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { rows, err := conn.db.QueryContext(conn.ctx, query, args...) if err != nil { - err = sqldb.WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) + err = sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) return sqldb.RowsScannerWithError(err) } - return sqldb.NewRowsScanner(conn.ctx, rows, conn.structFieldMapper, query, conn.argFmt, args) + return sqldb.NewRowsScanner(conn.ctx, rows, conn.structFieldMapper, query, conn, args) } func (conn *connection) IsTransaction() bool { diff --git a/mysqlconn/transaction.go b/mysqlconn/transaction.go index 95d4a03..7e5c000 100644 --- a/mysqlconn/transaction.go +++ b/mysqlconn/transaction.go @@ -62,8 +62,8 @@ func (conn *transaction) ValidateColumnName(name string) error { return validateColumnName(name) } -func (conn *transaction) ArgFmt() string { - return conn.parent.argFmt +func (conn *transaction) ParamPlaceholder(index int) string { + return conn.parent.ParamPlaceholder(index) } func (conn *transaction) Err() error { @@ -80,25 +80,25 @@ func (conn *transaction) Now() (now time.Time, err error) { func (conn *transaction) Exec(query string, args ...any) error { _, err := conn.tx.Exec(query, args...) - return sqldb.WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) + return sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) } func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { - err = sqldb.WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) + err = sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) return sqldb.RowScannerWithError(err) } - return sqldb.NewRowScanner(rows, conn.structFieldMapper, query, conn.parent.argFmt, args) + return sqldb.NewRowScanner(rows, conn.structFieldMapper, query, conn, args) } func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { - err = sqldb.WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) + err = sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) return sqldb.RowsScannerWithError(err) } - return sqldb.NewRowsScanner(conn.parent.ctx, rows, conn.structFieldMapper, query, conn.parent.argFmt, args) + return sqldb.NewRowsScanner(conn.parent.ctx, rows, conn.structFieldMapper, query, conn, args) } func (conn *transaction) IsTransaction() bool { diff --git a/paramplaceholder.go b/paramplaceholder.go new file mode 100644 index 0000000..feecdb6 --- /dev/null +++ b/paramplaceholder.go @@ -0,0 +1,23 @@ +package sqldb + +import "fmt" + +type ParamPlaceholderFormatter interface { + // ParamPlaceholder returns a parameter value placeholder + // for the parameter with the passed zero based index + // specific to the database type of the connection. + ParamPlaceholder(index int) string +} + +func NewParamPlaceholderFormatter(format string, indexOffset int) ParamPlaceholderFormatter { + return ¶mPlaceholderFormatter{format, indexOffset} +} + +type paramPlaceholderFormatter struct { + format string + indexOffset int +} + +func (f *paramPlaceholderFormatter) ParamPlaceholder(index int) string { + return fmt.Sprintf(f.format, index+f.indexOffset) +} diff --git a/pqconn/connection.go b/pqconn/connection.go index e309ab7..687d026 100644 --- a/pqconn/connection.go +++ b/pqconn/connection.go @@ -10,8 +10,6 @@ import ( "github.com/domonda/go-sqldb/reflection" ) -const argFmt = "$%d" - // New creates a new sqldb.Connection using the passed sqldb.Config // and github.com/lib/pq as driver implementation. // The connection is pinged with the passed context @@ -102,8 +100,8 @@ func (conn *connection) ValidateColumnName(name string) error { return validateColumnName(name) } -func (*connection) ArgFmt() string { - return argFmt +func (*connection) ParamPlaceholder(index int) string { + return fmt.Sprintf("$%d", index+1) } func (conn *connection) Err() error { @@ -120,25 +118,25 @@ func (conn *connection) Now() (now time.Time, err error) { func (conn *connection) Exec(query string, args ...any) error { _, err := conn.db.ExecContext(conn.ctx, query, args...) - return sqldb.WrapNonNilErrorWithQuery(err, query, argFmt, args) + return sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) } func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { rows, err := conn.db.QueryContext(conn.ctx, query, args...) if err != nil { - err = sqldb.WrapNonNilErrorWithQuery(err, query, argFmt, args) + err = sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) return sqldb.RowScannerWithError(err) } - return sqldb.NewRowScanner(rows, conn.structFieldMapper, query, argFmt, args) + return sqldb.NewRowScanner(rows, conn.structFieldMapper, query, conn, args) } func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { rows, err := conn.db.QueryContext(conn.ctx, query, args...) if err != nil { - err = sqldb.WrapNonNilErrorWithQuery(err, query, argFmt, args) + err = sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) return sqldb.RowsScannerWithError(err) } - return sqldb.NewRowsScanner(conn.ctx, rows, conn.structFieldMapper, query, argFmt, args) + return sqldb.NewRowsScanner(conn.ctx, rows, conn.structFieldMapper, query, conn, args) } func (conn *connection) IsTransaction() bool { diff --git a/pqconn/transaction.go b/pqconn/transaction.go index 098c125..fd84ae5 100644 --- a/pqconn/transaction.go +++ b/pqconn/transaction.go @@ -61,8 +61,8 @@ func (conn *transaction) ValidateColumnName(name string) error { return validateColumnName(name) } -func (*transaction) ArgFmt() string { - return argFmt +func (conn *transaction) ParamPlaceholder(index int) string { + return conn.parent.ParamPlaceholder(index) } func (conn *transaction) Err() error { @@ -79,25 +79,25 @@ func (conn *transaction) Now() (now time.Time, err error) { func (conn *transaction) Exec(query string, args ...any) error { _, err := conn.tx.Exec(query, args...) - return sqldb.WrapNonNilErrorWithQuery(err, query, argFmt, args) + return sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) } func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { - err = sqldb.WrapNonNilErrorWithQuery(err, query, argFmt, args) + err = sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) return sqldb.RowScannerWithError(err) } - return sqldb.NewRowScanner(rows, conn.structFieldMapper, query, argFmt, args) + return sqldb.NewRowScanner(rows, conn.structFieldMapper, query, conn, args) } func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { - err = sqldb.WrapNonNilErrorWithQuery(err, query, argFmt, args) + err = sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) return sqldb.RowsScannerWithError(err) } - return sqldb.NewRowsScanner(conn.parent.ctx, rows, conn.structFieldMapper, query, argFmt, args) + return sqldb.NewRowsScanner(conn.parent.ctx, rows, conn.structFieldMapper, query, conn, args) } func (conn *transaction) IsTransaction() bool { diff --git a/rowscanner.go b/rowscanner.go index df6d792..7a754a3 100644 --- a/rowscanner.go +++ b/rowscanner.go @@ -39,12 +39,12 @@ var ( type rowScanner struct { rows Rows structFieldMapper reflection.StructFieldMapper - query string // for error wrapping - argFmt string // for error wrapping - args []any // for error wrapping + query string // for error wrapping + argFmt ParamPlaceholderFormatter // for error wrapping + args []any // for error wrapping } -func NewRowScanner(rows Rows, structFieldMapper reflection.StructFieldMapper, query, argFmt string, args []any) *rowScanner { +func NewRowScanner(rows Rows, structFieldMapper reflection.StructFieldMapper, query string, argFmt ParamPlaceholderFormatter, args []any) *rowScanner { return &rowScanner{rows, structFieldMapper, query, argFmt, args} } diff --git a/rowsscanner.go b/rowsscanner.go index 09f4960..d5e4c74 100644 --- a/rowsscanner.go +++ b/rowsscanner.go @@ -58,12 +58,12 @@ type rowsScanner struct { ctx context.Context // ctx is checked for every row and passed through to callbacks rows Rows structFieldMapper reflection.StructFieldMapper - query string // for error wrapping - argFmt string // for error wrapping - args []any // for error wrapping + query string // for error wrapping + argFmt ParamPlaceholderFormatter // for error wrapping + args []any // for error wrapping } -func NewRowsScanner(ctx context.Context, rows Rows, structFieldMapper reflection.StructFieldMapper, query, argFmt string, args []any) *rowsScanner { +func NewRowsScanner(ctx context.Context, rows Rows, structFieldMapper reflection.StructFieldMapper, query string, argFmt ParamPlaceholderFormatter, args []any) *rowsScanner { return &rowsScanner{ctx, rows, structFieldMapper, query, argFmt, args} } From 0155608f8e55b306ad351da33842ee31d18ed68b Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Wed, 15 Jun 2022 14:40:26 +0200 Subject: [PATCH 07/12] added db.QueryStructSlice --- db/query.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/db/query.go b/db/query.go index a0ad4ec..93e412d 100644 --- a/db/query.go +++ b/db/query.go @@ -118,3 +118,12 @@ func pkColumnsOfStruct(conn sqldb.Connection, t reflect.Type) (table string, col } return table, columns, nil } + +// QueryStructSlice returns queried rows as slice of pointers to the generic struct type S +func QueryStructSlice[S any](ctx context.Context, query string, args ...any) (rows []*S, err error) { + err = Conn(ctx).QueryRows(query, args...).ScanStructSlice(&rows) + if err != nil { + return nil, err + } + return rows, nil +} From 958a5d072d5d8ee4749d9461038d6c91721889bc Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Wed, 15 Jun 2022 14:46:25 +0200 Subject: [PATCH 08/12] QueryStructSlice --- db/query.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/db/query.go b/db/query.go index 93e412d..4eae184 100644 --- a/db/query.go +++ b/db/query.go @@ -119,8 +119,9 @@ func pkColumnsOfStruct(conn sqldb.Connection, t reflect.Type) (table string, col return table, columns, nil } -// QueryStructSlice returns queried rows as slice of pointers to the generic struct type S -func QueryStructSlice[S any](ctx context.Context, query string, args ...any) (rows []*S, err error) { +// QueryStructSlice returns queried rows as slice of the generic type S +// which must be a struct or a pointer to a struct. +func QueryStructSlice[S any](ctx context.Context, query string, args ...any) (rows []S, err error) { err = Conn(ctx).QueryRows(query, args...).ScanStructSlice(&rows) if err != nil { return nil, err From 6abd72b2262bce26614553962dbab4b007584337 Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Fri, 17 Jun 2022 14:24:42 +0200 Subject: [PATCH 09/12] added db.QueryValueOrDefault --- db/query.go | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/db/query.go b/db/query.go index 4eae184..56e63b0 100644 --- a/db/query.go +++ b/db/query.go @@ -2,6 +2,7 @@ package db import ( "context" + "database/sql" "errors" "fmt" "reflect" @@ -34,10 +35,27 @@ func QueryRows(ctx context.Context, query string, args ...any) sqldb.RowsScanner } // QueryValue queries a single value of type T. -func QueryValue[T any](ctx context.Context, query string, args ...any) (T, error) { - var val T - err := Conn(ctx).QueryRow(query, args...).Scan(&val) - return val, err +func QueryValue[T any](ctx context.Context, query string, args ...any) (value T, err error) { + err = Conn(ctx).QueryRow(query, args...).Scan(&value) + if err != nil { + var zero T + return zero, err + } + return value, nil +} + +// QueryValueOrDefault queries a single value of type T +// or returns the default zero value of T in case of sql.ErrNoRows. +func QueryValueOrDefault[T any](ctx context.Context, query string, args ...any) (value T, err error) { + err = Conn(ctx).QueryRow(query, args...).Scan(&value) + if err != nil { + var zero T + if errors.Is(err, sql.ErrNoRows) { + return zero, nil + } + return zero, err + } + return value, err } // QueryStruct uses the passed pkValues to query a table row @@ -79,7 +97,13 @@ func QueryStruct[S any](ctx context.Context, pkValues ...any) (row *S, err error // passed pkValues. func QueryStructOrNil[S any](ctx context.Context, pkValues ...any) (row *S, err error) { row, err = QueryStruct[S](ctx, pkValues...) - return row, ReplaceErrNoRows(err, nil) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + return row, nil } func pkColumnsOfStruct(conn sqldb.Connection, t reflect.Type) (table string, columns []string, err error) { From 8dd37744ae18d6dbbb65aab73f39b538dc949b52 Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Wed, 22 Jun 2022 17:39:33 +0200 Subject: [PATCH 10/12] work in progress --- config.go | 68 ++++-- connection.go | 61 ++---- connectionwitherror.go | 206 ------------------ db/config.go | 8 +- db/conn.go | 28 +-- db/insert.go | 62 +++--- db/query.go | 29 ++- db/update.go | 15 +- db/upsert.go | 5 +- dbconnection.go | 101 +++++++++ errconnection.go | 88 ++++++++ errors.go | 51 ++--- errors_test.go | 2 +- examples/user_demo/user_demo.go | 2 +- mockconn/connection.go | 58 ++--- mockconn/onetimerowsprovider.go | 16 +- mockconn/rowsprovider.go | 5 +- mockconn/singlerowprovider.go | 9 +- mysqlconn/connection.go | 43 ++-- mysqlconn/transaction.go | 43 ++-- pqconn/connection.go | 43 ++-- pqconn/transaction.go | 45 ++-- foreachrow.go => reflection/foreachrow.go | 8 +- .../foreachrow_test.go | 2 +- reflection/scan.go | 91 ++++++++ .../{scanstruct_test.go => scan_test.go} | 0 reflection/scanslice.go | 140 +++++------- reflection/scanstruct.go | 5 +- row.go | 84 +++++++ rows.go | 105 ++++++--- rowscanner.go | 151 ------------- rowsscanner.go | 139 ------------ transaction.go | 17 +- txconnection.go | 89 ++++++++ 34 files changed, 866 insertions(+), 953 deletions(-) delete mode 100644 connectionwitherror.go create mode 100644 dbconnection.go create mode 100644 errconnection.go rename foreachrow.go => reflection/foreachrow.go (93%) rename foreachrow_test.go => reflection/foreachrow_test.go (97%) create mode 100644 reflection/scan.go rename reflection/{scanstruct_test.go => scan_test.go} (100%) delete mode 100644 rowscanner.go delete mode 100644 rowsscanner.go create mode 100644 txconnection.go diff --git a/config.go b/config.go index 3e1ab65..8694dce 100644 --- a/config.go +++ b/config.go @@ -3,35 +3,53 @@ package sqldb import ( "context" "database/sql" + "errors" "fmt" "net/url" "time" - - "github.com/domonda/go-sqldb/reflection" ) -// DefaultStructFieldMapping provides the default StructFieldTagNaming -// using "db" as NameTag and IgnoreStructField as UntaggedNameFunc. -// Implements StructFieldMapper. -var DefaultStructFieldMapping = reflection.NewTaggedStructFieldMapping() - // Config for a connection. // For tips see https://www.alexedwards.net/blog/configuring-sqldb type Config struct { - Driver string `json:"driver"` - Host string `json:"host"` - Port uint16 `json:"port,omitempty"` - User string `json:"user,omitempty"` - Password string `json:"password,omitempty"` - Database string `json:"database"` - Extra map[string]string `json:"misc,omitempty"` - MaxOpenConns int `json:"maxOpenConns,omitempty"` - MaxIdleConns int `json:"maxIdleConns,omitempty"` - ConnMaxLifetime time.Duration `json:"connMaxLifetime,omitempty"` + Driver string `json:"driver"` + Host string `json:"host"` + Port uint16 `json:"port,omitempty"` + User string `json:"user,omitempty"` + Password string `json:"password,omitempty"` + Database string `json:"database"` + Extra map[string]string `json:"misc,omitempty"` + MaxOpenConns int `json:"maxOpenConns,omitempty"` + MaxIdleConns int `json:"maxIdleConns,omitempty"` + ConnMaxLifetime time.Duration `json:"connMaxLifetime,omitempty"` + + // ValidateColumnName returns an error + // if the passed name is not valid for a + // column of the connection's database. + ValidateColumnName func(name string) error `json:"-"` + + // ParamPlaceholder returns a parameter value placeholder + // for the parameter with the passed zero based index + // specific to the database type of the connection. + ParamPlaceholderFormatter `json:"-"` + DefaultIsolationLevel sql.IsolationLevel `json:"-"` - Err error `json:"-"` + + // Err will be returned from Connection.Err() + Err error `json:"-"` } +// func (c *DBConnection) ValidateColumnName(name string) error { +// if name == "" { +// return errors.New("empty column name") +// } +// return nil +// } + +// func (c *DBConnection) ParamPlaceholder(index int) string { +// return fmt.Sprintf(":%d", index+1) +// } + // Validate returns Config.Err if it is not nil // or an error if the Config does not have // a Driver, Host, or Database. @@ -39,19 +57,25 @@ func (c *Config) Validate() error { if c.Err != nil { return c.Err } + if c.ValidateColumnName == nil { + return errors.New("missing sqldb.Config.ValidateColumnName") + } + if c.ParamPlaceholderFormatter == nil { + return errors.New("missing sqldb.Config.ParamPlaceholderFormatter") + } if c.Driver == "" { - return fmt.Errorf("missing sqldb.Config.Driver") + return errors.New("missing sqldb.Config.Driver") } if c.Host == "" { - return fmt.Errorf("missing sqldb.Config.Host") + return errors.New("missing sqldb.Config.Host") } if c.Database == "" { - return fmt.Errorf("missing sqldb.Config.Database") + return errors.New("missing sqldb.Config.Database") } return nil } -// ConnectURL for connecting to a database +// ConnectURL returns a connection URL for the Config func (c *Config) ConnectURL() string { extra := make(url.Values) for key, val := range c.Extra { diff --git a/connection.go b/connection.go index d3ed0c9..e2bfd26 100644 --- a/connection.go +++ b/connection.go @@ -4,8 +4,6 @@ import ( "context" "database/sql" "time" - - "github.com/domonda/go-sqldb/reflection" ) type ( @@ -18,75 +16,44 @@ type ( // Connection represents a database connection or transaction type Connection interface { - // Context that all connection operations use. - // See also WithContext. - Context() context.Context - - // WithContext returns a connection that uses the passed - // context for its operations. - WithContext(ctx context.Context) Connection - - // WithStructFieldMapper returns a copy of the connection - // that will use the passed reflection.StructFieldMapper. - WithStructFieldMapper(reflection.StructFieldMapper) Connection + // Config returns the configuration used + // to create this connection. + Config() *Config - // StructFieldMapper used by methods of this Connection. - StructFieldMapper() reflection.StructFieldMapper + // Stats returns the sql.DBStats of this connection. + Stats() sql.DBStats // Ping returns an error if the database // does not answer on this connection // with an optional timeout. // The passed timeout has to be greater zero // to be considered. - Ping(timeout time.Duration) error - - // Stats returns the sql.DBStats of this connection. - Stats() sql.DBStats - - // Config returns the configuration used - // to create this connection. - Config() *Config - - // ValidateColumnName returns an error - // if the passed name is not valid for a - // column of the connection's database. - ValidateColumnName(name string) error - - // ParamPlaceholder returns a parameter value placeholder - // for the parameter with the passed zero based index - // specific to the database type of the connection. - ParamPlaceholder(index int) string + Ping(ctx context.Context, timeout time.Duration) error // Err returns any current error of the connection Err() error - // Now returns the result of the SQL now() - // function for the current connection. - // Useful for getting the timestamp of a - // SQL transaction for use in Go code. - Now() (time.Time, error) - // Exec executes a query with optional args. - Exec(query string, args ...any) error + Exec(ctx context.Context, query string, args ...any) error // QueryRow queries a single row and returns a RowScanner for the results. - QueryRow(query string, args ...any) RowScanner + QueryRow(ctx context.Context, query string, args ...any) Row // QueryRows queries multiple rows and returns a RowsScanner for the results. - QueryRows(query string, args ...any) RowsScanner + QueryRows(ctx context.Context, query string, args ...any) Rows // IsTransaction returns if the connection is a transaction IsTransaction() bool - // TransactionOptions returns the sql.TxOptions of the - // current transaction and true as second result value, - // or false if the connection is not a transaction. - TransactionOptions() (*sql.TxOptions, bool) + // TxOptions returns the sql.TxOptions of the + // current transaction which can be nil for the default options. + // Use IsTransaction to check if the connection is a transaction. + TxOptions() *sql.TxOptions // Begin a new transaction. // If the connection is already a transaction, a brand // new transaction will begin on the parent's connection. - Begin(opts *sql.TxOptions) (Connection, error) + Begin(ctx context.Context, opts *sql.TxOptions) (Connection, error) // Commit the current transaction. // Returns ErrNotWithinTransaction if the connection diff --git a/connectionwitherror.go b/connectionwitherror.go deleted file mode 100644 index 886b2bd..0000000 --- a/connectionwitherror.go +++ /dev/null @@ -1,206 +0,0 @@ -package sqldb - -import ( - "context" - "database/sql" - "fmt" - "time" - - "github.com/domonda/go-sqldb/reflection" -) - -// ConnectionWithError returns a dummy Connection -// where all methods return the passed error. -func ConnectionWithError(ctx context.Context, err error) Connection { - if err == nil { - panic("ConnectionWithError needs an error") - } - return connectionWithError{ctx, err} -} - -type connectionWithError struct { - ctx context.Context - err error -} - -func (e connectionWithError) Context() context.Context { return e.ctx } - -func (e connectionWithError) WithContext(ctx context.Context) Connection { - return connectionWithError{ctx: ctx, err: e.err} -} - -func (e connectionWithError) WithStructFieldMapper(reflection.StructFieldMapper) Connection { - return e -} - -func (e connectionWithError) StructFieldMapper() reflection.StructFieldMapper { - return DefaultStructFieldMapping -} - -func (e connectionWithError) Ping(time.Duration) error { - return e.err -} - -func (e connectionWithError) Stats() sql.DBStats { - return sql.DBStats{} -} - -func (e connectionWithError) Config() *Config { - return &Config{Err: e.err} -} - -func (e connectionWithError) ValidateColumnName(name string) error { - return e.err -} - -func (e connectionWithError) ParamPlaceholder(index int) string { - return fmt.Sprintf(":%d", index+1) -} - -func (e connectionWithError) Err() error { - return e.err -} - -func (e connectionWithError) Now() (time.Time, error) { - return time.Time{}, e.err -} - -func (e connectionWithError) Exec(query string, args ...any) error { - return e.err -} - -func (e connectionWithError) Update(table string, values Values, where string, args ...any) error { - return e.err -} - -func (e connectionWithError) UpdateReturningRow(table string, values Values, returning, where string, args ...any) RowScanner { - return RowScannerWithError(e.err) -} - -func (e connectionWithError) UpdateReturningRows(table string, values Values, returning, where string, args ...any) RowsScanner { - return RowsScannerWithError(e.err) -} - -func (e connectionWithError) UpdateStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error { - return e.err -} - -func (e connectionWithError) UpsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error { - return e.err -} - -func (e connectionWithError) QueryRow(query string, args ...any) RowScanner { - return RowScannerWithError(e.err) -} - -func (e connectionWithError) QueryRows(query string, args ...any) RowsScanner { - return RowsScannerWithError(e.err) -} - -func (e connectionWithError) IsTransaction() bool { - return false -} - -func (ce connectionWithError) TransactionOptions() (*sql.TxOptions, bool) { - return nil, false -} - -func (e connectionWithError) Begin(opts *sql.TxOptions) (Connection, error) { - return nil, e.err -} - -func (e connectionWithError) Commit() error { - return e.err -} - -func (e connectionWithError) Rollback() error { - return e.err -} - -func (e connectionWithError) Transaction(opts *sql.TxOptions, txFunc func(tx Connection) error) error { - return e.err -} - -func (e connectionWithError) ListenOnChannel(channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) error { - return e.err -} - -func (e connectionWithError) UnlistenChannel(channel string) error { - return e.err -} - -func (e connectionWithError) IsListeningOnChannel(channel string) bool { - return false -} - -func (e connectionWithError) Close() error { - return e.err -} - -// RowScannerWithError - -// RowScannerWithError returns a dummy RowScanner -// where all methods return the passed error. -func RowScannerWithError(err error) RowScanner { - return rowScannerWithError{err} -} - -type rowScannerWithError struct { - err error -} - -func (e rowScannerWithError) Scan(dest ...any) error { - return e.err -} - -func (e rowScannerWithError) ScanStruct(dest any) error { - return e.err -} - -func (e rowScannerWithError) ScanValues() ([]any, error) { - return nil, e.err -} - -func (e rowScannerWithError) ScanStrings() ([]string, error) { - return nil, e.err -} - -func (e rowScannerWithError) Columns() ([]string, error) { - return nil, e.err -} - -// RowsScannerWithError - -// RowsScannerWithError returns a dummy RowsScanner -// where all methods return the passed error. -func RowsScannerWithError(err error) RowsScanner { - return rowsScannerWithError{err} -} - -type rowsScannerWithError struct { - err error -} - -func (e rowsScannerWithError) ScanSlice(dest any) error { - return e.err -} - -func (e rowsScannerWithError) ScanStructSlice(dest any) error { - return e.err -} - -func (e rowsScannerWithError) Columns() ([]string, error) { - return nil, e.err -} - -func (e rowsScannerWithError) ScanAllRowsAsStrings(headerRow bool) ([][]string, error) { - return nil, e.err -} - -func (e rowsScannerWithError) ForEachRow(callback func(RowScanner) error) error { - return e.err -} - -func (e rowsScannerWithError) ForEachRowCall(callback any) error { - return e.err -} diff --git a/db/config.go b/db/config.go index bd6c1b4..793ef87 100644 --- a/db/config.go +++ b/db/config.go @@ -1,21 +1,25 @@ package db import ( - "context" "errors" "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/reflection" ) var ( // Number of retries used for a SerializedTransaction // before it fails SerializedTransactionRetries = 10 + + // DefaultStructFieldMapping provides the default StructFieldTagNaming + // using "db" as NameTag and IgnoreStructField as UntaggedNameFunc. + // Implements StructFieldMapper. + DefaultStructFieldMapping = reflection.NewTaggedStructFieldMapping() ) var ( globalConn = sqldb.ConnectionWithError( - context.Background(), errors.New("database connection not initialized"), ) globalConnCtxKey int diff --git a/db/conn.go b/db/conn.go index c43dbf8..770db11 100644 --- a/db/conn.go +++ b/db/conn.go @@ -22,24 +22,24 @@ func SetConn(c sqldb.Connection) { // The returned connection will use the passed context. // See sqldb.Connection.WithContext func Conn(ctx context.Context) sqldb.Connection { - return ConnDefault(ctx, globalConn) -} - -// ConnDefault returns a non nil sqldb.Connection from ctx -// or the passed defaultConn. -// The returned connection will use the passed context. -// See sqldb.Connection.WithContext -func ConnDefault(ctx context.Context, defaultConn sqldb.Connection) sqldb.Connection { - c, _ := ctx.Value(&globalConnCtxKey).(sqldb.Connection) - if c == nil { - c = defaultConn - } - if c.Context() == ctx { + if c, _ := ctx.Value(&globalConnCtxKey).(sqldb.Connection); c != nil { return c } - return c.WithContext(ctx) + return globalConn } +// // ConnDefault returns a non nil sqldb.Connection from ctx +// // or the passed defaultConn. +// // The returned connection will use the passed context. +// // See sqldb.Connection.WithContext +// func ConnDefault(ctx context.Context, defaultConn sqldb.Connection) sqldb.Connection { +// c, _ := ctx.Value(&globalConnCtxKey).(sqldb.Connection) +// if c == nil { +// return defaultConn +// } +// return c +// } + // ContextWithConn returns a new context with the passed sqldb.Connection // added as value so it can be retrieved again using Conn(ctx). // Passing a nil connection causes Conn(ctx) diff --git a/db/insert.go b/db/insert.go index 634c580..a097382 100644 --- a/db/insert.go +++ b/db/insert.go @@ -19,14 +19,17 @@ func Insert(ctx context.Context, table string, values Values) error { } conn := Conn(ctx) + argFmt := conn.Config().ParamPlaceholderFormatter names, vals := values.Sorted() b := strings.Builder{} - writeInsertQuery(&b, table, conn, names) + writeInsertQuery(&b, table, argFmt, names) query := b.String() - err := conn.Exec(query, vals...) - - return sqldb.WrapNonNilErrorWithQuery(err, query, conn, vals) + err := conn.Exec(ctx, query, vals...) + if err != nil { + return sqldb.WrapErrorWithQuery(err, query, vals, argFmt) + } + return nil } // InsertUnique inserts a new row into table using the passed values @@ -42,32 +45,34 @@ func InsertUnique(ctx context.Context, table string, values Values, onConflict s } conn := Conn(ctx) + argFmt := conn.Config().ParamPlaceholderFormatter names, vals := values.Sorted() var query strings.Builder - writeInsertQuery(&query, table, conn, names) + writeInsertQuery(&query, table, argFmt, names) fmt.Fprintf(&query, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict) - err = conn.QueryRow(query.String(), vals...).Scan(&inserted) - - err = sqldb.ReplaceErrNoRows(err, nil) - err = sqldb.WrapNonNilErrorWithQuery(err, query.String(), conn, vals) + err = conn.QueryRow(ctx, query.String(), vals...).Scan(&inserted) + if err != nil { + return false, sqldb.WrapErrorWithQuery(err, query.String(), vals, argFmt) + } return inserted, err } // InsertReturning inserts a new row into table using values // and returns values from the inserted row listed in returning. -func InsertReturning(ctx context.Context, table string, values Values, returning string) sqldb.RowScanner { +func InsertReturning(ctx context.Context, table string, values Values, returning string) sqldb.Row { if len(values) == 0 { - return sqldb.RowScannerWithError(fmt.Errorf("InsertReturning into table %s: no values", table)) + return sqldb.RowWithError(fmt.Errorf("InsertReturning into table %s: no values", table)) } conn := Conn(ctx) + argFmt := conn.Config().ParamPlaceholderFormatter names, vals := values.Sorted() var query strings.Builder - writeInsertQuery(&query, table, conn, names) + writeInsertQuery(&query, table, argFmt, names) query.WriteString(" RETURNING ") query.WriteString(returning) - return conn.QueryRow(query.String(), vals...) + return conn.QueryRow(ctx, query.String(), vals...) } // InsertStruct inserts a new row into table using the connection's @@ -75,20 +80,21 @@ func InsertReturning(ctx context.Context, table string, values Values, returning // Optional ColumnFilter can be passed to ignore mapped columns. func InsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflection.ColumnFilter) error { conn := Conn(ctx) - mapper := conn.StructFieldMapper() - - table, columns, vals, err := insertStructValues(rowStruct, mapper, ignoreColumns) + table, columns, vals, err := insertStructValues(rowStruct, DefaultStructFieldMapping, ignoreColumns) if err != nil { return err } + argFmt := conn.Config().ParamPlaceholderFormatter var b strings.Builder - writeInsertQuery(&b, table, conn, columns) + writeInsertQuery(&b, table, argFmt, columns) query := b.String() - err = conn.Exec(query, vals...) - - return sqldb.WrapNonNilErrorWithQuery(err, query, conn, vals) + err = conn.Exec(ctx, query, vals...) + if err != nil { + return sqldb.WrapErrorWithQuery(err, query, vals, argFmt) + } + return nil } // InsertUniqueStruct inserts a new row into table using the connection's @@ -98,9 +104,7 @@ func InsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflectio // and returns if a row was inserted. func InsertUniqueStruct(ctx context.Context, rowStruct any, onConflict string, ignoreColumns ...reflection.ColumnFilter) (inserted bool, err error) { conn := Conn(ctx) - mapper := conn.StructFieldMapper() - - table, columns, vals, err := insertStructValues(rowStruct, mapper, ignoreColumns) + table, columns, vals, err := insertStructValues(rowStruct, DefaultStructFieldMapping, ignoreColumns) if err != nil { return false, err } @@ -109,15 +113,17 @@ func InsertUniqueStruct(ctx context.Context, rowStruct any, onConflict string, i onConflict = onConflict[1 : len(onConflict)-1] } + argFmt := conn.Config().ParamPlaceholderFormatter var b strings.Builder - writeInsertQuery(&b, table, conn, columns) + writeInsertQuery(&b, table, argFmt, columns) fmt.Fprintf(&b, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict) query := b.String() - err = conn.QueryRow(query, vals...).Scan(&inserted) - err = sqldb.ReplaceErrNoRows(err, nil) - - return inserted, sqldb.WrapNonNilErrorWithQuery(err, query, conn, vals) + err = conn.QueryRow(ctx, query, vals...).Scan(&inserted) + if err != nil { + return false, sqldb.WrapErrorWithQuery(err, query, vals, argFmt) + } + return inserted, nil } func writeInsertQuery(w *strings.Builder, table string, argFmt sqldb.ParamPlaceholderFormatter, names []string) { diff --git a/db/query.go b/db/query.go index 56e63b0..03137a4 100644 --- a/db/query.go +++ b/db/query.go @@ -16,27 +16,32 @@ import ( // Useful for getting the timestamp of a // SQL transaction for use in Go code. func Now(ctx context.Context) (time.Time, error) { - return Conn(ctx).Now() + var now time.Time + err := Conn(ctx).QueryRow(ctx, `SELECT now()`).Scan(&now) + if err != nil { + return time.Time{}, err + } + return now, nil } // Exec executes a query with optional args. func Exec(ctx context.Context, query string, args ...any) error { - return Conn(ctx).Exec(query, args...) + return Conn(ctx).Exec(ctx, query, args...) } -// QueryRow queries a single row and returns a RowScanner for the results. -func QueryRow(ctx context.Context, query string, args ...any) sqldb.RowScanner { - return Conn(ctx).QueryRow(query, args...) +// QueryRow queries a single row and returns a Row for the results. +func QueryRow(ctx context.Context, query string, args ...any) sqldb.Row { + return Conn(ctx).QueryRow(ctx, query, args...) } -// QueryRows queries multiple rows and returns a RowsScanner for the results. -func QueryRows(ctx context.Context, query string, args ...any) sqldb.RowsScanner { - return Conn(ctx).QueryRows(query, args...) +// QueryRows queries multiple rows and returns a Rows for the results. +func QueryRows(ctx context.Context, query string, args ...any) sqldb.Rows { + return Conn(ctx).QueryRows(ctx, query, args...) } // QueryValue queries a single value of type T. func QueryValue[T any](ctx context.Context, query string, args ...any) (value T, err error) { - err = Conn(ctx).QueryRow(query, args...).Scan(&value) + err = Conn(ctx).QueryRow(ctx, query, args...).Scan(&value) if err != nil { var zero T return zero, err @@ -47,7 +52,7 @@ func QueryValue[T any](ctx context.Context, query string, args ...any) (value T, // QueryValueOrDefault queries a single value of type T // or returns the default zero value of T in case of sql.ErrNoRows. func QueryValueOrDefault[T any](ctx context.Context, query string, args ...any) (value T, err error) { - err = Conn(ctx).QueryRow(query, args...).Scan(&value) + err = Conn(ctx).QueryRow(ctx, query, args...).Scan(&value) if err != nil { var zero T if errors.Is(err, sql.ErrNoRows) { @@ -82,7 +87,7 @@ func QueryStruct[S any](ctx context.Context, pkValues ...any) (row *S, err error for i := 1; i < len(pkColumns); i++ { query += fmt.Sprintf(` AND "%s" = $%d`, pkColumns[i], i+1) } - err = conn.QueryRow(query, pkValues...).ScanStruct(&row) + err = conn.QueryRow(ctx, query, pkValues...).ScanStruct(&row) if err != nil { return nil, err } @@ -146,7 +151,7 @@ func pkColumnsOfStruct(conn sqldb.Connection, t reflect.Type) (table string, col // QueryStructSlice returns queried rows as slice of the generic type S // which must be a struct or a pointer to a struct. func QueryStructSlice[S any](ctx context.Context, query string, args ...any) (rows []S, err error) { - err = Conn(ctx).QueryRows(query, args...).ScanStructSlice(&rows) + err = Conn(ctx).QueryRows(ctx, query, args...).ScanStructSlice(&rows) if err != nil { return nil, err } diff --git a/db/update.go b/db/update.go index 35fe9ae..1cdd1f5 100644 --- a/db/update.go +++ b/db/update.go @@ -21,14 +21,14 @@ func Update(ctx context.Context, table string, values sqldb.Values, where string conn := Conn(ctx) query, vals := buildUpdateQuery(table, values, where, conn, args) err := conn.Exec(query, vals...) - return sqldb.WrapNonNilErrorWithQuery(err, query, conn, vals) + return sqldb.WrapErrorWithQuery(err, query, conn, vals) } // UpdateReturningRow updates a table row with values using the where statement with passed in args starting at $1 // and returning a single row with the columns specified in returning argument. -func UpdateReturningRow(ctx context.Context, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { +func UpdateReturningRow(ctx context.Context, table string, values sqldb.Values, returning, where string, args ...any) sqldb.Row { if len(values) == 0 { - return sqldb.RowScannerWithError(fmt.Errorf("UpdateReturningRow table %s: no values passed", table)) + return sqldb.RowWithError(fmt.Errorf("UpdateReturningRow table %s: no values passed", table)) } conn := Conn(ctx) @@ -39,9 +39,9 @@ func UpdateReturningRow(ctx context.Context, table string, values sqldb.Values, // UpdateReturningRows updates table rows with values using the where statement with passed in args starting at $1 // and returning multiple rows with the columns specified in returning argument. -func UpdateReturningRows(ctx context.Context, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { +func UpdateReturningRows(ctx context.Context, table string, values sqldb.Values, returning, where string, args ...any) sqldb.Rows { if len(values) == 0 { - return sqldb.RowsScannerWithError(fmt.Errorf("UpdateReturningRows table %s: no values passed", table)) + return sqldb.RowsWithError(fmt.Errorf("UpdateReturningRows table %s: no values passed", table)) } conn := Conn(ctx) @@ -78,8 +78,7 @@ func UpdateStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflectio } conn := Conn(ctx) - mapper := conn.StructFieldMapper() - table, columns, pkCols, vals, err := reflection.ReflectStructValues(v, mapper, append(ignoreColumns, sqldb.IgnoreReadOnly)) + table, columns, pkCols, vals, err := reflection.ReflectStructValues(v, DefaultStructFieldMapping, append(ignoreColumns, sqldb.IgnoreReadOnly)) if err != nil { return err } @@ -117,7 +116,7 @@ func UpdateStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflectio err = conn.Exec(query, vals...) - return sqldb.WrapNonNilErrorWithQuery(err, query, conn, vals) + return sqldb.WrapErrorWithQuery(err, query, conn, vals) } func derefStruct(rowStruct any) (reflect.Value, error) { diff --git a/db/upsert.go b/db/upsert.go index 9ed6964..5cc38ce 100644 --- a/db/upsert.go +++ b/db/upsert.go @@ -24,8 +24,7 @@ func UpsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflectio } conn := Conn(ctx) - mapper := conn.StructFieldMapper() - table, columns, pkCols, vals, err := reflection.ReflectStructValues(v, mapper, append(ignoreColumns, sqldb.IgnoreReadOnly)) + table, columns, pkCols, vals, err := reflection.ReflectStructValues(v, DefaultStructFieldMapping, append(ignoreColumns, sqldb.IgnoreReadOnly)) if err != nil { return err } @@ -63,5 +62,5 @@ func UpsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflectio err = conn.Exec(query, vals...) - return sqldb.WrapNonNilErrorWithQuery(err, query, conn, vals) + return sqldb.WrapErrorWithQuery(err, query, conn, vals) } diff --git a/dbconnection.go b/dbconnection.go new file mode 100644 index 0000000..b6836fc --- /dev/null +++ b/dbconnection.go @@ -0,0 +1,101 @@ +package sqldb + +import ( + "context" + "database/sql" + "time" +) + +type DBConnection struct { + Conf *Config + DB *sql.DB +} + +func (c *DBConnection) Config() *Config { + return c.Conf +} + +func (c *DBConnection) Stats() sql.DBStats { + return c.DB.Stats() +} + +func (c *DBConnection) Ping(ctx context.Context, timeout time.Duration) error { + if timeout > 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + return c.DB.PingContext(ctx) +} + +func (c *DBConnection) Err() error { + return c.Conf.Err +} + +func (c *DBConnection) Exec(ctx context.Context, query string, args ...any) error { + _, err := c.DB.ExecContext(ctx, query, args...) + if err != nil { + return WrapErrorWithQuery(err, query, args, c.Conf.ParamPlaceholderFormatter) + } + return nil +} + +func (c *DBConnection) QueryRow(ctx context.Context, query string, args ...any) Row { + rows, err := c.DB.QueryContext(ctx, query, args...) + if err != nil { + return RowWithError(WrapErrorWithQuery(err, query, args, c.Conf.ParamPlaceholderFormatter)) + } + return NewRow(ctx, rows, c, query, args) +} + +func (c *DBConnection) QueryRows(ctx context.Context, query string, args ...any) Rows { + rows, err := c.DB.QueryContext(ctx, query, args...) + if err != nil { + return RowsWithError(WrapErrorWithQuery(err, query, args, c.Conf.ParamPlaceholderFormatter)) + } + return NewRows(ctx, rows, c, query, args) +} + +func (c *DBConnection) IsTransaction() bool { + return false +} + +func (c *DBConnection) TxOptions() *sql.TxOptions { + return nil +} + +func (c *DBConnection) Begin(ctx context.Context, opts *sql.TxOptions) (Connection, error) { + tx, err := c.DB.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &TxConnection{ + Parent: c, + Tx: tx, + Opts: opts, + }, nil +} + +func (c *DBConnection) Commit() error { + return ErrNotWithinTransaction +} + +func (c *DBConnection) Rollback() error { + return ErrNotWithinTransaction +} + +func (c *DBConnection) ListenOnChannel(channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) error { + return ErrNotSupported +} + +func (c *DBConnection) UnlistenChannel(channel string) error { + return ErrNotSupported +} + +func (c *DBConnection) IsListeningOnChannel(channel string) bool { + return false +} + +func (c *DBConnection) Close() error { + return c.DB.Close() +} diff --git a/errconnection.go b/errconnection.go new file mode 100644 index 0000000..e94d10c --- /dev/null +++ b/errconnection.go @@ -0,0 +1,88 @@ +package sqldb + +import ( + "context" + "database/sql" + "time" +) + +// ConnectionWithError returns a dummy Connection +// where all methods return the passed error. +func ConnectionWithError(err error) Connection { + if err == nil { + panic("ConnectionWithError needs an error") + } + return errConn{err} +} + +type errConn struct { + err error +} + +func (e errConn) Config() *Config { + return &Config{Err: e.err} +} + +func (e errConn) Stats() sql.DBStats { + return sql.DBStats{} +} + +func (e errConn) Ping(context.Context, time.Duration) error { + return e.err +} + +func (e errConn) Err() error { + return e.err +} + +func (e errConn) Exec(ctx context.Context, query string, args ...any) error { + return e.err +} + +func (e errConn) QueryRow(ctx context.Context, query string, args ...any) Row { + return RowWithError(e.err) +} + +func (e errConn) QueryRows(ctx context.Context, query string, args ...any) Rows { + return RowsWithError(e.err) +} + +func (e errConn) IsTransaction() bool { + return false +} + +func (ce errConn) TxOptions() *sql.TxOptions { + return nil +} + +func (e errConn) Begin(ctx context.Context, opts *sql.TxOptions) (Connection, error) { + return nil, e.err +} + +func (e errConn) Commit() error { + return e.err +} + +func (e errConn) Rollback() error { + return e.err +} + +func (e errConn) Transaction(opts *sql.TxOptions, txFunc func(tx Connection) error) error { + return e.err +} + +func (e errConn) ListenOnChannel(channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) error { + return e.err +} + +func (e errConn) UnlistenChannel(channel string) error { + return e.err +} + +func (e errConn) IsListeningOnChannel(channel string) bool { + return false +} + +func (e errConn) Close() error { + return e.err +} diff --git a/errors.go b/errors.go index bbe2f87..09bee08 100644 --- a/errors.go +++ b/errors.go @@ -6,11 +6,17 @@ import ( "fmt" ) -var ( - _ Connection = connectionWithError{} - _ RowScanner = rowScannerWithError{} - _ RowsScanner = rowsScannerWithError{} -) +func combineTwoErrors(prim, sec error) error { + switch { + case prim != nil && sec != nil: + return fmt.Errorf("%w\n%s", prim, sec) + case prim != nil: + return prim + case sec != nil: + return sec + } + return nil +} // ReplaceErrNoRows returns the passed replacement error // if errors.Is(err, sql.ErrNoRows), @@ -57,40 +63,29 @@ const ( ErrNotSupported sentinelError = "not supported" ) -// WrapNonNilErrorWithQuery wraps non nil errors with a formatted query +// WrapErrorWithQuery wraps non nil errors with a formatted query // if the error was not already wrapped with a query. // If the passed error is nil, then nil will be returned. -func WrapNonNilErrorWithQuery(err error, query string, argFmt ParamPlaceholderFormatter, args []any) error { +func WrapErrorWithQuery(err error, query string, args []any, paramFmt ParamPlaceholderFormatter) error { + if err == nil { + return nil + } var wrapped errWithQuery - if err == nil || errors.As(err, &wrapped) { + if errors.As(err, &wrapped) { return err } - return errWithQuery{err, query, argFmt, args} + return errWithQuery{err, query, args, paramFmt} } type errWithQuery struct { - err error - query string - argFmt ParamPlaceholderFormatter - args []any + err error + query string + args []any + paramFmt ParamPlaceholderFormatter } func (e errWithQuery) Unwrap() error { return e.err } func (e errWithQuery) Error() string { - return fmt.Sprintf("%s from query: %s", e.err, FormatQuery(e.query, e.argFmt, e.args...)) + return fmt.Sprintf("%s from query: %s", e.err, FormatQuery(e.query, e.paramFmt, e.args...)) } - -func combineErrors(prim, sec error) error { - switch { - case prim != nil && sec != nil: - return fmt.Errorf("%w\n%s", prim, sec) - case prim != nil: - return prim - case sec != nil: - return sec - } - return nil -} - -// ConnectionWithError diff --git a/errors_test.go b/errors_test.go index fa0f181..a32b2c2 100644 --- a/errors_test.go +++ b/errors_test.go @@ -33,7 +33,7 @@ func TestWrapNonNilErrorWithQuery(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := WrapNonNilErrorWithQuery(tt.args.err, tt.args.query, tt.args.argFmt, tt.args.args) + err := WrapErrorWithQuery(tt.args.err, tt.args.query, tt.args.argFmt, tt.args.args) if tt.wantError == "" && err != nil || tt.wantError != "" && (err == nil || err.Error() != tt.wantError) { t.Errorf("WrapNonNilErrorWithQuery() error = %v, wantErr %v", err, tt.wantError) } diff --git a/examples/user_demo/user_demo.go b/examples/user_demo/user_demo.go index 9f0a17f..a329ab1 100644 --- a/examples/user_demo/user_demo.go +++ b/examples/user_demo/user_demo.go @@ -65,7 +65,7 @@ func main() { } err = conn.QueryRows(`select name, email from public.user`).ForEachRow( - func(row sqldb.RowScanner) error { + func(row sqldb.Row) error { var name, email string err := row.Scan(&name, &email) if err != nil { diff --git a/mockconn/connection.go b/mockconn/connection.go index 54e0573..f42b4cd 100644 --- a/mockconn/connection.go +++ b/mockconn/connection.go @@ -8,27 +8,24 @@ import ( "time" "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/reflection" ) var DefaultParamPlaceholderFormatter = sqldb.NewParamPlaceholderFormatter("$%d", 1) func New(ctx context.Context, queryWriter io.Writer, rowsProvider RowsProvider) sqldb.Connection { return &connection{ - ctx: ctx, - queryWriter: queryWriter, - listening: newBoolMap(), - rowsProvider: rowsProvider, - structFieldMapper: sqldb.DefaultStructFieldMapping, + ctx: ctx, + queryWriter: queryWriter, + listening: newBoolMap(), + rowsProvider: rowsProvider, } } type connection struct { - ctx context.Context - queryWriter io.Writer - listening *boolMap - rowsProvider RowsProvider - structFieldMapper reflection.StructFieldMapper + ctx context.Context + queryWriter io.Writer + listening *boolMap + rowsProvider RowsProvider } func (conn *connection) Context() context.Context { return conn.ctx } @@ -38,28 +35,13 @@ func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { return conn } return &connection{ - ctx: ctx, - queryWriter: conn.queryWriter, - listening: conn.listening, - rowsProvider: conn.rowsProvider, - structFieldMapper: conn.structFieldMapper, + ctx: ctx, + queryWriter: conn.queryWriter, + listening: conn.listening, + rowsProvider: conn.rowsProvider, } } -func (conn *connection) WithStructFieldMapper(mapper reflection.StructFieldMapper) sqldb.Connection { - return &connection{ - ctx: conn.ctx, - queryWriter: conn.queryWriter, - listening: conn.listening, - rowsProvider: conn.rowsProvider, - structFieldMapper: mapper, - } -} - -func (conn *connection) StructFieldMapper() reflection.StructFieldMapper { - return conn.structFieldMapper -} - func (conn *connection) Stats() sql.DBStats { return sql.DBStats{} } @@ -95,30 +77,30 @@ func (conn *connection) Exec(query string, args ...any) error { return nil } -func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { +func (conn *connection) QueryRow(query string, args ...any) sqldb.Row { if conn.ctx.Err() != nil { - return sqldb.RowScannerWithError(conn.ctx.Err()) + return sqldb.RowWithError(conn.ctx.Err()) } if conn.queryWriter != nil { fmt.Fprint(conn.queryWriter, query) } if conn.rowsProvider == nil { - return sqldb.RowScannerWithError(nil) + return sqldb.RowWithError(nil) } - return conn.rowsProvider.QueryRow(conn.structFieldMapper, query, args...) + return conn.rowsProvider.QueryRow(query, args...) } -func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { +func (conn *connection) QueryRows(query string, args ...any) sqldb.Rows { if conn.ctx.Err() != nil { - return sqldb.RowsScannerWithError(conn.ctx.Err()) + return sqldb.RowsWithError(conn.ctx.Err()) } if conn.queryWriter != nil { fmt.Fprint(conn.queryWriter, query) } if conn.rowsProvider == nil { - return sqldb.RowsScannerWithError(nil) + return sqldb.RowsWithError(nil) } - return conn.rowsProvider.QueryRows(conn.structFieldMapper, query, args...) + return conn.rowsProvider.QueryRows(query, args...) } func (conn *connection) IsTransaction() bool { diff --git a/mockconn/onetimerowsprovider.go b/mockconn/onetimerowsprovider.go index 44c4b39..059f9ee 100644 --- a/mockconn/onetimerowsprovider.go +++ b/mockconn/onetimerowsprovider.go @@ -11,19 +11,19 @@ import ( ) type OneTimeRowsProvider struct { - rowScanners map[string]sqldb.RowScanner - rowsScanners map[string]sqldb.RowsScanner + rowScanners map[string]sqldb.Row + rowsScanners map[string]sqldb.Rows mtx sync.Mutex } func NewOneTimeRowsProvider() *OneTimeRowsProvider { return &OneTimeRowsProvider{ - rowScanners: make(map[string]sqldb.RowScanner), - rowsScanners: make(map[string]sqldb.RowsScanner), + rowScanners: make(map[string]sqldb.Row), + rowsScanners: make(map[string]sqldb.Rows), } } -func (p *OneTimeRowsProvider) AddRowScannerQuery(scanner sqldb.RowScanner, query string, args ...any) { +func (p *OneTimeRowsProvider) AddRowQuery(scanner sqldb.Row, query string, args ...any) { p.mtx.Lock() defer p.mtx.Unlock() @@ -34,7 +34,7 @@ func (p *OneTimeRowsProvider) AddRowScannerQuery(scanner sqldb.RowScanner, query p.rowScanners[key] = scanner } -func (p *OneTimeRowsProvider) AddRowsScannerQuery(scanner sqldb.RowsScanner, query string, args ...any) { +func (p *OneTimeRowsProvider) AddRowsQuery(scanner sqldb.Rows, query string, args ...any) { p.mtx.Lock() defer p.mtx.Unlock() @@ -45,7 +45,7 @@ func (p *OneTimeRowsProvider) AddRowsScannerQuery(scanner sqldb.RowsScanner, que p.rowsScanners[key] = scanner } -func (p *OneTimeRowsProvider) QueryRow(structFieldMapper reflection.StructFieldMapper, query string, args ...any) sqldb.RowScanner { +func (p *OneTimeRowsProvider) QueryRow(structFieldMapper reflection.StructFieldMapper, query string, args ...any) sqldb.Row { p.mtx.Lock() defer p.mtx.Unlock() @@ -55,7 +55,7 @@ func (p *OneTimeRowsProvider) QueryRow(structFieldMapper reflection.StructFieldM return scanner } -func (p *OneTimeRowsProvider) QueryRows(structFieldMapper reflection.StructFieldMapper, query string, args ...any) sqldb.RowsScanner { +func (p *OneTimeRowsProvider) QueryRows(structFieldMapper reflection.StructFieldMapper, query string, args ...any) sqldb.Rows { p.mtx.Lock() defer p.mtx.Unlock() diff --git a/mockconn/rowsprovider.go b/mockconn/rowsprovider.go index a3074e6..74526dc 100644 --- a/mockconn/rowsprovider.go +++ b/mockconn/rowsprovider.go @@ -2,10 +2,9 @@ package mockconn import ( sqldb "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/reflection" ) type RowsProvider interface { - QueryRow(structFieldMapper reflection.StructFieldMapper, query string, args ...any) sqldb.RowScanner - QueryRows(structFieldMapper reflection.StructFieldMapper, query string, args ...any) sqldb.RowsScanner + QueryRow(query string, args ...any) sqldb.Row + QueryRows(query string, args ...any) sqldb.Rows } diff --git a/mockconn/singlerowprovider.go b/mockconn/singlerowprovider.go index 0bbec5c..f1c9ff1 100644 --- a/mockconn/singlerowprovider.go +++ b/mockconn/singlerowprovider.go @@ -4,7 +4,6 @@ import ( "context" sqldb "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/reflection" ) // NewSingleRowProvider a RowsProvider implementation @@ -20,10 +19,10 @@ type singleRowProvider struct { argFmt sqldb.ParamPlaceholderFormatter } -func (p *singleRowProvider) QueryRow(mapper reflection.StructFieldMapper, query string, args ...any) sqldb.RowScanner { - return sqldb.NewRowScanner(sqldb.RowAsRows(p.row), mapper, query, p.argFmt, args) +func (p *singleRowProvider) QueryRow(query string, args ...any) sqldb.Row { + return sqldb.NewRow(context.Background(), sqldb.RowAsRows(p.row), query, p.argFmt, args) } -func (p *singleRowProvider) QueryRows(mapper reflection.StructFieldMapper, query string, args ...any) sqldb.RowsScanner { - return sqldb.NewRowsScanner(context.Background(), NewRows(p.row), mapper, query, p.argFmt, args) +func (p *singleRowProvider) QueryRows(query string, args ...any) sqldb.Rows { + return sqldb.NewRows(context.Background(), sqldb.NewRows(p.row), query, p.argFmt, args) } diff --git a/mysqlconn/connection.go b/mysqlconn/connection.go index ca09094..e78ebc9 100644 --- a/mysqlconn/connection.go +++ b/mysqlconn/connection.go @@ -7,7 +7,6 @@ import ( "time" "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/reflection" ) // New creates a new sqldb.Connection using the passed sqldb.Config @@ -25,10 +24,9 @@ func New(ctx context.Context, config *sqldb.Config) (sqldb.Connection, error) { return nil, err } conn := &connection{ - ctx: ctx, - db: db, - config: config, - structFieldMapper: sqldb.DefaultStructFieldMapping, + ctx: ctx, + db: db, + config: config, } return conn, nil } @@ -47,10 +45,9 @@ func MustNew(ctx context.Context, config *sqldb.Config) sqldb.Connection { } type connection struct { - ctx context.Context - db *sql.DB - config *sqldb.Config - structFieldMapper reflection.StructFieldMapper + ctx context.Context + db *sql.DB + config *sqldb.Config } func (conn *connection) clone() *connection { @@ -69,16 +66,6 @@ func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { return c } -func (conn *connection) WithStructFieldMapper(mapper reflection.StructFieldMapper) sqldb.Connection { - c := conn.clone() - c.structFieldMapper = mapper - return c -} - -func (conn *connection) StructFieldMapper() reflection.StructFieldMapper { - return conn.structFieldMapper -} - func (conn *connection) Ping(timeout time.Duration) error { ctx := conn.ctx if timeout > 0 { @@ -119,25 +106,25 @@ func (conn *connection) Now() (now time.Time, err error) { func (conn *connection) Exec(query string, args ...any) error { _, err := conn.db.ExecContext(conn.ctx, query, args...) - return sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) + return sqldb.WrapErrorWithQuery(err, query, conn, args) } -func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { +func (conn *connection) QueryRow(query string, args ...any) sqldb.Row { rows, err := conn.db.QueryContext(conn.ctx, query, args...) if err != nil { - err = sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) - return sqldb.RowScannerWithError(err) + err = sqldb.WrapErrorWithQuery(err, query, conn, args) + return sqldb.RowWithError(err) } - return sqldb.NewRowScanner(rows, conn.structFieldMapper, query, conn, args) + return sqldb.NewRow(conn.ctx, rows, query, conn, args) } -func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { +func (conn *connection) QueryRows(query string, args ...any) sqldb.Rows { rows, err := conn.db.QueryContext(conn.ctx, query, args...) if err != nil { - err = sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) - return sqldb.RowsScannerWithError(err) + err = sqldb.WrapErrorWithQuery(err, query, conn, args) + return sqldb.RowsWithError(err) } - return sqldb.NewRowsScanner(conn.ctx, rows, conn.structFieldMapper, query, conn, args) + return sqldb.NewRows(conn.ctx, rows, query, conn, args) } func (conn *connection) IsTransaction() bool { diff --git a/mysqlconn/transaction.go b/mysqlconn/transaction.go index 7e5c000..281a4d9 100644 --- a/mysqlconn/transaction.go +++ b/mysqlconn/transaction.go @@ -7,24 +7,21 @@ import ( "time" "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/reflection" ) type transaction struct { // The parent non-transaction connection is needed // for its ctx, Ping(), Stats(), and Config() - parent *connection - tx *sql.Tx - opts *sql.TxOptions - structFieldMapper reflection.StructFieldMapper + parent *connection + tx *sql.Tx + opts *sql.TxOptions } func newTransaction(parent *connection, tx *sql.Tx, opts *sql.TxOptions) *transaction { return &transaction{ - parent: parent, - tx: tx, - opts: opts, - structFieldMapper: parent.structFieldMapper, + parent: parent, + tx: tx, + opts: opts, } } @@ -44,16 +41,6 @@ func (conn *transaction) WithContext(ctx context.Context) sqldb.Connection { return newTransaction(parent, conn.tx, conn.opts) } -func (conn *transaction) WithStructFieldMapper(mapper reflection.StructFieldMapper) sqldb.Connection { - c := conn.clone() - c.structFieldMapper = mapper - return c -} - -func (conn *transaction) StructFieldMapper() reflection.StructFieldMapper { - return conn.structFieldMapper -} - func (conn *transaction) Ping(timeout time.Duration) error { return conn.parent.Ping(timeout) } func (conn *transaction) Stats() sql.DBStats { return conn.parent.Stats() } func (conn *transaction) Config() *sqldb.Config { return conn.parent.Config() } @@ -80,25 +67,25 @@ func (conn *transaction) Now() (now time.Time, err error) { func (conn *transaction) Exec(query string, args ...any) error { _, err := conn.tx.Exec(query, args...) - return sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) + return sqldb.WrapErrorWithQuery(err, query, conn, args) } -func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { +func (conn *transaction) QueryRow(query string, args ...any) sqldb.Row { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { - err = sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) - return sqldb.RowScannerWithError(err) + err = sqldb.WrapErrorWithQuery(err, query, conn, args) + return sqldb.RowWithError(err) } - return sqldb.NewRowScanner(rows, conn.structFieldMapper, query, conn, args) + return sqldb.NewRow(conn.parent.ctx, rows, query, conn, args) } -func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner { +func (conn *transaction) QueryRows(query string, args ...any) sqldb.Rows { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { - err = sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) - return sqldb.RowsScannerWithError(err) + err = sqldb.WrapErrorWithQuery(err, query, conn, args) + return sqldb.RowsWithError(err) } - return sqldb.NewRowsScanner(conn.parent.ctx, rows, conn.structFieldMapper, query, conn, args) + return sqldb.NewRows(conn.parent.ctx, rows, query, conn, args) } func (conn *transaction) IsTransaction() bool { diff --git a/pqconn/connection.go b/pqconn/connection.go index 687d026..2f610f8 100644 --- a/pqconn/connection.go +++ b/pqconn/connection.go @@ -7,7 +7,6 @@ import ( "time" "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/reflection" ) // New creates a new sqldb.Connection using the passed sqldb.Config @@ -25,10 +24,9 @@ func New(ctx context.Context, config *sqldb.Config) (sqldb.Connection, error) { return nil, err } return &connection{ - ctx: ctx, - db: db, - config: config, - structFieldMapper: sqldb.DefaultStructFieldMapping, + ctx: ctx, + db: db, + config: config, }, nil } @@ -46,10 +44,9 @@ func MustNew(ctx context.Context, config *sqldb.Config) sqldb.Connection { } type connection struct { - ctx context.Context - db *sql.DB - config *sqldb.Config - structFieldMapper reflection.StructFieldMapper + ctx context.Context + db *sql.DB + config *sqldb.Config } func (conn *connection) clone() *connection { @@ -68,16 +65,6 @@ func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { return c } -func (conn *connection) WithStructFieldMapper(mapper reflection.StructFieldMapper) sqldb.Connection { - c := conn.clone() - c.structFieldMapper = mapper - return c -} - -func (conn *connection) StructFieldMapper() reflection.StructFieldMapper { - return conn.structFieldMapper -} - func (conn *connection) Ping(timeout time.Duration) error { ctx := conn.ctx if timeout > 0 { @@ -118,25 +105,25 @@ func (conn *connection) Now() (now time.Time, err error) { func (conn *connection) Exec(query string, args ...any) error { _, err := conn.db.ExecContext(conn.ctx, query, args...) - return sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) + return sqldb.WrapErrorWithQuery(err, query, conn, args) } -func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { +func (conn *connection) QueryRow(query string, args ...any) sqldb.Row { rows, err := conn.db.QueryContext(conn.ctx, query, args...) if err != nil { - err = sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) - return sqldb.RowScannerWithError(err) + err = sqldb.WrapErrorWithQuery(err, query, conn, args) + return sqldb.RowWithError(err) } - return sqldb.NewRowScanner(rows, conn.structFieldMapper, query, conn, args) + return sqldb.NewRow(conn.ctx, rows, query, conn, args) } -func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { +func (conn *connection) QueryRows(query string, args ...any) sqldb.Rows { rows, err := conn.db.QueryContext(conn.ctx, query, args...) if err != nil { - err = sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) - return sqldb.RowsScannerWithError(err) + err = sqldb.WrapErrorWithQuery(err, query, conn, args) + return sqldb.RowsWithError(err) } - return sqldb.NewRowsScanner(conn.ctx, rows, conn.structFieldMapper, query, conn, args) + return sqldb.NewRows(conn.ctx, rows, query, conn, args) } func (conn *connection) IsTransaction() bool { diff --git a/pqconn/transaction.go b/pqconn/transaction.go index fd84ae5..dc01192 100644 --- a/pqconn/transaction.go +++ b/pqconn/transaction.go @@ -6,24 +6,21 @@ import ( "time" "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/reflection" ) type transaction struct { // The parent non-transaction connection is needed // for its ctx, Ping(), Stats(), and Config() - parent *connection - tx *sql.Tx - opts *sql.TxOptions - structFieldMapper reflection.StructFieldMapper + parent *connection + tx *sql.Tx + opts *sql.TxOptions } func newTransaction(parent *connection, tx *sql.Tx, opts *sql.TxOptions) *transaction { return &transaction{ - parent: parent, - tx: tx, - opts: opts, - structFieldMapper: parent.structFieldMapper, + parent: parent, + tx: tx, + opts: opts, } } @@ -43,16 +40,6 @@ func (conn *transaction) WithContext(ctx context.Context) sqldb.Connection { return newTransaction(parent, conn.tx, conn.opts) } -func (conn *transaction) WithStructFieldMapper(mapper reflection.StructFieldMapper) sqldb.Connection { - c := conn.clone() - c.structFieldMapper = mapper - return c -} - -func (conn *transaction) StructFieldMapper() reflection.StructFieldMapper { - return conn.structFieldMapper -} - func (conn *transaction) Ping(timeout time.Duration) error { return conn.parent.Ping(timeout) } func (conn *transaction) Stats() sql.DBStats { return conn.parent.Stats() } func (conn *transaction) Config() *sqldb.Config { return conn.parent.Config() } @@ -78,26 +65,26 @@ func (conn *transaction) Now() (now time.Time, err error) { } func (conn *transaction) Exec(query string, args ...any) error { - _, err := conn.tx.Exec(query, args...) - return sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) + _, err := conn.tx.ExecContext(conn.parent.ctx, query, args...) + return sqldb.WrapErrorWithQuery(err, query, conn, args) } -func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { +func (conn *transaction) QueryRow(query string, args ...any) sqldb.Row { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { - err = sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) - return sqldb.RowScannerWithError(err) + err = sqldb.WrapErrorWithQuery(err, query, conn, args) + return sqldb.RowWithError(err) } - return sqldb.NewRowScanner(rows, conn.structFieldMapper, query, conn, args) + return sqldb.NewRow(conn.parent.ctx, rows, query, conn, args) } -func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner { +func (conn *transaction) QueryRows(query string, args ...any) sqldb.Rows { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { - err = sqldb.WrapNonNilErrorWithQuery(err, query, conn, args) - return sqldb.RowsScannerWithError(err) + err = sqldb.WrapErrorWithQuery(err, query, conn, args) + return sqldb.RowsWithError(err) } - return sqldb.NewRowsScanner(conn.parent.ctx, rows, conn.structFieldMapper, query, conn, args) + return sqldb.NewRows(conn.parent.ctx, rows, query, conn, args) } func (conn *transaction) IsTransaction() bool { diff --git a/foreachrow.go b/reflection/foreachrow.go similarity index 93% rename from foreachrow.go rename to reflection/foreachrow.go index ab639a4..b9448f5 100644 --- a/foreachrow.go +++ b/reflection/foreachrow.go @@ -1,4 +1,4 @@ -package sqldb +package reflection import ( "context" @@ -25,7 +25,7 @@ var ( // If a non nil error is returned from the callback, then this error // is returned immediately by this function without scanning further rows. // In case of zero rows, no error will be returned. -func forEachRowCallFunc(ctx context.Context, callback any) (f func(RowScanner) error, err error) { +func ForEachRowCallFunc(ctx context.Context, mapper StructFieldMapper, callback any) (f func(Row) error, err error) { val := reflect.ValueOf(callback) typ := val.Type() if typ.Kind() != reflect.Func { @@ -70,14 +70,14 @@ func forEachRowCallFunc(ctx context.Context, callback any) (f func(RowScanner) e return nil, fmt.Errorf("ForEachRowCallFunc callback function result must be of type error: %s", typ) } - f = func(row RowScanner) (err error) { + f = func(row Row) (err error) { // First scan row scannedValPtrs := make([]any, typ.NumIn()-firstArg) for i := range scannedValPtrs { scannedValPtrs[i] = reflect.New(typ.In(firstArg + i)).Interface() } if structArg { - err = row.ScanStruct(scannedValPtrs[0]) + err = ScanStruct(row, scannedValPtrs[0], mapper) } else { err = row.Scan(scannedValPtrs...) } diff --git a/foreachrow_test.go b/reflection/foreachrow_test.go similarity index 97% rename from foreachrow_test.go rename to reflection/foreachrow_test.go index 5f0f3cb..395b7ab 100644 --- a/foreachrow_test.go +++ b/reflection/foreachrow_test.go @@ -1,4 +1,4 @@ -package sqldb +package reflection import ( "testing" diff --git a/reflection/scan.go b/reflection/scan.go new file mode 100644 index 0000000..d8d5f89 --- /dev/null +++ b/reflection/scan.go @@ -0,0 +1,91 @@ +package reflection + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "time" +) + +func ScanValue(src driver.Value, dest reflect.Value) error { + if dest.Kind() == reflect.Interface { + if src != nil { + dest.Set(reflect.ValueOf(src)) + } else { + dest.Set(reflect.Zero(dest.Type())) + } + return nil + } + + if dest.Addr().Type().Implements(typeOfSQLScanner) { + return dest.Addr().Interface().(sql.Scanner).Scan(src) + } + + switch x := src.(type) { + case int64: + switch dest.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + dest.SetInt(x) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + dest.SetUint(uint64(x)) + return nil + case reflect.Float32, reflect.Float64: + dest.SetFloat(float64(x)) + return nil + } + + case float64: + switch dest.Kind() { + case reflect.Float32, reflect.Float64: + dest.SetFloat(x) + return nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + dest.SetInt(int64(x)) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + dest.SetUint(uint64(x)) + return nil + } + + case bool: + dest.SetBool(x) + return nil + + case []byte: + switch { + case dest.Kind() == reflect.String: + dest.SetString(string(x)) + return nil + case dest.Kind() == reflect.Slice && dest.Type().Elem().Kind() == reflect.Uint8: + dest.Set(reflect.ValueOf(x)) + return nil + } + + case string: + switch { + case dest.Kind() == reflect.String: + dest.SetString(x) + return nil + case dest.Kind() == reflect.Slice && dest.Type().Elem().Kind() == reflect.Uint8: + dest.Set(reflect.ValueOf([]byte(x))) + return nil + } + + case time.Time: + if srcVal := reflect.ValueOf(src); srcVal.Type().AssignableTo(dest.Type()) { + dest.Set(srcVal) + return nil + } + + case nil: + switch dest.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map: + dest.Set(reflect.Zero(dest.Type())) + return nil + } + } + + return fmt.Errorf("can't scan %#v as %s", src, dest.Type()) +} diff --git a/reflection/scanstruct_test.go b/reflection/scan_test.go similarity index 100% rename from reflection/scanstruct_test.go rename to reflection/scan_test.go diff --git a/reflection/scanslice.go b/reflection/scanslice.go index acac89a..f15cd41 100644 --- a/reflection/scanslice.go +++ b/reflection/scanslice.go @@ -6,14 +6,64 @@ import ( "errors" "fmt" "reflect" - "time" "github.com/domonda/go-types/nullable" ) -var ( - typeOfSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() -) +// // TODO doc +// // ScanSlice scans one value per row into one slice element of dest. +// // dest must be a pointer to a slice with a row value compatible element type. +// // In case of zero rows, dest will be set to nil and no error will be returned. +// // In case of an error, dest will not be modified. +// // It is an error to query more than one column.func (s *rowsScanner) ScanSlice(dest any) error { +// err := reflection.ScanRowsAsSlice(s.ctx, s.rows, dest, nil) +// if err != nil { +// return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) +// } +// return nil +// } +// // ScanStructSlice scans every row into the struct fields of dest slice elements. +// // dest must be a pointer to a slice of structs or struct pointers. +// // In case of zero rows, dest will be set to nil and no error will be returned. +// // In case of an error, dest will not be modified. +// // Every mapped struct field must have a corresponding column in the query results. +// func (s *rowsScanner) ScanStructSlice(dest any) error { +// err := reflection.ScanRowsAsSlice(s.ctx, s.rows, dest, s.structFieldMapper) +// if err != nil { +// return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) +// } +// return nil +// } + +// // ScanAllRowsAsStrings scans the values of all rows as strings. +// // Byte slices will be interpreted as strings, +// // nil (SQL NULL) will be converted to an empty string, +// // all other types are converted with fmt.Sprint. +// // If true is passed for headerRow, then a row +// // with the column names will be prepended. +// func (s *rowsScanner) ScanAllRowsAsStrings(headerRow bool) (rows [][]string, err error) { +// cols, err := s.rows.Columns() +// if err != nil { +// return nil, err +// } +// if headerRow { +// rows = [][]string{cols} +// } +// stringScannablePtrs := make([]any, len(cols)) +// err = s.ForEachRow(func(rowScanner RowScanner) error { +// row := make([]string, len(cols)) +// for i := range stringScannablePtrs { +// stringScannablePtrs[i] = (*StringScannable)(&row[i]) +// } +// err := rowScanner.Scan(stringScannablePtrs...) +// if err != nil { +// return err +// } +// rows = append(rows, row) +// return nil +// }) +// return rows, err +// } // ScanRowsAsSlice scans all srcRows as slice into dest. // The rows must either have only one column compatible with the element type of the slice, @@ -125,85 +175,3 @@ func (a *SliceScanner) scanString(src string) error { a.destSlice.Set(newSlice) return nil } - -func ScanValue(src any, dest reflect.Value) error { - if dest.Kind() == reflect.Interface { - if src != nil { - dest.Set(reflect.ValueOf(src)) - } else { - dest.Set(reflect.Zero(dest.Type())) - } - return nil - } - - if dest.Addr().Type().Implements(typeOfSQLScanner) { - return dest.Addr().Interface().(sql.Scanner).Scan(src) - } - - switch x := src.(type) { - case int64: - switch dest.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - dest.SetInt(x) - return nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - dest.SetUint(uint64(x)) - return nil - case reflect.Float32, reflect.Float64: - dest.SetFloat(float64(x)) - return nil - } - - case float64: - switch dest.Kind() { - case reflect.Float32, reflect.Float64: - dest.SetFloat(x) - return nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - dest.SetInt(int64(x)) - return nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - dest.SetUint(uint64(x)) - return nil - } - - case bool: - dest.SetBool(x) - return nil - - case []byte: - switch { - case dest.Kind() == reflect.String: - dest.SetString(string(x)) - return nil - case dest.Kind() == reflect.Slice && dest.Type().Elem().Kind() == reflect.Uint8: - dest.Set(reflect.ValueOf(x)) - return nil - } - - case string: - switch { - case dest.Kind() == reflect.String: - dest.SetString(x) - return nil - case dest.Kind() == reflect.Slice && dest.Type().Elem().Kind() == reflect.Uint8: - dest.Set(reflect.ValueOf([]byte(x))) - return nil - } - - case time.Time: - if srcVal := reflect.ValueOf(src); srcVal.Type().AssignableTo(dest.Type()) { - dest.Set(srcVal) - return nil - } - - case nil: - switch dest.Kind() { - case reflect.Ptr, reflect.Slice, reflect.Map: - dest.Set(reflect.Zero(dest.Type())) - return nil - } - } - - return fmt.Errorf("can't scan %#v as %s", src, dest.Type()) -} diff --git a/reflection/scanstruct.go b/reflection/scanstruct.go index 3c5a77a..bb3afc8 100644 --- a/reflection/scanstruct.go +++ b/reflection/scanstruct.go @@ -5,7 +5,8 @@ import ( "reflect" ) -func ScanStruct(srcRow Row, destStruct any, namer StructFieldMapper) error { +// ScanStruct scans values of a srcRow into a destStruct which must be passed as pointer. +func ScanStruct(srcRow Row, destStruct any, mapper StructFieldMapper) error { v := reflect.ValueOf(destStruct) for v.Kind() == reflect.Ptr && !v.IsNil() { v = v.Elem() @@ -33,7 +34,7 @@ func ScanStruct(srcRow Row, destStruct any, namer StructFieldMapper) error { return err } - fieldPointers, err := ReflectStructColumnPointers(v, namer, columns) + fieldPointers, err := ReflectStructColumnPointers(v, mapper, columns) if err != nil { return fmt.Errorf("ScanStruct: %w", err) } diff --git a/row.go b/row.go index a82431a..352cc80 100644 --- a/row.go +++ b/row.go @@ -1,5 +1,11 @@ package sqldb +import ( + "context" + "database/sql" + "errors" +) + // Row is an interface with the methods of sql.Rows // that are needed for ScanStruct. // Allows mocking for tests without an SQL driver. @@ -11,3 +17,81 @@ type Row interface { // number of columns in Rows. Scan(dest ...any) error } + +/////////////////////////////////////////////////////////////////////////////// + +// RowWithError returns a dummy Row +// where all methods return the passed error. +func RowWithError(err error) Row { + return rowWithError{err} +} + +type rowWithError struct{ err error } + +func (e rowWithError) Columns() ([]string, error) { return nil, e.err } +func (e rowWithError) Scan(dest ...any) error { return e.err } + +/////////////////////////////////////////////////////////////////////////////// + +type rowWrapper struct { + ctx context.Context // ctx is checked for every row and passed through to callbacks + rows *sql.Rows + conn Connection // for error wrapping + query string // for error wrapping + args []any // for error wrapping +} + +func NewRow(ctx context.Context, rows *sql.Rows, conn Connection, query string, args []any) Row { + return &rowWrapper{ctx, rows, conn, query, args} +} + +func (r *rowWrapper) Columns() ([]string, error) { + columns, err := r.rows.Columns() + if err != nil { + return nil, WrapErrorWithQuery(err, r.query, r.args, r.conn.Config().ParamPlaceholderFormatter) + } + return columns, nil +} + +func (r *rowWrapper) Scan(dest ...any) (err error) { + defer func() { + err = combineTwoErrors(err, r.rows.Close()) + if err != nil { + err = WrapErrorWithQuery(err, r.query, r.args, r.conn.Config().ParamPlaceholderFormatter) + } + }() + + if r.ctx.Err() != nil { + return r.ctx.Err() + } + + // TODO(bradfitz): for now we need to defensively clone all + // []byte that the driver returned (not permitting + // *RawBytes in Rows.Scan), since we're about to close + // the Rows in our defer, when we return from this function. + // the contract with the driver.Next(...) interface is that it + // can return slices into read-only temporary memory that's + // only valid until the next Scan/Close. But the TODO is that + // for a lot of drivers, this copy will be unnecessary. We + // should provide an optional interface for drivers to + // implement to say, "don't worry, the []bytes that I return + // from Next will not be modified again." (for instance, if + // they were obtained from the network anyway) But for now we + // don't care. + for _, dp := range dest { + if _, ok := dp.(*sql.RawBytes); ok { + return errors.New("sql: RawBytes isn't allowed on Row.Scan") + } + } + if !r.rows.Next() { + if err := r.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err = r.rows.Scan(dest...) + if err != nil { + return err + } + return r.rows.Close() +} diff --git a/rows.go b/rows.go index f1f75f8..82d7570 100644 --- a/rows.go +++ b/rows.go @@ -1,41 +1,96 @@ package sqldb -// Rows is an interface with the methods of sql.Rows -// that are needed for ScanSlice. -// Allows mocking for tests without an SQL driver. +import ( + "context" + "database/sql" +) + type Rows interface { - Row + // ForEachRow will call the passed callback with a RowScanner for every row. + // In case of zero rows, no error will be returned. + ForEachRow(callback func(Row) error) error // Close closes the Rows, preventing further enumeration. If Next is called // and returns false and there are no further result sets, // the Rows are closed automatically and it will suffice to check the // result of Err. Close is idempotent and does not affect the result of Err. Close() error +} + +/////////////////////////////////////////////////////////////////////////////// + +// RowsWithError returns dummy Rows +// where all methods return the passed error. +func RowsWithError(err error) Rows { + return rowsWithError{err} +} + +type rowsWithError struct{ err error } + +func (e rowsWithError) ForEachRow(func(Row) error) error { return e.err } +func (e rowsWithError) Close() error { return e.err } + +/////////////////////////////////////////////////////////////////////////////// - // Next prepares the next result row for reading with the Scan method. It - // returns true on success, or false if there is no next result row or an error - // happened while preparing it. Err should be consulted to distinguish between - // the two cases. - // - // Every call to Scan, even the first one, must be preceded by a call to Next. - Next() bool - - // Err returns the error, if any, that was encountered during iteration. - // Err may be called after an explicit or implicit Close. - Err() error +type rowsWrapper struct { + ctx context.Context // ctx is checked for every row and passed through to callbacks + rows *sql.Rows + conn Connection // for error wrapping + query string // for error wrapping + args []any // for error wrapping } -// RowAsRows implements the methods of Rows for a Row as no-ops. -// Note that Next() always returns true leading to an endless loop -// if used to scan multiple rows. -func RowAsRows(row Row) Rows { - return rowAsRows{Row: row} +func NewRows(ctx context.Context, rows *sql.Rows, conn Connection, query string, args []any) Rows { + return &rowsWrapper{ctx, rows, conn, query, args} } -type rowAsRows struct { - Row +func (r *rowsWrapper) ForEachRow(callback func(Row) error) (err error) { + defer func() { + err = combineTwoErrors(err, r.rows.Close()) + if err != nil { + err = WrapErrorWithQuery(err, r.query, r.args, r.conn.Config().ParamPlaceholderFormatter) + } + }() + + for r.rows.Next() { + if r.ctx.Err() != nil { + return r.ctx.Err() + } + + err := callback(r.rows) + if err != nil { + return err + } + } + return r.rows.Err() +} + +func (r *rowsWrapper) Close() error { + return r.rows.Close() } -func (rowAsRows) Close() error { return nil } -func (rowAsRows) Next() bool { return true } -func (rowAsRows) Err() error { return nil } +/////////////////////////////////////////////////////////////////////////////// + +// RowAsRows returns a single Rows wrapped as a Rows implementation. +// func RowAsRows(row Row) Rows { +// return &rowAsRows{row: row, closed: false} +// } + +// type rowAsRows struct { +// row Row +// closed bool +// } + +// func (r *rowAsRows) ForEachRow(callback func(Row) error) error { +// if r.closed { +// return errors.New("Rows are closed") +// } +// err := callback(r.row) +// r.closed = true +// return err +// } + +// func (r *rowAsRows) Close() error { +// r.closed = true +// return nil +// } diff --git a/rowscanner.go b/rowscanner.go deleted file mode 100644 index 7a754a3..0000000 --- a/rowscanner.go +++ /dev/null @@ -1,151 +0,0 @@ -package sqldb - -import ( - "database/sql" - - "github.com/domonda/go-sqldb/reflection" -) - -// RowScanner scans the values from a single row. -type RowScanner interface { - // Scan values of a row into dest variables, which must be passed as pointers. - Scan(dest ...any) error - - // ScanStruct scans values of a row into a dest struct which must be passed as pointer. - ScanStruct(dest any) error - - // ScanValues returns the values of a row exactly how they are - // passed from the database driver to an sql.Scanner. - // Byte slices will be copied. - ScanValues() ([]any, error) - - // ScanStrings scans the values of a row as strings. - // Byte slices will be interpreted as strings, - // nil (SQL NULL) will be converted to an empty string, - // all other types are converted with fmt.Sprint(src). - ScanStrings() ([]string, error) - - // Columns returns the column names. - Columns() ([]string, error) -} - -var ( - _ RowScanner = &rowScanner{} - _ RowScanner = CurrentRowScanner{} - _ RowScanner = SingleRowScanner{} -) - -// rowScanner implements rowScanner for a sql.Row -type rowScanner struct { - rows Rows - structFieldMapper reflection.StructFieldMapper - query string // for error wrapping - argFmt ParamPlaceholderFormatter // for error wrapping - args []any // for error wrapping -} - -func NewRowScanner(rows Rows, structFieldMapper reflection.StructFieldMapper, query string, argFmt ParamPlaceholderFormatter, args []any) *rowScanner { - return &rowScanner{rows, structFieldMapper, query, argFmt, args} -} - -func (s *rowScanner) Scan(dest ...any) (err error) { - defer func() { - err = combineErrors(err, s.rows.Close()) - err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) - }() - - if s.rows.Err() != nil { - return s.rows.Err() - } - if !s.rows.Next() { - if s.rows.Err() != nil { - return s.rows.Err() - } - return sql.ErrNoRows - } - - return s.rows.Scan(dest...) -} - -func (s *rowScanner) ScanStruct(dest any) (err error) { - defer func() { - err = combineErrors(err, s.rows.Close()) - err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) - }() - - if s.rows.Err() != nil { - return s.rows.Err() - } - if !s.rows.Next() { - if s.rows.Err() != nil { - return s.rows.Err() - } - return sql.ErrNoRows - } - - return reflection.ScanStruct(s.rows, dest, s.structFieldMapper) -} - -func (s *rowScanner) ScanValues() ([]any, error) { - return ScanValues(s.rows) -} - -func (s *rowScanner) ScanStrings() ([]string, error) { - return ScanStrings(s.rows) -} - -func (s *rowScanner) Columns() ([]string, error) { - return s.rows.Columns() -} - -// CurrentRowScanner calls Rows.Scan without Rows.Next and Rows.Close -type CurrentRowScanner struct { - Rows Rows - StructFieldMapper reflection.StructFieldMapper -} - -func (s CurrentRowScanner) Scan(dest ...any) error { - return s.Rows.Scan(dest...) -} - -func (s CurrentRowScanner) ScanStruct(dest any) error { - return reflection.ScanStruct(s.Rows, dest, s.StructFieldMapper) -} - -func (s CurrentRowScanner) ScanValues() ([]any, error) { - return ScanValues(s.Rows) -} - -func (s CurrentRowScanner) ScanStrings() ([]string, error) { - return ScanStrings(s.Rows) -} - -func (s CurrentRowScanner) Columns() ([]string, error) { - return s.Rows.Columns() -} - -// SingleRowScanner always uses the same Row -type SingleRowScanner struct { - Row Row - StructFieldMapper reflection.StructFieldMapper -} - -func (s SingleRowScanner) Scan(dest ...any) error { - return s.Row.Scan(dest...) -} - -func (s SingleRowScanner) ScanStruct(dest any) error { - return reflection.ScanStruct(s.Row, dest, s.StructFieldMapper) -} - -func (s SingleRowScanner) ScanValues() ([]any, error) { - return ScanValues(s.Row) -} - -func (s SingleRowScanner) ScanStrings() ([]string, error) { - return ScanStrings(s.Row) -} - -func (s SingleRowScanner) Columns() ([]string, error) { - return s.Row.Columns() -} diff --git a/rowsscanner.go b/rowsscanner.go deleted file mode 100644 index d5e4c74..0000000 --- a/rowsscanner.go +++ /dev/null @@ -1,139 +0,0 @@ -package sqldb - -import ( - "context" - "fmt" - - "github.com/domonda/go-sqldb/reflection" -) - -// RowsScanner scans the values from multiple rows. -type RowsScanner interface { - // ScanSlice scans one value per row into one slice element of dest. - // dest must be a pointer to a slice with a row value compatible element type. - // In case of zero rows, dest will be set to nil and no error will be returned. - // In case of an error, dest will not be modified. - // It is an error to query more than one column. - ScanSlice(dest any) error - - // ScanStructSlice scans every row into the struct fields of dest slice elements. - // dest must be a pointer to a slice of structs or struct pointers. - // In case of zero rows, dest will be set to nil and no error will be returned. - // In case of an error, dest will not be modified. - // Every mapped struct field must have a corresponding column in the query results. - ScanStructSlice(dest any) error - - // ScanAllRowsAsStrings scans the values of all rows as strings. - // Byte slices will be interpreted as strings, - // nil (SQL NULL) will be converted to an empty string, - // all other types are converted with fmt.Sprint. - // If true is passed for headerRow, then a row - // with the column names will be prepended. - ScanAllRowsAsStrings(headerRow bool) (rows [][]string, err error) - - // Columns returns the column names. - Columns() ([]string, error) - - // ForEachRow will call the passed callback with a RowScanner for every row. - // In case of zero rows, no error will be returned. - ForEachRow(callback func(RowScanner) error) error - - // ForEachRowCall will call the passed callback with scanned values or a struct for every row. - // If the callback function has a single struct or struct pointer argument, - // then RowScanner.ScanStruct will be used per row, - // else RowScanner.Scan will be used for all arguments of the callback. - // If the function has a context.Context as first argument, - // then the context of the query call will be passed on. - // The callback can have no result or a single error result value. - // If a non nil error is returned from the callback, then this error - // is returned immediately by this function without scanning further rows. - // In case of zero rows, no error will be returned. - ForEachRowCall(callback any) error -} - -var _ RowsScanner = &rowsScanner{} - -// rowsScanner implements rowsScanner with Rows -type rowsScanner struct { - ctx context.Context // ctx is checked for every row and passed through to callbacks - rows Rows - structFieldMapper reflection.StructFieldMapper - query string // for error wrapping - argFmt ParamPlaceholderFormatter // for error wrapping - args []any // for error wrapping -} - -func NewRowsScanner(ctx context.Context, rows Rows, structFieldMapper reflection.StructFieldMapper, query string, argFmt ParamPlaceholderFormatter, args []any) *rowsScanner { - return &rowsScanner{ctx, rows, structFieldMapper, query, argFmt, args} -} - -func (s *rowsScanner) ScanSlice(dest any) error { - err := reflection.ScanRowsAsSlice(s.ctx, s.rows, dest, nil) - if err != nil { - return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) - } - return nil -} - -func (s *rowsScanner) ScanStructSlice(dest any) error { - err := reflection.ScanRowsAsSlice(s.ctx, s.rows, dest, s.structFieldMapper) - if err != nil { - return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) - } - return nil -} - -func (s *rowsScanner) Columns() ([]string, error) { - return s.rows.Columns() -} - -func (s *rowsScanner) ScanAllRowsAsStrings(headerRow bool) (rows [][]string, err error) { - cols, err := s.rows.Columns() - if err != nil { - return nil, err - } - if headerRow { - rows = [][]string{cols} - } - stringScannablePtrs := make([]any, len(cols)) - err = s.ForEachRow(func(rowScanner RowScanner) error { - row := make([]string, len(cols)) - for i := range stringScannablePtrs { - stringScannablePtrs[i] = (*StringScannable)(&row[i]) - } - err := rowScanner.Scan(stringScannablePtrs...) - if err != nil { - return err - } - rows = append(rows, row) - return nil - }) - return rows, err -} - -func (s *rowsScanner) ForEachRow(callback func(RowScanner) error) (err error) { - defer func() { - err = combineErrors(err, s.rows.Close()) - err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) - }() - - for s.rows.Next() { - if s.ctx.Err() != nil { - return s.ctx.Err() - } - - err := callback(CurrentRowScanner{s.rows, s.structFieldMapper}) - if err != nil { - return err - } - } - return s.rows.Err() -} - -func (s *rowsScanner) ForEachRowCall(callback any) error { - forEachRowFunc, err := forEachRowCallFunc(s.ctx, callback) - if err != nil { - return err - } - return s.ForEachRow(forEachRowFunc) -} diff --git a/transaction.go b/transaction.go index 036568b..755505f 100644 --- a/transaction.go +++ b/transaction.go @@ -1,6 +1,7 @@ package sqldb import ( + "context" "database/sql" "errors" "fmt" @@ -14,15 +15,15 @@ import ( // are stricter than the options of the parent transaction. // Errors and panics from txFunc will rollback the transaction if parentConn was not already a transaction. // Recovered panics are re-paniced and rollback errors after a panic are logged with ErrLogger. -func Transaction(parentConn Connection, opts *sql.TxOptions, txFunc func(tx Connection) error) (err error) { - if parentOpts, parentIsTx := parentConn.TransactionOptions(); parentIsTx { - err = CheckTxOptionsCompatibility(parentOpts, opts, parentConn.Config().DefaultIsolationLevel) +func Transaction(ctx context.Context, parentConn Connection, opts *sql.TxOptions, txFunc func(tx Connection) error) (err error) { + if parentConn.IsTransaction() { + err = CheckConnectionTxOptionsCompatibility(parentConn, opts) if err != nil { return err } return txFunc(parentConn) } - return IsolatedTransaction(parentConn, opts, txFunc) + return IsolatedTransaction(ctx, parentConn, opts, txFunc) } // IsolatedTransaction executes txFunc within a database transaction that is passed in to txFunc as tx Connection. @@ -30,8 +31,8 @@ func Transaction(parentConn Connection, opts *sql.TxOptions, txFunc func(tx Conn // If parentConn is already a transaction, a brand new transaction will begin on the parent's connection. // Errors and panics from txFunc will rollback the transaction. // Recovered panics are re-paniced and rollback errors after a panic are logged with ErrLogger. -func IsolatedTransaction(parentConn Connection, opts *sql.TxOptions, txFunc func(tx Connection) error) (err error) { - tx, e := parentConn.Begin(opts) +func IsolatedTransaction(ctx context.Context, parentConn Connection, opts *sql.TxOptions, txFunc func(tx Connection) error) (err error) { + tx, e := parentConn.Begin(ctx, opts) if e != nil { return fmt.Errorf("Transaction Begin error: %w", e) } @@ -97,3 +98,7 @@ func CheckTxOptionsCompatibility(parent, child *sql.TxOptions, defaultIsolation } return nil } + +func CheckConnectionTxOptionsCompatibility(parentTx Connection, childTxOpts *sql.TxOptions) error { + return CheckTxOptionsCompatibility(parentTx.TxOptions(), childTxOpts, parentTx.Config().DefaultIsolationLevel) +} diff --git a/txconnection.go b/txconnection.go new file mode 100644 index 0000000..d0d58cc --- /dev/null +++ b/txconnection.go @@ -0,0 +1,89 @@ +package sqldb + +import ( + "context" + "database/sql" + "time" +) + +type TxConnection struct { + Parent Connection + Tx *sql.Tx + Opts *sql.TxOptions +} + +func (c *TxConnection) Config() *Config { + return c.Parent.Config() +} + +func (c *TxConnection) Stats() sql.DBStats { + return c.Parent.Stats() +} + +func (c *TxConnection) Ping(ctx context.Context, timeout time.Duration) error { + return c.Parent.Ping(ctx, timeout) +} + +func (c *TxConnection) Err() error { + return c.Parent.Err() +} + +func (c *TxConnection) Exec(ctx context.Context, query string, args ...any) error { + _, err := c.Tx.ExecContext(ctx, query, args...) + if err != nil { + return WrapErrorWithQuery(err, query, args, c.Config().ParamPlaceholderFormatter) + } + return nil +} + +func (c *TxConnection) QueryRow(ctx context.Context, query string, args ...any) Row { + rows, err := c.Tx.QueryContext(ctx, query, args...) + if err != nil { + return RowWithError(WrapErrorWithQuery(err, query, args, c.Config().ParamPlaceholderFormatter)) + } + return NewRow(ctx, rows, c, query, args) +} + +func (c *TxConnection) QueryRows(ctx context.Context, query string, args ...any) Rows { + rows, err := c.Tx.QueryContext(ctx, query, args...) + if err != nil { + return RowsWithError(WrapErrorWithQuery(err, query, args, c.Config().ParamPlaceholderFormatter)) + } + return NewRows(ctx, rows, c, query, args) +} + +func (c *TxConnection) IsTransaction() bool { + return true +} + +func (c *TxConnection) TxOptions() *sql.TxOptions { + return c.Opts +} + +func (c *TxConnection) Begin(ctx context.Context, opts *sql.TxOptions) (Connection, error) { + return nil, ErrWithinTransaction +} + +func (c *TxConnection) Commit() error { + return c.Tx.Commit() +} + +func (c *TxConnection) Rollback() error { + return c.Tx.Rollback() +} + +func (c *TxConnection) ListenOnChannel(channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) error { + return ErrWithinTransaction +} + +func (c *TxConnection) UnlistenChannel(channel string) error { + return ErrWithinTransaction +} + +func (c *TxConnection) IsListeningOnChannel(channel string) bool { + return false +} + +func (c *TxConnection) Close() error { + return c.Tx.Rollback() +} From 7c9eb3e7265eddfcf16367b2102b133d78bd1c5f Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Mon, 27 Jun 2022 10:41:46 +0200 Subject: [PATCH 11/12] WIP --- columnfilter.go | 42 +++-- go.mod | 25 ++- go.sum | 101 ++++++++++- reflection/columnfilter.go | 2 +- reflection/reflectstruct.go | 234 +++++++++++++------------- reflection/scanstruct.go | 52 ------ reflection/structfieldflags.go | 25 +++ reflection/structfieldmapping.go | 34 +--- reflection/structfieldmapping_test.go | 12 +- reflection/structmapper.go | 9 + reflection/structmapping.go | 188 +++++++++++++++++++++ reflection/taggedstructmapping.go | 142 ++++++++++++++++ row.go | 16 +- rows.go | 16 +- sqliteconn/connection.go | 27 +++ 15 files changed, 676 insertions(+), 249 deletions(-) delete mode 100644 reflection/scanstruct.go create mode 100644 reflection/structfieldflags.go create mode 100644 reflection/structmapper.go create mode 100644 reflection/structmapping.go create mode 100644 reflection/taggedstructmapping.go create mode 100644 sqliteconn/connection.go diff --git a/columnfilter.go b/columnfilter.go index fab3fe1..cc829f3 100644 --- a/columnfilter.go +++ b/columnfilter.go @@ -7,17 +7,17 @@ import ( ) type ColumnFilter interface { - IgnoreColumn(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool + IgnoreColumn(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool } -type ColumnFilterFunc func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool +type ColumnFilterFunc func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool -func (f ColumnFilterFunc) IgnoreColumn(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +func (f ColumnFilterFunc) IgnoreColumn(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { return f(name, flags, fieldType, fieldValue) } func IgnoreColumns(names ...string) ColumnFilter { - return ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { + return ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { for _, ignore := range names { if name == ignore { return true @@ -28,7 +28,7 @@ func IgnoreColumns(names ...string) ColumnFilter { } func OnlyColumns(names ...string) ColumnFilter { - return ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { + return ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { for _, include := range names { if name == include { return false @@ -39,7 +39,7 @@ func OnlyColumns(names ...string) ColumnFilter { } func IgnoreStructFields(names ...string) ColumnFilter { - return ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { + return ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { for _, ignore := range names { if fieldType.Name == ignore { return true @@ -50,7 +50,7 @@ func IgnoreStructFields(names ...string) ColumnFilter { } func OnlyStructFields(names ...string) ColumnFilter { - return ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { + return ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { for _, include := range names { if fieldType.Name == include { return false @@ -60,32 +60,40 @@ func OnlyStructFields(names ...string) ColumnFilter { }) } -func IgnoreFlags(ignore reflection.FieldFlag) ColumnFilter { - return ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +func IgnoreFlags(ignore reflection.StructFieldFlags) ColumnFilter { + return ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { return flags&ignore != 0 }) } -var IgnoreDefault ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - return flags.Default() +var IgnoreHasDefault ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { + return flags.HasDefault() }) -var IgnorePrimaryKey ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +var IgnorePrimaryKey ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { return flags.PrimaryKey() }) -var IgnoreReadOnly ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +var IgnoreReadOnly ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { return flags.ReadOnly() }) -var IgnoreNull ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +var IgnoreNull ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { return IsNull(fieldValue) }) -var IgnoreNullOrZero ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +var IgnoreNullOrZero ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { return IsNullOrZero(fieldValue) }) -var IgnoreNullOrZeroDefault ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - return flags.Default() && IsNullOrZero(fieldValue) +var IgnoreHasDefaultNullOrZero ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { + return flags.HasDefault() && IsNullOrZero(fieldValue) }) + +type noColumnFilter struct{} + +func (noColumnFilter) IgnoreColumn(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { + return false +} + +var AllColumns noColumnFilter diff --git a/go.mod b/go.mod index ef2493e..cd0a778 100644 --- a/go.mod +++ b/go.mod @@ -3,21 +3,38 @@ module github.com/domonda/go-sqldb go 1.18 require ( - github.com/domonda/go-errs v0.0.0-20220527085304-63cf6ad85d71 - github.com/domonda/go-types v0.0.0-20220603104906-eadd2cf77191 + github.com/domonda/go-errs v0.0.0-20220622113709-bc43209ba645 + github.com/domonda/go-types v0.0.0-20220614092523-688aad7f8c57 github.com/go-sql-driver/mysql v1.6.0 github.com/lib/pq v1.10.6 - github.com/stretchr/testify v1.7.1 - golang.org/x/exp v0.0.0-20220602145555-4a0574d9293f + github.com/stretchr/testify v1.7.5 + golang.org/x/exp v0.0.0-20220613132600-b0d781184e0d + modernc.org/sqlite v1.17.3 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/domonda/go-pretty v0.0.0-20220317123925-dd9e6bef129a // indirect + github.com/google/uuid v1.3.0 // indirect github.com/jinzhu/now v1.1.5 // indirect + github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/kr/pretty v0.1.0 // indirect + github.com/mattn/go-isatty v0.0.14 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/ungerik/go-reflection v0.0.0-20220113085621-6c5fc1f2694a // indirect + golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect + golang.org/x/sys v0.0.0-20220624220833-87e55d714810 // indirect + golang.org/x/tools v0.1.11 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + lukechampine.com/uint128 v1.2.0 // indirect + modernc.org/cc/v3 v3.36.0 // indirect + modernc.org/ccgo/v3 v3.16.6 // indirect + modernc.org/libc v1.16.11 // indirect + modernc.org/mathutil v1.4.1 // indirect + modernc.org/memory v1.1.1 // indirect + modernc.org/opt v0.1.3 // indirect + modernc.org/strutil v1.1.2 // indirect + modernc.org/token v1.0.0 // indirect ) diff --git a/go.sum b/go.sum index babe5fc..b76e511 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,24 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/domonda/go-errs v0.0.0-20220527085304-63cf6ad85d71 h1:WRag+fUJENLRM8N/wp6gf/0i1aEkLY9prNgoFQsWeso= -github.com/domonda/go-errs v0.0.0-20220527085304-63cf6ad85d71/go.mod h1:suiFfPp8l6I+OOaKgPK/bfX7Ci9ZtFRgPh5VNE0HPao= +github.com/domonda/go-errs v0.0.0-20220622113709-bc43209ba645 h1:hCCfGvOsbejnNPUdqD/wtE/t4pRjEn8/706tRqxUmck= +github.com/domonda/go-errs v0.0.0-20220622113709-bc43209ba645/go.mod h1:WvIoE59Dfs0hhB2GYSlwowlBr2WWGXf/F74bg6HWUpQ= github.com/domonda/go-pretty v0.0.0-20220317123925-dd9e6bef129a h1:6/Is0KGl5Ot3E8ZLAgAFWYiSRdU+3t3jL38+5yIlCV4= github.com/domonda/go-pretty v0.0.0-20220317123925-dd9e6bef129a/go.mod h1:3QkM8UJdyJMeKZiIo7hYzSkQBpRS3k0gOHw4ysyEIB4= -github.com/domonda/go-types v0.0.0-20220603104906-eadd2cf77191 h1:NcOIFS41zSztJog+aPw48HV8oVhRQPV0B6M6CshwFqc= -github.com/domonda/go-types v0.0.0-20220603104906-eadd2cf77191/go.mod h1:qZTRjdjIXo3g+8PUhfpkKbMPGsLVTuF3H7/AX5CzNeQ= +github.com/domonda/go-types v0.0.0-20220614092523-688aad7f8c57 h1:ivIpyltPSRPx1CdqqcXUi+hEp3SyFt6RR6B19pwpYOY= +github.com/domonda/go-types v0.0.0-20220614092523-688aad7f8c57/go.mod h1:jqmELFrQI8hv+uaTNjxht99Wn+14jbUoSmwkbnxaA/g= +github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -18,18 +26,97 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.10.6 h1:jbk+ZieJ0D7EVGJYpL9QTz7/YW6UHbmdnZWYyK5cdBs= github.com/lib/pq v1.10.6/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-sqlite3 v1.14.12 h1:TJ1bhYJPV44phC+IMu1u2K/i5RriLTPe+yc68XDJ1Z0= +github.com/mattn/go-sqlite3 v1.14.12/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6OkFY5QxjkYwrChwuRruF69c169dPK26NUlk= +github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.5 h1:s5PTfem8p8EbKQOctVV53k6jCJt3UX4IEJzwh+C324Q= +github.com/stretchr/testify v1.7.5/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/ungerik/go-reflection v0.0.0-20220113085621-6c5fc1f2694a h1:9vfYtqoyrPw08TbSLxkSXEflp6iXa3RL86Qjs+DrVas= github.com/ungerik/go-reflection v0.0.0-20220113085621-6c5fc1f2694a/go.mod h1:6Hnd2/4g3Tpt6TjvxHx8wXOZziwApVxRdIGkr7vNpXs= -golang.org/x/exp v0.0.0-20220602145555-4a0574d9293f h1:KK6mxegmt5hGJRcAnEDjSNLxIRhZxDcgwMbcO/lMCRM= -golang.org/x/exp v0.0.0-20220602145555-4a0574d9293f/go.mod h1:yh0Ynu2b5ZUe3MQfp2nM0ecK7wsgouWTDN0FNeJuIys= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/exp v0.0.0-20220613132600-b0d781184e0d h1:vtUKgx8dahOomfFzLREU8nSv25YHnTgLBn4rDnWZdU0= +golang.org/x/exp v0.0.0-20220613132600-b0d781184e0d/go.mod h1:Kr81I6Kryrl9sr8s2FK3vxD90NdsKWRuOIl2O4CvYbA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220624220833-87e55d714810 h1:rHZQSjJdAI4Xf5Qzeh2bBc5YJIkPFVM6oDtMFYmgws0= +golang.org/x/sys v0.0.0-20220624220833-87e55d714810/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.11 h1:loJ25fNOEhSXfHrpoGj91eCUThwdNX6u24rO1xnNteY= +golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +lukechampine.com/uint128 v1.1.1/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= +lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI= +lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= +modernc.org/cc/v3 v3.36.0 h1:0kmRkTmqNidmu3c7BNDSdVHCxXCkWLmWmCIVX4LUboo= +modernc.org/cc/v3 v3.36.0/go.mod h1:NFUHyPn4ekoC/JHeZFfZurN6ixxawE1BnVonP/oahEI= +modernc.org/ccgo/v3 v3.0.0-20220428102840-41399a37e894/go.mod h1:eI31LL8EwEBKPpNpA4bU1/i+sKOwOrQy8D87zWUcRZc= +modernc.org/ccgo/v3 v3.0.0-20220430103911-bc99d88307be/go.mod h1:bwdAnOoaIt8Ax9YdWGjxWsdkPcZyRPHqrOvJxaKAKGw= +modernc.org/ccgo/v3 v3.16.4/go.mod h1:tGtX0gE9Jn7hdZFeU88slbTh1UtCYKusWOoCJuvkWsQ= +modernc.org/ccgo/v3 v3.16.6 h1:3l18poV+iUemQ98O3X5OMr97LOqlzis+ytivU4NqGhA= +modernc.org/ccgo/v3 v3.16.6/go.mod h1:tGtX0gE9Jn7hdZFeU88slbTh1UtCYKusWOoCJuvkWsQ= +modernc.org/ccorpus v1.11.6 h1:J16RXiiqiCgua6+ZvQot4yUuUy8zxgqbqEEUuGPlISk= +modernc.org/ccorpus v1.11.6/go.mod h1:2gEUTrWqdpH2pXsmTM1ZkjeSrUWDpjMu2T6m29L/ErQ= +modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM= +modernc.org/httpfs v1.0.6/go.mod h1:7dosgurJGp0sPaRanU53W4xZYKh14wfzX420oZADeHM= +modernc.org/libc v0.0.0-20220428101251-2d5f3daf273b/go.mod h1:p7Mg4+koNjc8jkqwcoFBJx7tXkpj00G77X7A72jXPXA= +modernc.org/libc v1.16.0/go.mod h1:N4LD6DBE9cf+Dzf9buBlzVJndKr/iJHG97vGLHYnb5A= +modernc.org/libc v1.16.1/go.mod h1:JjJE0eu4yeK7tab2n4S1w8tlWd9MxXLRzheaRnAKymU= +modernc.org/libc v1.16.7/go.mod h1:hYIV5VZczAmGZAnG15Vdngn5HSF5cSkbvfz2B7GRuVU= +modernc.org/libc v1.16.11 h1:rR2BPB5e9zUm9gYqDgR0hUxcSmjgtmAL79lRObBLfPU= +modernc.org/libc v1.16.11/go.mod h1:hYIV5VZczAmGZAnG15Vdngn5HSF5cSkbvfz2B7GRuVU= +modernc.org/mathutil v1.2.2/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/mathutil v1.4.1 h1:ij3fYGe8zBF4Vu+g0oT7mB06r8sqGWKuJu1yXeR4by8= +modernc.org/mathutil v1.4.1/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/memory v1.1.1 h1:bDOL0DIDLQv7bWhP3gMvIrnoFw+Eo6F7a2QK9HPDiFU= +modernc.org/memory v1.1.1/go.mod h1:/0wo5ibyrQiaoUoH7f9D8dnglAmILJ5/cxZlRECf+Nw= +modernc.org/opt v0.1.1/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= +modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/sqlite v1.17.3 h1:iE+coC5g17LtByDYDWKpR6m2Z9022YrSh3bumwOnIrI= +modernc.org/sqlite v1.17.3/go.mod h1:10hPVYar9C0kfXuTWGz8s0XtB8uAGymUy51ZzStYe3k= +modernc.org/strutil v1.1.1/go.mod h1:DE+MQQ/hjKBZS2zNInV5hhcipt5rLPWkmpbGeW5mmdw= +modernc.org/strutil v1.1.2 h1:iFBDH6j1Z0bN/Q9udJnnFoFpENA4252qe/7/5woE5MI= +modernc.org/strutil v1.1.2/go.mod h1:OYajnUAcI/MX+XD/Wx7v1bbdvcQSvxgtb0gC+u3d3eg= +modernc.org/tcl v1.13.1 h1:npxzTwFTZYM8ghWicVIX1cRWzj7Nd8i6AqqX2p+IYao= +modernc.org/tcl v1.13.1/go.mod h1:XOLfOwzhkljL4itZkK6T72ckMgvj0BDsnKNdZVUOecw= +modernc.org/token v1.0.0 h1:a0jaWiNMDhDUtqOj09wvjWWAqd3q7WpBulmL9H2egsk= +modernc.org/token v1.0.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= +modernc.org/z v1.5.1 h1:RTNHdsrOpeoSeOF4FbzTo8gBYByaJ5xT7NgZ9ZqRiJM= +modernc.org/z v1.5.1/go.mod h1:eWFB510QWW5Th9YGZT81s+LwvaAs3Q2yr4sP0rmLkv8= diff --git a/reflection/columnfilter.go b/reflection/columnfilter.go index a94321b..112b271 100644 --- a/reflection/columnfilter.go +++ b/reflection/columnfilter.go @@ -5,5 +5,5 @@ import ( ) type ColumnFilter interface { - IgnoreColumn(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool + IgnoreColumn(*StructColumn, reflect.Value) bool } diff --git a/reflection/reflectstruct.go b/reflection/reflectstruct.go index 1c5e36d..5a2868d 100644 --- a/reflection/reflectstruct.go +++ b/reflection/reflectstruct.go @@ -1,131 +1,131 @@ package reflection -import ( - "errors" - "fmt" - "reflect" - "strings" +// import ( +// "errors" +// "fmt" +// "reflect" +// "strings" - "golang.org/x/exp/slices" -) +// "golang.org/x/exp/slices" +// ) -func ReflectStructValues(structVal reflect.Value, mapper StructFieldMapper, ignoreColumns []ColumnFilter) (table string, columns []string, pkCols []int, values []any, err error) { - structType := structVal.Type() - for i := 0; i < structType.NumField(); i++ { - fieldType := structType.Field(i) - fieldTable, column, flags, use := mapper.MapStructField(fieldType) - if !use { - continue - } - fieldValue := structVal.Field(i) +// func ReflectStructValues(structVal reflect.Value, mapper StructFieldMapper, ignoreColumns []ColumnFilter) (table string, columns []string, pkCols []int, values []any, err error) { +// structType := structVal.Type() +// for i := 0; i < structType.NumField(); i++ { +// fieldType := structType.Field(i) +// fieldTable, column, flags, use := mapper.MapStructField(fieldType) +// if !use { +// continue +// } +// fieldValue := structVal.Field(i) - if column == "" { - // Embedded struct field - fieldTable, columnsEmbed, pkColsEmbed, valuesEmbed, err := ReflectStructValues(fieldValue, mapper, ignoreColumns) - if err != nil { - return "", nil, nil, nil, err - } - if fieldTable != "" && fieldTable != table { - if table != "" { - return "", nil, nil, nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, structType) - } - table = fieldTable - } - for _, pkCol := range pkColsEmbed { - pkCols = append(pkCols, pkCol+len(columns)) - } - columns = append(columns, columnsEmbed...) - values = append(values, valuesEmbed...) - continue - } +// if column == "" { +// // Embedded struct field +// fieldTable, columnsEmbed, pkColsEmbed, valuesEmbed, err := ReflectStructValues(fieldValue, mapper, ignoreColumns) +// if err != nil { +// return "", nil, nil, nil, err +// } +// if fieldTable != "" && fieldTable != table { +// if table != "" { +// return "", nil, nil, nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, structType) +// } +// table = fieldTable +// } +// for _, pkCol := range pkColsEmbed { +// pkCols = append(pkCols, pkCol+len(columns)) +// } +// columns = append(columns, columnsEmbed...) +// values = append(values, valuesEmbed...) +// continue +// } - if ignoreColumn(ignoreColumns, column, flags, fieldType, fieldValue) { - continue - } +// if ignoreColumn(ignoreColumns, column, flags, fieldType, fieldValue) { +// continue +// } - if fieldTable != "" && fieldTable != table { - if table != "" { - return "", nil, nil, nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, structType) - } - table = fieldTable - } - if flags.PrimaryKey() { - pkCols = append(pkCols, len(columns)) - } - columns = append(columns, column) - values = append(values, fieldValue.Interface()) - } - return table, columns, pkCols, values, nil -} +// if fieldTable != "" && fieldTable != table { +// if table != "" { +// return "", nil, nil, nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, structType) +// } +// table = fieldTable +// } +// if flags.PrimaryKey() { +// pkCols = append(pkCols, len(columns)) +// } +// columns = append(columns, column) +// values = append(values, fieldValue.Interface()) +// } +// return table, columns, pkCols, values, nil +// } -func ReflectStructColumnPointers(structVal reflect.Value, mapper StructFieldMapper, columns []string) (pointers []any, err error) { - if len(columns) == 0 { - return nil, errors.New("no columns") - } - pointers = make([]any, len(columns)) - err = reflectStructColumnPointers(structVal, mapper, columns, pointers) - if err != nil { - return nil, err - } - for _, ptr := range pointers { - if ptr != nil { - continue - } - nilCols := new(strings.Builder) - for i, ptr := range pointers { - if ptr != nil { - continue - } - if nilCols.Len() > 0 { - nilCols.WriteString(", ") - } - fmt.Fprintf(nilCols, "column=%s, index=%d", columns[i], i) - } - return nil, fmt.Errorf("columns have no mapped struct fields in %s: %s", structVal.Type(), nilCols) - } - return pointers, nil -} +// func ReflectStructColumnPointers(structVal reflect.Value, mapper StructFieldMapper, columns []string) (pointers []any, err error) { +// if len(columns) == 0 { +// return nil, errors.New("no columns") +// } +// pointers = make([]any, len(columns)) +// err = reflectStructColumnPointers(structVal, mapper, columns, pointers) +// if err != nil { +// return nil, err +// } +// for _, ptr := range pointers { +// if ptr != nil { +// continue +// } +// nilCols := new(strings.Builder) +// for i, ptr := range pointers { +// if ptr != nil { +// continue +// } +// if nilCols.Len() > 0 { +// nilCols.WriteString(", ") +// } +// fmt.Fprintf(nilCols, "column=%s, index=%d", columns[i], i) +// } +// return nil, fmt.Errorf("columns have no mapped struct fields in %s: %s", structVal.Type(), nilCols) +// } +// return pointers, nil +// } -func reflectStructColumnPointers(structVal reflect.Value, mapper StructFieldMapper, columns []string, pointers []any) error { - var ( - structType = structVal.Type() - ) - for i := 0; i < structType.NumField(); i++ { - fieldType := structType.Field(i) - _, column, _, use := mapper.MapStructField(fieldType) - if !use { - continue - } - fieldValue := structVal.Field(i) +// func reflectStructColumnPointers(structVal reflect.Value, mapper StructFieldMapper, columns []string, pointers []any) error { +// var ( +// structType = structVal.Type() +// ) +// for i := 0; i < structType.NumField(); i++ { +// fieldType := structType.Field(i) +// _, column, _, use := mapper.MapStructField(fieldType) +// if !use { +// continue +// } +// fieldValue := structVal.Field(i) - if column == "" { - // Embedded struct field - err := reflectStructColumnPointers(fieldValue, mapper, columns, pointers) - if err != nil { - return err - } - continue - } +// if column == "" { +// // Embedded struct field +// err := reflectStructColumnPointers(fieldValue, mapper, columns, pointers) +// if err != nil { +// return err +// } +// continue +// } - colIndex := slices.Index(columns, column) - if colIndex == -1 { - continue - } +// colIndex := slices.Index(columns, column) +// if colIndex == -1 { +// continue +// } - if pointers[colIndex] != nil { - return fmt.Errorf("duplicate mapped column %s onto field %s of struct %s", column, fieldType.Name, structType) - } +// if pointers[colIndex] != nil { +// return fmt.Errorf("duplicate mapped column %s onto field %s of struct %s", column, fieldType.Name, structType) +// } - pointers[colIndex] = fieldValue.Addr().Interface() - } - return nil -} +// pointers[colIndex] = fieldValue.Addr().Interface() +// } +// return nil +// } -func ignoreColumn(filters []ColumnFilter, name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - for _, filter := range filters { - if filter.IgnoreColumn(name, flags, fieldType, fieldValue) { - return true - } - } - return false -} +// func ignoreColumn(filters []ColumnFilter, name string, flags StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { +// for _, filter := range filters { +// if filter.IgnoreColumn(name, flags, fieldType, fieldValue) { +// return true +// } +// } +// return false +// } diff --git a/reflection/scanstruct.go b/reflection/scanstruct.go deleted file mode 100644 index bb3afc8..0000000 --- a/reflection/scanstruct.go +++ /dev/null @@ -1,52 +0,0 @@ -package reflection - -import ( - "fmt" - "reflect" -) - -// ScanStruct scans values of a srcRow into a destStruct which must be passed as pointer. -func ScanStruct(srcRow Row, destStruct any, mapper StructFieldMapper) error { - v := reflect.ValueOf(destStruct) - for v.Kind() == reflect.Ptr && !v.IsNil() { - v = v.Elem() - } - - var ( - setDestStructPtr = false - destStructPtr reflect.Value - newStructPtr reflect.Value - ) - if v.Kind() == reflect.Ptr && v.IsNil() && v.CanSet() { - // Got a nil pointer that we can set with a newly allocated struct - setDestStructPtr = true - destStructPtr = v - newStructPtr = reflect.New(v.Type().Elem()) - // Continue with the newly allocated struct - v = newStructPtr.Elem() - } - if v.Kind() != reflect.Struct { - return fmt.Errorf("ScanStruct: expected struct but got %T", destStruct) - } - - columns, err := srcRow.Columns() - if err != nil { - return err - } - - fieldPointers, err := ReflectStructColumnPointers(v, mapper, columns) - if err != nil { - return fmt.Errorf("ScanStruct: %w", err) - } - - err = srcRow.Scan(fieldPointers...) - if err != nil { - return err - } - - if setDestStructPtr { - destStructPtr.Set(newStructPtr) - } - - return nil -} diff --git a/reflection/structfieldflags.go b/reflection/structfieldflags.go new file mode 100644 index 0000000..71878a1 --- /dev/null +++ b/reflection/structfieldflags.go @@ -0,0 +1,25 @@ +package reflection + +// StructFieldFlags is a bitmask for special properties +// of how struct fields relate to database columns. +type StructFieldFlags uint + +// PrimaryKey indicates if FlagPrimaryKey is set +func (f StructFieldFlags) PrimaryKey() bool { return f&FlagPrimaryKey != 0 } + +// ReadOnly indicates if FlagReadOnly is set +func (f StructFieldFlags) ReadOnly() bool { return f&FlagReadOnly != 0 } + +// HasDefault indicates if FlagHasDefault is set +func (f StructFieldFlags) HasDefault() bool { return f&FlagHasDefault != 0 } + +const ( + // FlagPrimaryKey marks a field as primary key + FlagPrimaryKey StructFieldFlags = 1 << iota + + // FlagReadOnly marks a field as read-only + FlagReadOnly + + // FlagHasDefault marks a field as having a column default value + FlagHasDefault +) diff --git a/reflection/structfieldmapping.go b/reflection/structfieldmapping.go index c0c836e..c3a9b0f 100644 --- a/reflection/structfieldmapping.go +++ b/reflection/structfieldmapping.go @@ -7,30 +7,6 @@ import ( "unicode" ) -// FieldFlag is a bitmask for special properties -// of how struct fields relate to database columns. -type FieldFlag uint - -// PrimaryKey indicates if FieldFlagPrimaryKey is set -func (f FieldFlag) PrimaryKey() bool { return f&FieldFlagPrimaryKey != 0 } - -// ReadOnly indicates if FieldFlagReadOnly is set -func (f FieldFlag) ReadOnly() bool { return f&FieldFlagReadOnly != 0 } - -// Default indicates if FieldFlagDefault is set -func (f FieldFlag) Default() bool { return f&FieldFlagDefault != 0 } - -const ( - // FieldFlagPrimaryKey marks a field as primary key - FieldFlagPrimaryKey FieldFlag = 1 << iota - - // FieldFlagReadOnly marks a field as read-only - FieldFlagReadOnly - - // FieldFlagDefault marks a field as having a column default value - FieldFlagDefault -) - // StructFieldMapper is used to map struct type fields to column names // and indicate special column properies via flags. type StructFieldMapper interface { @@ -39,7 +15,7 @@ type StructFieldMapper interface { // If false is returned for use then the field is not mapped. // An empty name and true for use indicates an embedded struct // field whose fields should be recursively mapped. - MapStructField(field reflect.StructField) (table, column string, flags FieldFlag, use bool) + MapStructField(field reflect.StructField) (table, column string, flags StructFieldFlags, use bool) } // NewTaggedStructFieldMapping returns a default mapping. @@ -74,7 +50,7 @@ type TaggedStructFieldMapping struct { UntaggedNameFunc func(fieldName string) string } -func (m *TaggedStructFieldMapping) MapStructField(field reflect.StructField) (table, column string, flags FieldFlag, use bool) { +func (m *TaggedStructFieldMapping) MapStructField(field reflect.StructField) (table, column string, flags StructFieldFlags, use bool) { if field.Anonymous { column, hasTag := field.Tag.Lookup(m.NameTag) if !hasTag { @@ -107,12 +83,12 @@ func (m *TaggedStructFieldMapping) MapStructField(field reflect.StructField) (ta case "": // Ignore empty flags case m.PrimaryKey: - flags |= FieldFlagPrimaryKey + flags |= FlagPrimaryKey table = value case m.ReadOnly: - flags |= FieldFlagReadOnly + flags |= FlagReadOnly case m.Default: - flags |= FieldFlagDefault + flags |= FlagHasDefault } } } else { diff --git a/reflection/structfieldmapping_test.go b/reflection/structfieldmapping_test.go index 2473639..e3fe80c 100644 --- a/reflection/structfieldmapping_test.go +++ b/reflection/structfieldmapping_test.go @@ -55,18 +55,18 @@ func TestTaggedStructFieldMapping_StructFieldName(t *testing.T) { structField reflect.StructField wantTable string wantColumn string - wantFlags FieldFlag + wantFlags StructFieldFlags wantOk bool }{ - {name: "index", structField: st.Field(0), wantTable: "public.my_table", wantColumn: "index", wantFlags: FieldFlagPrimaryKey, wantOk: true}, - {name: "index_b", structField: st.Field(1), wantTable: "", wantColumn: "index_b", wantFlags: FieldFlagPrimaryKey, wantOk: true}, + {name: "index", structField: st.Field(0), wantTable: "public.my_table", wantColumn: "index", wantFlags: FlagPrimaryKey, wantOk: true}, + {name: "index_b", structField: st.Field(1), wantTable: "", wantColumn: "index_b", wantFlags: FlagPrimaryKey, wantOk: true}, {name: "named_str", structField: st.Field(2), wantColumn: "named_str", wantFlags: 0, wantOk: true}, - {name: "read_only", structField: st.Field(3), wantColumn: "read_only", wantFlags: FieldFlagReadOnly, wantOk: true}, + {name: "read_only", structField: st.Field(3), wantColumn: "read_only", wantFlags: FlagReadOnly, wantOk: true}, {name: "untagged_field", structField: st.Field(4), wantColumn: "untagged_field", wantFlags: 0, wantOk: true}, {name: "ignore", structField: st.Field(5), wantColumn: "", wantFlags: 0, wantOk: false}, - {name: "pk_read_only", structField: st.Field(6), wantColumn: "pk_read_only", wantFlags: FieldFlagPrimaryKey | FieldFlagReadOnly, wantOk: true}, + {name: "pk_read_only", structField: st.Field(6), wantColumn: "pk_read_only", wantFlags: FlagPrimaryKey | FlagReadOnly, wantOk: true}, {name: "no_flag", structField: st.Field(7), wantColumn: "no_flag", wantFlags: 0, wantOk: true}, - {name: "malformed_flags", structField: st.Field(8), wantColumn: "malformed_flags", wantFlags: FieldFlagReadOnly, wantOk: true}, + {name: "malformed_flags", structField: st.Field(8), wantColumn: "malformed_flags", wantFlags: FlagReadOnly, wantOk: true}, {name: "Embedded", structField: st.Field(9), wantColumn: "", wantFlags: 0, wantOk: true}, } for _, tt := range tests { diff --git a/reflection/structmapper.go b/reflection/structmapper.go new file mode 100644 index 0000000..4210ab4 --- /dev/null +++ b/reflection/structmapper.go @@ -0,0 +1,9 @@ +package reflection + +import ( + "reflect" +) + +type StructMapper interface { + ReflectStructMapping(t reflect.Type) (*StructMapping, error) +} diff --git a/reflection/structmapping.go b/reflection/structmapping.go new file mode 100644 index 0000000..e13be7a --- /dev/null +++ b/reflection/structmapping.go @@ -0,0 +1,188 @@ +package reflection + +import ( + "errors" + "fmt" + "reflect" + "sync" +) + +type StructMapping struct { + StructType reflect.Type + Table string + Columns []*StructColumn + ColumnMap map[string]*StructColumn +} + +type StructColumn struct { + Name string + Flags StructFieldFlags + FieldIndex []int + FieldType reflect.StructField +} + +type mappingKey struct { + reflect.Type + StructMapper +} + +var ( + cachedMappings = make(map[mappingKey]*StructMapping) + cachedMappingsMtx sync.Mutex +) + +func CachedStructMapping(t reflect.Type, m StructMapper) (*StructMapping, error) { + cachedMappingsMtx.Lock() + defer cachedMappingsMtx.Unlock() + + key := mappingKey{t, m} + + if mapping, ok := cachedMappings[key]; ok { + return mapping, nil + } + + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("passed type %s is not a struct", t) + } + if m == nil { + return nil, errors.New("passed nil StructMapper") + } + mapping, err := m.ReflectStructMapping(t) + if err != nil { + return nil, err + } + cachedMappings[key] = mapping + return mapping, nil +} + +func (m *StructMapping) StructColumnValues(strct any, filter ColumnFilter) ([]any, error) { + v := reflect.ValueOf(strct) + switch v.Kind() { + case reflect.Struct: + // ok + case reflect.Pointer: + if v.IsNil() { + return nil, fmt.Errorf("passed nil %T", strct) + } + v = v.Elem() + if v.Kind() != reflect.Struct { + return nil, fmt.Errorf("passed type %T is not a struct pointer", strct) + } + default: + return nil, fmt.Errorf("passed type %T is not a struct or struct pointer", strct) + } + if v.Type() != m.StructType { + return nil, fmt.Errorf("passed struct of type %s to %s StructMapping", v.Type(), m.StructType) + } + + if filter == nil { + vals := make([]any, len(m.Columns)) + for i, col := range m.Columns { + vals[i] = v.FieldByIndex(col.FieldIndex).Interface() + } + return vals, nil + } + + vals := make([]any, 0, len(m.Columns)) + for _, col := range m.Columns { + val := v.FieldByIndex(col.FieldIndex) + if !filter.IgnoreColumn(col, val) { + vals = append(vals, val.Interface()) + } + } + return vals, nil +} + +// func (m *StructMapping) StructColumnPointers(structPtr any, filter ColumnFilter) ([]any, error) { +// v := reflect.ValueOf(structPtr) +// if v.Kind() != reflect.Pointer { +// return nil, fmt.Errorf("passed type %T is not a struct pointer", structPtr) +// } +// if v.IsNil() { +// return nil, fmt.Errorf("passed nil %T", structPtr) +// } +// v = v.Elem() +// if v.Kind() != reflect.Struct { +// return nil, fmt.Errorf("passed type %T is not a struct pointer", structPtr) +// } +// if v.Type() != m.StructType { +// return nil, fmt.Errorf("passed struct of type %s to %s StructMapping", v.Type(), m.StructType) +// } + +// if filter == nil { +// vals := make([]any, len(m.Columns)) +// for i, col := range m.Columns { +// vals[i] = v.FieldByIndex(col.FieldIndex).Addr().Interface() +// } +// return vals, nil +// } + +// vals := make([]any, 0, len(m.Columns)) +// for _, col := range m.Columns { +// val := v.FieldByIndex(col.FieldIndex) +// if !filter.IgnoreColumn(col, val) { +// vals = append(vals, val.Addr().Interface()) +// } +// } +// return vals, nil +// } + +// ScanStruct scans values of a srcRow into a destStruct which must be passed as pointer. +func (m *StructMapping) ScanStruct(srcRow Row, structPtr any, filter ColumnFilter) error { + v := reflect.ValueOf(structPtr) + // if v.Kind() != reflect.Pointer { + // return fmt.Errorf("passed type %T is not a struct pointer", structPtr) + // } + // if v.IsNil() { + // return fmt.Errorf("passed nil %T", structPtr) + // } + // v = v.Elem() + // if v.Kind() != reflect.Struct { + // return fmt.Errorf("passed type %T is not a struct pointer", structPtr) + // } + // if v.Type() != m.StructType { + // return fmt.Errorf("passed struct of type %s to %s StructMapping", v.Type(), m.StructType) + // } + + var ( + setDestStructPtr = false + destStructPtr reflect.Value + newStructPtr reflect.Value + ) + if v.Kind() == reflect.Ptr && v.IsNil() && v.CanSet() { + // Got a nil pointer that we can set with a newly allocated struct + setDestStructPtr = true + destStructPtr = v + newStructPtr = reflect.New(v.Type().Elem()) + // Continue with the newly allocated struct + v = newStructPtr.Elem() + } + if v.Kind() != reflect.Struct { + return fmt.Errorf("passed type %T is not a struct pointer", structPtr) + } + + columns, err := srcRow.Columns() + if err != nil { + return err + } + + fieldPointers := make([]any, len(columns)) + for i, name := range columns { + col, ok := m.ColumnMap[name] + if !ok { + return fmt.Errorf("no mapping for column %s to struct %s", name, m.StructType) + } + fieldPointers[i] = v.FieldByIndex(col.FieldIndex).Addr().Interface() + } + + err = srcRow.Scan(fieldPointers...) + if err != nil { + return err + } + + if setDestStructPtr { + destStructPtr.Set(newStructPtr) + } + + return nil +} diff --git a/reflection/taggedstructmapping.go b/reflection/taggedstructmapping.go new file mode 100644 index 0000000..f2a666d --- /dev/null +++ b/reflection/taggedstructmapping.go @@ -0,0 +1,142 @@ +package reflection + +import ( + "fmt" + "reflect" + "strings" +) + +// TaggedStructMapper implements StructFieldMapper with a struct field NameTag +// to be used for naming and a UntaggedNameFunc in case the NameTag is not set. +type TaggedStructMapper struct { + _Named_Fields_Required struct{} + + // NameTag is the struct field tag to be used as column name + NameTag string + + // Ignore will cause a struct field to be ignored if it has that name + Ignore string + + PrimaryKey string + ReadOnly string + Default string + + // UntaggedNameFunc will be called with the struct field name to + // return a column name in case the struct field has no tag named NameTag. + UntaggedNameFunc func(fieldName string) string +} + +// NewTaggedStructMapper returns a default mapping. +func NewTaggedStructMapper() *TaggedStructMapper { + return &TaggedStructMapper{ + NameTag: "db", + Ignore: "-", + PrimaryKey: "pk", + ReadOnly: "readonly", + Default: "default", + UntaggedNameFunc: IgnoreStructField, + } +} + +func (m *TaggedStructMapper) ReflectStructMapping(structType reflect.Type) (*StructMapping, error) { + if structType.Kind() != reflect.Struct { + return nil, fmt.Errorf("passed type %s is not a struct", structType) + } + mapping := &StructMapping{ + StructType: structType, + ColumnMap: make(map[string]*StructColumn), + } + err := m.reflectStructMapping(structType, mapping) + if err != nil { + return nil, err + } + return mapping, nil +} + +func (m *TaggedStructMapper) reflectStructMapping(structType reflect.Type, mapping *StructMapping) error { + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + fieldTable, name, flags, use := m.mapStructField(field) + if !use { + continue + } + + if name == "" { + // Embedded struct field + err := m.reflectStructMapping(field.Type, mapping) + if err != nil { + return err + } + continue + } + + if fieldTable != "" && fieldTable != mapping.Table { + if mapping.Table != "" { + return fmt.Errorf("table name not unique (%s vs %s) in struct %s", mapping.Table, fieldTable, mapping.StructType) + } + mapping.Table = fieldTable + } + + column := &StructColumn{ + Name: name, + Flags: flags, + FieldIndex: field.Index, + FieldType: field, + } + mapping.Columns = append(mapping.Columns, column) + mapping.ColumnMap[name] = column + } + return nil +} + +func (m *TaggedStructMapper) mapStructField(field reflect.StructField) (table, column string, flags StructFieldFlags, use bool) { + if field.Anonymous { + column, hasTag := field.Tag.Lookup(m.NameTag) + if !hasTag { + // Embedded struct fields are ok if not tagged with IgnoreName + return "", "", 0, true + } + if i := strings.IndexByte(column, ','); i != -1 { + column = column[:i] + } + // Embedded struct fields are ok if not tagged with IgnoreName + return "", "", 0, column != m.Ignore + } + + if !field.IsExported() { + // Not exported struct fields that are not + // anonymously embedded structs are not ok + return "", "", 0, false + } + + tag, hasTag := field.Tag.Lookup(m.NameTag) + if hasTag { + for i, part := range strings.Split(tag, ",") { + // First part is the name + if i == 0 { + column = part + continue + } + // Follow on parts are flags + flag, value, _ := strings.Cut(part, "=") + switch flag { + case "": + // Ignore empty flags + case m.PrimaryKey: + flags |= FlagPrimaryKey + table = value + case m.ReadOnly: + flags |= FlagReadOnly + case m.Default: + flags |= FlagHasDefault + } + } + } else if m.UntaggedNameFunc != nil { + column = m.UntaggedNameFunc(field.Name) + } + + if column == m.Ignore || column == "" { + return "", "", 0, false + } + return table, column, flags, true +} diff --git a/row.go b/row.go index 352cc80..fd70693 100644 --- a/row.go +++ b/row.go @@ -23,17 +23,17 @@ type Row interface { // RowWithError returns a dummy Row // where all methods return the passed error. func RowWithError(err error) Row { - return rowWithError{err} + return errRow{err} } -type rowWithError struct{ err error } +type errRow struct{ err error } -func (e rowWithError) Columns() ([]string, error) { return nil, e.err } -func (e rowWithError) Scan(dest ...any) error { return e.err } +func (e errRow) Columns() ([]string, error) { return nil, e.err } +func (e errRow) Scan(dest ...any) error { return e.err } /////////////////////////////////////////////////////////////////////////////// -type rowWrapper struct { +type sqlRow struct { ctx context.Context // ctx is checked for every row and passed through to callbacks rows *sql.Rows conn Connection // for error wrapping @@ -42,10 +42,10 @@ type rowWrapper struct { } func NewRow(ctx context.Context, rows *sql.Rows, conn Connection, query string, args []any) Row { - return &rowWrapper{ctx, rows, conn, query, args} + return &sqlRow{ctx, rows, conn, query, args} } -func (r *rowWrapper) Columns() ([]string, error) { +func (r *sqlRow) Columns() ([]string, error) { columns, err := r.rows.Columns() if err != nil { return nil, WrapErrorWithQuery(err, r.query, r.args, r.conn.Config().ParamPlaceholderFormatter) @@ -53,7 +53,7 @@ func (r *rowWrapper) Columns() ([]string, error) { return columns, nil } -func (r *rowWrapper) Scan(dest ...any) (err error) { +func (r *sqlRow) Scan(dest ...any) (err error) { defer func() { err = combineTwoErrors(err, r.rows.Close()) if err != nil { diff --git a/rows.go b/rows.go index 82d7570..350320e 100644 --- a/rows.go +++ b/rows.go @@ -22,17 +22,17 @@ type Rows interface { // RowsWithError returns dummy Rows // where all methods return the passed error. func RowsWithError(err error) Rows { - return rowsWithError{err} + return errRows{err} } -type rowsWithError struct{ err error } +type errRows struct{ err error } -func (e rowsWithError) ForEachRow(func(Row) error) error { return e.err } -func (e rowsWithError) Close() error { return e.err } +func (e errRows) ForEachRow(func(Row) error) error { return e.err } +func (e errRows) Close() error { return e.err } /////////////////////////////////////////////////////////////////////////////// -type rowsWrapper struct { +type sqlRows struct { ctx context.Context // ctx is checked for every row and passed through to callbacks rows *sql.Rows conn Connection // for error wrapping @@ -41,10 +41,10 @@ type rowsWrapper struct { } func NewRows(ctx context.Context, rows *sql.Rows, conn Connection, query string, args []any) Rows { - return &rowsWrapper{ctx, rows, conn, query, args} + return &sqlRows{ctx, rows, conn, query, args} } -func (r *rowsWrapper) ForEachRow(callback func(Row) error) (err error) { +func (r *sqlRows) ForEachRow(callback func(Row) error) (err error) { defer func() { err = combineTwoErrors(err, r.rows.Close()) if err != nil { @@ -65,7 +65,7 @@ func (r *rowsWrapper) ForEachRow(callback func(Row) error) (err error) { return r.rows.Err() } -func (r *rowsWrapper) Close() error { +func (r *sqlRows) Close() error { return r.rows.Close() } diff --git a/sqliteconn/connection.go b/sqliteconn/connection.go new file mode 100644 index 0000000..a74d3a5 --- /dev/null +++ b/sqliteconn/connection.go @@ -0,0 +1,27 @@ +package sqliteconn + +import ( + "context" + "fmt" + + _ "modernc.org/sqlite" + + "github.com/domonda/go-sqldb" +) + +// New creates a new sqldb.Connection using the passed sqldb.Config +// and modernc.org/sqlite as driver implementation. +// The connection is pinged with the passed context +// and only returned when there was no error from the ping. +func New(ctx context.Context, config *sqldb.Config) (sqldb.Connection, error) { + if config.Driver != "sqlite" { + return nil, fmt.Errorf(`invalid driver %q, pqconn expects "sqlite"`, config.Driver) + } + + db, err := config.Connect(ctx) + if err != nil { + return nil, err + } + _ = db + panic("TODO") +} From fc38e8ab0b8def1749c8dd410b793bb29f5cc4a7 Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Thu, 4 Aug 2022 15:13:08 +0200 Subject: [PATCH 12/12] db.QueryStruct with first pkValue arg --- db/query.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/db/query.go b/db/query.go index 03137a4..d7b45fd 100644 --- a/db/query.go +++ b/db/query.go @@ -63,14 +63,14 @@ func QueryValueOrDefault[T any](ctx context.Context, query string, args ...any) return value, err } -// QueryStruct uses the passed pkValues to query a table row +// QueryStruct uses the passed pkValue+pkValues to query a table row // and scan it into a struct of type S that must have tagged fields // with primary key flags to identify the primary key column names -// for the passed pkValues and a table name. -func QueryStruct[S any](ctx context.Context, pkValues ...any) (row *S, err error) { - if len(pkValues) == 0 { - return nil, errors.New("missing primary key values") - } +// for the passed pkValue+pkValues and a table name. +func QueryStruct[S any](ctx context.Context, pkValue any, pkValues ...any) (row *S, err error) { + // Using explicit first pkValue value + // to not be able to compile without any value + pkValues = append([]any{pkValue}, pkValues...) t := reflect.TypeOf(row).Elem() if t.Kind() != reflect.Struct { return nil, fmt.Errorf("expected struct template type instead of %s", t) @@ -94,14 +94,14 @@ func QueryStruct[S any](ctx context.Context, pkValues ...any) (row *S, err error return row, nil } -// QueryStructOrNil uses the passed pkValues to query a table row +// QueryStructOrNil uses the passed pkValue+pkValues to query a table row // and scan it into a struct of type S that must have tagged fields // with primary key flags to identify the primary key column names -// for the passed pkValues and a table name. +// for the passed pkValue+pkValues and a table name. // Returns nil as row and error if no row could be found with the -// passed pkValues. -func QueryStructOrNil[S any](ctx context.Context, pkValues ...any) (row *S, err error) { - row, err = QueryStruct[S](ctx, pkValues...) +// passed pkValue+pkValues. +func QueryStructOrNil[S any](ctx context.Context, pkValue any, pkValues ...any) (row *S, err error) { + row, err = QueryStruct[S](ctx, pkValue, pkValues...) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil