Skip to content
This repository was archived by the owner on Sep 30, 2024. It is now read-only.

Commit abd5fe9

Browse files
committed
chore(): small code update
1 parent 3062ed5 commit abd5fe9

File tree

4 files changed

+43
-26
lines changed

4 files changed

+43
-26
lines changed

example/simple/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ func main() {
2121
log.Fatalf("Unmarshal: %v", err)
2222
}
2323

24-
pgStream, err := pglogicalstream.NewPgStream(config, log.WithPrefix("pg-cdc"))
24+
pgStream, err := pglogicalstream.NewPgStream(config)
2525
if err != nil {
2626
panic(err)
2727
}

logical_stream.go

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ import (
66
"crypto/tls"
77
"database/sql"
88
"encoding/json"
9+
"errors"
910
"fmt"
1011
"os"
1112
"strings"
13+
"sync"
1214
"time"
1315

1416
"github.com/charmbracelet/log"
@@ -24,9 +26,11 @@ type Stream struct {
2426
pgConn *pgconn.PgConn
2527
// extra copy of db config is required to establish a new db connection
2628
// which is required to take snapshot data
27-
dbConfig pgconn.Config
28-
ctx context.Context
29-
cancel context.CancelFunc
29+
dbConfig pgconn.Config
30+
streamCtx context.Context
31+
streamCancel context.CancelFunc
32+
33+
standbyCtxCancel context.CancelFunc
3034
clientXLogPos pglogrepl.LSN
3135
standbyMessageTimeout time.Duration
3236
nextStandbyMessageDeadline time.Time
@@ -42,9 +46,12 @@ type Stream struct {
4246
snapshotBatchSize int
4347
snapshotMemorySafetyFactor float64
4448
logger *log.Logger
49+
50+
m sync.Mutex
51+
stopped bool
4552
}
4653

47-
func NewPgStream(config Config, logger *log.Logger) (*Stream, error) {
54+
func NewPgStream(config Config) (*Stream, error) {
4855
var (
4956
cfg *pgconn.Config
5057
err error
@@ -97,7 +104,9 @@ func NewPgStream(config Config, logger *log.Logger) (*Stream, error) {
97104
snapshotBatchSize: config.BatchSize,
98105
tableNames: tableNames,
99106
changeFilter: NewChangeFilter(tableNames, config.DbSchema),
100-
logger: logger,
107+
logger: log.WithPrefix("[pg-stream]"),
108+
m: sync.Mutex{},
109+
stopped: false,
101110
}
102111

103112
result := stream.pgConn.Exec(context.Background(), fmt.Sprintf("DROP PUBLICATION IF EXISTS pglog_stream_%s;", config.ReplicationSlotName))
@@ -106,7 +115,6 @@ func NewPgStream(config Config, logger *log.Logger) (*Stream, error) {
106115
stream.logger.Errorf("drop publication if exists error %s", err.Error())
107116
}
108117

109-
// TODO:: ADD Tables filter
110118
for i, table := range tableNames {
111119
tableNames[i] = fmt.Sprintf("%s.%s", config.DbSchema, table)
112120
}
@@ -170,7 +178,7 @@ func NewPgStream(config Config, logger *log.Logger) (*Stream, error) {
170178

171179
stream.standbyMessageTimeout = time.Second * 10
172180
stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout)
173-
stream.ctx, stream.cancel = context.WithCancel(context.Background())
181+
stream.streamCtx, stream.streamCancel = context.WithCancel(context.Background())
174182

175183
if !freshlyCreatedSlot || config.StreamOldData == false {
176184
stream.startLr()
@@ -215,8 +223,8 @@ func (s *Stream) AckLSN(lsn string) {
215223
func (s *Stream) streamMessagesAsync() {
216224
for {
217225
select {
218-
case <-s.ctx.Done():
219-
s.cancel()
226+
case <-s.streamCtx.Done():
227+
s.logger.Warn("Stream was cancelled...exiting...")
220228
return
221229
default:
222230
if time.Now().After(s.nextStandbyMessageDeadline) {
@@ -234,11 +242,18 @@ func (s *Stream) streamMessagesAsync() {
234242

235243
ctx, cancel := context.WithDeadline(context.Background(), s.nextStandbyMessageDeadline)
236244
rawMsg, err := s.pgConn.ReceiveMessage(ctx)
237-
s.cancel = cancel
245+
s.standbyCtxCancel = cancel
246+
247+
if err != nil && (errors.Is(err, context.Canceled) || s.stopped) {
248+
s.logger.Warn("Service was interrpupted....stop reading from replication slot")
249+
return
250+
}
251+
238252
if err != nil {
239253
if pgconn.Timeout(err) {
240254
continue
241255
}
256+
242257
s.logger.Fatalf("Failed to receive messages from PostgreSQL %s", err.Error())
243258
}
244259

@@ -288,12 +303,12 @@ func (s *Stream) streamMessagesAsync() {
288303
func (s *Stream) processSnapshot() {
289304
snapshotter, err := NewSnapshotter(s.dbConfig, s.snapshotName)
290305
if err != nil {
291-
s.logger.Errorf("Failed to create database snapshot: %", err.Error())
306+
s.logger.Errorf("Failed to create database snapshot: %v", err.Error())
292307
s.cleanUpOnFailure()
293308
os.Exit(1)
294309
}
295310
if err = snapshotter.Prepare(); err != nil {
296-
s.logger.Errorf("Failed to prepare database snapshot: %", err.Error())
311+
s.logger.Errorf("Failed to prepare database snapshot: %v", err.Error())
297312
s.cleanUpOnFailure()
298313
os.Exit(1)
299314
}
@@ -303,17 +318,16 @@ func (s *Stream) processSnapshot() {
303318
}()
304319

305320
for _, table := range s.tableNames {
306-
log.Printf("Processing snapshot for a table %s", table)
321+
s.logger.Info("Processing snapshot for table", "table", table)
307322

308323
var (
309324
avgRowSizeBytes sql.NullInt64
310325
offset = int(0)
311326
)
312327
avgRowSizeBytes = snapshotter.FindAvgRowSize(table)
313-
fmt.Println(avgRowSizeBytes, offset, "AVG SIZES")
314328

315329
batchSize := snapshotter.CalculateBatchSize(helpers.GetAvailableMemory(), uint64(avgRowSizeBytes.Int64))
316-
fmt.Println("Query with batch size", batchSize, "Available memory: ", helpers.GetAvailableMemory(), "Avg row size: ", avgRowSizeBytes.Int64)
330+
s.logger.Info("Querying snapshot", "batch_side", batchSize, "available_memory", helpers.GetAvailableMemory(), "avg_row_size", avgRowSizeBytes.Int64)
317331

318332
tablePk, err := s.getPrimaryKeyColumn(table)
319333
if err != nil {
@@ -427,7 +441,7 @@ func (s *Stream) OnMessage(callback OnMessage) {
427441
callback(snapshotMessage)
428442
case message := <-s.messages:
429443
callback(message)
430-
case <-s.ctx.Done():
444+
case <-s.streamCtx.Done():
431445
return
432446
}
433447
}
@@ -473,12 +487,16 @@ func (s *Stream) getPrimaryKeyColumn(tableName string) (string, error) {
473487
}
474488

475489
func (s *Stream) Stop() error {
490+
s.m.Lock()
491+
s.stopped = true
492+
s.m.Unlock()
493+
476494
if s.pgConn != nil {
477-
if s.ctx != nil {
478-
s.cancel()
495+
if s.streamCtx != nil {
496+
s.streamCancel()
497+
s.standbyCtxCancel()
479498
}
480-
481-
return s.pgConn.Close(context.TODO())
499+
return s.pgConn.Close(context.Background())
482500
}
483501

484502
return nil

pglogrepl.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,6 @@ func CreateReplicationSlot(
233233
snapshotString = options.SnapshotAction
234234
}
235235
sql := fmt.Sprintf("CREATE_REPLICATION_SLOT %s %s %s %s %s", slotName, temporaryString, options.Mode, outputPlugin, snapshotString)
236-
fmt.Println(sql)
237236
return ParseCreateReplicationSlot(conn.Exec(ctx, sql))
238237
}
239238

snapshotter.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ package pglogicalstream
33
import (
44
"database/sql"
55
"fmt"
6-
"log"
6+
7+
"github.com/charmbracelet/log"
78

89
"github.com/jackc/pgx/v5/pgconn"
910
_ "github.com/lib/pq"
@@ -25,8 +26,6 @@ func NewSnapshotter(dbConf pgconn.Config, snapshotName string) (*Snapshotter, er
2526
dbConf.Password, dbConf.Host, dbConf.Port, dbConf.Database, sslMode,
2627
)
2728

28-
fmt.Println("Conn string", connStr)
29-
3029
pgConn, err := sql.Open("postgres", connStr)
3130

3231
return &Snapshotter{
@@ -76,7 +75,8 @@ func (s *Snapshotter) CalculateBatchSize(availableMemory uint64, estimatedRowSiz
7675
}
7776

7877
func (s *Snapshotter) QuerySnapshotData(table string, pk string, limit, offset int) (rows *sql.Rows, err error) {
79-
fmt.Println("Query snapshot: ", fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d OFFSET %d;", table, pk, limit, offset))
78+
// fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d OFFSET %d;", table, pk, limit, offset)
79+
log.WithPrefix("[pg-stream/snapshotter]").Info("Query snapshot", "table", table, "limit", limit, "offset", offset, "pk", pk)
8080
return s.pgConnection.Query(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d OFFSET %d;", table, pk, limit, offset))
8181
}
8282

0 commit comments

Comments
 (0)