Skip to content

Commit

Permalink
feat: TX mempool and TxSubmission client (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
agaffney committed Apr 14, 2024
1 parent ff3ebf2 commit 86c797e
Show file tree
Hide file tree
Showing 5 changed files with 450 additions and 22 deletions.
97 changes: 97 additions & 0 deletions mempool/consumer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright 2024 Blink Labs Software
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package mempool

import (
"sync"
)

type MempoolConsumer struct {
txChan chan *MempoolTransaction
cache map[string]*MempoolTransaction
cacheMutex sync.Mutex
}

func newConsumer() *MempoolConsumer {
return &MempoolConsumer{
txChan: make(chan *MempoolTransaction),
cache: make(map[string]*MempoolTransaction),
}
}

func (m *MempoolConsumer) NextTx(blocking bool) *MempoolTransaction {
var ret *MempoolTransaction
if blocking {
// Wait until a transaction is available
tmpTx, ok := <-m.txChan
if ok {
ret = tmpTx
}
} else {
select {
case tmpTx, ok := <-m.txChan:
if ok {
ret = tmpTx
}
default:
// No transaction available
}
}
if ret != nil {
// Add transaction to cache
m.cacheMutex.Lock()
m.cache[ret.Hash] = ret
m.cacheMutex.Unlock()
}
return ret
}

func (m *MempoolConsumer) GetTxFromCache(hash string) *MempoolTransaction {
m.cacheMutex.Lock()
defer m.cacheMutex.Unlock()
return m.cache[hash]
}

func (m *MempoolConsumer) ClearCache() {
m.cacheMutex.Lock()
defer m.cacheMutex.Unlock()
m.cache = make(map[string]*MempoolTransaction)
}

func (m *MempoolConsumer) RemoveTxFromCache(hash string) {
m.cacheMutex.Lock()
defer m.cacheMutex.Unlock()
delete(m.cache, hash)
}

func (m *MempoolConsumer) stop() {
close(m.txChan)
}

func (m *MempoolConsumer) pushTx(tx *MempoolTransaction, wait bool) bool {
if wait {
// Block on write to channel
m.txChan <- tx
return true
} else {
// Return immediately if we can't write to channel
select {
case m.txChan <- tx:
return true
default:
return false
}
}
}
225 changes: 225 additions & 0 deletions mempool/mempool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
// Copyright 2024 Blink Labs Software
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package mempool

import (
"fmt"
"io"
"log/slog"
"slices"
"sync"
"time"

ouroboros "github.com/blinklabs-io/gouroboros"
)

const (
txsubmissionMempoolExpiration = 1 * time.Hour
txSubmissionMempoolExpirationPeriod = 1 * time.Minute
)

type MempoolTransaction struct {
Hash string
Type uint
Cbor []byte
LastSeen time.Time
}

type Mempool struct {
sync.Mutex
logger *slog.Logger
consumers map[ouroboros.ConnectionId]*MempoolConsumer
consumersMutex sync.Mutex
consumerIndex map[ouroboros.ConnectionId]int
consumerIndexMutex sync.Mutex
transactions []*MempoolTransaction
}

func NewMempool(logger *slog.Logger) *Mempool {
m := &Mempool{
consumers: make(map[ouroboros.ConnectionId]*MempoolConsumer),
}
if logger == nil {
// Create logger to throw away logs
// We do this so we don't have to add guards around every log operation
m.logger = slog.New(slog.NewJSONHandler(io.Discard, nil))
}
// TODO: replace this with purging based on on-chain TXs
// Schedule initial mempool expired cleanup
m.scheduleRemoveExpired()
return m
}

func (m *Mempool) AddConsumer(connId ouroboros.ConnectionId) *MempoolConsumer {
// Create consumer
m.consumersMutex.Lock()
defer m.consumersMutex.Unlock()
consumer := newConsumer()
m.consumers[connId] = consumer
// Start goroutine to send existing TXs to consumer
go func(consumer *MempoolConsumer) {
for {
m.Lock()
m.consumerIndexMutex.Lock()
nextTxIdx, ok := m.consumerIndex[connId]
if !ok {
// Our consumer has disappeared
return
}
if nextTxIdx >= len(m.transactions) {
// We've reached the current end of the mempool
return
}
nextTx := m.transactions[nextTxIdx]
if consumer.pushTx(nextTx, true) {
nextTxIdx++
m.consumerIndex[connId] = nextTxIdx
}
m.consumerIndexMutex.Unlock()
m.Unlock()
}
}(consumer)
return consumer
}

func (m *Mempool) RemoveConsumer(connId ouroboros.ConnectionId) {
m.consumersMutex.Lock()
m.consumerIndexMutex.Lock()
defer func() {
m.consumerIndexMutex.Unlock()
m.consumersMutex.Unlock()
}()
if consumer, ok := m.consumers[connId]; ok {
consumer.stop()
delete(m.consumers, connId)
delete(m.consumerIndex, connId)
}
}

func (m *Mempool) Consumer(connId ouroboros.ConnectionId) *MempoolConsumer {
m.consumersMutex.Lock()
defer m.consumersMutex.Unlock()
return m.consumers[connId]
}

// TODO: replace this with purging based on on-chain TXs
func (m *Mempool) removeExpired() {
m.Lock()
defer m.Unlock()
expiredBefore := time.Now().Add(-txsubmissionMempoolExpiration)
for _, tx := range m.transactions {
if tx.LastSeen.Before(expiredBefore) {
m.removeTransaction(tx.Hash)
m.logger.Debug(
fmt.Sprintf("removed expired transaction %s from mempool", tx.Hash),
)
}
}
m.scheduleRemoveExpired()
}

func (m *Mempool) scheduleRemoveExpired() {
_ = time.AfterFunc(txSubmissionMempoolExpirationPeriod, m.removeExpired)
}

func (m *Mempool) AddTransaction(tx MempoolTransaction) error {
m.Lock()
m.consumersMutex.Lock()
m.consumerIndexMutex.Lock()
defer func() {
m.consumerIndexMutex.Unlock()
m.consumersMutex.Unlock()
m.Unlock()
}()
// Update last seen for existing TX
existingTx := m.getTransaction(tx.Hash)
if existingTx != nil {
tx.LastSeen = time.Now()
m.logger.Debug(
fmt.Sprintf("updated last seen for transaction %s in mempool", tx.Hash),
)
return nil
}
// Add transaction record
m.transactions = append(m.transactions, &tx)
m.logger.Debug(
fmt.Sprintf("added transaction %s to mempool", tx.Hash),
)
// Send new TX to consumers that are ready for it
newTxIdx := len(m.transactions) - 1
for connId, consumerIdx := range m.consumerIndex {
if consumerIdx == newTxIdx {
consumer := m.consumers[connId]
if consumer.pushTx(&tx, false) {
consumerIdx++
m.consumerIndex[connId] = consumerIdx
}
}
}
return nil
}

func (m *Mempool) GetTransaction(txHash string) (MempoolTransaction, bool) {
m.Lock()
defer m.Unlock()
ret := m.getTransaction(txHash)
if ret == nil {
return MempoolTransaction{}, false
}
return *ret, true
}

func (m *Mempool) getTransaction(txHash string) *MempoolTransaction {
for _, tx := range m.transactions {
if tx.Hash == txHash {
return tx
}
}
return nil
}

func (m *Mempool) RemoveTransaction(hash string) {
m.Lock()
defer m.Unlock()
if m.removeTransaction(hash) {
m.logger.Debug(
fmt.Sprintf("removed transaction %s from mempool", hash),
)
}
}

func (m *Mempool) removeTransaction(hash string) bool {
for txIdx, tx := range m.transactions {
if tx.Hash == hash {
m.consumerIndexMutex.Lock()
m.transactions = slices.Delete(
m.transactions,
txIdx,
txIdx+1,
)
// Update consumer indexes to reflect removed TX
for connId, consumerIdx := range m.consumerIndex {
// Decrement consumer index if the consumer has reached the removed TX
if consumerIdx >= txIdx {
consumerIdx--
}
m.consumerIndex[connId] = consumerIdx
}
m.consumerIndexMutex.Unlock()
return true
}
}
return false
}
6 changes: 5 additions & 1 deletion node.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"sync"

"github.com/blinklabs-io/node/chainsync"
"github.com/blinklabs-io/node/mempool"

ouroboros "github.com/blinklabs-io/gouroboros"
)
Expand All @@ -29,12 +30,14 @@ type Node struct {
chainsyncState *chainsync.State
outboundConns map[ouroboros.ConnectionId]outboundPeer
outboundConnsMutex sync.Mutex
mempool *mempool.Mempool
}

