diff --git a/api/account.go b/api/account.go index e1bf645..196d2e3 100644 --- a/api/account.go +++ b/api/account.go @@ -1,14 +1,12 @@ package api import ( - "database/sql" "errors" "net/http" db "github.com/Samudra-G/simplebank/db/sqlc" "github.com/Samudra-G/simplebank/token" "github.com/gin-gonic/gin" - "github.com/lib/pq" ) type createAccountRequest struct { @@ -31,12 +29,10 @@ func (server *Server) createAccount(ctx *gin.Context) { account, err := server.store.CreateAccount(ctx, arg) if err != nil { - if pqErr, ok := err.(*pq.Error); ok { - switch pqErr.Code.Name() { - case "foreign_key_violation", "unique_violation": - ctx.JSON(http.StatusForbidden, errorResponse(err)) - return - } + errCode := db.ErrorCode(err) + if errCode == db.ForeignKeyViolation || errCode == db.UniqueViolation{ + ctx.JSON(http.StatusForbidden, errorResponse(err)) + return } ctx.JSON(http.StatusInternalServerError, errorResponse(err)) return @@ -57,7 +53,7 @@ func (server *Server) getAccount(ctx *gin.Context) { account, err := server.store.GetAccount(ctx, req.ID) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, db.ErrRecordNotFound) { ctx.JSON(http.StatusNotFound, errorResponse(err)) return } diff --git a/api/account_test.go b/api/account_test.go index 164fd36..0ed266c 100644 --- a/api/account_test.go +++ b/api/account_test.go @@ -88,7 +88,7 @@ func TestGetAccountAPI(t *testing.T){ store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account.ID)). Times(1). - Return(db.Account{}, sql.ErrNoRows) + Return(db.Account{}, db.ErrRecordNotFound) }, checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { require.Equal(t, http.StatusNotFound, recorder.Code) diff --git a/api/user.go b/api/user.go index 2ac29bc..e1ccd6e 100644 --- a/api/user.go +++ b/api/user.go @@ -9,7 +9,6 @@ import ( "github.com/Samudra-G/simplebank/util" "github.com/gin-gonic/gin" "github.com/google/uuid" - "github.com/lib/pq" ) type createUserRequest struct { @@ -57,12 +56,9 @@ func (server *Server) createUser(ctx *gin.Context) { user, err := server.store.CreateUser(ctx, arg) if err != nil { - if pqErr, ok := err.(*pq.Error); ok { - switch pqErr.Code.Name() { - case "unique_violation": - ctx.JSON(http.StatusForbidden, errorResponse(err)) - return - } + if db.ErrorCode(err) == db.UniqueViolation{ + ctx.JSON(http.StatusForbidden, errorResponse(err)) + return } ctx.JSON(http.StatusInternalServerError, errorResponse(err)) return diff --git a/api/user_test.go b/api/user_test.go index 65b0be7..0d80944 100644 --- a/api/user_test.go +++ b/api/user_test.go @@ -16,7 +16,6 @@ import ( "github.com/Samudra-G/simplebank/util" "github.com/gin-gonic/gin" "github.com/golang/mock/gomock" - "github.com/lib/pq" "github.com/stretchr/testify/require" ) @@ -118,7 +117,7 @@ func TestCreateUserAPI(t *testing.T) { store.EXPECT(). CreateUser(gomock.Any(), gomock.Any()). Times(1). - Return(db.User{}, &pq.Error{Code: "23505"}) + Return(db.User{},db.ErrUniqueViolation) }, checkResponse: func(t *testing.T, rec *httptest.ResponseRecorder) { require.Equal(t, http.StatusForbidden, rec.Code) diff --git a/db/sqlc/account.sql.go b/db/sqlc/account.sql.go index 28bb07e..659211d 100644 --- a/db/sqlc/account.sql.go +++ b/db/sqlc/account.sql.go @@ -22,7 +22,7 @@ type AddAccountBalanceParams struct { } func (q *Queries) AddAccountBalance(ctx context.Context, arg AddAccountBalanceParams) (Account, error) { - row := q.db.QueryRowContext(ctx, addAccountBalance, arg.Amount, arg.ID) + row := q.db.QueryRow(ctx, addAccountBalance, arg.Amount, arg.ID) var i Account err := row.Scan( &i.ID, @@ -52,7 +52,7 @@ type CreateAccountParams struct { } func (q *Queries) CreateAccount(ctx context.Context, arg CreateAccountParams) (Account, error) { - row := q.db.QueryRowContext(ctx, createAccount, arg.Owner, arg.Balance, arg.Currency) + row := q.db.QueryRow(ctx, createAccount, arg.Owner, arg.Balance, arg.Currency) var i Account err := row.Scan( &i.ID, @@ -70,7 +70,7 @@ WHERE id = $1 ` func (q *Queries) DeleteAccount(ctx context.Context, id int64) error { - _, err := q.db.ExecContext(ctx, deleteAccount, id) + _, err := q.db.Exec(ctx, deleteAccount, id) return err } @@ -80,7 +80,7 @@ WHERE id = $1 LIMIT 1 ` func (q *Queries) GetAccount(ctx context.Context, id int64) (Account, error) { - row := q.db.QueryRowContext(ctx, getAccount, id) + row := q.db.QueryRow(ctx, getAccount, id) var i Account err := row.Scan( &i.ID, @@ -99,7 +99,7 @@ FOR NO KEY UPDATE ` func (q *Queries) GetAccountForUpdate(ctx context.Context, id int64) (Account, error) { - row := q.db.QueryRowContext(ctx, getAccountForUpdate, id) + row := q.db.QueryRow(ctx, getAccountForUpdate, id) var i Account err := row.Scan( &i.ID, @@ -126,7 +126,7 @@ type ListAccountsParams struct { } func (q *Queries) ListAccounts(ctx context.Context, arg ListAccountsParams) ([]Account, error) { - rows, err := q.db.QueryContext(ctx, listAccounts, arg.Owner, arg.Limit, arg.Offset) + rows, err := q.db.Query(ctx, listAccounts, arg.Owner, arg.Limit, arg.Offset) if err != nil { return nil, err } @@ -145,9 +145,6 @@ func (q *Queries) ListAccounts(ctx context.Context, arg ListAccountsParams) ([]A } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -167,7 +164,7 @@ type UpdateAccountParams struct { } func (q *Queries) UpdateAccount(ctx context.Context, arg UpdateAccountParams) (Account, error) { - row := q.db.QueryRowContext(ctx, updateAccount, arg.ID, arg.Balance) + row := q.db.QueryRow(ctx, updateAccount, arg.ID, arg.Balance) var i Account err := row.Scan( &i.ID, diff --git a/db/sqlc/account_test.go b/db/sqlc/account_test.go index e7556ad..beded0b 100644 --- a/db/sqlc/account_test.go +++ b/db/sqlc/account_test.go @@ -2,7 +2,6 @@ package sqlc import ( "context" - "database/sql" "testing" "time" @@ -19,7 +18,7 @@ func createRandomAccount(t *testing.T) Account{ Currency: util.RandomCurrency(), } - account, err := testQueries.CreateAccount(context.Background(), arg) + account, err := testStore.CreateAccount(context.Background(), arg) require.NoError(t, err) require.NotEmpty(t, account) @@ -38,7 +37,7 @@ func TestCreateAccount(t *testing.T) { func TestGetAccount(t *testing.T) { account1 := createRandomAccount(t) - account2, err := testQueries.GetAccount(context.Background(), account1.ID) + account2, err := testStore.GetAccount(context.Background(), account1.ID) require.NoError(t, err) require.NotEmpty(t, account2) @@ -59,7 +58,7 @@ func TestUpdateAccount(t *testing.T) { Balance: util.RandomMoney(), } - account2, err := testQueries.UpdateAccount(context.Background(), arg) + account2, err := testStore.UpdateAccount(context.Background(), arg) require.NoError(t, err) require.NotEmpty(t, account2) @@ -74,12 +73,12 @@ func TestUpdateAccount(t *testing.T) { func TestDeleteAccount(t *testing.T) { account1 := createRandomAccount(t) - err := testQueries.DeleteAccount(context.Background(), account1.ID) + err := testStore.DeleteAccount(context.Background(), account1.ID) require.NoError(t, err) - account2, err := testQueries.GetAccount(context.Background(), account1.ID) + account2, err := testStore.GetAccount(context.Background(), account1.ID) require.Error(t, err) - require.EqualError(t, err, sql.ErrNoRows.Error()) + require.EqualError(t, err, ErrRecordNotFound.Error()) require.Empty(t, account2) } @@ -94,7 +93,7 @@ func TestListAccount(t *testing.T) { Offset: 0, } - accounts, err := testQueries.ListAccounts(context.Background(), arg) + accounts, err := testStore.ListAccounts(context.Background(), arg) require.NoError(t, err) require.NotEmpty(t, accounts) diff --git a/db/sqlc/db.go b/db/sqlc/db.go index e4d7828..2725108 100644 --- a/db/sqlc/db.go +++ b/db/sqlc/db.go @@ -6,14 +6,15 @@ package sqlc import ( "context" - "database/sql" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" ) type DBTX interface { - ExecContext(context.Context, string, ...interface{}) (sql.Result, error) - PrepareContext(context.Context, string) (*sql.Stmt, error) - QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) - QueryRowContext(context.Context, string, ...interface{}) *sql.Row + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row } func New(db DBTX) *Queries { @@ -24,7 +25,7 @@ type Queries struct { db DBTX } -func (q *Queries) WithTx(tx *sql.Tx) *Queries { +func (q *Queries) WithTx(tx pgx.Tx) *Queries { return &Queries{ db: tx, } diff --git a/db/sqlc/entry.sql.go b/db/sqlc/entry.sql.go index 7ac7200..3ff6500 100644 --- a/db/sqlc/entry.sql.go +++ b/db/sqlc/entry.sql.go @@ -24,7 +24,7 @@ type CreateEntryParams struct { } func (q *Queries) CreateEntry(ctx context.Context, arg CreateEntryParams) (Entry, error) { - row := q.db.QueryRowContext(ctx, createEntry, arg.AccountID, arg.Amount) + row := q.db.QueryRow(ctx, createEntry, arg.AccountID, arg.Amount) var i Entry err := row.Scan( &i.ID, @@ -41,7 +41,7 @@ WHERE id = $1 LIMIT 1 ` func (q *Queries) GetEntry(ctx context.Context, id int64) (Entry, error) { - row := q.db.QueryRowContext(ctx, getEntry, id) + row := q.db.QueryRow(ctx, getEntry, id) var i Entry err := row.Scan( &i.ID, @@ -67,7 +67,7 @@ type ListEntriesParams struct { } func (q *Queries) ListEntries(ctx context.Context, arg ListEntriesParams) ([]Entry, error) { - rows, err := q.db.QueryContext(ctx, listEntries, arg.AccountID, arg.Limit, arg.Offset) + rows, err := q.db.Query(ctx, listEntries, arg.AccountID, arg.Limit, arg.Offset) if err != nil { return nil, err } @@ -85,9 +85,6 @@ func (q *Queries) ListEntries(ctx context.Context, arg ListEntriesParams) ([]Ent } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } diff --git a/db/sqlc/entry_test.go b/db/sqlc/entry_test.go index 4ffa6df..70a3d0d 100644 --- a/db/sqlc/entry_test.go +++ b/db/sqlc/entry_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" "github.com/Samudra-G/simplebank/util" + "github.com/stretchr/testify/require" ) func createRandomEntry(t *testing.T, account Account) Entry { @@ -15,7 +15,7 @@ func createRandomEntry(t *testing.T, account Account) Entry { Amount: util.RandomMoney(), } - entry, err := testQueries.CreateEntry(context.Background(), arg) + entry, err := testStore.CreateEntry(context.Background(), arg) require.NoError(t, err) require.NotEmpty(t, entry) @@ -36,7 +36,7 @@ func TestCreateEntry(t *testing.T) { func TestGetEntry(t *testing.T) { account := createRandomAccount(t) entry1 := createRandomEntry(t, account) - entry2, err := testQueries.GetEntry(context.Background(), entry1.ID) + entry2, err := testStore.GetEntry(context.Background(), entry1.ID) require.NoError(t, err) require.NotEmpty(t, entry2) @@ -58,7 +58,7 @@ func TestListEntries(t *testing.T) { Offset: 5, } - entries, err := testQueries.ListEntries(context.Background(), arg) + entries, err := testStore.ListEntries(context.Background(), arg) require.NoError(t, err) require.Len(t, entries, 5) diff --git a/db/sqlc/error.go b/db/sqlc/error.go new file mode 100644 index 0000000..ee20bb2 --- /dev/null +++ b/db/sqlc/error.go @@ -0,0 +1,26 @@ +package sqlc + +import ( + "errors" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) +const ( + ForeignKeyViolation = "23503" + UniqueViolation = "23505" +) + +var ErrRecordNotFound = pgx.ErrNoRows +var ErrUniqueViolation = &pgconn.PgError{ + Code: UniqueViolation, +} + +func ErrorCode(err error) string { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + return pgErr.Code + } + + return "" +} \ No newline at end of file diff --git a/db/sqlc/exec_tx.go b/db/sqlc/exec_tx.go new file mode 100644 index 0000000..f6988bf --- /dev/null +++ b/db/sqlc/exec_tx.go @@ -0,0 +1,24 @@ +package sqlc + +import ( + "context" + "fmt" +) + +//execTx executes a function within a database transaction +func (store *SQLStore) execTx(ctx context.Context, fn func(*Queries) error) error { + tx, err := store.connPool.Begin(ctx) + if err != nil { + return err + } + q := New(tx) //New accepts dbtx interface + err = fn(q) + if err != nil { + if rbErr := tx.Rollback(ctx); rbErr != nil { + return fmt.Errorf("tx error: %v, rb err: %v", err, rbErr) + } + return err + } + + return tx.Commit(ctx) +} diff --git a/db/sqlc/main_test.go b/db/sqlc/main_test.go index 4dcc87c..483579f 100644 --- a/db/sqlc/main_test.go +++ b/db/sqlc/main_test.go @@ -1,28 +1,27 @@ package sqlc import ( - "database/sql" + "context" "log" "os" "testing" "github.com/Samudra-G/simplebank/util" - _ "github.com/lib/pq" + "github.com/jackc/pgx/v5/pgxpool" ) -var testQueries *Queries -var testDB *sql.DB +var testStore Store func TestMain(m *testing.M) { config, err := util.LoadConfig("../..") if err != nil { log.Fatal("cannot load config: ", err) } - testDB, err = sql.Open(config.DBDriver, config.DBSource) + connPool, err := pgxpool.New(context.Background(), config.DBSource) if err != nil { log.Fatal("cannot connect to db: ", err) } - testQueries = New(testDB) + testStore = NewStore(connPool) os.Exit(m.Run()) } \ No newline at end of file diff --git a/db/sqlc/session.sql.go b/db/sqlc/session.sql.go index 4ba5e83..fded23d 100644 --- a/db/sqlc/session.sql.go +++ b/db/sqlc/session.sql.go @@ -38,7 +38,7 @@ type CreateSessionParams struct { } func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) { - row := q.db.QueryRowContext(ctx, createSession, + row := q.db.QueryRow(ctx, createSession, arg.ID, arg.Username, arg.RefreshToken, @@ -67,7 +67,7 @@ WHERE id = $1 LIMIT 1 ` func (q *Queries) GetSession(ctx context.Context, id uuid.UUID) (Session, error) { - row := q.db.QueryRowContext(ctx, getSession, id) + row := q.db.QueryRow(ctx, getSession, id) var i Session err := row.Scan( &i.ID, diff --git a/db/sqlc/store.go b/db/sqlc/store.go index f428499..fb49b52 100644 --- a/db/sqlc/store.go +++ b/db/sqlc/store.go @@ -2,8 +2,8 @@ package sqlc import ( "context" - "database/sql" - "fmt" + + "github.com/jackc/pgx/v5/pgxpool" ) //Store providse all functions for db queries and transaction @@ -16,32 +16,15 @@ type Store interface { //SQLStore provides all functions for sql queries and transaction type SQLStore struct { + connPool *pgxpool.Pool *Queries - db *sql.DB } //NewStore creates a new Store -func NewStore(db *sql.DB) Store { +func NewStore(connPool *pgxpool.Pool) Store { return &SQLStore{ - db: db, - Queries: New(db), + connPool: connPool, + Queries: New(connPool), } } -//execTx executes a function within a database transaction -func (store *SQLStore) execTx(ctx context.Context, fn func(*Queries) error) error { - tx, err := store.db.BeginTx(ctx, nil) - if err != nil { - return err - } - q := New(tx) //New accepts dbtx interface - err = fn(q) - if err != nil { - if rbErr := tx.Rollback(); rbErr != nil { - return fmt.Errorf("tx error: %v, rb err: %v", err, rbErr) - } - return err - } - - return tx.Commit() -} diff --git a/db/sqlc/store_test.go b/db/sqlc/store_test.go index 5dcd2c2..627b0d5 100644 --- a/db/sqlc/store_test.go +++ b/db/sqlc/store_test.go @@ -9,7 +9,6 @@ import ( ) func TestTransferTx(t *testing.T) { - store := NewStore(testDB) account1 := createRandomAccount(t) account2 := createRandomAccount(t) @@ -25,7 +24,7 @@ func TestTransferTx(t *testing.T) { for i := 0; i < n; i++ { go func() { ctx := context.Background() - result, err := store.TransferTx(ctx, TransferTxParams{ + result, err := testStore.TransferTx(ctx, TransferTxParams{ FromAccountID: account1.ID, ToAccountID: account2.ID, Amount: amount, @@ -54,7 +53,7 @@ func TestTransferTx(t *testing.T) { require.NotZero(t, transfer.ID) require.NotZero(t, transfer.CreatedAt) - _, err = store.GetTransfer(context.Background(), transfer.ID) + _, err = testStore.GetTransfer(context.Background(), transfer.ID) require.NoError(t, err) //check entries @@ -65,7 +64,7 @@ func TestTransferTx(t *testing.T) { require.NotZero(t, fromEntry.ID) require.NotZero(t, fromEntry.CreatedAt) - _, err = store.GetEntry(context.Background(), fromEntry.ID) + _, err = testStore.GetEntry(context.Background(), fromEntry.ID) require.NoError(t, err) toEntry := result.ToEntry @@ -75,7 +74,7 @@ func TestTransferTx(t *testing.T) { require.NotZero(t, toEntry.ID) require.NotZero(t, toEntry.CreatedAt) - _, err = store.GetEntry(context.Background(), toEntry.ID) + _, err = testStore.GetEntry(context.Background(), toEntry.ID) require.NoError(t, err) //check accounts @@ -102,10 +101,10 @@ func TestTransferTx(t *testing.T) { } //check final updated balances - updatedAccount1, err := testQueries.GetAccount(context.Background(), account1.ID) + updatedAccount1, err := testStore.GetAccount(context.Background(), account1.ID) require.NoError(t, err) - updatedAccount2, err := testQueries.GetAccount(context.Background(), account2.ID) + updatedAccount2, err := testStore.GetAccount(context.Background(), account2.ID) require.NoError(t, err) fmt.Println(">> after: ", updatedAccount1.Balance, updatedAccount2.Balance) @@ -114,8 +113,6 @@ func TestTransferTx(t *testing.T) { } func TestTransferTxDeadlock(t *testing.T) { - store := NewStore(testDB) - account1 := createRandomAccount(t) account2 := createRandomAccount(t) fmt.Println(">> before:", account1.Balance, account2.Balance) @@ -136,7 +133,7 @@ func TestTransferTxDeadlock(t *testing.T) { go func() { ctx := context.Background() - _, err := store.TransferTx(ctx, TransferTxParams{ + _, err := testStore.TransferTx(ctx, TransferTxParams{ FromAccountID: fromAccountID, ToAccountID: toAccountID, Amount: amount, @@ -153,10 +150,10 @@ func TestTransferTxDeadlock(t *testing.T) { } //check final updated balances - updatedAccount1, err := testQueries.GetAccount(context.Background(), account1.ID) + updatedAccount1, err := testStore.GetAccount(context.Background(), account1.ID) require.NoError(t, err) - updatedAccount2, err := testQueries.GetAccount(context.Background(), account2.ID) + updatedAccount2, err := testStore.GetAccount(context.Background(), account2.ID) require.NoError(t, err) fmt.Println(">> after: ", updatedAccount1.Balance, updatedAccount2.Balance) diff --git a/db/sqlc/transfer.sql.go b/db/sqlc/transfer.sql.go index ee8d5da..aeb0269 100644 --- a/db/sqlc/transfer.sql.go +++ b/db/sqlc/transfer.sql.go @@ -26,7 +26,7 @@ type CreateTransferParams struct { } func (q *Queries) CreateTransfer(ctx context.Context, arg CreateTransferParams) (Transfer, error) { - row := q.db.QueryRowContext(ctx, createTransfer, arg.FromAccountID, arg.ToAccountID, arg.Amount) + row := q.db.QueryRow(ctx, createTransfer, arg.FromAccountID, arg.ToAccountID, arg.Amount) var i Transfer err := row.Scan( &i.ID, @@ -44,7 +44,7 @@ WHERE id = $1 LIMIT 1 ` func (q *Queries) GetTransfer(ctx context.Context, id int64) (Transfer, error) { - row := q.db.QueryRowContext(ctx, getTransfer, id) + row := q.db.QueryRow(ctx, getTransfer, id) var i Transfer err := row.Scan( &i.ID, @@ -74,7 +74,7 @@ type ListTransfersParams struct { } func (q *Queries) ListTransfers(ctx context.Context, arg ListTransfersParams) ([]Transfer, error) { - rows, err := q.db.QueryContext(ctx, listTransfers, + rows, err := q.db.Query(ctx, listTransfers, arg.FromAccountID, arg.ToAccountID, arg.Limit, @@ -98,9 +98,6 @@ func (q *Queries) ListTransfers(ctx context.Context, arg ListTransfersParams) ([ } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } diff --git a/db/sqlc/transfer_test.go b/db/sqlc/transfer_test.go index 7e58754..1b59c7c 100644 --- a/db/sqlc/transfer_test.go +++ b/db/sqlc/transfer_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" "github.com/Samudra-G/simplebank/util" + "github.com/stretchr/testify/require" ) func createRandomTransfer(t *testing.T, account1, account2 Account) Transfer { @@ -16,7 +16,7 @@ func createRandomTransfer(t *testing.T, account1, account2 Account) Transfer { Amount: util.RandomMoney(), } - transfer, err := testQueries.CreateTransfer(context.Background(), arg) + transfer, err := testStore.CreateTransfer(context.Background(), arg) require.NoError(t, err) require.NotEmpty(t, transfer) @@ -41,7 +41,7 @@ func TestGetTransfer(t *testing.T) { account2 := createRandomAccount(t) transfer1 := createRandomTransfer(t, account1, account2) - transfer2, err := testQueries.GetTransfer(context.Background(), transfer1.ID) + transfer2, err := testStore.GetTransfer(context.Background(), transfer1.ID) require.NoError(t, err) require.NotEmpty(t, transfer2) @@ -68,7 +68,7 @@ func TestListTransfer(t *testing.T) { Offset: 5, } - transfers, err := testQueries.ListTransfers(context.Background(), arg) + transfers, err := testStore.ListTransfers(context.Background(), arg) require.NoError(t, err) require.Len(t, transfers, 5) diff --git a/db/sqlc/tx_verify_email.go b/db/sqlc/tx_verify_email.go index 45835b7..e6a3ab7 100644 --- a/db/sqlc/tx_verify_email.go +++ b/db/sqlc/tx_verify_email.go @@ -2,7 +2,8 @@ package sqlc import ( "context" - "database/sql" + + "github.com/jackc/pgx/v5/pgtype" ) type VerifyEmailTxParams struct { @@ -31,7 +32,7 @@ func (store *SQLStore) VerifyEmailTx(ctx context.Context, arg VerifyEmailTxParam result.User, err = q.UpdateUser(ctx, UpdateUserParams{ Username: result.VerifyEmail.Username, - IsEmailVerified: sql.NullBool{ + IsEmailVerified: pgtype.Bool{ Bool: true, Valid: true, }, diff --git a/db/sqlc/user.sql.go b/db/sqlc/user.sql.go index 50786b4..4180a9c 100644 --- a/db/sqlc/user.sql.go +++ b/db/sqlc/user.sql.go @@ -7,7 +7,8 @@ package sqlc import ( "context" - "database/sql" + + "github.com/jackc/pgx/v5/pgtype" ) const createUser = `-- name: CreateUser :one @@ -30,7 +31,7 @@ type CreateUserParams struct { } func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { - row := q.db.QueryRowContext(ctx, createUser, + row := q.db.QueryRow(ctx, createUser, arg.Username, arg.HashedPassword, arg.FullName, @@ -55,7 +56,7 @@ WHERE username = $1 LIMIT 1 ` func (q *Queries) GetUser(ctx context.Context, username string) (User, error) { - row := q.db.QueryRowContext(ctx, getUser, username) + row := q.db.QueryRow(ctx, getUser, username) var i User err := row.Scan( &i.Username, @@ -83,16 +84,16 @@ RETURNING username, hashed_password, full_name, email, password_changed_at, crea ` type UpdateUserParams struct { - HashedPassword sql.NullString `json:"hashed_password"` - PasswordChangedAt sql.NullTime `json:"password_changed_at"` - FullName sql.NullString `json:"full_name"` - Email sql.NullString `json:"email"` - IsEmailVerified sql.NullBool `json:"is_email_verified"` - Username string `json:"username"` + HashedPassword pgtype.Text `json:"hashed_password"` + PasswordChangedAt pgtype.Timestamptz `json:"password_changed_at"` + FullName pgtype.Text `json:"full_name"` + Email pgtype.Text `json:"email"` + IsEmailVerified pgtype.Bool `json:"is_email_verified"` + Username string `json:"username"` } func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error) { - row := q.db.QueryRowContext(ctx, updateUser, + row := q.db.QueryRow(ctx, updateUser, arg.HashedPassword, arg.PasswordChangedAt, arg.FullName, diff --git a/db/sqlc/user_test.go b/db/sqlc/user_test.go index 92cfc4d..f47f95f 100644 --- a/db/sqlc/user_test.go +++ b/db/sqlc/user_test.go @@ -2,11 +2,11 @@ package sqlc import ( "context" - "database/sql" "testing" "time" "github.com/Samudra-G/simplebank/util" + "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/require" ) @@ -21,7 +21,7 @@ func createRandomUser(t *testing.T) User{ Email: util.RandomEmail(), } - user, err := testQueries.CreateUser(context.Background(), arg) + user, err := testStore.CreateUser(context.Background(), arg) require.NoError(t, err) require.NotEmpty(t, user) @@ -41,7 +41,7 @@ func TestCreateUser(t *testing.T) { func TestGetUser(t *testing.T) { user1 := createRandomUser(t) - user2, err := testQueries.GetUser(context.Background(), user1.Username) + user2, err := testStore.GetUser(context.Background(), user1.Username) require.NoError(t, err) require.NotEmpty(t, user2) @@ -58,9 +58,9 @@ func TestUpdateUserOnlyFullName(t *testing.T) { oldUser := createRandomUser(t) newFullName := util.RandomOwner() - updatedUser, err := testQueries.UpdateUser(context.Background(), UpdateUserParams{ + updatedUser, err := testStore.UpdateUser(context.Background(), UpdateUserParams{ Username: oldUser.Username, - FullName: sql.NullString{ + FullName: pgtype.Text{ String: newFullName, Valid: true, }, @@ -77,9 +77,9 @@ func TestUpdateUserOnlyEmail(t *testing.T) { oldUser := createRandomUser(t) newEmail := util.RandomEmail() - updatedUser, err := testQueries.UpdateUser(context.Background(), UpdateUserParams{ + updatedUser, err := testStore.UpdateUser(context.Background(), UpdateUserParams{ Username: oldUser.Username, - Email: sql.NullString{ + Email: pgtype.Text{ String: newEmail, Valid: true, }, @@ -98,9 +98,9 @@ func TestUpdateUserOnlyHashedPassword(t *testing.T) { newPassword := util.RandomString(6) newHashedPassword, err := util.HashPassword(newPassword) require.NoError(t, err) - updatedUser, err := testQueries.UpdateUser(context.Background(), UpdateUserParams{ + updatedUser, err := testStore.UpdateUser(context.Background(), UpdateUserParams{ Username: oldUser.Username, - HashedPassword: sql.NullString{ + HashedPassword: pgtype.Text{ String: newHashedPassword, Valid: true, }, @@ -121,17 +121,17 @@ func TestUpdateUserOnlyAllFields(t *testing.T) { newPassword := util.RandomString(6) newHashedPassword, err := util.HashPassword(newPassword) require.NoError(t,err) - updatedUser, err := testQueries.UpdateUser(context.Background(), UpdateUserParams{ + updatedUser, err := testStore.UpdateUser(context.Background(), UpdateUserParams{ Username: oldUser.Username, - FullName: sql.NullString{ + FullName: pgtype.Text{ String: newFullName, Valid: true, }, - HashedPassword: sql.NullString{ + HashedPassword: pgtype.Text{ String: newHashedPassword, Valid: true, }, - Email: sql.NullString{ + Email: pgtype.Text{ String: newEmail, Valid: true, }, diff --git a/db/sqlc/verify_email.sql.go b/db/sqlc/verify_email.sql.go index a0d7dd1..fddfc5e 100644 --- a/db/sqlc/verify_email.sql.go +++ b/db/sqlc/verify_email.sql.go @@ -28,7 +28,7 @@ type CreateVerifyEmailParams struct { } func (q *Queries) CreateVerifyEmail(ctx context.Context, arg CreateVerifyEmailParams) (VerifyEmail, error) { - row := q.db.QueryRowContext(ctx, createVerifyEmail, arg.Username, arg.Email, arg.SecretCode) + row := q.db.QueryRow(ctx, createVerifyEmail, arg.Username, arg.Email, arg.SecretCode) var i VerifyEmail err := row.Scan( &i.ID, @@ -60,7 +60,7 @@ type UpdateVerifyEmailParams struct { } func (q *Queries) UpdateVerifyEmail(ctx context.Context, arg UpdateVerifyEmailParams) (VerifyEmail, error) { - row := q.db.QueryRowContext(ctx, updateVerifyEmail, arg.ID, arg.SecretCode) + row := q.db.QueryRow(ctx, updateVerifyEmail, arg.ID, arg.SecretCode) var i VerifyEmail err := row.Scan( &i.ID, diff --git a/gapi/rpc_create_user.go b/gapi/rpc_create_user.go index 581ffba..dc4114e 100644 --- a/gapi/rpc_create_user.go +++ b/gapi/rpc_create_user.go @@ -10,7 +10,6 @@ import ( "github.com/Samudra-G/simplebank/val" "github.com/Samudra-G/simplebank/worker" "github.com/hibiken/asynq" - "github.com/lib/pq" "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -51,11 +50,8 @@ func (server *Server)CreateUser(ctx context.Context, req *pb.CreateUserRequest) txResult, err := server.store.CreateUserTx(ctx, arg) if err != nil { - if pqErr, ok := err.(*pq.Error); ok { - switch pqErr.Code.Name() { - case "unique_violation": - return nil, status.Errorf(codes.AlreadyExists, "username already exists: %v", err) - } + if db.ErrorCode(err) == db.UniqueViolation { + return nil, status.Error(codes.AlreadyExists, err.Error()) } return nil, status.Errorf(codes.Internal, "failed to create user: %v", err) } diff --git a/gapi/rpc_create_user_test.go b/gapi/rpc_create_user_test.go index c0806c6..b07be71 100644 --- a/gapi/rpc_create_user_test.go +++ b/gapi/rpc_create_user_test.go @@ -15,7 +15,6 @@ import ( "github.com/Samudra-G/simplebank/worker" mockwk "github.com/Samudra-G/simplebank/worker/mock" "github.com/golang/mock/gomock" - "github.com/lib/pq" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -163,7 +162,7 @@ func TestCreateUserAPI(t *testing.T) { store.EXPECT(). CreateUserTx(gomock.Any(), gomock.Any()). Times(1). - Return(db.CreateUserTxResult{}, &pq.Error{Code: "23505"}) + Return(db.CreateUserTxResult{}, db.ErrUniqueViolation) taskDistributor.EXPECT(). DistributeTaskSendVerifyEmail(gomock.Any(), gomock.Any(), gomock.Any()). diff --git a/gapi/rpc_update_user.go b/gapi/rpc_update_user.go index d45327d..210c375 100644 --- a/gapi/rpc_update_user.go +++ b/gapi/rpc_update_user.go @@ -9,6 +9,7 @@ import ( "github.com/Samudra-G/simplebank/pb" "github.com/Samudra-G/simplebank/util" "github.com/Samudra-G/simplebank/val" + "github.com/jackc/pgx/v5/pgtype" "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -31,11 +32,11 @@ func (server *Server)UpdateUser(ctx context.Context, req *pb.UpdateUserRequest) arg := db.UpdateUserParams{ Username: req.GetUsername(), - FullName: sql.NullString{ + FullName: pgtype.Text{ String: req.GetFullName(), Valid: req.FullName != nil, }, - Email: sql.NullString{ + Email: pgtype.Text{ String: req.GetEmail(), Valid: req.Email != nil, }, @@ -47,12 +48,12 @@ func (server *Server)UpdateUser(ctx context.Context, req *pb.UpdateUserRequest) return nil, status.Errorf(codes.Internal, "failed to hash password: %v", err) } - arg.HashedPassword = sql.NullString{ + arg.HashedPassword = pgtype.Text{ String: hashedPassword, Valid: true, } - arg.PasswordChangedAt = sql.NullTime{ + arg.PasswordChangedAt = pgtype.Timestamptz{ Time: time.Now(), Valid: true, } diff --git a/gapi/rpc_update_user_test.go b/gapi/rpc_update_user_test.go index ad6fd03..bbec706 100644 --- a/gapi/rpc_update_user_test.go +++ b/gapi/rpc_update_user_test.go @@ -12,6 +12,7 @@ import ( "github.com/Samudra-G/simplebank/token" "github.com/Samudra-G/simplebank/util" "github.com/golang/mock/gomock" + "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -64,11 +65,11 @@ func TestUpdateUserAPI(t *testing.T) { buildStubs: func(store *mockdb.MockStore) { arg := db.UpdateUserParams{ Username: user.Username, - FullName: sql.NullString{ + FullName: pgtype.Text{ String: newName, Valid: true, }, - Email: sql.NullString{ + Email: pgtype.Text{ String: newEmail, Valid: true, }, diff --git a/go.mod b/go.mod index 68cbd4d..c542664 100644 --- a/go.mod +++ b/go.mod @@ -12,8 +12,8 @@ require ( github.com/google/uuid v1.6.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 github.com/hibiken/asynq v0.25.1 + github.com/jackc/pgx/v5 v5.5.4 github.com/jordan-wright/email v4.0.1-0.20210109023952-943e75fe5223+incompatible - github.com/lib/pq v1.10.9 github.com/o1egl/paseto v1.0.0 github.com/rakyll/statik v0.1.7 github.com/rs/zerolog v1.34.0 @@ -44,9 +44,13 @@ require ( github.com/goccy/go-json v0.10.5 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.10 // indirect github.com/leodido/go-urn v1.4.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect @@ -68,6 +72,7 @@ require ( go.uber.org/multierr v1.9.0 // indirect golang.org/x/arch v0.18.0 // indirect golang.org/x/net v0.41.0 // indirect + golang.org/x/sync v0.15.0 // indirect golang.org/x/sys v0.34.0 // indirect golang.org/x/text v0.26.0 // indirect golang.org/x/time v0.12.0 // indirect diff --git a/go.sum b/go.sum index ec48aa0..b97e258 100644 --- a/go.sum +++ b/go.sum @@ -91,6 +91,14 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hibiken/asynq v0.25.1 h1:phj028N0nm15n8O2ims+IvJ2gz4k2auvermngh9JhTw= github.com/hibiken/asynq v0.25.1/go.mod h1:pazWNOLBu0FEynQRBvHA26qdIKRSmfdIfUm4HdsLmXg= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8= +github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jordan-wright/email v4.0.1-0.20210109023952-943e75fe5223+incompatible h1:jdpOPRN1zP63Td1hDQbZW73xKmzDvZHzVdNYxhnTMDA= github.com/jordan-wright/email v4.0.1-0.20210109023952-943e75fe5223+incompatible/go.mod h1:1c7szIrayyPPB/987hsnvNzLushdWf4o/79s3P08L8A= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -214,6 +222,8 @@ golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= +golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/main.go b/main.go index 391e961..b9cc2d2 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package main import ( "context" - "database/sql" "net" "net/http" "os" @@ -20,7 +19,7 @@ import ( _ "github.com/golang-migrate/migrate/v4/source/file" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/hibiken/asynq" - _ "github.com/lib/pq" + "github.com/jackc/pgx/v5/pgxpool" "github.com/rakyll/statik/fs" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -39,14 +38,14 @@ func main() { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) } - conn, err := sql.Open(config.DBDriver, config.DBSource) + connPool, err := pgxpool.New(context.Background(), config.DBSource) if err != nil { log.Fatal().Err(err).Msg("cannot connect to db") } runDBMigration(config.MigrationURL, config.DBSource) - store := db.NewStore(conn) + store := db.NewStore(connPool) redisOpt := asynq.RedisClientOpt{ Addr: config.RedisAddress, diff --git a/sqlc.yaml b/sqlc.yaml index 4af81b6..c30de5d 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -6,7 +6,12 @@ sql: gen: go: out: "./db/sqlc" + sql_package: "pgx/v5" emit_json_tags: true emit_empty_slices: true emit_interface: true - \ No newline at end of file + overrides: + - db_type: "timestamptz" + go_type: "time.Time" + - db_type: "uuid" + go_type: "github.com/google/uuid.UUID" \ No newline at end of file diff --git a/util/config.go b/util/config.go index ebc0efa..87fbc21 100644 --- a/util/config.go +++ b/util/config.go @@ -10,7 +10,6 @@ import ( // The values are read by viper from a config file or environment variables type Config struct { Environment string `mapstructure:"ENVIRONMENT"` - DBDriver string `mapstructure:"DB_DRIVER"` DBSource string `mapstructure:"DB_SOURCE"` MigrationURL string `mapstructure:"MIGRATION_URL"` RedisAddress string `mapstructure:"REDIS_ADDRESS"`