diff --git a/internal/connectionmanager/connection_manager.go b/internal/connectionmanager/connection_manager.go index 1967c46..d6ccd73 100644 --- a/internal/connectionmanager/connection_manager.go +++ b/internal/connectionmanager/connection_manager.go @@ -23,6 +23,13 @@ type ConnectionManager struct { reconnectionCount uint reconnectionCountMu *sync.Mutex dispatcher *dispatcher.Dispatcher + + // universalNotifyBlockingReceiver receives block signal from underlying + // connection which are broadcasted to all publisherNotifyBlockingReceivers + universalNotifyBlockingReceiver chan amqp.Blocking + universalNotifyBlockingReceiverUsed bool + publisherNotifyBlockingReceiversMu *sync.RWMutex + publisherNotifyBlockingReceivers []chan amqp.Blocking } type Resolver interface { @@ -62,17 +69,20 @@ func NewConnectionManager(resolver Resolver, conf amqp.Config, log logger.Logger } connManager := ConnectionManager{ - logger: log, - resolver: resolver, - connection: conn, - amqpConfig: conf, - connectionMu: &sync.RWMutex{}, - ReconnectInterval: reconnectInterval, - reconnectionCount: 0, - reconnectionCountMu: &sync.Mutex{}, - dispatcher: dispatcher.NewDispatcher(), + logger: log, + resolver: resolver, + connection: conn, + amqpConfig: conf, + connectionMu: &sync.RWMutex{}, + ReconnectInterval: reconnectInterval, + reconnectionCount: 0, + reconnectionCountMu: &sync.Mutex{}, + dispatcher: dispatcher.NewDispatcher(), + universalNotifyBlockingReceiver: make(chan amqp.Blocking), + publisherNotifyBlockingReceiversMu: &sync.RWMutex{}, } go connManager.startNotifyClose() + go connManager.readUniversalBlockReceiver() return &connManager, nil } diff --git a/internal/connectionmanager/safe_wraps.go b/internal/connectionmanager/safe_wraps.go index 6a6abbc..6e9dcde 100644 --- a/internal/connectionmanager/safe_wraps.go +++ b/internal/connectionmanager/safe_wraps.go @@ -8,10 +8,43 @@ import ( func (connManager *ConnectionManager) NotifyBlockedSafe( receiver chan amqp.Blocking, ) chan amqp.Blocking { - connManager.connectionMu.RLock() - defer connManager.connectionMu.RUnlock() + connManager.connectionMu.Lock() + defer connManager.connectionMu.Unlock() - return connManager.connection.NotifyBlocked( - receiver, - ) + // add receiver to connection manager. + connManager.publisherNotifyBlockingReceiversMu.Lock() + connManager.publisherNotifyBlockingReceivers = append(connManager.publisherNotifyBlockingReceivers, receiver) + connManager.publisherNotifyBlockingReceiversMu.Unlock() + + if !connManager.universalNotifyBlockingReceiverUsed { + connManager.connection.NotifyBlocked( + connManager.universalNotifyBlockingReceiver, + ) + connManager.universalNotifyBlockingReceiverUsed = true + } + + return receiver +} + +// readUniversalBlockReceiver reads on universal blocking receiver and broadcasts event to all blocking receivers of +// connection manager. +func (connManager *ConnectionManager) readUniversalBlockReceiver() { + for b := range connManager.universalNotifyBlockingReceiver { + connManager.publisherNotifyBlockingReceiversMu.RLock() + for _, br := range connManager.publisherNotifyBlockingReceivers { + br <- b + } + connManager.publisherNotifyBlockingReceiversMu.RUnlock() + } +} + +func (connManager *ConnectionManager) RemovePublisherBlockingReceiver(receiver chan amqp.Blocking) { + connManager.publisherNotifyBlockingReceiversMu.Lock() + for i, br := range connManager.publisherNotifyBlockingReceivers { + if br == receiver { + connManager.publisherNotifyBlockingReceivers = append(connManager.publisherNotifyBlockingReceivers[:i], connManager.publisherNotifyBlockingReceivers[i+1:]...) + } + } + connManager.publisherNotifyBlockingReceiversMu.Unlock() + close(receiver) } diff --git a/publish.go b/publish.go index 06f9cb0..49cb16a 100644 --- a/publish.go +++ b/publish.go @@ -58,6 +58,8 @@ type Publisher struct { notifyPublishHandler func(p Confirmation) options PublisherOptions + + blockings chan amqp.Blocking } type PublisherConfirmation []*amqp.DeferredConfirmation @@ -286,6 +288,7 @@ func (publisher *Publisher) Close() { publisher.options.Logger.Warnf("error while closing the channel: %v", err) } publisher.options.Logger.Infof("closing publisher...") + publisher.connManager.RemovePublisherBlockingReceiver(publisher.blockings) go func() { publisher.closeConnectionToManagerCh <- struct{}{} }() diff --git a/publish_flow_block.go b/publish_flow_block.go index b978a21..6fbc439 100644 --- a/publish_flow_block.go +++ b/publish_flow_block.go @@ -26,6 +26,7 @@ func (publisher *Publisher) startNotifyFlowHandler() { func (publisher *Publisher) startNotifyBlockedHandler() { blockings := publisher.connManager.NotifyBlockedSafe(make(chan amqp.Blocking)) publisher.disablePublishDueToBlockedMu.Lock() + publisher.blockings = blockings publisher.disablePublishDueToBlocked = false publisher.disablePublishDueToBlockedMu.Unlock()