Skip to content

Commit

Permalink
Introduce DB#InsertObtainID() method
Browse files Browse the repository at this point in the history
  • Loading branch information
yhabteab committed May 28, 2024
1 parent a66da14 commit 89805fa
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
13 changes: 13 additions & 0 deletions database/contracts.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package database

import (
"context"
"github.com/jmoiron/sqlx"
)

// Entity is implemented by each type that works with the database package.
type Entity interface {
Fingerprinter
Expand Down Expand Up @@ -54,3 +59,11 @@ type PgsqlOnConflictConstrainter interface {
// PgsqlOnConflictConstraint returns the primary or unique key constraint name of the PostgreSQL table.
PgsqlOnConflictConstraint() string
}

// TxOrDB is just a helper interface that can represent a *[sqlx.Tx] or *[DB] instance.
type TxOrDB interface {
sqlx.ExtContext
sqlx.PreparerContext

PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error)
}
45 changes: 45 additions & 0 deletions database/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package database

import (
"context"
"database/sql"
"database/sql/driver"
"github.com/go-sql-driver/mysql"
"github.com/icinga/icinga-go-library/com"
"github.com/icinga/icinga-go-library/strcase"
"github.com/icinga/icinga-go-library/types"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
)

Expand Down Expand Up @@ -43,6 +45,49 @@ func SplitOnDupId[T IDer]() com.BulkChunkSplitPolicy[T] {
}
}

// InsertObtainID executes the given query and fetches the last inserted ID.
//
// Using this method for database tables that don't define an auto-incrementing ID, or none at all,
// will not work. The only supported column that can be retrieved with this method is id.
// This function expects [TxOrDB] as an executor of the provided query, and is usually a *[sqlx.Tx] or *[DB] instance.
// Returns the retrieved ID wrapped in [types.Int] on success and error on any database inserting/retrieving failure.
func InsertObtainID(ctx context.Context, conn TxOrDB, stmt string, arg any) (types.Int, error) {
var resultID int64
switch conn.DriverName() {
case PostgreSQL:
query := stmt + " RETURNING id"
ps, err := conn.PrepareNamedContext(ctx, query)
if err != nil {
return types.Int{}, errors.Wrapf(err, "cannot prepare %q", query)
}
// We're deferring the ps#Close invocation here just to be on the safe side, otherwise it's
// closed manually later on and the error is handled gracefully (if any).
defer func() { _ = ps.Close() }()

if err = ps.GetContext(ctx, &resultID, arg); err != nil {
return types.Int{}, CantPerformQuery(err, query)
}

if err = ps.Close(); err != nil {
return types.Int{}, errors.Wrapf(err, "cannot close prepared statement %q", query)
}
case MySQL:
result, err := sqlx.NamedExecContext(ctx, conn, stmt, arg)
if err != nil {
return types.Int{}, CantPerformQuery(err, stmt)
}

resultID, err = result.LastInsertId()
if err != nil {
return types.Int{}, errors.Wrap(err, "cannot retrieve last inserted ID")
}
default:
return types.Int{}, errors.Errorf("unsupported driver: %s", conn.DriverName())
}

return types.Int{NullInt64: sql.NullInt64{Int64: resultID, Valid: true}}, nil
}

// setGaleraOpts sets the "wsrep_sync_wait" variable for each session ensures that causality checks are performed
// before execution and that each statement is executed on a fully synchronized node. Doing so prevents foreign key
// violation when inserting into dependent tables on different MariaDB/MySQL nodes. When using MySQL single nodes,
Expand Down

0 comments on commit 89805fa

Please sign in to comment.