diff --git a/lib/gcpspanner/client.go b/lib/gcpspanner/client.go index f6cdd12e5..045d372e4 100644 --- a/lib/gcpspanner/client.go +++ b/lib/gcpspanner/client.go @@ -81,6 +81,7 @@ type Client struct { featureSearchQuery FeatureSearchBaseQuery missingOneImplQuery MissingOneImplementationQuery searchCfg searchConfig + notificationCfg notificationConfig batchWriter batchSize int batchWriters int @@ -140,10 +141,17 @@ type searchConfig struct { maxBookmarksPerUser uint32 } +// notificationConfig holds the application configuation for notifications. +type notificationConfig struct { + // Max number of consecutive failures per channel + maxConsecutiveFailuresPerChannel uint32 +} + const defaultMaxOwnedSearchesPerUser = 25 const defaultMaxBookmarksPerUser = 25 const defaultBatchSize = 5000 const defaultBatchWriters = 8 +const defaultMaxConsecutiveFailuresPerChannel = 5 func combineAndDeduplicate(excluded []string, discouraged []string) []string { if excluded == nil && discouraged == nil { @@ -219,6 +227,9 @@ func NewSpannerClient(projectID string, instanceID string, name string) (*Client maxOwnedSearchesPerUser: defaultMaxOwnedSearchesPerUser, maxBookmarksPerUser: defaultMaxBookmarksPerUser, }, + notificationConfig{ + maxConsecutiveFailuresPerChannel: defaultMaxConsecutiveFailuresPerChannel, + }, bw, defaultBatchSize, defaultBatchWriters, diff --git a/lib/gcpspanner/notification_channel_delivery_attempt.go b/lib/gcpspanner/notification_channel_delivery_attempt.go index 4432a0e40..79484e49a 100644 --- a/lib/gcpspanner/notification_channel_delivery_attempt.go +++ b/lib/gcpspanner/notification_channel_delivery_attempt.go @@ -16,6 +16,7 @@ package gcpspanner import ( "context" + "encoding/json" "fmt" "time" @@ -25,13 +26,45 @@ import ( const notificationChannelDeliveryAttemptTable = "NotificationChannelDeliveryAttempts" const maxDeliveryAttemptsToKeep = 10 -// NotificationChannelDeliveryAttempt represents a row in the NotificationChannelDeliveryAttempt table. -type NotificationChannelDeliveryAttempt struct { +// spannerNotificationChannelDeliveryAttempt represents a row in the spannerNotificationChannelDeliveryAttempt table. +type spannerNotificationChannelDeliveryAttempt struct { ID string `spanner:"ID"` ChannelID string `spanner:"ChannelID"` AttemptTimestamp time.Time `spanner:"AttemptTimestamp"` Status NotificationChannelDeliveryAttemptStatus `spanner:"Status"` Details spanner.NullJSON `spanner:"Details"` + AttemptDetails *AttemptDetails `spanner:"-"` +} + +func (s spannerNotificationChannelDeliveryAttempt) toPublic() (*NotificationChannelDeliveryAttempt, error) { + var attemptDetails *AttemptDetails + if s.Details.Valid { + attemptDetails = new(AttemptDetails) + b, err := json.Marshal(s.Details.Value) + if err != nil { + return nil, err + } + err = json.Unmarshal(b, &attemptDetails) + if err != nil { + return nil, err + } + } + + return &NotificationChannelDeliveryAttempt{ + ID: s.ID, + ChannelID: s.ChannelID, + AttemptTimestamp: s.AttemptTimestamp, + Status: s.Status, + AttemptDetails: attemptDetails, + }, nil +} + +type NotificationChannelDeliveryAttempt struct { + ID string `spanner:"ID"` + ChannelID string `spanner:"ChannelID"` + AttemptTimestamp time.Time `spanner:"AttemptTimestamp"` + Status NotificationChannelDeliveryAttemptStatus `spanner:"Status"` + AttemptDetails *AttemptDetails `spanner:"AttemptDetails"` } type NotificationChannelDeliveryAttemptStatus string @@ -71,13 +104,14 @@ func (m notificationChannelDeliveryAttemptMapper) Table() string { func (m notificationChannelDeliveryAttemptMapper) NewEntity( id string, - req CreateNotificationChannelDeliveryAttemptRequest) (NotificationChannelDeliveryAttempt, error) { - return NotificationChannelDeliveryAttempt{ + req CreateNotificationChannelDeliveryAttemptRequest) (spannerNotificationChannelDeliveryAttempt, error) { + return spannerNotificationChannelDeliveryAttempt{ ID: id, ChannelID: req.ChannelID, AttemptTimestamp: req.AttemptTimestamp, Status: req.Status, Details: req.Details, + AttemptDetails: nil, }, nil } @@ -132,7 +166,8 @@ type notificationChannelDeliveryAttemptCursor struct { } // EncodePageToken returns the ID of the delivery attempt as a page token. -func (m notificationChannelDeliveryAttemptMapper) EncodePageToken(item NotificationChannelDeliveryAttempt) string { +func (m notificationChannelDeliveryAttemptMapper) EncodePageToken( + item spannerNotificationChannelDeliveryAttempt) string { return encodeCursor(notificationChannelDeliveryAttemptCursor{ LastID: item.ID, LastAttemptTimestamp: item.AttemptTimestamp, @@ -144,61 +179,65 @@ func (c *Client) CreateNotificationChannelDeliveryAttempt( ctx context.Context, req CreateNotificationChannelDeliveryAttemptRequest) (*string, error) { var newID *string _, err := c.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { - // 1. Create the new attempt - id, err := newEntityCreator[notificationChannelDeliveryAttemptMapper](c).createWithTransaction(ctx, txn, req) - if err != nil { - return err - } - newID = id + var err error + newID, err = c.createNotificationChannelDeliveryAttemptWithTransaction(ctx, txn, req) + + return err + }) + + return newID, err +} +func (c *Client) createNotificationChannelDeliveryAttemptWithTransaction( + ctx context.Context, txn *spanner.ReadWriteTransaction, + req CreateNotificationChannelDeliveryAttemptRequest) (*string, error) { + var newID *string + // 1. Create the new attempt + id, err := newEntityCreator[notificationChannelDeliveryAttemptMapper](c).createWithTransaction(ctx, txn, req) + if err != nil { + return nil, err + } + newID = id - // 2. Count existing attempts for the channel. Note: This count does not include the new attempt just buffered. - countStmt := spanner.NewStatement(` + // 2. Count existing attempts for the channel. Note: This count does not include the new attempt just buffered. + countStmt := spanner.NewStatement(` SELECT COUNT(*) FROM NotificationChannelDeliveryAttempts WHERE ChannelID = @channelID`) - countStmt.Params["channelID"] = req.ChannelID - var count int64 - err = txn.Query(ctx, countStmt).Do(func(r *spanner.Row) error { - return r.Column(0, &count) - }) - if err != nil { - return err - } + countStmt.Params["channelID"] = req.ChannelID + var count int64 + err = txn.Query(ctx, countStmt).Do(func(r *spanner.Row) error { + return r.Column(0, &count) + }) + if err != nil { + return nil, err + } - // 3. If the pre-insert count is at the limit, fetch the oldest attempts to delete. - if count >= maxDeliveryAttemptsToKeep { - // We need to delete enough to make room for the one we are adding. - deleteCount := count - maxDeliveryAttemptsToKeep + 1 - deleteStmt := spanner.NewStatement(` + // 3. If the pre-insert count is at the limit, fetch the oldest attempts to delete. + // We need to delete enough to make room for the one we are adding. + + if count < maxDeliveryAttemptsToKeep { + return newID, nil + } + + deleteCount := count - maxDeliveryAttemptsToKeep + 1 + deleteStmt := spanner.NewStatement(` SELECT ID FROM NotificationChannelDeliveryAttempts WHERE ChannelID = @channelID ORDER BY AttemptTimestamp ASC LIMIT @deleteCount`) - deleteStmt.Params["channelID"] = req.ChannelID - deleteStmt.Params["deleteCount"] = deleteCount - - var mutations []*spanner.Mutation - err := txn.Query(ctx, deleteStmt).Do(func(r *spanner.Row) error { - var attemptID string - if err := r.Column(0, &attemptID); err != nil { - return err - } - mutations = append(mutations, - spanner.Delete(notificationChannelDeliveryAttemptTable, - spanner.Key{attemptID, req.ChannelID})) - - return nil - }) - if err != nil { - return err - } - - // 4. Buffer delete mutations - if len(mutations) > 0 { - return txn.BufferWrite(mutations) - } + deleteStmt.Params["channelID"] = req.ChannelID + deleteStmt.Params["deleteCount"] = deleteCount + + var mutations []*spanner.Mutation + err = txn.Query(ctx, deleteStmt).Do(func(r *spanner.Row) error { + var attemptID string + if err := r.Column(0, &attemptID); err != nil { + return err } + mutations = append(mutations, + spanner.Delete(notificationChannelDeliveryAttemptTable, + spanner.Key{attemptID, req.ChannelID})) return nil }) @@ -206,6 +245,14 @@ func (c *Client) CreateNotificationChannelDeliveryAttempt( return nil, err } + // 4. Buffer delete mutations + if len(mutations) > 0 { + err := txn.BufferWrite(mutations) + if err != nil { + return nil, err + } + } + return newID, nil } @@ -214,8 +261,13 @@ func (c *Client) GetNotificationChannelDeliveryAttempt( ctx context.Context, attemptID string, channelID string) (*NotificationChannelDeliveryAttempt, error) { key := deliveryAttemptKey{ID: attemptID, ChannelID: channelID} - return newEntityReader[notificationChannelDeliveryAttemptMapper, - NotificationChannelDeliveryAttempt, deliveryAttemptKey](c).readRowByKey(ctx, key) + attempt, err := newEntityReader[notificationChannelDeliveryAttemptMapper, + spannerNotificationChannelDeliveryAttempt, deliveryAttemptKey](c).readRowByKey(ctx, key) + if err != nil { + return nil, err + } + + return attempt.toPublic() } // ListNotificationChannelDeliveryAttempts lists all delivery attempts for a channel. @@ -223,5 +275,19 @@ func (c *Client) ListNotificationChannelDeliveryAttempts( ctx context.Context, req ListNotificationChannelDeliveryAttemptsRequest, ) ([]NotificationChannelDeliveryAttempt, *string, error) { - return newEntityLister[notificationChannelDeliveryAttemptMapper](c).list(ctx, req) + attempts, nextPageToken, err := newEntityLister[notificationChannelDeliveryAttemptMapper](c).list(ctx, req) + if err != nil { + return nil, nil, err + } + + publicAttempts := make([]NotificationChannelDeliveryAttempt, 0, len(attempts)) + for _, attempt := range attempts { + publicAttempt, err := attempt.toPublic() + if err != nil { + return nil, nil, err + } + publicAttempts = append(publicAttempts, *publicAttempt) + } + + return publicAttempts, nextPageToken, nil } diff --git a/lib/gcpspanner/notification_channel_delivery_attempt_test.go b/lib/gcpspanner/notification_channel_delivery_attempt_test.go index a536391a8..e34520e6a 100644 --- a/lib/gcpspanner/notification_channel_delivery_attempt_test.go +++ b/lib/gcpspanner/notification_channel_delivery_attempt_test.go @@ -43,9 +43,10 @@ func TestCreateNotificationChannelDeliveryAttempt(t *testing.T) { channelID := *channelIDPtr req := CreateNotificationChannelDeliveryAttemptRequest{ - ChannelID: channelID, - Status: "SUCCESS", - Details: spanner.NullJSON{Value: map[string]interface{}{"info": "delivered"}, Valid: true}, + ChannelID: channelID, + Status: "SUCCESS", + Details: spanner.NullJSON{Value: map[string]interface{}{ + "event_id": "evt-123", "message": "delivered"}, Valid: true}, AttemptTimestamp: time.Now(), } @@ -71,6 +72,15 @@ func TestCreateNotificationChannelDeliveryAttempt(t *testing.T) { if retrieved.AttemptTimestamp.IsZero() { t.Error("expected a non-zero commit timestamp") } + if retrieved.AttemptDetails == nil { + t.Fatal("expected details to be non-nil") + } + if retrieved.AttemptDetails.Message != "delivered" { + t.Errorf("expected details info to be 'delivered', got %s", retrieved.AttemptDetails.Message) + } + if retrieved.AttemptDetails.EventID != "evt-123" { + t.Errorf("expected details eventID to be 'evt-123', got %s", retrieved.AttemptDetails.EventID) + } } func TestCreateNotificationChannelDeliveryAttemptPruning(t *testing.T) { @@ -202,3 +212,126 @@ func TestCreateNotificationChannelDeliveryAttemptConcurrency(t *testing.T) { t.Errorf("expected %d attempts, got %d", maxDeliveryAttemptsToKeep, len(attempts)) } } + +func TestListNotificationChannelDeliveryAttemptsPagination(t *testing.T) { + ctx := context.Background() + restartDatabaseContainer(t) + // We need a channel to associate the attempt with. + userID := uuid.NewString() + createReq := CreateNotificationChannelRequest{ + UserID: userID, + Name: "Test Channel", + Type: "EMAIL", + EmailConfig: &EmailConfig{Address: "test@example.com", IsVerified: true, VerificationToken: nil}, + } + channelIDPtr, err := spannerClient.CreateNotificationChannel(ctx, createReq) + if err != nil { + t.Fatalf("failed to create notification channel: %v", err) + } + channelID := *channelIDPtr + + // Create more attempts than the page size to test pagination. + totalAttempts := 5 + for i := 0; i < totalAttempts; i++ { + // The sleep is a simple way to ensure distinct AttemptTimestamps for ordering. + time.Sleep(1 * time.Millisecond) + req := CreateNotificationChannelDeliveryAttemptRequest{ + ChannelID: channelID, + Status: "SUCCESS", + Details: spanner.NullJSON{Value: nil, Valid: false}, + AttemptTimestamp: time.Now(), + } + _, err := spannerClient.CreateNotificationChannelDeliveryAttempt(ctx, req) + if err != nil { + t.Fatalf("CreateNotificationChannelDeliveryAttempt (pagination test) failed: %v", err) + } + } + + // 1. First Page + pageSize := 2 + listReq1 := ListNotificationChannelDeliveryAttemptsRequest{ + ChannelID: channelID, + PageSize: pageSize, + PageToken: nil, + } + attempts1, nextToken1, err := spannerClient.ListNotificationChannelDeliveryAttempts(ctx, listReq1) + if err != nil { + t.Fatalf("ListNotificationChannelDeliveryAttempts (page 1) failed: %v", err) + } + if len(attempts1) != pageSize { + t.Errorf("expected %d attempts on page 1, got %d", pageSize, len(attempts1)) + } + if nextToken1 == nil { + t.Fatal("expected a next page token, but got nil") + } + + // 2. Second Page + listReq2 := ListNotificationChannelDeliveryAttemptsRequest{ + ChannelID: channelID, + PageSize: pageSize, + PageToken: nextToken1, + } + attempts2, nextToken2, err := spannerClient.ListNotificationChannelDeliveryAttempts(ctx, listReq2) + if err != nil { + t.Fatalf("ListNotificationChannelDeliveryAttempts (page 2) failed: %v", err) + } + if len(attempts2) != pageSize { + t.Errorf("expected %d attempts on page 2, got %d", pageSize, len(attempts2)) + } + if nextToken2 == nil { + t.Fatal("expected a next page token, but got nil") + } + + // 3. Third and Final Page + listReq3 := ListNotificationChannelDeliveryAttemptsRequest{ + ChannelID: channelID, + PageSize: pageSize, + PageToken: nextToken2, + } + attempts3, nextToken3, err := spannerClient.ListNotificationChannelDeliveryAttempts(ctx, listReq3) + if err != nil { + t.Fatalf("ListNotificationChannelDeliveryAttempts (page 3) failed: %v", err) + } + if len(attempts3) != 1 { + t.Errorf("expected 1 attempt on page 3, got %d", len(attempts3)) + } + if nextToken3 != nil { + t.Errorf("expected no next page token, but got one: %s", *nextToken3) + } +} + +func TestToPublic(t *testing.T) { + attempt := spannerNotificationChannelDeliveryAttempt{ + ID: "test-id", + ChannelID: "test-channel-id", + AttemptTimestamp: time.Now(), + Status: DeliveryAttemptStatusSuccess, + Details: spanner.NullJSON{Value: map[string]interface{}{"message": "test-info", + "event_id": "test-event-id"}, Valid: true}, + AttemptDetails: nil, + } + + publicAttempt, err := attempt.toPublic() + if err != nil { + t.Fatalf("toPublic() failed: %v", err) + } + + if publicAttempt.ID != attempt.ID { + t.Errorf("expected ID %s, got %s", attempt.ID, publicAttempt.ID) + } + if publicAttempt.ChannelID != attempt.ChannelID { + t.Errorf("expected ChannelID %s, got %s", attempt.ChannelID, publicAttempt.ChannelID) + } + if publicAttempt.Status != attempt.Status { + t.Errorf("expected Status %s, got %s", attempt.Status, publicAttempt.Status) + } + if publicAttempt.AttemptDetails == nil { + t.Fatal("expected AttemptDetails to be non-nil") + } + if publicAttempt.AttemptDetails.Message != "test-info" { + t.Errorf("expected AttemptDetails.Message %s, got %s", "test-info", publicAttempt.AttemptDetails.Message) + } + if publicAttempt.AttemptDetails.EventID != "test-event-id" { + t.Errorf("expected AttemptDetails.EventID %s, got %s", "test-event-id", publicAttempt.AttemptDetails.EventID) + } +} diff --git a/lib/gcpspanner/notification_channel_state.go b/lib/gcpspanner/notification_channel_state.go index 087af22b8..fe7e2351b 100644 --- a/lib/gcpspanner/notification_channel_state.go +++ b/lib/gcpspanner/notification_channel_state.go @@ -16,6 +16,7 @@ package gcpspanner import ( "context" + "errors" "fmt" "time" @@ -78,3 +79,105 @@ func (c *Client) GetNotificationChannelState( return newEntityReader[notificationChannelStateMapper, NotificationChannelState, string](c).readRowByKey(ctx, channelID) } + +// RecordNotificationChannelSuccess resets the consecutive failures count in the NotificationChannelStates table +// and logs a successful delivery attempt in the NotificationChannelDeliveryAttempts table. +func (c *Client) RecordNotificationChannelSuccess( + ctx context.Context, channelID string, timestamp time.Time, eventID string) error { + _, err := c.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + // Update NotificationChannelStates + err := newEntityWriter[notificationChannelStateMapper](c).upsertWithTransaction(ctx, txn, + NotificationChannelState{ + ChannelID: channelID, + IsDisabledBySystem: false, + ConsecutiveFailures: 0, + CreatedAt: timestamp, + UpdatedAt: timestamp, + }) + if err != nil { + return err + } + + _, err = c.createNotificationChannelDeliveryAttemptWithTransaction(ctx, txn, + CreateNotificationChannelDeliveryAttemptRequest{ + ChannelID: channelID, + AttemptTimestamp: timestamp, + Status: DeliveryAttemptStatusSuccess, + Details: spanner.NullJSON{Value: AttemptDetails{ + EventID: eventID, + Message: "delivered"}, Valid: true}, + }) + + return err + }) + + return err + +} + +// RecordNotificationChannelFailure increments the consecutive failures count in the NotificationChannelStates table +// and logs a failure delivery attempt in the NotificationChannelDeliveryAttempts table. +// If isPermanent is true, it increments the failure count and potentially disables the channel. +// If isPermanent is false (transient), it logs the error but does not penalize the channel health. +func (c *Client) RecordNotificationChannelFailure( + ctx context.Context, channelID string, errorMsg string, timestamp time.Time, + isPermanent bool, eventID string) error { + _, err := c.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + // Read current state + state, err := newEntityReader[notificationChannelStateMapper, NotificationChannelState, string](c). + readRowByKeyWithTransaction(ctx, channelID, txn) + if err != nil && !errors.Is(err, ErrQueryReturnedNoResults) { + return err + } else if errors.Is(err, ErrQueryReturnedNoResults) { + state = &NotificationChannelState{ + ChannelID: channelID, + CreatedAt: timestamp, + UpdatedAt: timestamp, + IsDisabledBySystem: false, + ConsecutiveFailures: 0, + } + } + + // Calculate new state + if isPermanent { + state.ConsecutiveFailures++ + } + state.UpdatedAt = timestamp + state.IsDisabledBySystem = state.ConsecutiveFailures >= int64( + c.notificationCfg.maxConsecutiveFailuresPerChannel) + + // Update NotificationChannelStates + err = newEntityWriter[notificationChannelStateMapper](c).upsertWithTransaction(ctx, + txn, NotificationChannelState{ + ChannelID: channelID, + IsDisabledBySystem: state.IsDisabledBySystem, + ConsecutiveFailures: state.ConsecutiveFailures, + CreatedAt: state.CreatedAt, + UpdatedAt: state.UpdatedAt, + }) + if err != nil { + return err + } + + // Log attempt + _, err = c.createNotificationChannelDeliveryAttemptWithTransaction(ctx, txn, + CreateNotificationChannelDeliveryAttemptRequest{ + ChannelID: channelID, + AttemptTimestamp: timestamp, + Status: DeliveryAttemptStatusFailure, + Details: spanner.NullJSON{Value: AttemptDetails{ + EventID: eventID, + Message: errorMsg}, Valid: true}, + }) + + return err + + }) + + return err +} + +type AttemptDetails struct { + Message string `json:"message"` + EventID string `json:"event_id"` +} diff --git a/lib/gcpspanner/notification_channel_state_test.go b/lib/gcpspanner/notification_channel_state_test.go index d9178d421..a9815a579 100644 --- a/lib/gcpspanner/notification_channel_state_test.go +++ b/lib/gcpspanner/notification_channel_state_test.go @@ -17,6 +17,7 @@ package gcpspanner import ( "context" "testing" + "time" "cloud.google.com/go/spanner" "github.com/google/go-cmp/cmp" @@ -115,4 +116,151 @@ func TestNotificationChannelStateOperations(t *testing.T) { t.Errorf("GetNotificationChannelState after update mismatch (-want +got):\n%s", diff) } }) + + t.Run("RecordNotificationChannelSuccess", func(t *testing.T) { + testRecordNotificationChannelSuccess(t, channelID) + }) + + t.Run("RecordNotificationChannelFailure", func(t *testing.T) { + testRecordNotificationChannelFailure(t, channelID) + }) +} + +func testRecordNotificationChannelSuccess(t *testing.T, channelID string) { + ctx := t.Context() + // First, set up a channel state with some failures. + initialState := &NotificationChannelState{ + ChannelID: channelID, + IsDisabledBySystem: true, + ConsecutiveFailures: 3, + CreatedAt: spanner.CommitTimestamp, + UpdatedAt: spanner.CommitTimestamp, + } + err := spannerClient.UpsertNotificationChannelState(ctx, *initialState) + if err != nil { + t.Fatalf("pre-test UpsertNotificationChannelState failed: %v", err) + } + + testTime := time.Now() + eventID := "evt-1" + err = spannerClient.RecordNotificationChannelSuccess(ctx, channelID, testTime, eventID) + if err != nil { + t.Fatalf("RecordNotificationChannelSuccess failed: %v", err) + } + + // Verify state update. + retrievedState, err := spannerClient.GetNotificationChannelState(ctx, channelID) + if err != nil { + t.Fatalf("GetNotificationChannelState after success failed: %v", err) + } + if retrievedState.IsDisabledBySystem != false { + t.Errorf("expected IsDisabledBySystem to be false, got %t", retrievedState.IsDisabledBySystem) + } + if retrievedState.ConsecutiveFailures != 0 { + t.Errorf("expected ConsecutiveFailures to be 0, got %d", retrievedState.ConsecutiveFailures) + } + + // Verify delivery attempt log. + listAttemptsReq := ListNotificationChannelDeliveryAttemptsRequest{ + ChannelID: channelID, + PageSize: 1, + PageToken: nil, + } + attempts, _, err := spannerClient.ListNotificationChannelDeliveryAttempts(ctx, listAttemptsReq) + if err != nil { + t.Fatalf("ListNotificationChannelDeliveryAttempts after success failed: %v", err) + } + if len(attempts) != 1 { + t.Fatalf("expected 1 delivery attempt, got %d", len(attempts)) + } + if attempts[0].Status != DeliveryAttemptStatusSuccess { + t.Errorf("expected status SUCCESS, got %s", attempts[0].Status) + } + if attempts[0].AttemptDetails == nil || attempts[0].AttemptDetails.Message != "delivered" || + attempts[0].AttemptDetails.EventID != "evt-1" { + t.Errorf("expected details message 'delivered', got %v", attempts[0].AttemptDetails) + } +} + +func testRecordNotificationChannelFailure(t *testing.T, channelID string) { + ctx := t.Context() + // Reset state for new test + initialState := &NotificationChannelState{ + ChannelID: channelID, + IsDisabledBySystem: false, + ConsecutiveFailures: 0, + CreatedAt: spanner.CommitTimestamp, + UpdatedAt: spanner.CommitTimestamp, + } + err := spannerClient.UpsertNotificationChannelState(ctx, *initialState) + if err != nil { + t.Fatalf("pre-test UpsertNotificationChannelState failed: %v", err) + } + + t.Run("Permanent Failure", func(t *testing.T) { + _ = spannerClient.UpsertNotificationChannelState(ctx, *initialState) // Ensure clean state + testTime := time.Now() + errorMsg := "permanent error" + eventID := "evt-124" + err = spannerClient.RecordNotificationChannelFailure(ctx, channelID, errorMsg, testTime, true, eventID) + if err != nil { + t.Fatalf("RecordNotificationChannelFailure (permanent) failed: %v", err) + } + + verifyFailureAttemptAndState(t, channelID, 1, false, errorMsg, eventID) + }) + + t.Run("Transient Failure", func(t *testing.T) { + _ = spannerClient.UpsertNotificationChannelState(ctx, *initialState) // Ensure clean state + testTime := time.Now() + errorMsg := "transient error" + eventID := "evt-125" + err = spannerClient.RecordNotificationChannelFailure(ctx, channelID, errorMsg, testTime, false, eventID) + if err != nil { + t.Fatalf("RecordNotificationChannelFailure (transient) failed: %v", err) + } + + verifyFailureAttemptAndState(t, channelID, 0, false, errorMsg, eventID) + }) +} + +// verifyFailureAttemptAndState is a helper function to verify the state and delivery attempt after a failure. +func verifyFailureAttemptAndState(t *testing.T, channelID string, + expectedFailures int64, expectedIsDisabled bool, expectedAttemptMessage string, expectedEventID string) { + t.Helper() + ctx := t.Context() + + // Verify state update. + retrievedState, err := spannerClient.GetNotificationChannelState(ctx, channelID) + if err != nil { + t.Fatalf("GetNotificationChannelState after failure failed: %v", err) + } + if retrievedState.ConsecutiveFailures != expectedFailures { + t.Errorf("expected ConsecutiveFailures to be %d, got %d", expectedFailures, retrievedState.ConsecutiveFailures) + } + if retrievedState.IsDisabledBySystem != expectedIsDisabled { + t.Errorf("expected IsDisabledBySystem to be %t, got %t", expectedIsDisabled, retrievedState.IsDisabledBySystem) + } + + // Verify delivery attempt log. + listAttemptsReq := ListNotificationChannelDeliveryAttemptsRequest{ + ChannelID: channelID, + PageSize: 1, + PageToken: nil, + } + attempts, _, err := spannerClient.ListNotificationChannelDeliveryAttempts(ctx, listAttemptsReq) + if err != nil { + t.Fatalf("ListNotificationChannelDeliveryAttempts after failure failed: %v", err) + } + if len(attempts) != 1 { + t.Fatalf("expected 1 delivery attempt, got %d", len(attempts)) + } + if attempts[0].Status != DeliveryAttemptStatusFailure { + t.Errorf("expected status FAILURE, got %s", attempts[0].Status) + } + if attempts[0].AttemptDetails == nil || attempts[0].AttemptDetails.Message != expectedAttemptMessage || + attempts[0].AttemptDetails.EventID != expectedEventID { + t.Errorf("expected details message '%s' eventID '%s', got %v", expectedAttemptMessage, expectedEventID, + attempts[0].AttemptDetails) + } }