@@ -6,9 +6,11 @@ import (
6
6
"crypto/tls"
7
7
"database/sql"
8
8
"encoding/json"
9
+ "errors"
9
10
"fmt"
10
11
"os"
11
12
"strings"
13
+ "sync"
12
14
"time"
13
15
14
16
"github.com/charmbracelet/log"
@@ -24,9 +26,11 @@ type Stream struct {
24
26
pgConn * pgconn.PgConn
25
27
// extra copy of db config is required to establish a new db connection
26
28
// which is required to take snapshot data
27
- dbConfig pgconn.Config
28
- ctx context.Context
29
- cancel context.CancelFunc
29
+ dbConfig pgconn.Config
30
+ streamCtx context.Context
31
+ streamCancel context.CancelFunc
32
+
33
+ standbyCtxCancel context.CancelFunc
30
34
clientXLogPos pglogrepl.LSN
31
35
standbyMessageTimeout time.Duration
32
36
nextStandbyMessageDeadline time.Time
@@ -42,9 +46,12 @@ type Stream struct {
42
46
snapshotBatchSize int
43
47
snapshotMemorySafetyFactor float64
44
48
logger * log.Logger
49
+
50
+ m sync.Mutex
51
+ stopped bool
45
52
}
46
53
47
- func NewPgStream (config Config , logger * log. Logger ) (* Stream , error ) {
54
+ func NewPgStream (config Config ) (* Stream , error ) {
48
55
var (
49
56
cfg * pgconn.Config
50
57
err error
@@ -97,7 +104,9 @@ func NewPgStream(config Config, logger *log.Logger) (*Stream, error) {
97
104
snapshotBatchSize : config .BatchSize ,
98
105
tableNames : tableNames ,
99
106
changeFilter : NewChangeFilter (tableNames , config .DbSchema ),
100
- logger : logger ,
107
+ logger : log .WithPrefix ("[pg-stream]" ),
108
+ m : sync.Mutex {},
109
+ stopped : false ,
101
110
}
102
111
103
112
result := stream .pgConn .Exec (context .Background (), fmt .Sprintf ("DROP PUBLICATION IF EXISTS pglog_stream_%s;" , config .ReplicationSlotName ))
@@ -106,7 +115,6 @@ func NewPgStream(config Config, logger *log.Logger) (*Stream, error) {
106
115
stream .logger .Errorf ("drop publication if exists error %s" , err .Error ())
107
116
}
108
117
109
- // TODO:: ADD Tables filter
110
118
for i , table := range tableNames {
111
119
tableNames [i ] = fmt .Sprintf ("%s.%s" , config .DbSchema , table )
112
120
}
@@ -170,7 +178,7 @@ func NewPgStream(config Config, logger *log.Logger) (*Stream, error) {
170
178
171
179
stream .standbyMessageTimeout = time .Second * 10
172
180
stream .nextStandbyMessageDeadline = time .Now ().Add (stream .standbyMessageTimeout )
173
- stream .ctx , stream .cancel = context .WithCancel (context .Background ())
181
+ stream .streamCtx , stream .streamCancel = context .WithCancel (context .Background ())
174
182
175
183
if ! freshlyCreatedSlot || config .StreamOldData == false {
176
184
stream .startLr ()
@@ -215,8 +223,8 @@ func (s *Stream) AckLSN(lsn string) {
215
223
func (s * Stream ) streamMessagesAsync () {
216
224
for {
217
225
select {
218
- case <- s .ctx .Done ():
219
- s .cancel ( )
226
+ case <- s .streamCtx .Done ():
227
+ s .logger . Warn ( "Stream was cancelled...exiting..." )
220
228
return
221
229
default :
222
230
if time .Now ().After (s .nextStandbyMessageDeadline ) {
@@ -234,11 +242,18 @@ func (s *Stream) streamMessagesAsync() {
234
242
235
243
ctx , cancel := context .WithDeadline (context .Background (), s .nextStandbyMessageDeadline )
236
244
rawMsg , err := s .pgConn .ReceiveMessage (ctx )
237
- s .cancel = cancel
245
+ s .standbyCtxCancel = cancel
246
+
247
+ if err != nil && (errors .Is (err , context .Canceled ) || s .stopped ) {
248
+ s .logger .Warn ("Service was interrpupted....stop reading from replication slot" )
249
+ return
250
+ }
251
+
238
252
if err != nil {
239
253
if pgconn .Timeout (err ) {
240
254
continue
241
255
}
256
+
242
257
s .logger .Fatalf ("Failed to receive messages from PostgreSQL %s" , err .Error ())
243
258
}
244
259
@@ -288,12 +303,12 @@ func (s *Stream) streamMessagesAsync() {
288
303
func (s * Stream ) processSnapshot () {
289
304
snapshotter , err := NewSnapshotter (s .dbConfig , s .snapshotName )
290
305
if err != nil {
291
- s .logger .Errorf ("Failed to create database snapshot: %" , err .Error ())
306
+ s .logger .Errorf ("Failed to create database snapshot: %v " , err .Error ())
292
307
s .cleanUpOnFailure ()
293
308
os .Exit (1 )
294
309
}
295
310
if err = snapshotter .Prepare (); err != nil {
296
- s .logger .Errorf ("Failed to prepare database snapshot: %" , err .Error ())
311
+ s .logger .Errorf ("Failed to prepare database snapshot: %v " , err .Error ())
297
312
s .cleanUpOnFailure ()
298
313
os .Exit (1 )
299
314
}
@@ -303,17 +318,16 @@ func (s *Stream) processSnapshot() {
303
318
}()
304
319
305
320
for _ , table := range s .tableNames {
306
- log . Printf ("Processing snapshot for a table %s " , table )
321
+ s . logger . Info ("Processing snapshot for table" , "table " , table )
307
322
308
323
var (
309
324
avgRowSizeBytes sql.NullInt64
310
325
offset = int (0 )
311
326
)
312
327
avgRowSizeBytes = snapshotter .FindAvgRowSize (table )
313
- fmt .Println (avgRowSizeBytes , offset , "AVG SIZES" )
314
328
315
329
batchSize := snapshotter .CalculateBatchSize (helpers .GetAvailableMemory (), uint64 (avgRowSizeBytes .Int64 ))
316
- fmt . Println ( "Query with batch size" , batchSize , "Available memory: " , helpers .GetAvailableMemory (), "Avg row size: " , avgRowSizeBytes .Int64 )
330
+ s . logger . Info ( "Querying snapshot" , "batch_side" , batchSize , "available_memory " , helpers .GetAvailableMemory (), "avg_row_size " , avgRowSizeBytes .Int64 )
317
331
318
332
tablePk , err := s .getPrimaryKeyColumn (table )
319
333
if err != nil {
@@ -427,7 +441,7 @@ func (s *Stream) OnMessage(callback OnMessage) {
427
441
callback (snapshotMessage )
428
442
case message := <- s .messages :
429
443
callback (message )
430
- case <- s .ctx .Done ():
444
+ case <- s .streamCtx .Done ():
431
445
return
432
446
}
433
447
}
@@ -473,12 +487,16 @@ func (s *Stream) getPrimaryKeyColumn(tableName string) (string, error) {
473
487
}
474
488
475
489
func (s * Stream ) Stop () error {
490
+ s .m .Lock ()
491
+ s .stopped = true
492
+ s .m .Unlock ()
493
+
476
494
if s .pgConn != nil {
477
- if s .ctx != nil {
478
- s .cancel ()
495
+ if s .streamCtx != nil {
496
+ s .streamCancel ()
497
+ s .standbyCtxCancel ()
479
498
}
480
-
481
- return s .pgConn .Close (context .TODO ())
499
+ return s .pgConn .Close (context .Background ())
482
500
}
483
501
484
502
return nil
0 commit comments