Skip to content

Commit

Permalink
feat(tcp): validate conn on accept
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Jul 16, 2024
1 parent 8fa23be commit 4ffacaf
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 14 deletions.
38 changes: 38 additions & 0 deletions gold/p2p/tcp/pdu.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"errors"
"io"
"net"
"time"

"github.com/fumiama/WireGold/helper"
"github.com/sirupsen/logrus"
)

var (
Expand Down Expand Up @@ -86,3 +88,39 @@ func (p *packet) WriteTo(w io.Writer) (n int64, err error) {
defer cl()
return io.Copy(w, &buf)
}

func isvalid(tcpconn *net.TCPConn) bool {
pckt := packet{}

stopch := make(chan struct{})
t := time.AfterFunc(time.Second, func() {
stopch <- struct{}{}
})

var err error
copych := make(chan struct{})
go func() {
_, err = io.Copy(&pckt, tcpconn)
copych <- struct{}{}
}()

select {
case <-stopch:
logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "timeout")
return false
case <-copych:
t.Stop()
}

if err != nil {
logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "err:", err)
return false
}
if pckt.typ != packetTypeKeepAlive {
logrus.Debugln("[tcp] validate got invalid typ", pckt.typ, "from", tcpconn.RemoteAddr())
return false
}

logrus.Debugln("[tcp] passed validate recv from", tcpconn.RemoteAddr())
return true
}
42 changes: 28 additions & 14 deletions gold/p2p/tcp/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ func (ep *EndPoint) Listen() (p2p.Conn, error) {
}
ep.addr = lstn.Addr().(*net.TCPAddr)
peerstimeout := ep.peerstimeout
if peerstimeout < time.Second {
peerstimeout = time.Second * 5
if peerstimeout < time.Second*30 {
peerstimeout = time.Second * 30
}
chansz := ep.recvchansize
if chansz < 32 {
Expand Down Expand Up @@ -112,21 +112,28 @@ func (conn *Conn) accept() {
logrus.Info("[tcp] re-listen on", conn.addr)
continue
}
ep, _ := newEndpoint(tcpconn.RemoteAddr().String(), &Config{
DialTimeout: conn.addr.dialtimeout,
PeersTimeout: conn.addr.peerstimeout,
ReceiveChannelSize: conn.addr.recvchansize,
})
go conn.receive(tcpconn, false)
}
}

func (conn *Conn) receive(tcpconn *net.TCPConn, hasvalidated bool) {
ep, _ := newEndpoint(tcpconn.RemoteAddr().String(), &Config{
DialTimeout: conn.addr.dialtimeout,
PeersTimeout: conn.addr.peerstimeout,
ReceiveChannelSize: conn.addr.recvchansize,
})

if !hasvalidated {
if !isvalid(tcpconn) {
return
}
logrus.Debugln("[tcp] accept from", ep)
conn.peers.Set(ep.String(), tcpconn)
go conn.receive(ep)
}
}

func (conn *Conn) receive(ep *EndPoint) {
peerstimeout := ep.peerstimeout
if peerstimeout < time.Second {
peerstimeout = time.Second * 5
peerstimeout := conn.addr.peerstimeout
if peerstimeout < time.Second*30 {
peerstimeout = time.Second * 30
}
peerstimeout *= 2
for {
Expand Down Expand Up @@ -244,9 +251,16 @@ func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (n int, err error) {
if !ok {
return 0, errors.New("expect *net.TCPConn but got " + reflect.ValueOf(cn).Type().String())
}
_, err = io.Copy(tcpconn, &packet{
typ: packetTypeKeepAlive,
})
if err != nil {
logrus.Debugln("[tcp] dial to", tcpep.addr, "success, but write err:", err)
return 0, err
}
logrus.Debugln("[tcp] dial to", tcpep.addr, "success, local:", tcpconn.LocalAddr())
conn.peers.Set(tcpep.String(), tcpconn)
go conn.receive(tcpep)
go conn.receive(tcpconn, true)
} else {
logrus.Debugln("[tcp] reuse tcpconn from", tcpconn.LocalAddr(), "to", tcpconn.RemoteAddr())
}
Expand Down

0 comments on commit 4ffacaf

Please sign in to comment.