Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use PostgreSQL protocol for deallocating prepared statements #1797

Merged
merged 6 commits into from
Nov 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,17 +338,26 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.Statem
return sd, nil
}

// Deallocate releases a prepared statement.
// Deallocate releases a prepared statement. Calling Deallocate on a non-existent prepared statement will succeed.
func (c *Conn) Deallocate(ctx context.Context, name string) error {
var psName string
if sd, ok := c.preparedStatements[name]; ok {
delete(c.preparedStatements, name)
sd := c.preparedStatements[name]
if sd != nil {
psName = sd.Name
} else {
psName = name
}
_, err := c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(psName)).ReadAll()
return err

err := c.pgConn.Deallocate(ctx, psName)
if err != nil {
return err
}

if sd != nil {
delete(c.preparedStatements, name)
}

return nil
}

// DeallocateAll releases all previously prepared statements from the server and client, where it also resets the statement and description cache.
Expand Down
63 changes: 63 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,69 @@ func TestPrepareWithDigestedName(t *testing.T) {
})
}

// https://github.com/jackc/pgx/pull/1795
func TestDeallocateInAbortedTransaction(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
tx, err := conn.Begin(ctx)
require.NoError(t, err)

sql := "select $1::text"
sd, err := tx.Prepare(ctx, sql, sql)
require.NoError(t, err)
require.Equal(t, "stmt_2510cc7db17de3f42758a2a29c8b9ef8305d007b997ebdd6", sd.Name)

var s string
err = tx.QueryRow(ctx, sql, "hello").Scan(&s)
require.NoError(t, err)
require.Equal(t, "hello", s)

_, err = tx.Exec(ctx, "select 1/0") // abort transaction with divide by zero error
require.Error(t, err)

err = conn.Deallocate(ctx, sql)
require.NoError(t, err)

err = tx.Rollback(ctx)
require.NoError(t, err)

sd, err = conn.Prepare(ctx, sql, sql)
require.NoError(t, err)
require.Equal(t, "stmt_2510cc7db17de3f42758a2a29c8b9ef8305d007b997ebdd6", sd.Name)
})
}

func TestDeallocateMissingPreparedStatementStillClearsFromPreparedStatementMap(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Prepare(ctx, "ps", "select $1::text")
require.NoError(t, err)

_, err = conn.Exec(ctx, "deallocate ps")
require.NoError(t, err)

err = conn.Deallocate(ctx, "ps")
require.NoError(t, err)

_, err = conn.Prepare(ctx, "ps", "select $1::text, $2::text")
require.NoError(t, err)

var s1, s2 string
err = conn.QueryRow(ctx, "ps", "hello", "world").Scan(&s1, &s2)
require.NoError(t, err)
require.Equal(t, "hello", s1)
require.Equal(t, "world", s2)
})
}

func TestListenNotify(t *testing.T) {
t.Parallel()

Expand Down
49 changes: 49 additions & 0 deletions pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,9 @@ type StatementDescription struct {

// Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This
// allows Prepare to also to describe statements without creating a server-side prepared statement.
//
// Prepare does not send a PREPARE statement to the server. It uses the PostgreSQL Parse and Describe protocol messages
// directly.
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) {
if err := pgConn.lock(); err != nil {
return nil, err
Expand Down Expand Up @@ -869,6 +872,52 @@ readloop:
return psd, nil
}

// Deallocate deallocates a prepared statement.
//
// Deallocate does not send a DEALLOCATE statement to the server. It uses the PostgreSQL Close protocol message
// directly. This has slightly different behavior than executing DEALLOCATE statement.
// - Deallocate can succeed in an aborted transaction.
// - Deallocating a non-existent prepared statement is not an error.
func (pgConn *PgConn) Deallocate(ctx context.Context, name string) error {
if err := pgConn.lock(); err != nil {
return err
}
defer pgConn.unlock()

if ctx != context.Background() {
select {
case <-ctx.Done():
return newContextAlreadyDoneError(ctx)
default:
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
}

pgConn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name})
pgConn.frontend.SendSync(&pgproto3.Sync{})
err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
return err
}

for {
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.asyncClose()
return normalizeTimeoutError(ctx, err)
}

switch msg := msg.(type) {
case *pgproto3.ErrorResponse:
return ErrorResponseToPgError(msg)
case *pgproto3.ReadyForQuery:
return nil
}
}
}

// ErrorResponseToPgError converts a wire protocol error message to a *PgError.
func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
return &PgError{
Expand Down
83 changes: 83 additions & 0 deletions pgconn/pgconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,89 @@ func TestConnPrepareContextPrecanceled(t *testing.T) {
ensureConnValid(t, pgConn)
}

func TestConnDeallocate(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)

_, err = pgConn.Prepare(ctx, "ps1", "select 1", nil)
require.NoError(t, err)

_, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
require.NoError(t, err)

err = pgConn.Deallocate(ctx, "ps1")
require.NoError(t, err)

_, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
require.Error(t, err)
var pgErr *pgconn.PgError
require.ErrorAs(t, err, &pgErr)
require.Equal(t, "26000", pgErr.Code)

ensureConnValid(t, pgConn)
}

func TestConnDeallocateSucceedsInAbortedTransaction(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)

err = pgConn.Exec(ctx, "begin").Close()
require.NoError(t, err)

_, err = pgConn.Prepare(ctx, "ps1", "select 1", nil)
require.NoError(t, err)

_, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
require.NoError(t, err)

err = pgConn.Exec(ctx, "select 1/0").Close() // break transaction with divide by 0 error
require.Error(t, err)
var pgErr *pgconn.PgError
require.ErrorAs(t, err, &pgErr)
require.Equal(t, "22012", pgErr.Code)

err = pgConn.Deallocate(ctx, "ps1")
require.NoError(t, err)

err = pgConn.Exec(ctx, "rollback").Close()
require.NoError(t, err)

_, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
require.Error(t, err)
require.ErrorAs(t, err, &pgErr)
require.Equal(t, "26000", pgErr.Code)

ensureConnValid(t, pgConn)
}

func TestConnDeallocateNonExistantStatementSucceeds(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)

err = pgConn.Deallocate(ctx, "ps1")
require.NoError(t, err)

ensureConnValid(t, pgConn)
}

func TestConnExec(t *testing.T) {
t.Parallel()

Expand Down