diff --git a/HISTORY.md b/HISTORY.md index 015fadd9..0f418e74 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,8 @@ +## v7.3.1 + +* [ADDED] Exposed `goqu.NewTx` to allow creating a goqu tx directly from a `sql.Tx` instead of using `goqu.Database#Begin` [#95](https://github.com/doug-martin/goqu/issues/95) +* [ADDED] `goqu.Database.BeginTx` [#98](https://github.com/doug-martin/goqu/issues/98) + ## v7.3.0 * [ADDED] UPDATE and INSERT should use struct Field name if db tag is not specified [#57](https://github.com/doug-martin/goqu/issues/57) diff --git a/database.go b/database.go index da270d7c..7e457b47 100644 --- a/database.go +++ b/database.go @@ -15,6 +15,7 @@ type ( // libraries such as sqlx instead of the native sql.DB SQLDatabase interface { Begin() (*sql.Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) @@ -69,11 +70,24 @@ func (d *Database) Dialect() string { // Starts a new Transaction. func (d *Database) Begin() (*TxDatabase, error) { - tx, err := d.Db.Begin() + sqlTx, err := d.Db.Begin() if err != nil { return nil, err } - return &TxDatabase{dialect: d.dialect, Tx: tx, logger: d.logger}, nil + tx := NewTx(d.dialect, sqlTx) + tx.Logger(d.logger) + return tx, nil +} + +// Starts a new Transaction. See sql.DB#BeginTx for option description +func (d *Database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*TxDatabase, error) { + sqlTx, err := d.Db.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + tx := NewTx(d.dialect, sqlTx) + tx.Logger(d.logger) + return tx, nil } // Creates a new Dataset that uses the correct adapter and supports queries. @@ -385,14 +399,29 @@ func (d *Database) ScanValContext(ctx context.Context, i interface{}, query stri // A wrapper around a sql.Tx and works the same way as Database type ( + // Interface for sql.Tx, an interface is used so you can use with other + // libraries such as sqlx instead of the native sql.DB + SQLTx interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row + Commit() error + Rollback() error + } TxDatabase struct { logger Logger dialect string - Tx *sql.Tx + Tx SQLTx qf exec.QueryFactory } ) +// Creates a new TxDatabase +func NewTx(dialect string, tx SQLTx) *TxDatabase { + return &TxDatabase{dialect: dialect, Tx: tx} +} + // returns this databases dialect func (td *TxDatabase) Dialect() string { return td.dialect diff --git a/database_example_test.go b/database_example_test.go index dd39f021..7878dcf5 100644 --- a/database_example_test.go +++ b/database_example_test.go @@ -2,6 +2,7 @@ package goqu_test import ( "context" + "database/sql" "fmt" "time" @@ -40,6 +41,39 @@ func ExampleDatabase_Begin() { // Updated users in transaction [ids:=[1 2 3]] } +func ExampleDatabase_BeginTx() { + db := getDb() + + ctx := context.Background() + tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted}) + if err != nil { + fmt.Println("Error starting transaction", err.Error()) + } + + // use tx.From to get a dataset that will execute within this transaction + update := tx.From("goqu_user"). + Where(goqu.Ex{"last_name": "Yukon"}). + Returning("id"). + Update(goqu.Record{"last_name": "Ucon"}) + + var ids []int64 + if err := update.ScanVals(&ids); err != nil { + if rErr := tx.Rollback(); rErr != nil { + fmt.Println("An error occurred while issuing ROLLBACK\n\t", rErr.Error()) + } else { + fmt.Println("An error occurred while updating users ROLLBACK transaction\n\t", err.Error()) + } + return + } + if err := tx.Commit(); err != nil { + fmt.Println("An error occurred while issuing COMMIT\n\t", err.Error()) + } else { + fmt.Printf("Updated users in transaction [ids:=%+v]", ids) + } + // Output: + // Updated users in transaction [ids:=[1 2 3]] +} + func ExampleDatabase_Dialect() { db := getDb() diff --git a/database_test.go b/database_test.go index 5096cbcd..31152363 100644 --- a/database_test.go +++ b/database_test.go @@ -1,6 +1,7 @@ package goqu import ( + "context" "fmt" "testing" @@ -250,6 +251,22 @@ func (dt *databaseTest) TestBegin() { assert.EqualError(t, err, "goqu: transaction error") } +func (dt *databaseTest) TestBeginTx() { + t := dt.T() + ctx := context.Background() + mDb, mock, err := sqlmock.New() + assert.NoError(t, err) + mock.ExpectBegin() + mock.ExpectBegin().WillReturnError(errors.New("transaction error")) + db := New("mock", mDb) + tx, err := db.BeginTx(ctx, nil) + assert.NoError(t, err) + assert.Equal(t, tx.Dialect(), "mock") + + _, err = db.BeginTx(ctx, nil) + assert.EqualError(t, err, "goqu: transaction error") +} + func TestDatabaseSuite(t *testing.T) { suite.Run(t, new(databaseTest)) }