func New(cfg Config) (*Node, error) {
n := &Node{
config: cfg,
chainsyncState: chainsync.NewState(),
mempool: mempool.NewMempool(cfg.logger),
outboundConns: make(map[ouroboros.ConnectionId]outboundPeer),
}
if err := n.configPopulateNetworkMagic(); err != nil {
Expand Down Expand Up @@ -64,7 +67,6 @@ func (n *Node) Run() error {
n.connManager.AddHostsFromTopology(n.config.topologyConfig)
}
n.startOutboundConnections()
// TODO

// Wait forever
select {}
Expand All @@ -84,6 +86,8 @@ func (n *Node) connectionManagerConnClosed(connId ouroboros.ConnectionId, err er
n.connManager.RemoveConnection(connId)
// Remove any chainsync client state
n.chainsyncState.RemoveClient(connId)
// Remove mempool consumer
n.mempool.RemoveConsumer(connId)
// Outbound connections
n.outboundConnsMutex.Lock()
if peer, ok := n.outboundConns[connId]; ok {
Expand Down
14 changes: 4 additions & 10 deletions outbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,6 @@ func (n *Node) startOutboundConnections() {

}

/*
func (n *Node) getOutboundConn(connId ouroboros.ConnectionId) (outboundPeer, bool) {
n.outboundConnsMutex.Lock()
defer n.outboundConnsMutex.Unlock()
conn, ok := n.outboundConns[connId]
return conn, ok
}
*/

func (n *Node) createOutboundConnection(peer outboundPeer) error {
var clientAddr net.Addr
dialer := net.Dialer{
Expand Down Expand Up @@ -161,7 +152,10 @@ func (n *Node) createOutboundConnection(peer outboundPeer) error {
return err
}
}
// TODO: start txsubmission client
// Start txsubmission client
if err := n.txsubmissionClientStart(oConn.Id()); err != nil {
return err
}
return nil
}

Expand Down
Loading

0 comments on commit 86c797e

Please sign in to comment.