Skip to content

Commit

Permalink
CBG-4179 check early for context cancellation (#7079)
Browse files Browse the repository at this point in the history
* CBG-4179 check early for context cancellation

* clarify options.ChangesCtx

* Switch to use ChangesOptions.ChangesCtx

* do not expect an error message ChangesEntry.Err if context is cancelled

- ChangesFeed.Err is checked in a few places
    - GenerateChanges checks entry.Err and only serves to exit the
      function, it won't return that entry. If there is no end entry, it
      will hit <-options.ChangesCtx.Done() and set forceClose.
    - TestChannelQueryCancellation checked the error but this was
      unnecssary since the assertions are that the changes feed exits
      with only one ViewQuery
- drop error return from generateBlipSyncChanges since it was unused
- move DatabaseContext.GetChanges test only function to test only code
  since it has particular use cases.
     - require that there is no error so the caller doesn't have to
       check since it was only used for TestChannelQueryCancellation
- slight behavior change to set forceClose in
  GenerateChanges/sendSimpleChanges to check
  for context cancellation in the case that the ChangesEntry.Err was hit
  first, so this marks this as forceClose. forceClose is used to call
  DatabaseContext.NotifyTerminatedChanges to wake other changes feeds.

* add force close
  • Loading branch information
torcolvin authored Sep 19, 2024
1 parent be71c9c commit 7b5a401
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 110 deletions.
2 changes: 1 addition & 1 deletion db/blip_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ func (bh *blipHandler) sendChanges(sender *blip.Sender, opts *sendChangesOptions

}

_, forceClose := generateBlipSyncChanges(bh.loggingCtx, changesDb, channelSet, options, opts.docIDs, func(changes []*ChangeEntry) error {
forceClose := generateBlipSyncChanges(bh.loggingCtx, changesDb, channelSet, options, opts.docIDs, func(changes []*ChangeEntry) error {
base.DebugfCtx(bh.loggingCtx, base.KeySync, " Sending %d changes", len(changes))
for _, change := range changes {
if !strings.HasPrefix(change.ID, "_") {
Expand Down
45 changes: 24 additions & 21 deletions db/change_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,9 +509,7 @@ func TestChannelCacheBackfill(t *testing.T) {
collectionWithUser, ctx := GetSingleDatabaseCollectionWithUser(ctx, t, db)
collectionWithUser.user, err = authenticator.GetUser("naomi")
require.NoError(t, err)
changes, err := collectionWithUser.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithZeroSeq(t))
assert.NoError(t, err, "Couldn't GetChanges")
assert.Len(t, changes, 4)
changes := getChanges(t, collectionWithUser, base.SetOf("*"), getChangesOptionsWithZeroSeq(t))

collectionID := collection.GetCollectionID()

Expand Down Expand Up @@ -547,8 +545,7 @@ func TestChannelCacheBackfill(t *testing.T) {

// verify changes has three entries (needs to resend all since previous LowSeq, which
// will be the late arriver (3) along with 5, 6)
changes, err = collectionWithUser.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithSeq(t, lastSeq))
require.NoError(t, err)
changes = getChanges(t, collectionWithUser, base.SetOf("*"), getChangesOptionsWithSeq(t, lastSeq))
assert.Len(t, changes, 3)
assert.Equal(t, &ChangeEntry{
Seq: SequenceID{Seq: 3, LowSeq: 3},
Expand Down Expand Up @@ -939,8 +936,7 @@ func TestChannelQueryCancellation(t *testing.T) {
options.Continuous = false
options.Wait = false
options.Limit = 2 // Avoid prepending results in cache, as we don't want second changes to serve results from cache
_, err := collection.GetChanges(ctx, base.SetOf("ABC"), options)
assert.NoError(t, err, "Expect no error for first changes request")
_ = getChanges(t, collection, base.SetOf("ABC"), options)
}()

// Wait for queryBlocked=true - ensures the first goroutine has acquired view lock
Expand All @@ -964,8 +960,7 @@ func TestChannelQueryCancellation(t *testing.T) {
options.Continuous = false
options.Limit = 2
options.Wait = false
_, err := collection.GetChanges(ctx, base.SetOf("ABC"), options)
assert.Error(t, err, "Expected error for second changes")
_ = getChanges(t, collection, base.SetOf("ABC"), options)
}()

// wait for second goroutine to be queued for the view lock (based on expvar)
Expand Down Expand Up @@ -1245,8 +1240,7 @@ func TestChannelCacheSize(t *testing.T) {
collectionWithUser, ctx := GetSingleDatabaseCollectionWithUser(ctx, t, db)
collectionWithUser.user, err = authenticator.GetUser("naomi")
require.NoError(t, err)
changes, err := collectionWithUser.GetChanges(ctx, base.SetOf("ABC"), getChangesOptionsWithZeroSeq(t))
assert.NoError(t, err, "Couldn't GetChanges")
changes := getChanges(t, collectionWithUser, base.SetOf("ABC"), getChangesOptionsWithZeroSeq(t))
assert.Len(t, changes, 750)

// Validate that cache stores the expected number of values
Expand Down Expand Up @@ -1533,8 +1527,7 @@ func TestInitializeEmptyCache(t *testing.T) {
}

// Issue getChanges for empty channel
changes, err := collection.GetChanges(ctx, channels.BaseSetOf(t, "zero"), getChangesOptionsWithCtxOnly(t))
assert.NoError(t, err, "Couldn't GetChanges")
changes := getChanges(t, collection, channels.BaseSetOf(t, "zero"), getChangesOptionsWithCtxOnly(t))
changesCount := len(changes)
assert.Equal(t, 0, changesCount)

Expand All @@ -1551,10 +1544,8 @@ func TestInitializeEmptyCache(t *testing.T) {
cacheWaiter.Add(docCount)
cacheWaiter.Wait()

changes, err = collection.GetChanges(ctx, channels.BaseSetOf(t, "zero"), getChangesOptionsWithCtxOnly(t))
assert.NoError(t, err, "Couldn't GetChanges")
changesCount = len(changes)
assert.Equal(t, 10, changesCount)
changes = getChanges(t, collection, channels.BaseSetOf(t, "zero"), getChangesOptionsWithCtxOnly(t))
assert.Len(t, changes, 10)
}

// Trigger initialization of the channel cache under load via getChanges. Ensures validFrom handling correctly
Expand Down Expand Up @@ -1598,8 +1589,7 @@ func TestInitializeCacheUnderLoad(t *testing.T) {

// Wait for writes to be in progress, then getChanges for channel zero
writesInProgress.Wait()
changes, err := collection.GetChanges(ctx, channels.BaseSetOf(t, "zero"), getChangesOptionsWithCtxOnly(t))
require.NoError(t, err, "Couldn't GetChanges")
changes := getChanges(t, collection, channels.BaseSetOf(t, "zero"), getChangesOptionsWithCtxOnly(t))
firstChangesCount := len(changes)
var lastSeq SequenceID
if firstChangesCount > 0 {
Expand All @@ -1609,8 +1599,7 @@ func TestInitializeCacheUnderLoad(t *testing.T) {
// Wait for all writes to be cached, then getChanges again
cacheWaiter.Wait()

changes, err = collection.GetChanges(ctx, channels.BaseSetOf(t, "zero"), getChangesOptionsWithSeq(t, lastSeq))
require.NoError(t, err, "Couldn't GetChanges")
changes = getChanges(t, collection, channels.BaseSetOf(t, "zero"), getChangesOptionsWithSeq(t, lastSeq))
secondChangesCount := len(changes)
assert.Equal(t, docCount, firstChangesCount+secondChangesCount)

Expand Down Expand Up @@ -2863,6 +2852,20 @@ func TestReleasedSequenceRangeHandlingDuplicateSequencesInSkipped(t *testing.T)
}, time.Second*10, time.Millisecond*100)
}

// getChanges is a synchronous convenience function that returns all changes as a simple array. This will fail the test if an error is returned.
func getChanges(t *testing.T, collection *DatabaseCollectionWithUser, channels base.Set, options ChangesOptions) []*ChangeEntry {
require.NotNil(t, options.ChangesCtx)
feed, err := collection.MultiChangesFeed(options.ChangesCtx, channels, options)

require.NoError(t, err)
require.NotNil(t, feed)
var changes = make([]*ChangeEntry, 0, 50)
for entry := range feed {
changes = append(changes, entry)
}
return changes
}

// TestAddPendingLogs:
// - Test age-based eviction of sequences and ranges from pending logs.
// - Adds to pending logs directly via heap.Push with backdated TimeReceived,
Expand Down
111 changes: 62 additions & 49 deletions db/changes.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ func (db *DatabaseCollectionWithUser) buildRevokedFeed(ctx context.Context, ch c
// 3. An error is returned when calling singleChannelCache.GetChanges
// 4. An error is returned when calling wasDocInChannelPriorToRevocation
for {
if options.ChangesCtx.Err() != nil {
base.DebugfCtx(ctx, base.KeyChanges, "Terminating revocation channel feed %s", base.UD(to))
return
}

if requestLimit == 0 {
paginationOptions.Limit = queryLimit
} else {
Expand All @@ -241,6 +246,11 @@ func (db *DatabaseCollectionWithUser) buildRevokedFeed(ctx context.Context, ch c

sentChanges := 0
for _, logEntry := range changes {
if options.ChangesCtx.Err() != nil {
base.DebugfCtx(ctx, base.KeyChanges, "Terminating revocation channel feed %s", base.UD(to))
return
}

seqID := SequenceID{
Seq: logEntry.Sequence,
TriggeredBy: revokedAt,
Expand Down Expand Up @@ -409,6 +419,10 @@ func (db *DatabaseCollectionWithUser) changesFeed(ctx context.Context, singleCha
// 2. A limit is specified on the incoming ChangesOptions, and that limit is reached
// 3. An error is returned when calling singleChannelCache.GetChanges
for {
if options.ChangesCtx.Err() != nil {
base.DebugfCtx(ctx, base.KeyChanges, "Terminating channel feed %s", base.UD(to))
return
}
// Calculate limit for this iteration
if requestLimit == 0 {
paginationOptions.Limit = queryLimit
Expand All @@ -432,6 +446,10 @@ func (db *DatabaseCollectionWithUser) changesFeed(ctx context.Context, singleCha
// Now write each log entry to the 'feed' channel in turn:
sentChanges := 0
for _, logEntry := range changes {
if options.ChangesCtx.Err() != nil {
base.DebugfCtx(ctx, base.KeyChanges, "Terminating channel feed %s", base.UD(to))
return
}
if logEntry.Sequence >= options.Since.TriggeredBy {
options.Since.TriggeredBy = 0
}
Expand Down Expand Up @@ -758,7 +776,10 @@ func (col *DatabaseCollectionWithUser) SimpleMultiChangesFeed(ctx context.Contex
if err != nil {
base.WarnfCtx(ctx, "Unable to obtain channel cache for %s, terminating feed", base.UD(chanName))
change := makeErrorEntry("Channel cache unavailable, terminating feed")
output <- &change
select {
case output <- &change:
case <-options.ChangesCtx.Done():
}
return
}

Expand All @@ -769,7 +790,7 @@ func (col *DatabaseCollectionWithUser) SimpleMultiChangesFeed(ctx context.Contex
if useLateSequenceFeeds {
lateSequenceFeedHandler := lateSequenceFeeds[chanID]
if lateSequenceFeedHandler != nil {
latefeed, err := col.getLateFeed(lateSequenceFeedHandler, singleChannelCache)
latefeed, err := col.getLateFeed(options.ChangesCtx, lateSequenceFeedHandler, singleChannelCache)
if err != nil {
base.WarnfCtx(ctx, "MultiChangesFeed got error reading late sequence feed %q, rolling back channel changes feed to last sent low sequence #%d.", base.UD(chanName), lastSentLowSeq)
chanOpts.Since.LowSeq = lastSentLowSeq
Expand Down Expand Up @@ -889,7 +910,10 @@ func (col *DatabaseCollectionWithUser) SimpleMultiChangesFeed(ctx context.Contex
// On feed error, send the error and exit changes processing
if current[i].Err == base.ErrChannelFeed {
base.WarnfCtx(ctx, "MultiChangesFeed got error reading changes feed: %v", current[i].Err)
output <- current[i]
select {
case <-options.ChangesCtx.Done():
case output <- current[i]:
}
return
}
}
Expand Down Expand Up @@ -967,6 +991,9 @@ func (col *DatabaseCollectionWithUser) SimpleMultiChangesFeed(ctx context.Contex
// Send the entry, and repeat the loop:
base.DebugfCtx(ctx, base.KeyChanges, "MultiChangesFeed sending %+v %s", base.UD(minEntry), base.UD(to))

if options.ChangesCtx.Err() != nil {
return
}
select {
case <-options.ChangesCtx.Done():
return
Expand Down Expand Up @@ -1001,7 +1028,11 @@ func (col *DatabaseCollectionWithUser) SimpleMultiChangesFeed(ctx context.Contex
// If nothing found, and in wait mode: wait for the db to change, then run again.
// First notify the reader that we're waiting by sending a nil.
base.DebugfCtx(ctx, base.KeyChanges, "MultiChangesFeed waiting... %s", base.UD(to))
output <- nil
select {
case <-options.ChangesCtx.Done():
return
case output <- nil:
}

// If this is an initial replication using CBL 2.x (active only), flip activeOnly now the client has caught up.
if options.clientType == clientTypeCBL2 && options.ActiveOnly {
Expand All @@ -1027,22 +1058,12 @@ func (col *DatabaseCollectionWithUser) SimpleMultiChangesFeed(ctx context.Contex
waitResponse := changeWaiter.Wait(ctx)
col.dbStats().CBLReplicationPull().NumPullReplCaughtUp.Add(-1)

if waitResponse == WaiterClosed {
if options.ChangesCtx.Err() != nil {
return
} else if waitResponse == WaiterClosed {
break outer
} else if waitResponse == WaiterHasChanges {
select {
case <-options.ChangesCtx.Done():
return
default:
break waitForChanges
}
} else if waitResponse == WaiterCheckTerminated {
// Check whether I was terminated while waiting for a change. If not, resume wait.
select {
case <-options.ChangesCtx.Done():
return
default:
}
break waitForChanges
}
}
// Update the current max cached sequence for the next changes iteration
Expand All @@ -1054,7 +1075,10 @@ func (col *DatabaseCollectionWithUser) SimpleMultiChangesFeed(ctx context.Contex
if err != nil {
change := makeErrorEntry("User not found during reload - terminating changes feed")
base.DebugfCtx(ctx, base.KeyChanges, "User not found during reload - terminating changes feed with entry %+v", base.UD(change))
output <- &change
select {
case <-options.ChangesCtx.Done():
case output <- &change:
}
return
}
if userChanged && col.user != nil {
Expand Down Expand Up @@ -1099,28 +1123,6 @@ func (col *DatabaseCollectionWithUser) waitForCacheUpdate(ctx context.Context, c
return false
}

// FOR TEST USE ONLY: Synchronous convenience function that returns all changes as a simple array,
// Returns error if initial feed creation fails, or if an error is returned with the changes entries
func (db *DatabaseCollectionWithUser) GetChanges(ctx context.Context, channels base.Set, options ChangesOptions) ([]*ChangeEntry, error) {
if options.ChangesCtx == nil {
changesCtx, changesCtxCancel := context.WithCancel(context.Background())
options.ChangesCtx = changesCtx
defer changesCtxCancel()
}

var changes = make([]*ChangeEntry, 0, 50)
feed, err := db.MultiChangesFeed(ctx, channels, options)
if err == nil && feed != nil {
for entry := range feed {
if entry.Err != nil {
err = entry.Err
}
changes = append(changes, entry)
}
}
return changes, err
}

// Returns the set of cached log entries for a given channel
func (c *DatabaseCollection) GetChangeLog(ctx context.Context, channel channels.ID, afterSeq uint64) (entries []*LogEntry, err error) {
return c.changeCache().getChannelCache().GetCachedChanges(ctx, channel)
Expand Down Expand Up @@ -1177,7 +1179,7 @@ func (db *DatabaseCollectionWithUser) newLateSequenceFeed(singleChannelCache Sin

// Feed to process late sequences for the channel. Updates lastSequence as it works the feed. Error indicates
// previous position in late sequence feed isn't available, and caller should reset to low sequence.
func (db *DatabaseCollectionWithUser) getLateFeed(feedHandler *lateSequenceFeed, singleChannelCache SingleChannelCache) (<-chan *ChangeEntry, error) {
func (db *DatabaseCollectionWithUser) getLateFeed(ctx context.Context, feedHandler *lateSequenceFeed, singleChannelCache SingleChannelCache) (<-chan *ChangeEntry, error) {

if !singleChannelCache.SupportsLateFeed() {
return nil, errors.New("Cache doesn't support late feeds")
Expand Down Expand Up @@ -1217,7 +1219,12 @@ func (db *DatabaseCollectionWithUser) getLateFeed(feedHandler *lateSequenceFeed,
Seq: logEntry.Sequence,
}
change := makeChangeEntry(logEntry, seqID, singleChannelCache.ChannelID())
feed <- &change
select {
case <-ctx.Done():
return

case feed <- &change:
}
}
}()

Expand Down Expand Up @@ -1354,30 +1361,31 @@ func (options ChangesOptions) String() string {
}

// Used by BLIP connections for changes. Supports both one-shot and continuous changes.
func generateBlipSyncChanges(ctx context.Context, database *DatabaseCollectionWithUser, inChannels base.Set, options ChangesOptions, docIDFilter []string, send func([]*ChangeEntry) error) (err error, forceClose bool) {
func generateBlipSyncChanges(ctx context.Context, database *DatabaseCollectionWithUser, inChannels base.Set, options ChangesOptions, docIDFilter []string, send func([]*ChangeEntry) error) (forceClose bool) {

// Store one-shot here to protect
isOneShot := !options.Continuous
err, forceClose = GenerateChanges(ctx, options.ChangesCtx, database, inChannels, options, docIDFilter, send)
err, forceClose := GenerateChanges(ctx, database, inChannels, options, docIDFilter, send)

if _, ok := err.(*ChangesSendErr); ok {
return nil, forceClose // error is probably because the client closed the connection
// If there was already an error in a send function, do not send last one shot changes message, since it probably will not work anyway.
return forceClose // error is probably because the client closed the connection
}

// For one-shot changes, invoke the callback w/ nil to trigger the 'caught up' changes message. (For continuous changes, this
// is done by MultiChangesFeed prior to going into Wait mode)
if isOneShot {
_ = send(nil)
}
return err, forceClose
return forceClose
}

type ChangesSendErr struct{ error }

// Shell of the continuous changes feed -- calls out to a `send` function to deliver the change.
// This is called from BLIP connections as well as HTTP handlers, which is why this is not a
// method on `handler`.
func GenerateChanges(ctx context.Context, cancelCtx context.Context, database *DatabaseCollectionWithUser, inChannels base.Set, options ChangesOptions, docIDFilter []string, send func([]*ChangeEntry) error) (err error, forceClose bool) {
func GenerateChanges(ctx context.Context, database *DatabaseCollectionWithUser, inChannels base.Set, options ChangesOptions, docIDFilter []string, send func([]*ChangeEntry) error) (err error, forceClose bool) {
// Set up heartbeat/timeout
var timeoutInterval time.Duration
var timer *time.Timer
Expand Down Expand Up @@ -1508,7 +1516,7 @@ loop:
case <-database.exitChanges():
forceClose = true
break loop
case <-cancelCtx.Done():
case <-options.ChangesCtx.Done():
forceClose = true
break loop
}
Expand All @@ -1518,5 +1526,10 @@ loop:
}
}

// if the ChangesCtx is done, the connection was force closed. This could actually happen and send a ChangeEntry.Err. Instead of checking each place in this function, set the forceClose flag here.
if options.ChangesCtx.Err() != nil {
forceClose = true
}

return nil, forceClose
}
Loading

0 comments on commit 7b5a401

Please sign in to comment.