From 079a2dcac235d1001e4db686090019b5aef90c11 Mon Sep 17 00:00:00 2001 From: Chris Gianelloni Date: Sat, 27 Apr 2024 17:45:45 -0400 Subject: [PATCH] fix: guard operating on nil resource fields Signed-off-by: Chris Gianelloni --- blockfetch.go | 6 +++++- chainsync.go | 19 +++++++++++++++---- chainsync/chainsync.go | 5 ++++- cmd/node/main.go | 6 ++++-- config.go | 13 ++++++++++--- connection.go | 7 ++++++- listener.go | 14 +++++++++++--- mempool/consumer.go | 36 ++++++++++++++++++++++++++---------- mempool/mempool.go | 10 ++++++++-- node.go | 13 +++++++++++-- outbound.go | 38 +++++++++++++++++++++++++++++++------- peersharing.go | 5 ++++- txsubmission.go | 31 +++++++++++++++++++++++++------ 13 files changed, 160 insertions(+), 43 deletions(-) diff --git a/blockfetch.go b/blockfetch.go index a7175fc..a056420 100644 --- a/blockfetch.go +++ b/blockfetch.go @@ -34,7 +34,11 @@ func (n *Node) blockfetchClientConnOpts() []oblockfetch.BlockFetchOptionFunc { } } -func (n *Node) blockfetchServerRequestRange(ctx oblockfetch.CallbackContext, start ocommon.Point, end ocommon.Point) error { +func (n *Node) blockfetchServerRequestRange( + ctx oblockfetch.CallbackContext, + start ocommon.Point, + end ocommon.Point, +) error { // TODO: check if we have requested block range available and send NoBlocks if not // Start async process to send requested block range go func() { diff --git a/chainsync.go b/chainsync.go index 4b48fb9..bf77889 100644 --- a/chainsync.go +++ b/chainsync.go @@ -58,7 +58,10 @@ func (n *Node) chainsyncClientStart(connId ouroboros.ConnectionId) error { return nil } -func (n *Node) chainsyncServerFindIntersect(ctx ochainsync.CallbackContext, points []ocommon.Point) (ocommon.Point, ochainsync.Tip, error) { +func (n *Node) chainsyncServerFindIntersect( + ctx ochainsync.CallbackContext, + points []ocommon.Point, +) (ocommon.Point, ochainsync.Tip, error) { var retPoint ocommon.Point var retTip ochainsync.Tip // Find intersection @@ -95,9 +98,14 @@ func (n *Node) chainsyncServerFindIntersect(ctx ochainsync.CallbackContext, poin return retPoint, retTip, nil } -func (n *Node) chainsyncServerRequestNext(ctx ochainsync.CallbackContext) error { +func (n *Node) chainsyncServerRequestNext( + ctx ochainsync.CallbackContext, +) error { // Create/retrieve chainsync state for connection - clientState := n.chainsyncState.AddClient(ctx.ConnectionId, n.chainsyncState.Tip()) + clientState := n.chainsyncState.AddClient( + ctx.ConnectionId, + n.chainsyncState.Tip(), + ) if clientState.NeedsInitialRollback { err := ctx.Server.RollBackward( clientState.Cursor.ToTip().Point, @@ -137,7 +145,10 @@ func (n *Node) chainsyncServerRequestNext(ctx ochainsync.CallbackContext) error return nil } -func (n *Node) chainsyncServerSendNext(ctx ochainsync.CallbackContext, block chainsync.ChainsyncBlock) error { +func (n *Node) chainsyncServerSendNext( + ctx ochainsync.CallbackContext, + block chainsync.ChainsyncBlock, +) error { var err error if block.Rollback { err = ctx.Server.RollBackward( diff --git a/chainsync/chainsync.go b/chainsync/chainsync.go index 49333a3..509622f 100644 --- a/chainsync/chainsync.go +++ b/chainsync/chainsync.go @@ -100,7 +100,10 @@ func (s *State) RecentBlocks() []ChainsyncBlock { return s.recentBlocks[:] } -func (s *State) AddClient(connId connection.ConnectionId, cursor ChainsyncPoint) *ChainsyncClientState { +func (s *State) AddClient( + connId connection.ConnectionId, + cursor ChainsyncPoint, +) *ChainsyncClientState { s.Lock() defer s.Unlock() // Create initial chainsync state for connection diff --git a/cmd/node/main.go b/cmd/node/main.go index 5bee79b..2201d10 100644 --- a/cmd/node/main.go +++ b/cmd/node/main.go @@ -61,8 +61,10 @@ func main() { } // Global flags - rootCmd.PersistentFlags().BoolVarP(&globalFlags.debug, "debug", "D", false, "enable debug logging") - rootCmd.PersistentFlags().BoolVarP(&globalFlags.version, "version", "", false, "show version and exit") + rootCmd.PersistentFlags(). + BoolVarP(&globalFlags.debug, "debug", "D", false, "enable debug logging") + rootCmd.PersistentFlags(). + BoolVarP(&globalFlags.version, "version", "", false, "show version and exit") // Execute cobra command if err := rootCmd.Execute(); err != nil { diff --git a/config.go b/config.go index 9dbc75b..6b769bf 100644 --- a/config.go +++ b/config.go @@ -48,7 +48,10 @@ func (n *Node) configPopulateNetworkMagic() error { func (n *Node) configValidate() error { if n.config.networkMagic == 0 { - return fmt.Errorf("invalid network magic value: %d", n.config.networkMagic) + return fmt.Errorf( + "invalid network magic value: %d", + n.config.networkMagic, + ) } if len(n.config.listeners) == 0 { return fmt.Errorf("no listeners defined") @@ -60,7 +63,9 @@ func (n *Node) configValidate() error { if listener.ListenNetwork != "" && listener.ListenAddress != "" { continue } - return fmt.Errorf("listener must provide net.Listener or listen network/address values") + return fmt.Errorf( + "listener must provide net.Listener or listen network/address values", + ) } return nil } @@ -126,7 +131,9 @@ func WithOutboundSourcePort(port int) ConfigOptionFunc { } // WithTopologyConfig specifies an ouroboros.TopologyConfig to use for outbound peers -func WithTopologyConfig(topologyConfig *ouroboros.TopologyConfig) ConfigOptionFunc { +func WithTopologyConfig( + topologyConfig *ouroboros.TopologyConfig, +) ConfigOptionFunc { return func(c *Config) { c.topologyConfig = topologyConfig } diff --git a/connection.go b/connection.go index b3dffda..110b1e8 100644 --- a/connection.go +++ b/connection.go @@ -24,7 +24,12 @@ import ( func socketControl(network, address string, c syscall.RawConn) error { var innerErr error err := c.Control(func(fd uintptr) { - err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1) + err := unix.SetsockoptInt( + int(fd), + unix.SOL_SOCKET, + unix.SO_REUSEADDR, + 1, + ) if err != nil { innerErr = err return diff --git a/listener.go b/listener.go index a62e4f3..e3cdb66 100644 --- a/listener.go +++ b/listener.go @@ -41,7 +41,11 @@ func (n *Node) startListener(l ListenerConfig) error { if l.ReuseAddress { listenConfig.Control = socketControl } - listener, err := listenConfig.Listen(context.Background(), l.ListenNetwork, l.ListenAddress) + listener, err := listenConfig.Listen( + context.Background(), + l.ListenNetwork, + l.ListenAddress, + ) if err != nil { return fmt.Errorf("failed to open listening socket: %s", err) } @@ -108,7 +112,9 @@ func (n *Node) startListener(l ListenerConfig) error { n.config.logger.Error(fmt.Sprintf("accept failed: %s", err)) continue } - n.config.logger.Info(fmt.Sprintf("accepted connection from %s", conn.RemoteAddr())) + n.config.logger.Info( + fmt.Sprintf("accepted connection from %s", conn.RemoteAddr()), + ) // Setup Ouroboros connection connOpts := append( defaultConnOpts, @@ -116,7 +122,9 @@ func (n *Node) startListener(l ListenerConfig) error { ) oConn, err := ouroboros.NewConnection(connOpts...) if err != nil { - n.config.logger.Error(fmt.Sprintf("failed to setup connection: %s", err)) + n.config.logger.Error( + fmt.Sprintf("failed to setup connection: %s", err), + ) continue } // Add to connection manager diff --git a/mempool/consumer.go b/mempool/consumer.go index a208fd1..bece148 100644 --- a/mempool/consumer.go +++ b/mempool/consumer.go @@ -32,6 +32,9 @@ func newConsumer() *MempoolConsumer { } func (m *MempoolConsumer) NextTx(blocking bool) *MempoolTransaction { + if m == nil { + return nil + } var ret *MempoolTransaction if blocking { // Wait until a transaction is available @@ -59,28 +62,41 @@ func (m *MempoolConsumer) NextTx(blocking bool) *MempoolTransaction { } func (m *MempoolConsumer) GetTxFromCache(hash string) *MempoolTransaction { - m.cacheMutex.Lock() - defer m.cacheMutex.Unlock() - return m.cache[hash] + if m != nil { + m.cacheMutex.Lock() + defer m.cacheMutex.Unlock() + return m.cache[hash] + } + var ret *MempoolTransaction + return ret } func (m *MempoolConsumer) ClearCache() { - m.cacheMutex.Lock() - defer m.cacheMutex.Unlock() - m.cache = make(map[string]*MempoolTransaction) + if m != nil { + m.cacheMutex.Lock() + defer m.cacheMutex.Unlock() + m.cache = make(map[string]*MempoolTransaction) + } } func (m *MempoolConsumer) RemoveTxFromCache(hash string) { - m.cacheMutex.Lock() - defer m.cacheMutex.Unlock() - delete(m.cache, hash) + if m != nil { + m.cacheMutex.Lock() + defer m.cacheMutex.Unlock() + delete(m.cache, hash) + } } func (m *MempoolConsumer) stop() { - close(m.txChan) + if m != nil { + close(m.txChan) + } } func (m *MempoolConsumer) pushTx(tx *MempoolTransaction, wait bool) bool { + if m == nil { + return false + } if wait { // Block on write to channel m.txChan <- tx diff --git a/mempool/mempool.go b/mempool/mempool.go index 6771025..d91854b 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -123,7 +123,10 @@ func (m *Mempool) removeExpired() { if tx.LastSeen.Before(expiredBefore) { m.removeTransaction(tx.Hash) m.logger.Debug( - fmt.Sprintf("removed expired transaction %s from mempool", tx.Hash), + fmt.Sprintf( + "removed expired transaction %s from mempool", + tx.Hash, + ), ) } } @@ -148,7 +151,10 @@ func (m *Mempool) AddTransaction(tx MempoolTransaction) error { if existingTx != nil { tx.LastSeen = time.Now() m.logger.Debug( - fmt.Sprintf("updated last seen for transaction %s in mempool", tx.Hash), + fmt.Sprintf( + "updated last seen for transaction %s in mempool", + tx.Hash, + ), ) return nil } diff --git a/node.go b/node.go index 55bf5cf..3d78889 100644 --- a/node.go +++ b/node.go @@ -72,9 +72,18 @@ func (n *Node) Run() error { select {} } -func (n *Node) connectionManagerConnClosed(connId ouroboros.ConnectionId, err error) { +func (n *Node) connectionManagerConnClosed( + connId ouroboros.ConnectionId, + err error, +) { if err != nil { - n.config.logger.Error(fmt.Sprintf("unexpected connection failure: %s: %s", connId.String(), err)) + n.config.logger.Error( + fmt.Sprintf( + "unexpected connection failure: %s: %s", + connId.String(), + err, + ), + ) } else { n.config.logger.Info(fmt.Sprintf("connection closed: %s", connId.String())) } diff --git a/outbound.go b/outbound.go index 97bbd11..531ccb9 100644 --- a/outbound.go +++ b/outbound.go @@ -42,16 +42,25 @@ type outboundPeer struct { func (n *Node) startOutboundConnections() { var tmpHosts []string for _, host := range n.config.topologyConfig.Producers { - tmpHosts = append(tmpHosts, net.JoinHostPort(host.Address, strconv.Itoa(int(host.Port)))) + tmpHosts = append( + tmpHosts, + net.JoinHostPort(host.Address, strconv.Itoa(int(host.Port))), + ) } for _, localRoot := range n.config.topologyConfig.LocalRoots { for _, host := range localRoot.AccessPoints { - tmpHosts = append(tmpHosts, net.JoinHostPort(host.Address, strconv.Itoa(int(host.Port)))) + tmpHosts = append( + tmpHosts, + net.JoinHostPort(host.Address, strconv.Itoa(int(host.Port))), + ) } } for _, publicRoot := range n.config.topologyConfig.PublicRoots { for _, host := range publicRoot.AccessPoints { - tmpHosts = append(tmpHosts, net.JoinHostPort(host.Address, strconv.Itoa(int(host.Port)))) + tmpHosts = append( + tmpHosts, + net.JoinHostPort(host.Address, strconv.Itoa(int(host.Port))), + ) } } // Start outbound connections @@ -60,7 +69,11 @@ func (n *Node) startOutboundConnections() { go func(peer outboundPeer) { if err := n.createOutboundConnection(peer); err != nil { n.config.logger.Error( - fmt.Sprintf("failed to establish connection to %s: %s", peer.Address, err), + fmt.Sprintf( + "failed to establish connection to %s: %s", + peer.Address, + err, + ), ) go n.reconnectOutboundConnection(peer) } @@ -77,7 +90,10 @@ func (n *Node) createOutboundConnection(peer outboundPeer) error { if n.config.outboundSourcePort > 0 { // Setup connection to use our listening port as the source port // This is required for peer sharing to be useful - clientAddr, _ = net.ResolveTCPAddr("tcp", fmt.Sprintf(":%d", n.config.outboundSourcePort)) + clientAddr, _ = net.ResolveTCPAddr( + "tcp", + fmt.Sprintf(":%d", n.config.outboundSourcePort), + ) dialer.LocalAddr = clientAddr dialer.Control = socketControl } @@ -167,12 +183,20 @@ func (n *Node) reconnectOutboundConnection(peer outboundPeer) { peer.ReconnectDelay = peer.ReconnectDelay * reconnectBackoffFactor } n.config.logger.Info( - fmt.Sprintf("delaying %s before reconnecting to %s", peer.ReconnectDelay, peer.Address), + fmt.Sprintf( + "delaying %s before reconnecting to %s", + peer.ReconnectDelay, + peer.Address, + ), ) time.Sleep(peer.ReconnectDelay) if err := n.createOutboundConnection(peer); err != nil { n.config.logger.Error( - fmt.Sprintf("failed to establish connection to %s: %s", peer.Address, err), + fmt.Sprintf( + "failed to establish connection to %s: %s", + peer.Address, + err, + ), ) continue } diff --git a/peersharing.go b/peersharing.go index 87e503e..9d46f74 100644 --- a/peersharing.go +++ b/peersharing.go @@ -30,7 +30,10 @@ func (n *Node) peersharingClientConnOpts() []opeersharing.PeerSharingOptionFunc } } -func (n *Node) peersharingShareRequest(ctx opeersharing.CallbackContext, amount int) ([]opeersharing.PeerAddress, error) { +func (n *Node) peersharingShareRequest( + ctx opeersharing.CallbackContext, + amount int, +) ([]opeersharing.PeerAddress, error) { // TODO: add hooks for getting peers to share return []opeersharing.PeerAddress{}, nil } diff --git a/txsubmission.go b/txsubmission.go index 2328304..f64290a 100644 --- a/txsubmission.go +++ b/txsubmission.go @@ -64,9 +64,14 @@ func (n *Node) txsubmissionServerInit(ctx otxsubmission.CallbackContext) error { for { // Request available TX IDs (era and TX hash) and sizes // We make the request blocking to avoid looping on our side - txIds, err := ctx.Server.RequestTxIds(true, txsubmissionRequestTxIdsCount) + txIds, err := ctx.Server.RequestTxIds( + true, + txsubmissionRequestTxIdsCount, + ) if err != nil { - n.config.logger.Error(fmt.Sprintf("failed to request TxIds: %s", err)) + n.config.logger.Error( + fmt.Sprintf("failed to request TxIds: %s", err), + ) return } if len(txIds) > 0 { @@ -78,14 +83,24 @@ func (n *Node) txsubmissionServerInit(ctx otxsubmission.CallbackContext) error { // Request TX content for TxIds from above txs, err := ctx.Server.RequestTxs(requestTxIds) if err != nil { - n.config.logger.Error(fmt.Sprintf("failed to request Txs: %s", err)) + n.config.logger.Error( + fmt.Sprintf("failed to request Txs: %s", err), + ) return } for _, txBody := range txs { // Decode TX from CBOR - tx, err := ledger.NewTransactionFromCbor(uint(txBody.EraId), txBody.TxBody) + tx, err := ledger.NewTransactionFromCbor( + uint(txBody.EraId), + txBody.TxBody, + ) if err != nil { - n.config.logger.Error(fmt.Sprintf("failed to parse transaction CBOR: %s", err)) + n.config.logger.Error( + fmt.Sprintf( + "failed to parse transaction CBOR: %s", + err, + ), + ) return } n.config.logger.Debug( @@ -104,7 +119,11 @@ func (n *Node) txsubmissionServerInit(ctx otxsubmission.CallbackContext) error { ) if err != nil { n.config.logger.Error( - fmt.Sprintf("failed to add TX %s to mempool: %s", tx.Hash(), err), + fmt.Sprintf( + "failed to add TX %s to mempool: %s", + tx.Hash(), + err, + ), ) return }