From 0ffeeae1d53a868f01e5c1cdd8de89390bcf16ea Mon Sep 17 00:00:00 2001 From: francescopepe <3891780+francescopepe@users.noreply.github.com> Date: Tue, 27 Aug 2024 23:35:11 +0200 Subject: [PATCH 1/2] feat: make the retriever stop immediately if the context is canceled --- go.mod | 8 +++++++- go.sum | 12 ++++++++++++ internal/client/client.go | 8 ++++++-- retriever.go | 8 +++++++- sqs.go | 4 ++-- 5 files changed, 34 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 4307c56..f22b95a 100644 --- a/go.mod +++ b/go.mod @@ -2,11 +2,17 @@ module github.com/francescopepe/formigo go 1.21 -require github.com/aws/aws-sdk-go-v2/service/sqs v1.34.4 +require ( + github.com/aws/aws-sdk-go-v2/service/sqs v1.34.5 + github.com/stretchr/testify v1.9.0 +) require ( github.com/aws/aws-sdk-go-v2 v1.30.4 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 // indirect github.com/aws/smithy-go v1.20.4 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 3609e04..c419e11 100644 --- a/go.sum +++ b/go.sum @@ -6,5 +6,17 @@ github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 h1:jYfy8UPmd+6kJW5YhY github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16/go.mod h1:7ZfEPZxkW42Afq4uQB8H2E2e6ebh6mXTueEpYzjCzcs= github.com/aws/aws-sdk-go-v2/service/sqs v1.34.4 h1:FXPO72iKC5YmYNEANltl763bUj8A6qT20wx8Jwvxlsw= github.com/aws/aws-sdk-go-v2/service/sqs v1.34.4/go.mod h1:7idt3XszF6sE9WPS1GqZRiDJOxw4oPtlRBXodWnCGjU= +github.com/aws/aws-sdk-go-v2/service/sqs v1.34.5 h1:HYyVDOC2/PIg+3oBX1q0wtDU5kONki6lrgIG0afrBkY= +github.com/aws/aws-sdk-go-v2/service/sqs v1.34.5/go.mod h1:7idt3XszF6sE9WPS1GqZRiDJOxw4oPtlRBXodWnCGjU= github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4= github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/client/client.go b/internal/client/client.go index 9e24aeb..7b4c7b0 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -1,9 +1,13 @@ package client -import "github.com/francescopepe/formigo/internal/messages" +import ( + "context" + + "github.com/francescopepe/formigo/internal/messages" +) type MessageReceiver interface { - ReceiveMessages() ([]messages.Message, error) + ReceiveMessages(ctx context.Context) ([]messages.Message, error) } type MessageDeleter interface { diff --git a/retriever.go b/retriever.go index d0939fc..3765e74 100644 --- a/retriever.go +++ b/retriever.go @@ -17,8 +17,14 @@ func retriever(ctx context.Context, receiver client.MessageReceiver, ctrl *contr case <-ctx.Done(): return default: - msgs, err := receiver.ReceiveMessages() + msgs, err := receiver.ReceiveMessages(ctx) if err != nil { + if errors.Is(err, context.Canceled) && errors.Is(ctx.Err(), context.Canceled) { + // The worker's context was canceled. We can exit. + return + } + + // Report the error to the controller and continue. ctrl.reportError(fmt.Errorf("unable to receive message: %w", err)) continue } diff --git a/sqs.go b/sqs.go index 9beb547..1b09a30 100644 --- a/sqs.go +++ b/sqs.go @@ -35,8 +35,8 @@ type sqsClient struct { messageCtxTimeout time.Duration } -func (c sqsClient) ReceiveMessages() ([]messages.Message, error) { - out, err := c.svc.ReceiveMessage(context.Background(), c.receiveMessageInput) +func (c sqsClient) ReceiveMessages(ctx context.Context) ([]messages.Message, error) { + out, err := c.svc.ReceiveMessage(ctx, c.receiveMessageInput) if err != nil { return nil, fmt.Errorf("unable to receive messages: %w", err) } From 01021115ac15f8cf3edef45b4a8b92df033e42bf Mon Sep 17 00:00:00 2001 From: francescopepe <3891780+francescopepe@users.noreply.github.com> Date: Tue, 27 Aug 2024 23:35:37 +0200 Subject: [PATCH 2/2] ci: add some simple tests and workflow --- .github/workflows/test.yml | 19 +++ worker_test.go | 250 +++++++++++++++++++++++++++++++++++++ 2 files changed, 269 insertions(+) create mode 100644 .github/workflows/test.yml create mode 100644 worker_test.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..21744eb --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,19 @@ +name: Tests + +on: + push: + branches: + - main + pull_request: + +jobs: + tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-go@v4 + with: + go-version-file: 'go.mod' + cache: false + - name: Run tests + run: go test ./... diff --git a/worker_test.go b/worker_test.go new file mode 100644 index 0000000..7731880 --- /dev/null +++ b/worker_test.go @@ -0,0 +1,250 @@ +package formigo + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/francescopepe/formigo/internal/messages" + "github.com/stretchr/testify/assert" +) + +type SimpleInMemoryBrokerMessage struct { + messageId string + body string + deleteReqCh chan struct{} + deleteAckCh chan struct{} + timer *time.Timer +} + +type SimpleInMemoryBroker struct { + visibilityTimeout time.Duration + queue chan *SimpleInMemoryBrokerMessage + inFlights chan *SimpleInMemoryBrokerMessage + expired chan *SimpleInMemoryBrokerMessage + + statics struct { + rwMutex sync.RWMutex + enqueuedMessages int + inFlightMessages int + } +} + +func NewSimpleInMemoryBroker(visibilityTimeout time.Duration) *SimpleInMemoryBroker { + return &SimpleInMemoryBroker{ + visibilityTimeout: visibilityTimeout, + queue: make(chan *SimpleInMemoryBrokerMessage, 1000), + inFlights: make(chan *SimpleInMemoryBrokerMessage), + expired: make(chan *SimpleInMemoryBrokerMessage, 1000), + } +} + +func (b *SimpleInMemoryBroker) run(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case msg := <-b.inFlights: + go func(ctx context.Context) { + select { + case <-ctx.Done(): + return + case <-msg.deleteReqCh: + msg.deleteAckCh <- struct{}{} + case <-msg.timer.C: + b.expired <- msg + } + }(ctx) + } + } +} + +func (b *SimpleInMemoryBroker) AddMessages(msgs []*SimpleInMemoryBrokerMessage) { + for _, msg := range msgs { + b.queue <- msg + b.statics.rwMutex.Lock() + b.statics.enqueuedMessages++ + b.statics.rwMutex.Unlock() + } +} + +func (b *SimpleInMemoryBroker) DeleteMessages(msgs []messages.Message) error { + requestTimer := time.NewTimer(time.Second * 5) + defer requestTimer.Stop() + + for _, msg := range msgs { + brokerMsg := msg.Content().(*SimpleInMemoryBrokerMessage) + + select { + case <-requestTimer.C: + return fmt.Errorf("failed to delete message %s: request timeout", brokerMsg.messageId) + case brokerMsg.deleteReqCh <- struct{}{}: + } + + if !brokerMsg.timer.Stop() { + return fmt.Errorf("failed to delete message %s: visibility timeout exipired", brokerMsg.messageId) + } + + <-brokerMsg.deleteAckCh + + b.statics.rwMutex.Lock() + b.statics.inFlightMessages-- + b.statics.rwMutex.Unlock() + } + + return nil +} + +func (b *SimpleInMemoryBroker) ReceiveMessages(ctx context.Context) ([]messages.Message, error) { + var polledMessage *SimpleInMemoryBrokerMessage + select { + case polledMessage = <-b.expired: + default: + timer := time.NewTimer(time.Millisecond * 500) + defer timer.Stop() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + return nil, nil + case polledMessage = <-b.expired: + case polledMessage = <-b.queue: + } + } + + polledMessage.timer = time.NewTimer(b.visibilityTimeout) + polledMessage.deleteReqCh = make(chan struct{}) + polledMessage.deleteAckCh = make(chan struct{}) + + time.After(time.Millisecond * 5) + + msg := messages.Message{ + MsgId: polledMessage.messageId, + Msg: polledMessage, + ReceivedTime: time.Now(), + } + + // Set a context with timeout + msg.Ctx, msg.CancelCtx = context.WithTimeout(context.Background(), b.visibilityTimeout) + + // Move the message to inflight + b.inFlights <- polledMessage + b.statics.rwMutex.Lock() + b.statics.enqueuedMessages-- + b.statics.inFlightMessages++ + b.statics.rwMutex.Unlock() + + return []messages.Message{msg}, nil +} + +func (b *SimpleInMemoryBroker) EnqueuedMessages() int { + b.statics.rwMutex.RLock() + defer b.statics.rwMutex.RUnlock() + return b.statics.enqueuedMessages +} + +func (b *SimpleInMemoryBroker) InFlightMessages() int { + b.statics.rwMutex.RLock() + defer b.statics.rwMutex.RUnlock() + return b.statics.inFlightMessages +} + +func TestWorker(t *testing.T) { + inMemoryBroker := NewSimpleInMemoryBroker(time.Second * 10) + go inMemoryBroker.run(context.Background()) + + t.Run("can receive a message", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + msgs := []*SimpleInMemoryBrokerMessage{ + { + messageId: "1", + body: "Hello, world!", + }, + } + + inMemoryBroker.AddMessages(msgs) + + wkr := NewWorker(Configuration{ + Client: inMemoryBroker, + Concurrency: 1, + Retrievers: 1, + ErrorConfig: ErrorConfiguration{ + ReportFunc: func(err error) bool { + t.Fatalf("unexpected error: %v", err) + return true + }, + }, + Consumer: NewMessageConsumer(MessageConsumerConfiguration{ + Handler: func(ctx context.Context, msg Message) error { + defer cancel() + + assert.Equal(t, "Hello, world!", msg.Content().(*SimpleInMemoryBrokerMessage).body) + + return nil + }, + }), + }) + + assert.NoError(t, wkr.Run(ctx)) + assert.Equal(t, 0, inMemoryBroker.EnqueuedMessages()) + assert.Equal(t, 0, inMemoryBroker.InFlightMessages()) + }) + + t.Run("can receive a batch of messages", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + msgs := []*SimpleInMemoryBrokerMessage{ + { + messageId: "1", + body: "Hello, world 1!", + }, + { + messageId: "2", + body: "Hello, world 2!", + }, + { + messageId: "3", + body: "Hello, world 3!", + }, + } + + inMemoryBroker.AddMessages(msgs) + + wkr := NewWorker(Configuration{ + Client: inMemoryBroker, + Concurrency: 1, + Retrievers: 1, + ErrorConfig: ErrorConfiguration{ + ReportFunc: func(err error) bool { + t.Fatalf("unexpected error: %v", err) + return true + }, + }, + Consumer: NewBatchConsumer(BatchConsumerConfiguration{ + BufferConfig: BatchConsumerBufferConfiguration{ + Size: 3, + Timeout: time.Second, + }, + Handler: func(ctx context.Context, msgs []Message) (BatchResponse, error) { + defer cancel() + + if len(msgs) < 3 { + t.Fatalf("expected 3 messages, got %d", len(msgs)) + } + + assert.Equal(t, "Hello, world 1!", msgs[0].Content().(*SimpleInMemoryBrokerMessage).body) + assert.Equal(t, "Hello, world 2!", msgs[1].Content().(*SimpleInMemoryBrokerMessage).body) + assert.Equal(t, "Hello, world 3!", msgs[2].Content().(*SimpleInMemoryBrokerMessage).body) + + return BatchResponse{}, nil + }, + }), + }) + + assert.NoError(t, wkr.Run(ctx)) + assert.Equal(t, 0, inMemoryBroker.EnqueuedMessages()) + assert.Equal(t, 0, inMemoryBroker.InFlightMessages()) + }) +}