diff --git a/mempool/consumer.go b/mempool/consumer.go new file mode 100644 index 0000000..a208fd1 --- /dev/null +++ b/mempool/consumer.go @@ -0,0 +1,97 @@ +// Copyright 2024 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mempool + +import ( + "sync" +) + +type MempoolConsumer struct { + txChan chan *MempoolTransaction + cache map[string]*MempoolTransaction + cacheMutex sync.Mutex +} + +func newConsumer() *MempoolConsumer { + return &MempoolConsumer{ + txChan: make(chan *MempoolTransaction), + cache: make(map[string]*MempoolTransaction), + } +} + +func (m *MempoolConsumer) NextTx(blocking bool) *MempoolTransaction { + var ret *MempoolTransaction + if blocking { + // Wait until a transaction is available + tmpTx, ok := <-m.txChan + if ok { + ret = tmpTx + } + } else { + select { + case tmpTx, ok := <-m.txChan: + if ok { + ret = tmpTx + } + default: + // No transaction available + } + } + if ret != nil { + // Add transaction to cache + m.cacheMutex.Lock() + m.cache[ret.Hash] = ret + m.cacheMutex.Unlock() + } + return ret +} + +func (m *MempoolConsumer) GetTxFromCache(hash string) *MempoolTransaction { + m.cacheMutex.Lock() + defer m.cacheMutex.Unlock() + return m.cache[hash] +} + +func (m *MempoolConsumer) ClearCache() { + m.cacheMutex.Lock() + defer m.cacheMutex.Unlock() + m.cache = make(map[string]*MempoolTransaction) +} + +func (m *MempoolConsumer) RemoveTxFromCache(hash string) { + m.cacheMutex.Lock() + defer m.cacheMutex.Unlock() + delete(m.cache, hash) +} + +func (m *MempoolConsumer) stop() { + close(m.txChan) +} + +func (m *MempoolConsumer) pushTx(tx *MempoolTransaction, wait bool) bool { + if wait { + // Block on write to channel + m.txChan <- tx + return true + } else { + // Return immediately if we can't write to channel + select { + case m.txChan <- tx: + return true + default: + return false + } + } +} diff --git a/mempool/mempool.go b/mempool/mempool.go new file mode 100644 index 0000000..6771025 --- /dev/null +++ b/mempool/mempool.go @@ -0,0 +1,225 @@ +// Copyright 2024 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mempool + +import ( + "fmt" + "io" + "log/slog" + "slices" + "sync" + "time" + + ouroboros "github.com/blinklabs-io/gouroboros" +) + +const ( + txsubmissionMempoolExpiration = 1 * time.Hour + txSubmissionMempoolExpirationPeriod = 1 * time.Minute +) + +type MempoolTransaction struct { + Hash string + Type uint + Cbor []byte + LastSeen time.Time +} + +type Mempool struct { + sync.Mutex + logger *slog.Logger + consumers map[ouroboros.ConnectionId]*MempoolConsumer + consumersMutex sync.Mutex + consumerIndex map[ouroboros.ConnectionId]int + consumerIndexMutex sync.Mutex + transactions []*MempoolTransaction +} + +func NewMempool(logger *slog.Logger) *Mempool { + m := &Mempool{ + consumers: make(map[ouroboros.ConnectionId]*MempoolConsumer), + } + if logger == nil { + // Create logger to throw away logs + // We do this so we don't have to add guards around every log operation + m.logger = slog.New(slog.NewJSONHandler(io.Discard, nil)) + } + // TODO: replace this with purging based on on-chain TXs + // Schedule initial mempool expired cleanup + m.scheduleRemoveExpired() + return m +} + +func (m *Mempool) AddConsumer(connId ouroboros.ConnectionId) *MempoolConsumer { + // Create consumer + m.consumersMutex.Lock() + defer m.consumersMutex.Unlock() + consumer := newConsumer() + m.consumers[connId] = consumer + // Start goroutine to send existing TXs to consumer + go func(consumer *MempoolConsumer) { + for { + m.Lock() + m.consumerIndexMutex.Lock() + nextTxIdx, ok := m.consumerIndex[connId] + if !ok { + // Our consumer has disappeared + return + } + if nextTxIdx >= len(m.transactions) { + // We've reached the current end of the mempool + return + } + nextTx := m.transactions[nextTxIdx] + if consumer.pushTx(nextTx, true) { + nextTxIdx++ + m.consumerIndex[connId] = nextTxIdx + } + m.consumerIndexMutex.Unlock() + m.Unlock() + } + }(consumer) + return consumer +} + +func (m *Mempool) RemoveConsumer(connId ouroboros.ConnectionId) { + m.consumersMutex.Lock() + m.consumerIndexMutex.Lock() + defer func() { + m.consumerIndexMutex.Unlock() + m.consumersMutex.Unlock() + }() + if consumer, ok := m.consumers[connId]; ok { + consumer.stop() + delete(m.consumers, connId) + delete(m.consumerIndex, connId) + } +} + +func (m *Mempool) Consumer(connId ouroboros.ConnectionId) *MempoolConsumer { + m.consumersMutex.Lock() + defer m.consumersMutex.Unlock() + return m.consumers[connId] +} + +// TODO: replace this with purging based on on-chain TXs +func (m *Mempool) removeExpired() { + m.Lock() + defer m.Unlock() + expiredBefore := time.Now().Add(-txsubmissionMempoolExpiration) + for _, tx := range m.transactions { + if tx.LastSeen.Before(expiredBefore) { + m.removeTransaction(tx.Hash) + m.logger.Debug( + fmt.Sprintf("removed expired transaction %s from mempool", tx.Hash), + ) + } + } + m.scheduleRemoveExpired() +} + +func (m *Mempool) scheduleRemoveExpired() { + _ = time.AfterFunc(txSubmissionMempoolExpirationPeriod, m.removeExpired) +} + +func (m *Mempool) AddTransaction(tx MempoolTransaction) error { + m.Lock() + m.consumersMutex.Lock() + m.consumerIndexMutex.Lock() + defer func() { + m.consumerIndexMutex.Unlock() + m.consumersMutex.Unlock() + m.Unlock() + }() + // Update last seen for existing TX + existingTx := m.getTransaction(tx.Hash) + if existingTx != nil { + tx.LastSeen = time.Now() + m.logger.Debug( + fmt.Sprintf("updated last seen for transaction %s in mempool", tx.Hash), + ) + return nil + } + // Add transaction record + m.transactions = append(m.transactions, &tx) + m.logger.Debug( + fmt.Sprintf("added transaction %s to mempool", tx.Hash), + ) + // Send new TX to consumers that are ready for it + newTxIdx := len(m.transactions) - 1 + for connId, consumerIdx := range m.consumerIndex { + if consumerIdx == newTxIdx { + consumer := m.consumers[connId] + if consumer.pushTx(&tx, false) { + consumerIdx++ + m.consumerIndex[connId] = consumerIdx + } + } + } + return nil +} + +func (m *Mempool) GetTransaction(txHash string) (MempoolTransaction, bool) { + m.Lock() + defer m.Unlock() + ret := m.getTransaction(txHash) + if ret == nil { + return MempoolTransaction{}, false + } + return *ret, true +} + +func (m *Mempool) getTransaction(txHash string) *MempoolTransaction { + for _, tx := range m.transactions { + if tx.Hash == txHash { + return tx + } + } + return nil +} + +func (m *Mempool) RemoveTransaction(hash string) { + m.Lock() + defer m.Unlock() + if m.removeTransaction(hash) { + m.logger.Debug( + fmt.Sprintf("removed transaction %s from mempool", hash), + ) + } +} + +func (m *Mempool) removeTransaction(hash string) bool { + for txIdx, tx := range m.transactions { + if tx.Hash == hash { + m.consumerIndexMutex.Lock() + m.transactions = slices.Delete( + m.transactions, + txIdx, + txIdx+1, + ) + // Update consumer indexes to reflect removed TX + for connId, consumerIdx := range m.consumerIndex { + // Decrement consumer index if the consumer has reached the removed TX + if consumerIdx >= txIdx { + consumerIdx-- + } + m.consumerIndex[connId] = consumerIdx + } + m.consumerIndexMutex.Unlock() + return true + } + } + return false +} diff --git a/node.go b/node.go index c76e14a..55bf5cf 100644 --- a/node.go +++ b/node.go @@ -19,6 +19,7 @@ import ( "sync" "github.com/blinklabs-io/node/chainsync" + "github.com/blinklabs-io/node/mempool" ouroboros "github.com/blinklabs-io/gouroboros" ) @@ -29,12 +30,14 @@ type Node struct { chainsyncState *chainsync.State outboundConns map[ouroboros.ConnectionId]outboundPeer outboundConnsMutex sync.Mutex + mempool *mempool.Mempool } func New(cfg Config) (*Node, error) { n := &Node{ config: cfg, chainsyncState: chainsync.NewState(), + mempool: mempool.NewMempool(cfg.logger), outboundConns: make(map[ouroboros.ConnectionId]outboundPeer), } if err := n.configPopulateNetworkMagic(); err != nil { @@ -64,7 +67,6 @@ func (n *Node) Run() error { n.connManager.AddHostsFromTopology(n.config.topologyConfig) } n.startOutboundConnections() - // TODO // Wait forever select {} @@ -84,6 +86,8 @@ func (n *Node) connectionManagerConnClosed(connId ouroboros.ConnectionId, err er n.connManager.RemoveConnection(connId) // Remove any chainsync client state n.chainsyncState.RemoveClient(connId) + // Remove mempool consumer + n.mempool.RemoveConsumer(connId) // Outbound connections n.outboundConnsMutex.Lock() if peer, ok := n.outboundConns[connId]; ok { diff --git a/outbound.go b/outbound.go index 0d56743..97bbd11 100644 --- a/outbound.go +++ b/outbound.go @@ -69,15 +69,6 @@ func (n *Node) startOutboundConnections() { } -/* -func (n *Node) getOutboundConn(connId ouroboros.ConnectionId) (outboundPeer, bool) { - n.outboundConnsMutex.Lock() - defer n.outboundConnsMutex.Unlock() - conn, ok := n.outboundConns[connId] - return conn, ok -} -*/ - func (n *Node) createOutboundConnection(peer outboundPeer) error { var clientAddr net.Addr dialer := net.Dialer{ @@ -161,7 +152,10 @@ func (n *Node) createOutboundConnection(peer outboundPeer) error { return err } } - // TODO: start txsubmission client + // Start txsubmission client + if err := n.txsubmissionClientStart(oConn.Id()); err != nil { + return err + } return nil } diff --git a/txsubmission.go b/txsubmission.go index f8efe63..2328304 100644 --- a/txsubmission.go +++ b/txsubmission.go @@ -15,10 +15,16 @@ package node import ( + "encoding/hex" "fmt" + "log/slog" + "time" + ouroboros "github.com/blinklabs-io/gouroboros" "github.com/blinklabs-io/gouroboros/ledger" + "github.com/blinklabs-io/gouroboros/protocol/txsubmission" otxsubmission "github.com/blinklabs-io/gouroboros/protocol/txsubmission" + "github.com/blinklabs-io/node/mempool" ) const ( @@ -33,18 +39,25 @@ func (n *Node) txsubmissionServerConnOpts() []otxsubmission.TxSubmissionOptionFu func (n *Node) txsubmissionClientConnOpts() []otxsubmission.TxSubmissionOptionFunc { return []otxsubmission.TxSubmissionOptionFunc{ - // TODO - /* - txsubmission.WithRequestTxIdsFunc( - n.txsubmissionClientRequestTxIds, - ), - txsubmission.WithRequestTxsFunc( - n.txsubmissionClientRequestTxs, - ), - */ + txsubmission.WithRequestTxIdsFunc(n.txsubmissionClientRequestTxIds), + txsubmission.WithRequestTxsFunc(n.txsubmissionClientRequestTxs), } } +func (n *Node) txsubmissionClientStart(connId ouroboros.ConnectionId) error { + // Register mempool consumer + // We don't bother capturing the consumer because we can easily look it up later by connection ID + _ = n.mempool.AddConsumer(connId) + // Start TxSubmission loop + conn := n.connManager.GetConnectionById(connId) + if conn == nil { + return fmt.Errorf("failed to lookup connection ID: %s", connId.String()) + } + oConn := conn.Conn + oConn.TxSubmission().Client.Init() + return nil +} + func (n *Node) txsubmissionServerInit(ctx otxsubmission.CallbackContext) error { // Start async loop to request transactions from the peer's mempool go func() { @@ -75,11 +88,106 @@ func (n *Node) txsubmissionServerInit(ctx otxsubmission.CallbackContext) error { n.config.logger.Error(fmt.Sprintf("failed to parse transaction CBOR: %s", err)) return } - n.config.logger.Debug(fmt.Sprintf("received TX %s via TxSubmission", tx.Hash())) - // TODO: add hooks to do something with TX + n.config.logger.Debug( + "received TX via TxSubmission", + slog.String("tx_hash", tx.Hash()), + slog.String("connection_id", ctx.ConnectionId.String()), + ) + // Add transaction to mempool + err = n.mempool.AddTransaction( + mempool.MempoolTransaction{ + Hash: tx.Hash(), + Type: uint(txBody.EraId), + Cbor: txBody.TxBody, + LastSeen: time.Now(), + }, + ) + if err != nil { + n.config.logger.Error( + fmt.Sprintf("failed to add TX %s to mempool: %s", tx.Hash(), err), + ) + return + } } } } }() return nil } + +func (n *Node) txsubmissionClientRequestTxIds( + ctx txsubmission.CallbackContext, + blocking bool, + ack uint16, + req uint16, +) ([]txsubmission.TxIdAndSize, error) { + connId := ctx.ConnectionId + ret := []txsubmission.TxIdAndSize{} + consumer := n.mempool.Consumer(connId) + // Clear TX cache + if ack > 0 { + consumer.ClearCache() + } + // Get available TXs + var tmpTxs []*mempool.MempoolTransaction + for { + if blocking && len(tmpTxs) == 0 { + // Wait until we see a TX + tmpTx := consumer.NextTx(true) + if tmpTx == nil { + break + } + tmpTxs = append(tmpTxs, tmpTx) + } else { + // Return immediately if no TX is available + tmpTx := consumer.NextTx(false) + if tmpTx == nil { + break + } + tmpTxs = append(tmpTxs, tmpTx) + } + } + for _, tmpTx := range tmpTxs { + tmpTx := tmpTx + // Add to return value + txHashBytes, err := hex.DecodeString(tmpTx.Hash) + if err != nil { + return nil, err + } + ret = append( + ret, + txsubmission.TxIdAndSize{ + TxId: txsubmission.TxId{ + EraId: uint16(tmpTx.Type), + TxId: [32]byte(txHashBytes), + }, + Size: uint32(len(tmpTx.Cbor)), + }, + ) + } + return ret, nil +} + +func (n *Node) txsubmissionClientRequestTxs( + ctx txsubmission.CallbackContext, + txIds []txsubmission.TxId, +) ([]txsubmission.TxBody, error) { + connId := ctx.ConnectionId + ret := []txsubmission.TxBody{} + consumer := n.mempool.Consumer(connId) + for _, txId := range txIds { + txHash := hex.EncodeToString(txId.TxId[:]) + tx := consumer.GetTxFromCache(txHash) + if tx != nil { + ret = append( + ret, + txsubmission.TxBody{ + EraId: uint16(tx.Type), + TxBody: tx.Cbor, + }, + ) + } + consumer.RemoveTxFromCache(txHash) + } + return ret, nil +}