diff --git a/messaging/natsjscm/natsjscm.go b/messaging/natsjscm/natsjscm.go index 9571bf2..7aadb26 100644 --- a/messaging/natsjscm/natsjscm.go +++ b/messaging/natsjscm/natsjscm.go @@ -8,7 +8,6 @@ import ( "github.com/nats-io/nats.go" "github.com/nats-io/nats.go/jetstream" - "github.com/sanity-io/litter" "github.com/simiancreative/simiango/logger" ) @@ -152,8 +151,6 @@ func (cm *ConnectionManager) Connect() error { return nil } - litter.Dump(cm.config) - // Connect to NATS nc, err := nats.Connect(cm.config.URL, cm.config.Options...) if err != nil { diff --git a/messaging/natsjsdlq/natsjsdlq.go b/messaging/natsjsdlq/natsjsdlq.go index ab4f469..04aad28 100644 --- a/messaging/natsjsdlq/natsjsdlq.go +++ b/messaging/natsjsdlq/natsjsdlq.go @@ -1,17 +1,16 @@ package natsjsdlq import ( + "context" "fmt" "time" "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" + "github.com/simiancreative/simiango/messaging/natsjscm" + "github.com/simiancreative/simiango/messaging/natsjspub" ) -type JetStreamContext interface { - AddStream(cfg *nats.StreamConfig, opts ...nats.JSOpt) (*nats.StreamInfo, error) - PublishMsg(msg *nats.Msg, opts ...nats.PubOpt) (*nats.PubAck, error) -} - type Msg interface { Metadata() (*nats.MsgMetadata, error) } @@ -28,20 +27,26 @@ type Config struct { MaxDeliveries int // Storage type for the DLQ stream - Storage nats.StorageType + Storage jetstream.StorageType // Optional handler for DLQ errors ErrorHandler func(error) + + // Context for the DLQ handler + Context context.Context } type Dependencies struct { - JetStream JetStreamContext + ConnectionManager natsjscm.Connector + Publisher natsjspub.Publisher } // Handler manages dead letter queue operations type Handler struct { config Config - js JetStreamContext + cm natsjscm.Connector + p natsjspub.Publisher + ctx context.Context } // NewHandler creates a new DLQ handler @@ -52,7 +57,9 @@ func NewHandler(deps Dependencies, config Config) (*Handler, error) { handler := &Handler{ config: config, - js: deps.JetStream, + ctx: config.Context, + cm: deps.ConnectionManager, + p: deps.Publisher, } if err := handler.setup(); err != nil { @@ -63,8 +70,12 @@ func NewHandler(deps Dependencies, config Config) (*Handler, error) { } func validateConfig(deps Dependencies, config Config) error { - if deps.JetStream == nil { - return fmt.Errorf("JetStream context is required") + if deps.ConnectionManager == nil { + return fmt.Errorf("connection manager is required") + } + + if deps.Publisher == nil { + return fmt.Errorf("publisher is required") } if config.StreamName == "" { @@ -80,7 +91,11 @@ func validateConfig(deps Dependencies, config Config) error { } if config.Storage == 0 { - config.Storage = nats.FileStorage + config.Storage = jetstream.FileStorage + } + + if config.Context == nil { + config.Context = context.Background() } return nil @@ -88,14 +103,14 @@ func validateConfig(deps Dependencies, config Config) error { // setup ensures the DLQ stream exists func (h *Handler) setup() error { - streamConfig := &nats.StreamConfig{ + streamConfig := jetstream.StreamConfig{ Name: h.config.StreamName, Subjects: []string{h.config.Subject}, Storage: h.config.Storage, - Retention: nats.WorkQueuePolicy, + Retention: jetstream.WorkQueuePolicy, } - _, err := h.js.AddStream(streamConfig) + _, err := h.cm.EnsureStream(h.ctx, streamConfig) if err != nil && err != nats.ErrStreamNameAlreadyInUse { return fmt.Errorf("failed to create DLQ stream: %w", err) } @@ -126,7 +141,7 @@ func (h *Handler) PublishMessage(msg *nats.Msg, reason string) error { dlqMsg.Data = msg.Data // Publish to DLQ - _, err := h.js.PublishMsg(dlqMsg) + _, err := h.p.Publish(h.ctx, dlqMsg) if err != nil && h.config.ErrorHandler != nil { h.config.ErrorHandler(fmt.Errorf("failed to publish to DLQ: %w", err)) } diff --git a/messaging/natsjsdlq/natsjsdlq_test.go b/messaging/natsjsdlq/natsjsdlq_test.go index 71ce596..0ed3e40 100644 --- a/messaging/natsjsdlq/natsjsdlq_test.go +++ b/messaging/natsjsdlq/natsjsdlq_test.go @@ -1,282 +1,438 @@ package natsjsdlq_test import ( + "context" "errors" "testing" + "time" "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" "github.com/simiancreative/simiango/messaging/natsjsdlq" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/tj/assert" ) -// MockJetStreamContext is a test double for nats.JetStreamContext -type MockJetStreamContext struct { - AddStreamFunc func(*nats.StreamConfig, ...nats.JSOpt) (*nats.StreamInfo, error) - PublishMsgFunc func(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error) - publishCalls []*nats.Msg +// Mock implementations +type MockConnectionManager struct { + mock.Mock } -func (m *MockJetStreamContext) AddStream( - cfg *nats.StreamConfig, - opts ...nats.JSOpt, -) (*nats.StreamInfo, error) { - if m.AddStreamFunc != nil { - return m.AddStreamFunc(cfg, opts...) +func (m *MockConnectionManager) Connect() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockConnectionManager) GetConnection() *nats.Conn { + args := m.Called() + return args.Get(0).(*nats.Conn) +} + +func (m *MockConnectionManager) GetJetStream() jetstream.JetStream { + args := m.Called() + return args.Get(0).(jetstream.JetStream) +} + +func (m *MockConnectionManager) EnsureStream( + ctx context.Context, + config jetstream.StreamConfig, +) (jetstream.JetStream, error) { + args := m.Called(ctx, config) + if args.Get(0) == nil { + return nil, args.Error(1) } - return &nats.StreamInfo{Config: *cfg}, nil + return args.Get(0).(jetstream.JetStream), args.Error(1) +} + +func (m *MockConnectionManager) Disconnect() error { + args := m.Called() + return args.Error(0) } -func (m *MockJetStreamContext) PublishMsg( - msg *nats.Msg, - opts ...nats.PubOpt, -) (*nats.PubAck, error) { - m.publishCalls = append(m.publishCalls, msg) - if m.PublishMsgFunc != nil { - return m.PublishMsgFunc(msg, opts...) +func (m *MockConnectionManager) IsConnected() bool { + args := m.Called() + return args.Bool(0) +} + +type MockPublisher struct { + mock.Mock +} + +func (m *MockPublisher) Publish(ctx context.Context, msg *nats.Msg) (*jetstream.PubAck, error) { + args := m.Called(ctx, msg) + if args.Get(0) == nil { + return nil, args.Error(1) } - return &nats.PubAck{}, nil + return args.Get(0).(*jetstream.PubAck), args.Error(1) +} + +type MockMsg struct { + mock.Mock } -func TestNewHandler(t *testing.T) { - tests := []struct { - name string - deps natsjsdlq.Dependencies - config natsjsdlq.Config - wantErr bool - errMsg string +func (m *MockMsg) Metadata() (*nats.MsgMetadata, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*nats.MsgMetadata), args.Error(1) +} + +// Test Suite +type DLQHandlerTestSuite struct { + mockCM *MockConnectionManager + mockPub *MockPublisher + ctx context.Context +} + +func (suite *DLQHandlerTestSuite) SetupTest() { + suite.mockCM = new(MockConnectionManager) + suite.mockPub = new(MockPublisher) + suite.ctx = context.Background() +} + +func TestNewHandlerValidation(t *testing.T) { + suite := new(DLQHandlerTestSuite) + suite.SetupTest() + + testCases := []struct { + name string + deps natsjsdlq.Dependencies + config natsjsdlq.Config + expectedErr string }{ { - name: "valid configuration", + name: "Valid configuration", + deps: natsjsdlq.Dependencies{ + ConnectionManager: suite.mockCM, + Publisher: suite.mockPub, + }, + config: natsjsdlq.Config{ + StreamName: "test-dlq", + Subject: "test.dlq", + MaxDeliveries: 3, + Storage: jetstream.FileStorage, + Context: suite.ctx, + }, + expectedErr: "", + }, + { + name: "Missing connection manager", + deps: natsjsdlq.Dependencies{ + ConnectionManager: nil, + Publisher: suite.mockPub, + }, + config: natsjsdlq.Config{ + StreamName: "test-dlq", + Subject: "test.dlq", + MaxDeliveries: 3, + }, + expectedErr: "invalid DLQ configuration: connection manager is required", + }, + { + name: "Missing publisher", deps: natsjsdlq.Dependencies{ - JetStream: &MockJetStreamContext{}, + ConnectionManager: suite.mockCM, + Publisher: nil, }, config: natsjsdlq.Config{ - StreamName: "test_dlq", + StreamName: "test-dlq", Subject: "test.dlq", MaxDeliveries: 3, - Storage: nats.FileStorage, }, - wantErr: false, + expectedErr: "invalid DLQ configuration: publisher is required", }, { - name: "missing jetstream", + name: "Empty stream name", deps: natsjsdlq.Dependencies{ - JetStream: nil, + ConnectionManager: suite.mockCM, + Publisher: suite.mockPub, }, config: natsjsdlq.Config{ - StreamName: "test_dlq", - Subject: "test.dlq", + StreamName: "", + Subject: "test.dlq", + MaxDeliveries: 3, }, - wantErr: true, - errMsg: "JetStream context is required", + expectedErr: "invalid DLQ configuration: stream name is required", }, { - name: "missing required config", + name: "Empty subject", deps: natsjsdlq.Dependencies{ - JetStream: &MockJetStreamContext{}, + ConnectionManager: suite.mockCM, + Publisher: suite.mockPub, }, - config: natsjsdlq.Config{}, - wantErr: true, - errMsg: "stream name is required", + config: natsjsdlq.Config{ + StreamName: "test-dlq", + Subject: "", + MaxDeliveries: 3, + }, + expectedErr: "invalid DLQ configuration: subject is required", + }, + { + name: "Invalid max deliveries", + deps: natsjsdlq.Dependencies{ + ConnectionManager: suite.mockCM, + Publisher: suite.mockPub, + }, + config: natsjsdlq.Config{ + StreamName: "test-dlq", + Subject: "test.dlq", + MaxDeliveries: 0, + }, + expectedErr: "invalid DLQ configuration: max deliveries must be greater than 0", }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - handler, err := natsjsdlq.NewHandler(tt.deps, tt.config) - if tt.wantErr { - assert.Error(t, err) - assert.Contains(t, err.Error(), tt.errMsg) + for _, tc := range testCases { + + t.Run(tc.name, func(t *testing.T) { + if tc.expectedErr == "" { + // Mock successful stream creation + streamConfig := jetstream.StreamConfig{ + Name: tc.config.StreamName, + Subjects: []string{tc.config.Subject}, + Storage: tc.config.Storage, + Retention: jetstream.WorkQueuePolicy, + } + suite.mockCM.On("EnsureStream", mock.Anything, streamConfig).Return(nil, nil).Once() + } + + handler, err := natsjsdlq.NewHandler(tc.deps, tc.config) + + if tc.expectedErr != "" { + assert.EqualError(t, err, tc.expectedErr) assert.Nil(t, handler) } else { assert.NoError(t, err) assert.NotNil(t, handler) + suite.mockCM.AssertExpectations(t) } }) } } -func TestHandlerPublishMessage(t *testing.T) { - tests := []struct { +func TestShouldDLQ(t *testing.T) { + suite := new(DLQHandlerTestSuite) + suite.SetupTest() + + testCases := []struct { name string - msg *nats.Msg - reason string - publishErr error - wantErr bool - validateHeader func(*testing.T, nats.Header) + maxDeliveries int + numDelivered uint64 + metadataErr error + expectedResult bool }{ { - name: "successful publish", - msg: &nats.Msg{ - Subject: "original.subject", - Data: []byte("test data"), - Header: nats.Header{"Original-Key": []string{"value"}}, - }, - reason: "test failure", - validateHeader: func(t *testing.T, h nats.Header) { - assert.Equal(t, "test failure", h.Get("DLQ-Reason")) - assert.Equal(t, "original.subject", h.Get("Original-Subject")) - assert.Equal(t, "value", h.Get("Original-Key")) - assert.NotEmpty(t, h.Get("DLQ-Timestamp")) - }, + name: "Should send to DLQ - equal to max", + maxDeliveries: 3, + numDelivered: 3, + metadataErr: nil, + expectedResult: true, }, { - name: "publish with no headers", - msg: &nats.Msg{ - Subject: "original.subject", - Data: []byte("test data"), - }, - reason: "test failure", - validateHeader: func(t *testing.T, h nats.Header) { - assert.Equal(t, "test failure", h.Get("DLQ-Reason")) - assert.Equal(t, "original.subject", h.Get("Original-Subject")) - }, + name: "Should send to DLQ - greater than max", + maxDeliveries: 3, + numDelivered: 5, + metadataErr: nil, + expectedResult: true, }, { - name: "publish error", - msg: &nats.Msg{ - Subject: "original.subject", - Data: []byte("test data"), - }, - reason: "test failure", - publishErr: errors.New("publish failed"), - wantErr: true, + name: "Should not send to DLQ - less than max", + maxDeliveries: 3, + numDelivered: 2, + metadataErr: nil, + expectedResult: false, + }, + { + name: "Metadata error", + maxDeliveries: 3, + numDelivered: 0, + metadataErr: errors.New("metadata error"), + expectedResult: false, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mock := &MockJetStreamContext{ - PublishMsgFunc: func(msg *nats.Msg, opts ...nats.PubOpt) (*nats.PubAck, error) { - if tt.publishErr != nil { - return nil, tt.publishErr - } - return &nats.PubAck{}, nil - }, - } + // Setup for all test cases + validConfig := natsjsdlq.Config{ + StreamName: "test-dlq", + Subject: "test.dlq", + MaxDeliveries: 3, + Context: context.Background(), + } - var errorCaught error - handler, err := natsjsdlq.NewHandler( - natsjsdlq.Dependencies{JetStream: mock}, - natsjsdlq.Config{ - StreamName: "test_dlq", - Subject: "test.dlq", - MaxDeliveries: 3, - ErrorHandler: func(err error) { - errorCaught = err - }, - }, - ) - assert.NoError(t, err) + deps := natsjsdlq.Dependencies{ + ConnectionManager: suite.mockCM, + Publisher: suite.mockPub, + } - err = handler.PublishMessage(tt.msg, tt.reason) - if tt.wantErr { - assert.Error(t, err) - assert.NotNil(t, errorCaught) - } else { - assert.NoError(t, err) - assert.Nil(t, errorCaught) - assert.Len(t, mock.publishCalls, 1) - if tt.validateHeader != nil { - tt.validateHeader(t, mock.publishCalls[0].Header) - } - } - }) + // Mock stream creation once for all test cases + streamConfig := jetstream.StreamConfig{ + Name: validConfig.StreamName, + Subjects: []string{validConfig.Subject}, + Storage: jetstream.FileStorage, + Retention: jetstream.WorkQueuePolicy, } -} + suite.mockCM.On("EnsureStream", mock.Anything, streamConfig).Return(nil, nil).Once() -type MockMsg struct { - *nats.Msg - metadata *nats.MsgMetadata - metadataError error -} + handler, err := natsjsdlq.NewHandler(deps, validConfig) + assert.NoError(t, err) + assert.NotNil(t, handler) -func (m *MockMsg) Metadata() (*nats.MsgMetadata, error) { - if m.metadataError != nil { - return nil, m.metadataError + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create a mock message with the configured metadata + mockMsg := new(MockMsg) + metadata := &nats.MsgMetadata{ + NumDelivered: tc.numDelivered, + } + + mockMsg.On("Metadata").Return(metadata, tc.metadataErr).Once() + + // Test the ShouldDLQ method + result := handler.ShouldDLQ(mockMsg) + assert.Equal(t, tc.expectedResult, result) + mockMsg.AssertExpectations(t) + }) } - return m.metadata, nil } -func TestHandlerShouldDLQ(t *testing.T) { - tests := []struct { - name string - msg func() *MockMsg - maxDeliveries int - want bool - wantErr bool +func TestPublishMessage(t *testing.T) { + suite := new(DLQHandlerTestSuite) + suite.SetupTest() + + testCases := []struct { + name string + originalMsg *nats.Msg + reason string + publishError error + expectedDLQSubject string + expectHeaders map[string]string }{ { - name: "should dlq when deliveries exceeded", - msg: func() *MockMsg { - return &MockMsg{ - Msg: &nats.Msg{ - Subject: "test.subject", - Data: []byte("test data"), - }, - metadata: &nats.MsgMetadata{ - NumDelivered: 4, - }, - } + name: "Successful DLQ publish with headers", + originalMsg: &nats.Msg{ + Subject: "original.subject", + Data: []byte("test data"), + Header: nats.Header{ + "Nats-Msg-Id": []string{"msg-123"}, + "Custom": []string{"value"}, + }, + }, + reason: "processing failed", + publishError: nil, + expectedDLQSubject: "test.dlq", + expectHeaders: map[string]string{ + "DLQ-Reason": "processing failed", + "Original-Subject": "original.subject", + "Original-Message-ID": "msg-123", + "Custom": "value", }, - maxDeliveries: 3, - want: true, }, { - name: "should not dlq when under max deliveries", - msg: func() *MockMsg { - return &MockMsg{ - Msg: &nats.Msg{ - Subject: "test.subject", - Data: []byte("test data"), - }, - metadata: &nats.MsgMetadata{ - NumDelivered: 2, - }, - } + name: "Successful DLQ publish without headers", + originalMsg: &nats.Msg{ + Subject: "original.subject", + Data: []byte("test data"), + }, + reason: "processing failed", + publishError: nil, + expectedDLQSubject: "test.dlq", + expectHeaders: map[string]string{ + "DLQ-Reason": "processing failed", + "Original-Subject": "original.subject", }, - maxDeliveries: 3, - want: false, }, { - name: "should not dlq on metadata error", - msg: func() *MockMsg { - return &MockMsg{ - Msg: &nats.Msg{ - Subject: "test.subject", - Data: []byte("test data"), - }, - metadataError: errors.New("metadata error"), - } + name: "Failed DLQ publish", + originalMsg: &nats.Msg{ + Subject: "original.subject", + Data: []byte("test data"), + }, + reason: "processing failed", + publishError: errors.New("publish error"), + expectedDLQSubject: "test.dlq", + expectHeaders: map[string]string{ + "DLQ-Reason": "processing failed", + "Original-Subject": "original.subject", }, - maxDeliveries: 3, - want: false, - wantErr: true, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create a message with mock metadata function - msg := tt.msg() + // Setup for all test cases + validConfig := natsjsdlq.Config{ + StreamName: "test-dlq", + Subject: "test.dlq", + MaxDeliveries: 3, + Context: context.Background(), + ErrorHandler: func(err error) {}, // No-op error handler + } + + deps := natsjsdlq.Dependencies{ + ConnectionManager: suite.mockCM, + Publisher: suite.mockPub, + } - var errorCaught error - handler, err := natsjsdlq.NewHandler(natsjsdlq.Dependencies{ - JetStream: &MockJetStreamContext{}, - }, natsjsdlq.Config{ - StreamName: "test_dlq", - Subject: "test.dlq", - MaxDeliveries: tt.maxDeliveries, - ErrorHandler: func(err error) { - errorCaught = err - }, - }) - assert.NoError(t, err) + // Mock stream creation once for all test cases + streamConfig := jetstream.StreamConfig{ + Name: validConfig.StreamName, + Subjects: []string{validConfig.Subject}, + Storage: jetstream.FileStorage, + Retention: jetstream.WorkQueuePolicy, + } + suite.mockCM.On("EnsureStream", mock.Anything, streamConfig).Return(nil, nil).Once() + + handler, err := natsjsdlq.NewHandler(deps, validConfig) + assert.NoError(t, err) + assert.NotNil(t, handler) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup the publisher mock + suite.mockPub.On("Publish", mock.Anything, mock.MatchedBy(func(msg *nats.Msg) bool { + // Verify the DLQ message properties + if msg.Subject != tc.expectedDLQSubject { + return false + } - got := handler.ShouldDLQ(msg) - assert.Equal(t, tt.want, got) + if len(msg.Data) != len(tc.originalMsg.Data) { + return false + } + + // Verify headers + for key, expectedValue := range tc.expectHeaders { + if msg.Header.Get(key) == "" { + return false + } + + if key == "DLQ-Timestamp" { // Skip timestamp check as it's dynamic + continue + } + + if expectedValue != msg.Header.Get(key) { + return false + } + } + + // Verify timestamp header exists and is valid + timestampStr := msg.Header.Get("DLQ-Timestamp") + if timestampStr == "" { + return false + } + _, err := time.Parse(time.RFC3339, timestampStr) + return err == nil + })).Return(nil, tc.publishError).Once() + + // Test the PublishMessage method + err := handler.PublishMessage(tc.originalMsg, tc.reason) + + if tc.publishError != nil { + assert.Error(t, err) + assert.Equal(t, tc.publishError, err) + } else { + assert.NoError(t, err) + } - assert.Equal(t, tt.wantErr, errorCaught != nil) + suite.mockPub.AssertExpectations(t) }) } } diff --git a/messaging/natsjspub/natsjspub.go b/messaging/natsjspub/natsjspub.go new file mode 100644 index 0000000..949e2b4 --- /dev/null +++ b/messaging/natsjspub/natsjspub.go @@ -0,0 +1,219 @@ +package natsjspub + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" + "github.com/simiancreative/simiango/circuitbreaker" + "github.com/simiancreative/simiango/logger" + "github.com/simiancreative/simiango/messaging/natsjscm" +) + +type Logger interface { + Debugf(format string, args ...interface{}) +} + +type Publisher interface { + Publish(ctx context.Context, msg *nats.Msg) (*jetstream.PubAck, error) +} + +type JsonPublisher interface { + PublishJSON( + ctx context.Context, + data interface{}, + headers ...map[string]string, + ) (*jetstream.PubAck, error) +} + +// Config holds publisher configuration +type Config struct { + // Stream name to publish to + StreamName string + + // Subject to publish on + Subject string + + // CircuitBreaker configuration (optional) + CircuitBreaker *circuitbreaker.Config + + // Publish timeout (default 5s) + Timeout time.Duration + + // Message retention policy (optional, default is WorkQueuePolicy) + RetentionPolicy jetstream.RetentionPolicy +} + +// Dependencies for the publisher +type Dependencies struct { + // ConnectionManager for NATS + ConnectionManager *natsjscm.ConnectionManager +} + +// Publisher is a JetStream publisher with circuit breaker capabilities +type PublishManager struct { + config Config + cm *natsjscm.ConnectionManager + cb *circuitbreaker.CircuitBreaker + log Logger +} + +// NewPublisher creates a new JetStream publisher +func NewPublisher(deps Dependencies, config Config) (*PublishManager, error) { + // Validation + if deps.ConnectionManager == nil { + return nil, fmt.Errorf("connection manager is required") + } + + if config.StreamName == "" { + return nil, fmt.Errorf("stream name is required") + } + + if config.Subject == "" { + return nil, fmt.Errorf("subject is required") + } + + if config.Timeout <= 0 { + config.Timeout = 5 * time.Second + } + + pub := &PublishManager{ + config: config, + cm: deps.ConnectionManager, + log: logger.New(), + } + + // Initialize circuit breaker if configured + if config.CircuitBreaker != nil { + cb, err := circuitbreaker.New(*config.CircuitBreaker) + if err != nil { + return nil, fmt.Errorf("failed to create circuit breaker: %w", err) + } + pub.cb = cb + } + + // Ensure the stream exists + err := pub.ensureStream(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to ensure stream: %w", err) + } + + return pub, nil +} + +// ensureStream makes sure the configured stream exists +func (p *PublishManager) ensureStream(ctx context.Context) error { + // Get JetStream connection + if !p.cm.IsConnected() { + if err := p.cm.Connect(); err != nil { + return fmt.Errorf("failed to connect to NATS: %w", err) + } + } + + // Set default retention policy if not configured + retentionPolicy := p.config.RetentionPolicy + if retentionPolicy == 0 { + retentionPolicy = jetstream.WorkQueuePolicy + } + + // Create stream config + streamConfig := jetstream.StreamConfig{ + Name: p.config.StreamName, + Subjects: []string{p.config.Subject}, + Retention: retentionPolicy, + } + + // Ensure stream exists + _, err := p.cm.EnsureStream(ctx, streamConfig) + return err +} + +// Publish publishes a message to the configured subject +func (p *PublishManager) Publish( + ctx context.Context, + msg *nats.Msg, +) (*jetstream.PubAck, error) { + // Check if circuit breaker allows the request + if p.cb != nil && !p.cb.Allow() { + return nil, fmt.Errorf("circuit breaker is open") + } + + // Record attempt start if circuit breaker is configured + var cbRecorded bool + if p.cb != nil { + cbRecorded = p.cb.RecordStart() + if !cbRecorded { + return nil, fmt.Errorf("circuit breaker rejected request") + } + } + + // Get JetStream connection + js := p.cm.GetJetStream() + if js == nil { + // Record failure if circuit breaker is configured + if p.cb != nil && cbRecorded { + p.cb.RecordResult(false) + } + return nil, fmt.Errorf("jetstream connection not available") + } + + // Apply timeout + ctxWithTimeout, cancel := context.WithTimeout(ctx, p.config.Timeout) + defer cancel() + + // Publish to JetStream + ack, err := js.PublishMsg(ctxWithTimeout, msg) + + // Record result in circuit breaker + if p.cb != nil && cbRecorded { + p.cb.RecordResult(err == nil) + } + + if err != nil { + p.log.Debugf("failed to publish message: %s", err) + } + + return ack, err +} + +// PublishJSON publishes a JSON-serializable object to the configured subject +func (p *PublishManager) PublishJSON( + ctx context.Context, + data interface{}, + headers ...map[string]string, +) (*jetstream.PubAck, error) { + msg := &nats.Msg{} + + // Marshal data to JSON + jsonData, err := json.Marshal(data) + if err != nil { + return nil, fmt.Errorf("failed to marshal JSON: %w", err) + } + + for _, header := range headers { + for k, v := range header { + msg.Header.Add(k, v) + } + } + + // Add content-type header if not already present + if val := msg.Header.Get("content-type"); val == "" { + msg.Header.Add("content-type", "application/json") + } + + msg.Data = jsonData + msg.Subject = p.config.Subject + + return p.Publish(ctx, msg) +} + +// Close cleans up resources +func (p *PublishManager) Close() error { + if p.cb != nil { + p.cb.Reset() + } + return nil +}