diff --git a/cmd/broker/main.go b/cmd/broker/main.go index 61928ae..865aa36 100644 --- a/cmd/broker/main.go +++ b/cmd/broker/main.go @@ -39,6 +39,7 @@ import ( "github.com/KafScale/platform/pkg/metadata" "github.com/KafScale/platform/pkg/protocol" "github.com/KafScale/platform/pkg/storage" + "golang.org/x/sync/semaphore" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -54,8 +55,9 @@ const ( defaultMinioRegion = "us-east-1" defaultMinioEndpoint = "http://127.0.0.1:9000" defaultMinioAccessKey = "minioadmin" - defaultMinioSecretKey = "minioadmin" - brokerVersion = "dev" + defaultMinioSecretKey = "minioadmin" + defaultS3Concurrency = 64 + brokerVersion = "dev" ) type handler struct { @@ -91,6 +93,7 @@ type handler struct { authMetrics *authMetrics authLogMu sync.Mutex authLogLast map[string]time.Time + s3sem *semaphore.Weighted } type etcdAvailability interface { @@ -1971,7 +1974,7 @@ func (h *handler) getPartitionLog(ctx context.Context, topic string, partition i if err := h.store.UpdateOffsets(cbCtx, topic, partition, artifact.LastOffset); err != nil { h.logger.Error("update offsets failed", "error", err, "topic", topic, "partition", partition) } - }, h.recordS3Op) + }, h.recordS3Op, h.s3sem) lastOffset, err := plog.RestoreFromS3(ctx) if err != nil { h.logger.Error("restore partition log from S3 failed", "topic", topic, "partition", partition, "error", err) @@ -2007,6 +2010,12 @@ func newHandler(store metadata.Store, s3Client storage.S3Client, brokerInfo prot segmentBytes := parseEnvInt("KAFSCALE_SEGMENT_BYTES", 4<<20) flushInterval := time.Duration(parseEnvInt("KAFSCALE_FLUSH_INTERVAL_MS", 500)) * time.Millisecond flushOnAck := parseEnvBool("KAFSCALE_PRODUCE_SYNC_FLUSH", true) + // 0 or negative disables the S3 concurrency limit (no semaphore, default HTTP pool). + s3Concurrency := parseEnvInt("KAFSCALE_S3_CONCURRENCY", defaultS3Concurrency) + var s3sem *semaphore.Weighted + if s3Concurrency > 0 { + s3sem = semaphore.NewWeighted(int64(s3Concurrency)) + } produceLatencyBuckets := []float64{1, 2, 5, 10, 25, 50, 100, 250, 500, 1000, 2000, 5000} consumerLagBuckets := []float64{1, 10, 100, 1000, 5000, 10000, 50000, 100000, 500000, 1000000} if autoPartitions < 1 { @@ -2055,6 +2064,7 @@ func newHandler(store metadata.Store, s3Client storage.S3Client, brokerInfo prot authorizer: authorizer, authMetrics: newAuthMetrics(), authLogLast: make(map[string]time.Time), + s3sem: s3sem, } } @@ -2184,6 +2194,7 @@ func buildS3ConfigsFromEnv() (storage.S3Config, storage.S3Config, bool, bool, bo secretKey = defaultMinioSecretKey } credsProvided := accessKey != "" && secretKey != "" + s3Concurrency := parseEnvInt("KAFSCALE_S3_CONCURRENCY", defaultS3Concurrency) writeCfg := storage.S3Config{ Bucket: writeBucket, Region: writeRegion, @@ -2193,6 +2204,7 @@ func buildS3ConfigsFromEnv() (storage.S3Config, storage.S3Config, bool, bool, bo SecretAccessKey: secretKey, SessionToken: sessionToken, KMSKeyARN: kmsARN, + MaxConnections: s3Concurrency, } readBucket := os.Getenv("KAFSCALE_S3_READ_BUCKET") @@ -2217,6 +2229,7 @@ func buildS3ConfigsFromEnv() (storage.S3Config, storage.S3Config, bool, bool, bo SecretAccessKey: secretKey, SessionToken: sessionToken, KMSKeyARN: kmsARN, + MaxConnections: s3Concurrency, } return writeCfg, readCfg, false, usingDefaultMinio, credsProvided, useReadReplica } diff --git a/docs/operations.md b/docs/operations.md index 13c607a..83abcb3 100644 --- a/docs/operations.md +++ b/docs/operations.md @@ -295,6 +295,7 @@ Recommended operator alerting (when using Prometheus Operator): - `KAFSCALE_S3_PATH_STYLE` – Force path-style addressing (`true/false`). - `KAFSCALE_S3_KMS_ARN` – KMS key ARN for SSE-KMS. - `KAFSCALE_S3_ACCESS_KEY`, `KAFSCALE_S3_SECRET_KEY`, `KAFSCALE_S3_SESSION_TOKEN` – S3 credentials. +- `KAFSCALE_S3_CONCURRENCY` – Broker-wide cap on concurrent S3 operations (default `64`, `0` to disable). Lower for slower S3-compatible backends. Read replica example (multi-region reads): diff --git a/pkg/storage/log.go b/pkg/storage/log.go index bc966c7..bf44b55 100644 --- a/pkg/storage/log.go +++ b/pkg/storage/log.go @@ -28,6 +28,7 @@ import ( "github.com/KafScale/platform/pkg/cache" "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" ) // PartitionLogConfig configures per-partition log behavior. @@ -54,6 +55,9 @@ type PartitionLog struct { indexEntries map[int64][]*IndexEntry prefetchMu sync.Mutex mu sync.Mutex + flushCond *sync.Cond + s3sem *semaphore.Weighted + flushing bool } type segmentRange struct { @@ -66,11 +70,11 @@ type segmentRange struct { var ErrOffsetOutOfRange = errors.New("offset out of range") // NewPartitionLog constructs a log for a topic partition. -func NewPartitionLog(namespace string, topic string, partition int32, startOffset int64, s3Client S3Client, cache *cache.SegmentCache, cfg PartitionLogConfig, onFlush func(context.Context, *SegmentArtifact), onS3Op func(string, time.Duration, error)) *PartitionLog { +func NewPartitionLog(namespace string, topic string, partition int32, startOffset int64, s3Client S3Client, cache *cache.SegmentCache, cfg PartitionLogConfig, onFlush func(context.Context, *SegmentArtifact), onS3Op func(string, time.Duration, error), sem *semaphore.Weighted) *PartitionLog { if namespace == "" { namespace = "default" } - return &PartitionLog{ + pl := &PartitionLog{ namespace: namespace, topic: topic, partition: partition, @@ -83,13 +87,42 @@ func NewPartitionLog(namespace string, topic string, partition int32, startOffse onS3Op: onS3Op, segments: make([]segmentRange, 0), indexEntries: make(map[int64][]*IndexEntry), + s3sem: sem, + } + pl.flushCond = sync.NewCond(&pl.mu) + return pl +} + +// acquireS3 blocks until a semaphore token is available or ctx is cancelled. +// If no semaphore is configured, it returns immediately. +func (l *PartitionLog) acquireS3(ctx context.Context) error { + if l.s3sem == nil { + return nil + } + return l.s3sem.Acquire(ctx, 1) +} + +func (l *PartitionLog) tryAcquireS3() bool { + if l.s3sem == nil { + return true + } + return l.s3sem.TryAcquire(1) +} + +func (l *PartitionLog) releaseS3() { + if l.s3sem != nil { + l.s3sem.Release(1) } } // RestoreFromS3 rebuilds segment ranges from objects already stored in S3. func (l *PartitionLog) RestoreFromS3(ctx context.Context) (int64, error) { prefix := l.segmentPrefix() + if err := l.acquireS3(ctx); err != nil { + return -1, err + } objects, err := l.s3.ListSegments(ctx, prefix) + l.releaseS3() if err != nil { return -1, err } @@ -112,8 +145,12 @@ func (l *PartitionLog) RestoreFromS3(ctx context.Context) (int64, error) { } start := obj.Size - segmentFooterLen rng := &ByteRange{Start: start, End: obj.Size - 1} + if err := l.acquireS3(ctx); err != nil { + return -1, err + } startTime := time.Now() footerBytes, err := l.s3.DownloadSegment(ctx, obj.Key, rng) + l.releaseS3() if l.onS3Op != nil { l.onS3Op("download_segment_footer", time.Since(startTime), err) } @@ -136,8 +173,12 @@ func (l *PartitionLog) RestoreFromS3(ctx context.Context) (int64, error) { indexByBase := make(map[int64][]*IndexEntry, len(entries)) for _, entry := range entries { indexKey := l.indexKey(entry.base) + if err := l.acquireS3(ctx); err != nil { + return -1, err + } startTime := time.Now() indexBytes, err := l.s3.DownloadIndex(ctx, indexKey) + l.releaseS3() if l.onS3Op != nil { l.onS3Op("download_index", time.Since(startTime), err) } @@ -170,8 +211,6 @@ func (l *PartitionLog) RestoreFromS3(ctx context.Context) (int64, error) { // AppendBatch writes a record batch to the log, updating offsets and flushing as needed. func (l *PartitionLog) AppendBatch(ctx context.Context, batch RecordBatch) (*AppendResult, error) { - var flushed *SegmentArtifact - l.mu.Lock() baseOffset := l.nextOffset PatchRecordBatchBaseOffset(&batch, baseOffset) @@ -182,9 +221,11 @@ func (l *PartitionLog) AppendBatch(ctx context.Context, batch RecordBatch) (*App BaseOffset: baseOffset, LastOffset: l.nextOffset - 1, } + + var artifact *SegmentArtifact if l.buffer.ShouldFlush(time.Now()) { var err error - flushed, err = l.flushLocked(ctx) + artifact, err = l.prepareFlush() if err != nil { l.mu.Unlock() return nil, err @@ -192,8 +233,13 @@ func (l *PartitionLog) AppendBatch(ctx context.Context, batch RecordBatch) (*App } l.mu.Unlock() - if flushed != nil && l.onFlush != nil { - l.onFlush(ctx, flushed) + if artifact != nil { + if err := l.uploadFlush(ctx, artifact); err != nil { + return nil, err + } + if l.onFlush != nil { + l.onFlush(ctx, artifact) + } } return result, nil } @@ -209,17 +255,34 @@ func (l *PartitionLog) EarliestOffset() int64 { } // Flush forces buffered batches to be written to S3 immediately. +// If another flush is already in progress on this partition, Flush waits for +// it to complete and then flushes any data that accumulated in the meantime. func (l *PartitionLog) Flush(ctx context.Context) error { l.mu.Lock() - artifact, err := l.flushLocked(ctx) + for l.flushing { + if ctx.Err() != nil { + l.mu.Unlock() + return ctx.Err() + } + l.flushCond.Wait() + } + artifact, err := l.prepareFlush() l.mu.Unlock() if err != nil { return err } + + if artifact != nil { + if err := l.uploadFlush(ctx, artifact); err != nil { + return err + } + } if l.onFlush != nil { target := artifact if target == nil { + l.mu.Lock() current := l.nextOffset - 1 + l.mu.Unlock() if current >= 0 { target = &SegmentArtifact{LastOffset: current} } @@ -231,7 +294,15 @@ func (l *PartitionLog) Flush(ctx context.Context) error { return nil } -func (l *PartitionLog) flushLocked(ctx context.Context) (*SegmentArtifact, error) { +// prepareFlush drains the buffer and builds a segment artifact under l.mu. +// It sets l.flushing = true to prevent concurrent flushes on the same +// partition. Returns (nil, nil) if the buffer is empty or a flush is already +// in progress (the data stays in the buffer for the next flush). +// Caller must hold l.mu. +func (l *PartitionLog) prepareFlush() (*SegmentArtifact, error) { + if l.flushing { + return nil, nil + } batches := l.buffer.Drain() if len(batches) == 0 { return nil, nil @@ -240,11 +311,22 @@ func (l *PartitionLog) flushLocked(ctx context.Context) (*SegmentArtifact, error if err != nil { return nil, fmt.Errorf("build segment: %w", err) } + l.flushing = true + return artifact, nil +} + +// uploadFlush uploads the segment and index to S3 (with semaphore gating), +// then re-acquires l.mu to commit segment metadata. Called without l.mu held. +func (l *PartitionLog) uploadFlush(ctx context.Context, artifact *SegmentArtifact) error { segmentKey := l.segmentKey(artifact.BaseOffset) indexKey := l.indexKey(artifact.BaseOffset) g, gctx := errgroup.WithContext(ctx) g.Go(func() error { + if err := l.acquireS3(gctx); err != nil { + return err + } + defer l.releaseS3() start := time.Now() err := l.s3.UploadSegment(gctx, segmentKey, artifact.SegmentBytes) if l.onS3Op != nil { @@ -253,6 +335,10 @@ func (l *PartitionLog) flushLocked(ctx context.Context) (*SegmentArtifact, error return err }) g.Go(func() error { + if err := l.acquireS3(gctx); err != nil { + return err + } + defer l.releaseS3() start := time.Now() err := l.s3.UploadIndex(gctx, indexKey, artifact.IndexBytes) if l.onS3Op != nil { @@ -261,11 +347,18 @@ func (l *PartitionLog) flushLocked(ctx context.Context) (*SegmentArtifact, error return err }) if err := g.Wait(); err != nil { - return nil, err + l.mu.Lock() + l.flushing = false + l.flushCond.Broadcast() + l.mu.Unlock() + return err } + if l.cache != nil && l.cfg.CacheEnabled { l.cache.SetSegment(l.cacheTopicKey(), l.partition, artifact.BaseOffset, artifact.SegmentBytes) } + + l.mu.Lock() l.segments = append(l.segments, segmentRange{ baseOffset: artifact.BaseOffset, lastOffset: artifact.LastOffset, @@ -274,8 +367,13 @@ func (l *PartitionLog) flushLocked(ctx context.Context) (*SegmentArtifact, error if artifact.RelativeIndex != nil { l.indexEntries[artifact.BaseOffset] = artifact.RelativeIndex } - l.startPrefetch(ctx, len(l.segments)-1) - return artifact, nil + l.flushing = false + l.flushCond.Broadcast() + lastSegIdx := len(l.segments) - 1 + l.mu.Unlock() + + l.startPrefetch(ctx, lastSegIdx) + return nil } func (l *PartitionLog) segmentKey(baseOffset int64) string { @@ -321,11 +419,13 @@ func (l *PartitionLog) Read(ctx context.Context, offset int64, maxBytes int32) ( l.mu.Lock() var seg segmentRange found := false + segIdx := -1 var entries []*IndexEntry - for _, s := range l.segments { + for i, s := range l.segments { if offset >= s.baseOffset && offset <= s.lastOffset { seg = s found = true + segIdx = i entries = l.indexEntries[s.baseOffset] break } @@ -343,9 +443,13 @@ func (l *PartitionLog) Read(ctx context.Context, offset int64, maxBytes int32) ( } rangeReadUsed := false if !ok { + if err := l.acquireS3(ctx); err != nil { + return nil, err + } if rangeRead, rng := l.segmentRangeForOffset(seg, entries, offset, maxBytes); rangeRead { start := time.Now() bytes, err := l.s3.DownloadSegment(ctx, l.segmentKey(seg.baseOffset), rng) + l.releaseS3() if l.onS3Op != nil { l.onS3Op("download_segment_range", time.Since(start), err) } @@ -357,6 +461,7 @@ func (l *PartitionLog) Read(ctx context.Context, offset int64, maxBytes int32) ( } else { start := time.Now() bytes, err := l.s3.DownloadSegment(ctx, l.segmentKey(seg.baseOffset), nil) + l.releaseS3() if l.onS3Op != nil { l.onS3Op("download_segment", time.Since(start), err) } @@ -369,7 +474,7 @@ func (l *PartitionLog) Read(ctx context.Context, offset int64, maxBytes int32) ( } } } - l.startPrefetch(ctx, l.segmentIndex(seg.baseOffset)+1) + l.startPrefetch(ctx, segIdx+1) if ok { body, err := l.sliceCachedSegment(seg, entries, offset, maxBytes, data) @@ -388,31 +493,36 @@ func (l *PartitionLog) Read(ctx context.Context, offset int64, maxBytes int32) ( return body, nil } -func (l *PartitionLog) segmentIndex(baseOffset int64) int { - for i, seg := range l.segments { - if seg.baseOffset == baseOffset { - return i - } - } - return -1 -} - func (l *PartitionLog) startPrefetch(ctx context.Context, nextIndex int) { if l.cfg.ReadAheadSegments <= 0 || nextIndex < 0 || l.cache == nil || !l.cfg.CacheEnabled { return } l.prefetchMu.Lock() defer l.prefetchMu.Unlock() + + l.mu.Lock() + segsLen := len(l.segments) + // Collect the segments we need to prefetch while holding the lock. + var toFetch []segmentRange for i := 0; i < l.cfg.ReadAheadSegments; i++ { idx := nextIndex + i - if idx >= len(l.segments) { + if idx >= segsLen { break } seg := l.segments[idx] if _, ok := l.cache.GetSegment(l.cacheTopicKey(), l.partition, seg.baseOffset); ok { continue } + toFetch = append(toFetch, seg) + } + l.mu.Unlock() + + for _, seg := range toFetch { go func(seg segmentRange) { + if !l.tryAcquireS3() { + return // semaphore full, skip prefetch + } + defer l.releaseS3() data, err := l.s3.DownloadSegment(ctx, l.segmentKey(seg.baseOffset), nil) if err != nil { return diff --git a/pkg/storage/log_test.go b/pkg/storage/log_test.go index 28bf7ad..d20dbef 100644 --- a/pkg/storage/log_test.go +++ b/pkg/storage/log_test.go @@ -18,11 +18,14 @@ package storage import ( "context" "encoding/binary" + "fmt" + "sync" "sync/atomic" "testing" "time" "github.com/KafScale/platform/pkg/cache" + "golang.org/x/sync/semaphore" ) func TestPartitionLogAppendFlush(t *testing.T) { @@ -39,7 +42,7 @@ func TestPartitionLogAppendFlush(t *testing.T) { }, }, func(ctx context.Context, artifact *SegmentArtifact) { flushCount++ - }, nil) + }, nil, nil) batchData := make([]byte, 70) batch, err := NewRecordBatchFromBytes(batchData) @@ -77,7 +80,7 @@ func TestPartitionLogRead(t *testing.T) { }, ReadAheadSegments: 1, CacheEnabled: true, - }, nil, nil) + }, nil, nil, nil) batchData := make([]byte, 70) batch, _ := NewRecordBatchFromBytes(batchData) @@ -107,7 +110,7 @@ func TestPartitionLogReadUsesIndexRange(t *testing.T) { Segment: SegmentWriterConfig{ IndexIntervalMessages: 1, }, - }, nil, nil) + }, nil, nil, nil) batch1 := makeBatchBytes(0, 0, 1, 0x11) batch2 := makeBatchBytes(1, 0, 1, 0x22) @@ -155,7 +158,7 @@ func TestPartitionLogReportsS3Uploads(t *testing.T) { }, }, nil, func(op string, d time.Duration, err error) { uploads.Add(1) - }) + }, nil) batchData := make([]byte, 70) batch, _ := NewRecordBatchFromBytes(batchData) @@ -182,7 +185,7 @@ func TestPartitionLogRestoreFromS3(t *testing.T) { Segment: SegmentWriterConfig{ IndexIntervalMessages: 1, }, - }, nil, nil) + }, nil, nil, nil) batchData := make([]byte, 70) batch, _ := NewRecordBatchFromBytes(batchData) @@ -202,7 +205,7 @@ func TestPartitionLogRestoreFromS3(t *testing.T) { Segment: SegmentWriterConfig{ IndexIntervalMessages: 1, }, - }, nil, nil) + }, nil, nil, nil) lastOffset, err := recovered.RestoreFromS3(context.Background()) if err != nil { t.Fatalf("RestoreFromS3: %v", err) @@ -226,6 +229,186 @@ func TestPartitionLogRestoreFromS3(t *testing.T) { } } +func TestPartitionLogPrefetchSkippedWhenSemaphoreFull(t *testing.T) { + s3mem := NewMemoryS3Client() + // First, write two segments without semaphore constraint. + writer := NewPartitionLog("default", "orders", 0, 0, s3mem, nil, PartitionLogConfig{ + Buffer: WriteBufferConfig{ + MaxBytes: 1, + FlushInterval: time.Millisecond, + }, + Segment: SegmentWriterConfig{ + IndexIntervalMessages: 1, + }, + }, nil, nil, nil) + for i := 0; i < 2; i++ { + batch, _ := NewRecordBatchFromBytes(makeBatchBytes(int64(i), 0, 1, byte(i+1))) + if _, err := writer.AppendBatch(context.Background(), batch); err != nil { + t.Fatalf("AppendBatch %d: %v", i, err) + } + time.Sleep(2 * time.Millisecond) + if err := writer.Flush(context.Background()); err != nil { + t.Fatalf("Flush %d: %v", i, err) + } + } + + // Now create a reader with a full semaphore and a fresh cache. + sem := semaphore.NewWeighted(1) + sem.Acquire(context.Background(), 1) // exhaust the semaphore + c := cache.NewSegmentCache(1 << 20) + reader := NewPartitionLog("default", "orders", 0, 0, s3mem, c, PartitionLogConfig{ + Buffer: WriteBufferConfig{ + MaxBytes: 1, + FlushInterval: time.Millisecond, + }, + Segment: SegmentWriterConfig{ + IndexIntervalMessages: 1, + }, + ReadAheadSegments: 2, + CacheEnabled: true, + }, nil, nil, sem) + // Restore segments so the reader knows about them. + sem.Release(1) // temporarily release for RestoreFromS3 + if _, err := reader.RestoreFromS3(context.Background()); err != nil { + t.Fatalf("RestoreFromS3: %v", err) + } + sem.Acquire(context.Background(), 1) // re-exhaust + + // Trigger prefetch — should be skipped because TryAcquire fails. + reader.startPrefetch(context.Background(), 0) + time.Sleep(5 * time.Millisecond) + + if _, ok := c.GetSegment("default/orders", 0, 0); ok { + t.Fatalf("expected prefetch to be skipped when semaphore is full") + } + sem.Release(1) +} + +func TestPartitionLogFlushWaitsForInflight(t *testing.T) { + // Verify that a concurrent Flush blocks until an in-flight flush completes + // and then flushes any data that accumulated in the meantime. + uploadStarted := make(chan struct{}) + uploadRelease := make(chan struct{}) + slowS3 := &slowUploadS3{ + MemoryS3Client: NewMemoryS3Client(), + started: uploadStarted, + release: uploadRelease, + } + flushCh := make(chan int64, 4) + log := NewPartitionLog("default", "orders", 0, 0, slowS3, nil, PartitionLogConfig{ + Buffer: WriteBufferConfig{ + MaxBytes: 1 << 20, // large — manual flush only + }, + Segment: SegmentWriterConfig{ + IndexIntervalMessages: 1, + }, + }, func(_ context.Context, a *SegmentArtifact) { + flushCh <- a.LastOffset + }, nil, nil) + + // Append batch 0 and start a slow flush in the background. + batch0, _ := NewRecordBatchFromBytes(makeBatchBytes(0, 0, 1, 0x01)) + if _, err := log.AppendBatch(context.Background(), batch0); err != nil { + t.Fatalf("AppendBatch 0: %v", err) + } + errCh := make(chan error, 1) + go func() { + errCh <- log.Flush(context.Background()) + }() + <-uploadStarted // first flush is now blocked in UploadSegment + + // Append batch 1 while the first flush is in progress. + batch1, _ := NewRecordBatchFromBytes(makeBatchBytes(1, 0, 1, 0x02)) + if _, err := log.AppendBatch(context.Background(), batch1); err != nil { + t.Fatalf("AppendBatch 1: %v", err) + } + + // Start second Flush — it must block until first completes. + errCh2 := make(chan error, 1) + go func() { + errCh2 <- log.Flush(context.Background()) + }() + time.Sleep(5 * time.Millisecond) // give second Flush time to reach the wait + + // Release the slow upload — both flushes should complete. + close(uploadRelease) + + if err := <-errCh; err != nil { + t.Fatalf("first Flush: %v", err) + } + if err := <-errCh2; err != nil { + t.Fatalf("second Flush: %v", err) + } + + // Both batches must have been flushed (two onFlush callbacks). + // Collect results from the channel now that both goroutines have returned. + close(flushCh) + var flushedOffsets []int64 + for off := range flushCh { + flushedOffsets = append(flushedOffsets, off) + } + if len(flushedOffsets) < 2 { + t.Fatalf("expected 2 flush callbacks, got %d", len(flushedOffsets)) + } +} + +// slowUploadS3 blocks UploadSegment until release is closed. +type slowUploadS3 struct { + *MemoryS3Client + started chan struct{} + release chan struct{} + once sync.Once +} + +func (s *slowUploadS3) UploadSegment(ctx context.Context, key string, body []byte) error { + s.once.Do(func() { close(s.started) }) + <-s.release + return s.MemoryS3Client.UploadSegment(ctx, key, body) +} + +func TestPartitionLogFlushErrorClearsFlushing(t *testing.T) { + failingS3 := &failingUploadS3{MemoryS3Client: NewMemoryS3Client()} + log := NewPartitionLog("default", "orders", 0, 0, failingS3, nil, PartitionLogConfig{ + Buffer: WriteBufferConfig{ + MaxBytes: 1 << 20, // large threshold so AppendBatch won't auto-flush + }, + Segment: SegmentWriterConfig{ + IndexIntervalMessages: 1, + }, + }, nil, nil, nil) + + batchData := make([]byte, 70) + batch, _ := NewRecordBatchFromBytes(batchData) + if _, err := log.AppendBatch(context.Background(), batch); err != nil { + t.Fatalf("AppendBatch: %v", err) + } + + err := log.Flush(context.Background()) + if err == nil { + t.Fatalf("expected flush to fail") + } + + // flushing flag should be cleared so a subsequent flush can proceed. + log.mu.Lock() + if log.flushing { + t.Fatalf("expected flushing flag to be false after failed upload") + } + log.mu.Unlock() +} + +// failingUploadS3 wraps MemoryS3Client but fails all uploads. +type failingUploadS3 struct { + *MemoryS3Client +} + +func (f *failingUploadS3) UploadSegment(ctx context.Context, key string, body []byte) error { + return fmt.Errorf("simulated S3 upload failure") +} + +func (f *failingUploadS3) UploadIndex(ctx context.Context, key string, body []byte) error { + return fmt.Errorf("simulated S3 upload failure") +} + func makeBatchBytes(baseOffset int64, lastOffsetDelta int32, messageCount int32, marker byte) []byte { const size = 70 data := make([]byte, size) diff --git a/pkg/storage/s3_aws.go b/pkg/storage/s3_aws.go index 1cb0555..4124100 100644 --- a/pkg/storage/s3_aws.go +++ b/pkg/storage/s3_aws.go @@ -21,8 +21,10 @@ import ( "errors" "fmt" "io" + "net/http" "github.com/aws/aws-sdk-go-v2/aws" + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/s3" @@ -81,6 +83,12 @@ func NewS3Client(ctx context.Context, cfg S3Config) (S3Client, error) { client := s3.NewFromConfig(awsCfg, func(o *s3.Options) { o.UsePathStyle = cfg.ForcePathStyle + if cfg.MaxConnections > 0 { + o.HTTPClient = awshttp.NewBuildableClient().WithTransportOptions(func(t *http.Transport) { + t.MaxIdleConnsPerHost = cfg.MaxConnections + t.MaxConnsPerHost = cfg.MaxConnections + }) + } }) return newAWSClientWithAPI(cfg.Bucket, cfg.Region, cfg.KMSKeyARN, client), nil diff --git a/pkg/storage/s3client.go b/pkg/storage/s3client.go index f35df84..f5aefa3 100644 --- a/pkg/storage/s3client.go +++ b/pkg/storage/s3client.go @@ -60,4 +60,5 @@ type S3Config struct { SecretAccessKey string SessionToken string KMSKeyARN string + MaxConnections int }