From fd0c65478e18be837b77c7ef24d7220f50540d49 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 13 Sep 2024 08:03:37 -0500 Subject: [PATCH] Fix prepared statement already exists on batch prepare failure When a batch successfully prepared some statements, but then failed to prepare others, the prepared statements that were successfully prepared were not properly cleaned up. This could lead to a "prepared statement already exists" error on subsequent attempts to prepare the same statement. https://github.com/jackc/pgx/issues/1847#issuecomment-2347858887 --- batch_test.go | 30 +++++++++++++++++++++ conn.go | 75 +++++++++++++++++++++++++++++++-------------------- 2 files changed, 76 insertions(+), 29 deletions(-) diff --git a/batch_test.go b/batch_test.go index eb560e068..b1bc25de6 100644 --- a/batch_test.go +++ b/batch_test.go @@ -1008,6 +1008,36 @@ func TestSendBatchSimpleProtocol(t *testing.T) { assert.False(t, rows.Next()) } +// https://github.com/jackc/pgx/issues/1847#issuecomment-2347858887 +func TestConnSendBatchErrorDoesNotLeaveOrphanedPreparedStatement(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test") + + mustExec(t, conn, `create temporary table foo(col1 text primary key);`) + + batch := &pgx.Batch{} + batch.Queue("select col1 from foo") + batch.Queue("select col1 from baz") + err := conn.SendBatch(ctx, batch).Close() + require.EqualError(t, err, `ERROR: relation "baz" does not exist (SQLSTATE 42P01)`) + + mustExec(t, conn, `create temporary table baz(col1 text primary key);`) + + // Since table baz now exists, the batch should succeed. + + batch = &pgx.Batch{} + batch.Queue("select col1 from foo") + batch.Queue("select col1 from baz") + err = conn.SendBatch(ctx, batch).Close() + require.NoError(t, err) + }) +} + func ExampleConn_SendBatch() { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() diff --git a/conn.go b/conn.go index 187b3dd57..1d4c414fb 100644 --- a/conn.go +++ b/conn.go @@ -1126,47 +1126,64 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d // Prepare any needed queries if len(distinctNewQueries) > 0 { - for _, sd := range distinctNewQueries { - pipeline.SendPrepare(sd.Name, sd.SQL, nil) - } + err := func() (err error) { + for _, sd := range distinctNewQueries { + pipeline.SendPrepare(sd.Name, sd.SQL, nil) + } - err := pipeline.Sync() - if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} - } + // Store all statements we are preparing into the cache. It's fine if it overflows because HandleInvalidated will + // clean them up later. + if sdCache != nil { + for _, sd := range distinctNewQueries { + sdCache.Put(sd) + } + } + + // If something goes wrong preparing the statements, we need to invalidate the cache entries we just added. + defer func() { + if err != nil && sdCache != nil { + for _, sd := range distinctNewQueries { + sdCache.Invalidate(sd.SQL) + } + } + }() + + err = pipeline.Sync() + if err != nil { + return err + } + + for _, sd := range distinctNewQueries { + results, err := pipeline.GetResults() + if err != nil { + return err + } + + resultSD, ok := results.(*pgconn.StatementDescription) + if !ok { + return fmt.Errorf("expected statement description, got %T", results) + } + + // Fill in the previously empty / pending statement descriptions. + sd.ParamOIDs = resultSD.ParamOIDs + sd.Fields = resultSD.Fields + } - for _, sd := range distinctNewQueries { results, err := pipeline.GetResults() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} + return err } - resultSD, ok := results.(*pgconn.StatementDescription) + _, ok := results.(*pgconn.PipelineSync) if !ok { - return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true} + return fmt.Errorf("expected sync, got %T", results) } - // Fill in the previously empty / pending statement descriptions. - sd.ParamOIDs = resultSD.ParamOIDs - sd.Fields = resultSD.Fields - } - - results, err := pipeline.GetResults() + return nil + }() if err != nil { return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } - - _, ok := results.(*pgconn.PipelineSync) - if !ok { - return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true} - } - } - - // Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later. - if sdCache != nil { - for _, sd := range distinctNewQueries { - sdCache.Put(sd) - } } // Queue the queries.