Skip to content

Commit

Permalink
Merge pull request #15 from francescopepe/FP/stop-retriever-immediately
Browse files Browse the repository at this point in the history
Make the retriever stop immediately and add some tests
  • Loading branch information
francescopepe authored Aug 27, 2024
2 parents cc77bc3 + 0102111 commit 3ead0c6
Show file tree
Hide file tree
Showing 7 changed files with 303 additions and 6 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name: Tests

on:
push:
branches:
- main
pull_request:

jobs:
tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-go@v4
with:
go-version-file: 'go.mod'
cache: false
- name: Run tests
run: go test ./...
8 changes: 7 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@ module github.com/francescopepe/formigo

go 1.21

require github.com/aws/aws-sdk-go-v2/service/sqs v1.34.4
require (
github.com/aws/aws-sdk-go-v2/service/sqs v1.34.5
github.com/stretchr/testify v1.9.0
)

require (
github.com/aws/aws-sdk-go-v2 v1.30.4 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 // indirect
github.com/aws/smithy-go v1.20.4 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
12 changes: 12 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,17 @@ github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 h1:jYfy8UPmd+6kJW5YhY
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16/go.mod h1:7ZfEPZxkW42Afq4uQB8H2E2e6ebh6mXTueEpYzjCzcs=
github.com/aws/aws-sdk-go-v2/service/sqs v1.34.4 h1:FXPO72iKC5YmYNEANltl763bUj8A6qT20wx8Jwvxlsw=
github.com/aws/aws-sdk-go-v2/service/sqs v1.34.4/go.mod h1:7idt3XszF6sE9WPS1GqZRiDJOxw4oPtlRBXodWnCGjU=
github.com/aws/aws-sdk-go-v2/service/sqs v1.34.5 h1:HYyVDOC2/PIg+3oBX1q0wtDU5kONki6lrgIG0afrBkY=
github.com/aws/aws-sdk-go-v2/service/sqs v1.34.5/go.mod h1:7idt3XszF6sE9WPS1GqZRiDJOxw4oPtlRBXodWnCGjU=
github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4=
github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
8 changes: 6 additions & 2 deletions internal/client/client.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package client

import "github.com/francescopepe/formigo/internal/messages"
import (
"context"

"github.com/francescopepe/formigo/internal/messages"
)

type MessageReceiver interface {
ReceiveMessages() ([]messages.Message, error)
ReceiveMessages(ctx context.Context) ([]messages.Message, error)
}

type MessageDeleter interface {
Expand Down
8 changes: 7 additions & 1 deletion retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@ func retriever(ctx context.Context, receiver client.MessageReceiver, ctrl *contr
case <-ctx.Done():
return
default:
msgs, err := receiver.ReceiveMessages()
msgs, err := receiver.ReceiveMessages(ctx)
if err != nil {
if errors.Is(err, context.Canceled) && errors.Is(ctx.Err(), context.Canceled) {
// The worker's context was canceled. We can exit.
return
}

// Report the error to the controller and continue.
ctrl.reportError(fmt.Errorf("unable to receive message: %w", err))
continue
}
Expand Down
4 changes: 2 additions & 2 deletions sqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ type sqsClient struct {
messageCtxTimeout time.Duration
}

func (c sqsClient) ReceiveMessages() ([]messages.Message, error) {
out, err := c.svc.ReceiveMessage(context.Background(), c.receiveMessageInput)
func (c sqsClient) ReceiveMessages(ctx context.Context) ([]messages.Message, error) {
out, err := c.svc.ReceiveMessage(ctx, c.receiveMessageInput)
if err != nil {
return nil, fmt.Errorf("unable to receive messages: %w", err)
}
Expand Down
250 changes: 250 additions & 0 deletions worker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
package formigo

import (
"context"
"fmt"
"sync"
"testing"
"time"

"github.com/francescopepe/formigo/internal/messages"
"github.com/stretchr/testify/assert"
)

type SimpleInMemoryBrokerMessage struct {
messageId string
body string
deleteReqCh chan struct{}
deleteAckCh chan struct{}
timer *time.Timer
}

type SimpleInMemoryBroker struct {
visibilityTimeout time.Duration
queue chan *SimpleInMemoryBrokerMessage
inFlights chan *SimpleInMemoryBrokerMessage
expired chan *SimpleInMemoryBrokerMessage

statics struct {
rwMutex sync.RWMutex
enqueuedMessages int
inFlightMessages int
}
}

func NewSimpleInMemoryBroker(visibilityTimeout time.Duration) *SimpleInMemoryBroker {
return &SimpleInMemoryBroker{
visibilityTimeout: visibilityTimeout,
queue: make(chan *SimpleInMemoryBrokerMessage, 1000),
inFlights: make(chan *SimpleInMemoryBrokerMessage),
expired: make(chan *SimpleInMemoryBrokerMessage, 1000),
}
}

func (b *SimpleInMemoryBroker) run(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case msg := <-b.inFlights:
go func(ctx context.Context) {
select {
case <-ctx.Done():
return
case <-msg.deleteReqCh:
msg.deleteAckCh <- struct{}{}
case <-msg.timer.C:
b.expired <- msg
}
}(ctx)
}
}
}

func (b *SimpleInMemoryBroker) AddMessages(msgs []*SimpleInMemoryBrokerMessage) {
for _, msg := range msgs {
b.queue <- msg
b.statics.rwMutex.Lock()
b.statics.enqueuedMessages++
b.statics.rwMutex.Unlock()
}
}

func (b *SimpleInMemoryBroker) DeleteMessages(msgs []messages.Message) error {
requestTimer := time.NewTimer(time.Second * 5)
defer requestTimer.Stop()

for _, msg := range msgs {
brokerMsg := msg.Content().(*SimpleInMemoryBrokerMessage)

select {
case <-requestTimer.C:
return fmt.Errorf("failed to delete message %s: request timeout", brokerMsg.messageId)
case brokerMsg.deleteReqCh <- struct{}{}:
}

if !brokerMsg.timer.Stop() {
return fmt.Errorf("failed to delete message %s: visibility timeout exipired", brokerMsg.messageId)
}

<-brokerMsg.deleteAckCh

b.statics.rwMutex.Lock()
b.statics.inFlightMessages--
b.statics.rwMutex.Unlock()
}

return nil
}

func (b *SimpleInMemoryBroker) ReceiveMessages(ctx context.Context) ([]messages.Message, error) {
var polledMessage *SimpleInMemoryBrokerMessage
select {
case polledMessage = <-b.expired:
default:
timer := time.NewTimer(time.Millisecond * 500)
defer timer.Stop()

select {
case <-ctx.Done():
return nil, ctx.Err()
case <-timer.C:
return nil, nil
case polledMessage = <-b.expired:
case polledMessage = <-b.queue:
}
}

polledMessage.timer = time.NewTimer(b.visibilityTimeout)
polledMessage.deleteReqCh = make(chan struct{})
polledMessage.deleteAckCh = make(chan struct{})

time.After(time.Millisecond * 5)

msg := messages.Message{
MsgId: polledMessage.messageId,
Msg: polledMessage,
ReceivedTime: time.Now(),
}

// Set a context with timeout
msg.Ctx, msg.CancelCtx = context.WithTimeout(context.Background(), b.visibilityTimeout)

// Move the message to inflight
b.inFlights <- polledMessage
b.statics.rwMutex.Lock()
b.statics.enqueuedMessages--
b.statics.inFlightMessages++
b.statics.rwMutex.Unlock()

return []messages.Message{msg}, nil
}

func (b *SimpleInMemoryBroker) EnqueuedMessages() int {
b.statics.rwMutex.RLock()
defer b.statics.rwMutex.RUnlock()
return b.statics.enqueuedMessages
}

func (b *SimpleInMemoryBroker) InFlightMessages() int {
b.statics.rwMutex.RLock()
defer b.statics.rwMutex.RUnlock()
return b.statics.inFlightMessages
}

func TestWorker(t *testing.T) {
inMemoryBroker := NewSimpleInMemoryBroker(time.Second * 10)
go inMemoryBroker.run(context.Background())

t.Run("can receive a message", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
msgs := []*SimpleInMemoryBrokerMessage{
{
messageId: "1",
body: "Hello, world!",
},
}

inMemoryBroker.AddMessages(msgs)

wkr := NewWorker(Configuration{
Client: inMemoryBroker,
Concurrency: 1,
Retrievers: 1,
ErrorConfig: ErrorConfiguration{
ReportFunc: func(err error) bool {
t.Fatalf("unexpected error: %v", err)
return true
},
},
Consumer: NewMessageConsumer(MessageConsumerConfiguration{
Handler: func(ctx context.Context, msg Message) error {
defer cancel()

assert.Equal(t, "Hello, world!", msg.Content().(*SimpleInMemoryBrokerMessage).body)

return nil
},
}),
})

assert.NoError(t, wkr.Run(ctx))
assert.Equal(t, 0, inMemoryBroker.EnqueuedMessages())
assert.Equal(t, 0, inMemoryBroker.InFlightMessages())
})

t.Run("can receive a batch of messages", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
msgs := []*SimpleInMemoryBrokerMessage{
{
messageId: "1",
body: "Hello, world 1!",
},
{
messageId: "2",
body: "Hello, world 2!",
},
{
messageId: "3",
body: "Hello, world 3!",
},
}

inMemoryBroker.AddMessages(msgs)

wkr := NewWorker(Configuration{
Client: inMemoryBroker,
Concurrency: 1,
Retrievers: 1,
ErrorConfig: ErrorConfiguration{
ReportFunc: func(err error) bool {
t.Fatalf("unexpected error: %v", err)
return true
},
},
Consumer: NewBatchConsumer(BatchConsumerConfiguration{
BufferConfig: BatchConsumerBufferConfiguration{
Size: 3,
Timeout: time.Second,
},
Handler: func(ctx context.Context, msgs []Message) (BatchResponse, error) {
defer cancel()

if len(msgs) < 3 {
t.Fatalf("expected 3 messages, got %d", len(msgs))
}

assert.Equal(t, "Hello, world 1!", msgs[0].Content().(*SimpleInMemoryBrokerMessage).body)
assert.Equal(t, "Hello, world 2!", msgs[1].Content().(*SimpleInMemoryBrokerMessage).body)
assert.Equal(t, "Hello, world 3!", msgs[2].Content().(*SimpleInMemoryBrokerMessage).body)

return BatchResponse{}, nil
},
}),
})

assert.NoError(t, wkr.Run(ctx))
assert.Equal(t, 0, inMemoryBroker.EnqueuedMessages())
assert.Equal(t, 0, inMemoryBroker.InFlightMessages())
})
}

0 comments on commit 3ead0c6

Please sign in to comment.