diff --git a/gold/p2p/tcp/pdu.go b/gold/p2p/tcp/pdu.go index fa4d23f..54933fb 100644 --- a/gold/p2p/tcp/pdu.go +++ b/gold/p2p/tcp/pdu.go @@ -5,8 +5,10 @@ import ( "errors" "io" "net" + "time" "github.com/fumiama/WireGold/helper" + "github.com/sirupsen/logrus" ) var ( @@ -86,3 +88,39 @@ func (p *packet) WriteTo(w io.Writer) (n int64, err error) { defer cl() return io.Copy(w, &buf) } + +func isvalid(tcpconn *net.TCPConn) bool { + pckt := packet{} + + stopch := make(chan struct{}) + t := time.AfterFunc(time.Second, func() { + stopch <- struct{}{} + }) + + var err error + copych := make(chan struct{}) + go func() { + _, err = io.Copy(&pckt, tcpconn) + copych <- struct{}{} + }() + + select { + case <-stopch: + logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "timeout") + return false + case <-copych: + t.Stop() + } + + if err != nil { + logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "err:", err) + return false + } + if pckt.typ != packetTypeKeepAlive { + logrus.Debugln("[tcp] validate got invalid typ", pckt.typ, "from", tcpconn.RemoteAddr()) + return false + } + + logrus.Debugln("[tcp] passed validate recv from", tcpconn.RemoteAddr()) + return true +} diff --git a/gold/p2p/tcp/tcp.go b/gold/p2p/tcp/tcp.go index 08f0691..c028c7a 100644 --- a/gold/p2p/tcp/tcp.go +++ b/gold/p2p/tcp/tcp.go @@ -50,8 +50,8 @@ func (ep *EndPoint) Listen() (p2p.Conn, error) { } ep.addr = lstn.Addr().(*net.TCPAddr) peerstimeout := ep.peerstimeout - if peerstimeout < time.Second { - peerstimeout = time.Second * 5 + if peerstimeout < time.Second*30 { + peerstimeout = time.Second * 30 } chansz := ep.recvchansize if chansz < 32 { @@ -112,21 +112,28 @@ func (conn *Conn) accept() { logrus.Info("[tcp] re-listen on", conn.addr) continue } - ep, _ := newEndpoint(tcpconn.RemoteAddr().String(), &Config{ - DialTimeout: conn.addr.dialtimeout, - PeersTimeout: conn.addr.peerstimeout, - ReceiveChannelSize: conn.addr.recvchansize, - }) + go conn.receive(tcpconn, false) + } +} + +func (conn *Conn) receive(tcpconn *net.TCPConn, hasvalidated bool) { + ep, _ := newEndpoint(tcpconn.RemoteAddr().String(), &Config{ + DialTimeout: conn.addr.dialtimeout, + PeersTimeout: conn.addr.peerstimeout, + ReceiveChannelSize: conn.addr.recvchansize, + }) + + if !hasvalidated { + if !isvalid(tcpconn) { + return + } logrus.Debugln("[tcp] accept from", ep) conn.peers.Set(ep.String(), tcpconn) - go conn.receive(ep) } -} -func (conn *Conn) receive(ep *EndPoint) { - peerstimeout := ep.peerstimeout - if peerstimeout < time.Second { - peerstimeout = time.Second * 5 + peerstimeout := conn.addr.peerstimeout + if peerstimeout < time.Second*30 { + peerstimeout = time.Second * 30 } peerstimeout *= 2 for { @@ -244,9 +251,16 @@ func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (n int, err error) { if !ok { return 0, errors.New("expect *net.TCPConn but got " + reflect.ValueOf(cn).Type().String()) } + _, err = io.Copy(tcpconn, &packet{ + typ: packetTypeKeepAlive, + }) + if err != nil { + logrus.Debugln("[tcp] dial to", tcpep.addr, "success, but write err:", err) + return 0, err + } logrus.Debugln("[tcp] dial to", tcpep.addr, "success, local:", tcpconn.LocalAddr()) conn.peers.Set(tcpep.String(), tcpconn) - go conn.receive(tcpep) + go conn.receive(tcpconn, true) } else { logrus.Debugln("[tcp] reuse tcpconn from", tcpconn.LocalAddr(), "to", tcpconn.RemoteAddr()) }