diff --git a/common/persistence/cassandra/matching_task_store_user_data.go b/common/persistence/cassandra/matching_task_store_user_data.go index efa1fb96eb..c61bab8b31 100644 --- a/common/persistence/cassandra/matching_task_store_user_data.go +++ b/common/persistence/cassandra/matching_task_store_user_data.go @@ -11,6 +11,7 @@ import ( const ( // Not much of a need to make this configurable, we're just reading some strings listTaskQueueNamesByBuildIdPageSize = 100 + maxTaskQueuesByBuildIdResults = 10000 templateUpdateTaskQueueUserDataQuery = `UPDATE task_queue_user_data SET data = ?, @@ -172,7 +173,7 @@ func (d *userDataStore) ListTaskQueueUserDataEntries(ctx context.Context, reques } func (d *userDataStore) GetTaskQueuesByBuildId(ctx context.Context, request *p.GetTaskQueuesByBuildIdRequest) ([]string, error) { - query := d.Session.Query(templateListTaskQueueNamesByBuildIdQuery, request.NamespaceID, request.BuildID).WithContext(ctx) + query := d.Session.Query(templateListTaskQueueNamesByBuildIdQuery, request.NamespaceID, request.BuildID).WithContext(ctx).Idempotent(true) iter := query.PageSize(listTaskQueueNamesByBuildIdPageSize).Iter() var taskQueues []string @@ -192,6 +193,11 @@ func (d *userDataStore) GetTaskQueuesByBuildId(ctx context.Context, request *p.G taskQueues = append(taskQueues, taskQueue) + if len(taskQueues) >= maxTaskQueuesByBuildIdResults { + _ = iter.Close() + return taskQueues, nil + } + row = make(map[string]any) // Reinitialize map as initialized fails on unmarshalling } if len(iter.PageState()) == 0 { diff --git a/common/persistence/cassandra/queue_store.go b/common/persistence/cassandra/queue_store.go index 0a85830eb1..c0da705224 100644 --- a/common/persistence/cassandra/queue_store.go +++ b/common/persistence/cassandra/queue_store.go @@ -135,7 +135,7 @@ func (q *QueueStore) ReadMessages( maxCount, ).WithContext(ctx) - iter := query.Iter() + iter := query.PageSize(maxCount).Iter() var result []*persistence.QueueMessage message := make(map[string]any) diff --git a/common/persistence/cassandra/queue_v2_store.go b/common/persistence/cassandra/queue_v2_store.go index c5f968fd44..028bdd6e77 100644 --- a/common/persistence/cassandra/queue_v2_store.go +++ b/common/persistence/cassandra/queue_v2_store.go @@ -29,6 +29,12 @@ type ( ) const ( + // maxListQueuesPages limits the number of CQL round-trips in ListQueues. With + // ALLOW FILTERING, Cassandra may return under-filled pages, so we may need + // multiple fetches. This cap prevents unbounded queries when most partitions + // don't match the queue_type filter. + maxListQueuesPages = 10 + TemplateEnqueueMessageQuery = `INSERT INTO queue_messages (queue_type, queue_name, queue_partition, message_id, message_payload, message_encoding) VALUES (?, ?, ?, ?, ?, ?) IF NOT EXISTS` TemplateGetMessagesQuery = `SELECT message_id, message_payload, message_encoding FROM queue_messages WHERE queue_type = ? AND queue_name = ? AND queue_partition = ? AND message_id >= ? ORDER BY message_id ASC LIMIT ?` TemplateGetMaxMessageIDQuery = `SELECT message_id FROM queue_messages WHERE queue_type = ? AND queue_name = ? AND queue_partition = ? ORDER BY message_id DESC LIMIT 1` @@ -36,8 +42,8 @@ const ( TemplateGetQueueQuery = `SELECT metadata_payload, metadata_encoding, version FROM queues WHERE queue_type = ? AND queue_name = ?` TemplateRangeDeleteMessagesQuery = `DELETE FROM queue_messages WHERE queue_type = ? AND queue_name = ? AND queue_partition = ? AND message_id >= ? AND message_id <= ?` TemplateUpdateQueueMetadataQuery = `UPDATE queues SET metadata_payload = ?, metadata_encoding = ?, version = ? WHERE queue_type = ? AND queue_name = ? IF version = ?` - // We will have to ALLOW FILTERING for this query since partition key consists of both queue_type and queue_name. - templateGetQueueNamesQuery = `SELECT queue_name, metadata_payload, metadata_encoding, version FROM queues WHERE queue_type = ? ALLOW FILTERING` + // TemplateGetQueueNamesQuery uses ALLOW FILTERING since the partition key consists of both queue_type and queue_name. + TemplateGetQueueNamesQuery = `SELECT queue_name, metadata_payload, metadata_encoding, version FROM queues WHERE queue_type = ? ALLOW FILTERING` ) var ( @@ -160,7 +166,7 @@ func (s *queueV2Store) ReadMessages( 0, int(minMessageID), request.PageSize, - ).WithContext(ctx).Iter() + ).WithContext(ctx).PageSize(request.PageSize).Iter() var ( messages []persistence.QueueV2Message @@ -471,45 +477,67 @@ func (s *queueV2Store) ListQueues( if request.PageSize <= 0 { return nil, persistence.ErrNonPositiveListQueuesPageSize } - iter := s.session.Query( - templateGetQueueNamesQuery, - request.QueueType, - ).PageSize(request.PageSize).PageState(request.NextPageToken).WithContext(ctx).Iter() - var queues []persistence.QueueInfo - for { - var ( - queueName string - metadataBytes []byte - metadataEncoding string - version int64 - ) - if !iter.Scan(&queueName, &metadataBytes, &metadataEncoding, &version) { - break + queues := make([]persistence.QueueInfo, 0, request.PageSize) + pageToken := request.NextPageToken + + // CQL queries with ALLOW FILTERING may return under-filled pages because Cassandra + // scans a fixed number of partitions per page and then post-filters. We loop over + // CQL pages until we have enough results or exhaust all pages, with an upper bound + // on round-trips to avoid unbounded queries when most partitions don't match. + for pages := 0; len(queues) < request.PageSize && pages < maxListQueuesPages; pages++ { + iter := s.session.Query( + TemplateGetQueueNamesQuery, + request.QueueType, + ).PageSize(request.PageSize).PageState(pageToken).WithContext(ctx).Iter() + + for { + var ( + queueName string + metadataBytes []byte + metadataEncoding string + version int64 + ) + if !iter.Scan(&queueName, &metadataBytes, &metadataEncoding, &version) { + break + } + q, err := getQueueFromMetadata(request.QueueType, queueName, metadataBytes, metadataEncoding, version) + if err != nil { + _ = iter.Close() + return nil, err + } + partition, err := persistence.GetPartitionForQueueV2(request.QueueType, queueName, q.Metadata) + if err != nil { + _ = iter.Close() + return nil, err + } + messageCount, lastMessageID, err := s.getMessageCountAndLastID(ctx, request.QueueType, queueName, partition) + if err != nil { + _ = iter.Close() + return nil, err + } + queues = append(queues, persistence.QueueInfo{ + QueueName: queueName, + MessageCount: messageCount, + LastMessageID: lastMessageID, + }) } - q, err := getQueueFromMetadata(request.QueueType, queueName, metadataBytes, metadataEncoding, version) - if err != nil { - return nil, err + + if len(iter.PageState()) > 0 { + pageToken = iter.PageState() + } else { + pageToken = nil } - partition, err := persistence.GetPartitionForQueueV2(request.QueueType, queueName, q.Metadata) - if err != nil { - return nil, err + if err := iter.Close(); err != nil { + return nil, gocql.ConvertError("QueueV2ListQueues", err) } - messageCount, lastMessageID, err := s.getMessageCountAndLastID(ctx, request.QueueType, queueName, partition) - if err != nil { - return nil, err + if pageToken == nil { + break } - queues = append(queues, persistence.QueueInfo{ - QueueName: queueName, - MessageCount: messageCount, - LastMessageID: lastMessageID, - }) - } - if err := iter.Close(); err != nil { - return nil, gocql.ConvertError("QueueV2ListQueues", err) } + return &persistence.InternalListQueuesResponse{ Queues: queues, - NextPageToken: iter.PageState(), + NextPageToken: pageToken, }, nil } diff --git a/common/persistence/tests/cassandra_test.go b/common/persistence/tests/cassandra_test.go index f8b38b14c3..5221b07370 100644 --- a/common/persistence/tests/cassandra_test.go +++ b/common/persistence/tests/cassandra_test.go @@ -98,6 +98,14 @@ func (q failingQuery) Iter() gocql.Iter { return failingIter{} } +func (q failingQuery) PageSize(int) gocql.Query { + return q +} + +func (q failingQuery) PageState([]byte) gocql.Query { + return q +} + func (q failingQuery) Scan(...any) error { return assert.AnError } @@ -428,6 +436,10 @@ func testCassandraQueueV2QueryErrors(t *testing.T, cluster *cassandra.TestCluste t.Parallel() testCassandraQueueV2ErrListQueuesGetMaxMessageIDQuery(t, cluster) }) + t.Run("ListQueuesGetQueueNamesQuery", func(t *testing.T) { + t.Parallel() + testCassandraQueueV2ErrListQueuesGetQueueNamesQuery(t, cluster) + }) t.Run("RangeDeleteMessagesGetMaxMessageIDQuery", func(t *testing.T) { t.Parallel() testCassandraQueueV2ErrRangeDeleteMessagesGetMaxMessageIDQuery(t, cluster) @@ -525,6 +537,23 @@ func testCassandraQueueV2ErrListQueuesGetMaxMessageIDQuery(t *testing.T, cluster assert.ErrorContains(t, err, "QueueV2GetMaxMessageID") } +func testCassandraQueueV2ErrListQueuesGetQueueNamesQuery(t *testing.T, cluster *cassandra.TestCluster) { + q := newQueueV2Store(failingSession{ + Session: cluster.GetSession(), + failingQueries: []string{cassandra.TemplateGetQueueNamesQuery}, + }) + ctx := context.Background() + queueType := persistence.QueueTypeHistoryDLQ + _, err := q.ListQueues(ctx, &persistence.InternalListQueuesRequest{ + QueueType: queueType, + PageSize: 100, + }) + require.Error(t, err) + assert.ErrorAs(t, err, new(*serviceerror.Unavailable)) + assert.ErrorContains(t, err, assert.AnError.Error()) + assert.ErrorContains(t, err, "QueueV2ListQueues") +} + func testCassandraQueueV2MultiplePartitions(t *testing.T, cluster *cassandra.TestCluster) { t.Run("RangeDeleteMessages", func(t *testing.T) { t.Parallel()