From fe588a82b8417001fd479906e36ae83a13d2db02 Mon Sep 17 00:00:00 2001 From: Felicitas Pojtinger Date: Sat, 14 Oct 2023 02:32:47 +0200 Subject: [PATCH 1/4] fix: Call `ConnContext` before reading first packet --- server.go | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/server.go b/server.go index 376883d..2c86b89 100644 --- a/server.go +++ b/server.go @@ -309,9 +309,6 @@ func (s *Server) handleSinglePacket(frisbeeConn *Async, connCtx context.Context) s.onClosed(frisbeeConn, err) return } - if s.ConnContext != nil { - connCtx = s.ConnContext(connCtx, frisbeeConn) - } for { handlerFunc = s.handlerTable[p.Metadata.Operation] if handlerFunc != nil { @@ -361,9 +358,6 @@ func (s *Server) handleUnlimitedPacket(frisbeeConn *Async, connCtx context.Conte s.onClosed(frisbeeConn, err) return } - if s.ConnContext != nil { - connCtx = s.ConnContext(connCtx, frisbeeConn) - } wg := new(sync.WaitGroup) closed := atomic.NewBool(false) connCtx, cancel := context.WithCancel(connCtx) @@ -391,9 +385,6 @@ func (s *Server) handleLimitedPacket(frisbeeConn *Async, connCtx context.Context s.onClosed(frisbeeConn, err) return } - if s.ConnContext != nil { - connCtx = s.ConnContext(connCtx, frisbeeConn) - } wg := new(sync.WaitGroup) closed := atomic.NewBool(false) connCtx, cancel := context.WithCancel(connCtx) @@ -465,6 +456,9 @@ func (s *Server) serveConn(newConn net.Conn) { } s.connections[frisbeeConn] = struct{}{} s.connectionsMu.Unlock() + if s.ConnContext != nil { + connCtx = s.ConnContext(connCtx, frisbeeConn) + } if s.concurrency == 0 { s.handleUnlimitedPacket(frisbeeConn, connCtx) } else if s.concurrency == 1 { From 48086fa27aaa742219c8a57b9e766ef7a1c5d247 Mon Sep 17 00:00:00 2001 From: Shivansh Vij Date: Tue, 26 Mar 2024 14:12:57 -0700 Subject: [PATCH 2/4] Updating test specification to improve speeds --- .github/workflows/tests.yaml | 24 ++++++++++++------------ server.go | 7 ++++--- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 6ff94d2..43c213f 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -7,11 +7,11 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: - go-version: "1.20" + go-version: "1.22" check-latest: true cache: true - name: Run Tests @@ -20,11 +20,11 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: - go-version: "1.20" + go-version: "1.22" check-latest: true cache: true - name: Test with Race Conditions @@ -34,11 +34,11 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: - go-version: "1.20" + go-version: "1.22" check-latest: true cache: true - name: Benchmark @@ -47,11 +47,11 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: - go-version: "1.20" + go-version: "1.22" check-latest: true cache: true - name: Benchmark with Race Conditions diff --git a/server.go b/server.go index 79ee05a..d3cc969 100644 --- a/server.go +++ b/server.go @@ -459,11 +459,12 @@ func (s *Server) serveConn(newConn net.Conn) { if s.ConnContext != nil { connCtx = s.ConnContext(connCtx, frisbeeConn) } - if s.concurrency == 0 { + switch s.concurrency { + case 0: s.handleUnlimitedPacket(frisbeeConn, connCtx) - } else if s.concurrency == 1 { + case 1: s.handleSinglePacket(frisbeeConn, connCtx) - } else { + default: s.handleLimitedPacket(frisbeeConn, connCtx) } s.connectionsMu.Lock() From ae14faaee016294adb99a7bade8a1d0cc58ebe43 Mon Sep 17 00:00:00 2001 From: Shivansh Vij Date: Tue, 26 Mar 2024 14:24:58 -0700 Subject: [PATCH 3/4] Sleeping for flaky tests --- async_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/async_test.go b/async_test.go index 12923a6..0f2f941 100644 --- a/async_test.go +++ b/async_test.go @@ -379,6 +379,7 @@ func TestAsyncTimeout(t *testing.T) { assert.Equal(t, uint32(0), p.Metadata.ContentLength) assert.Equal(t, 0, len(*p.Content)) + t.Logf("Sleeping for %v\n", DefaultDeadline*2) time.Sleep(DefaultDeadline * 2) err = writerConn.Error() @@ -392,12 +393,14 @@ func TestAsyncTimeout(t *testing.T) { packet.Put(p) + t.Logf("Sleeping for %v\n", DefaultDeadline) time.Sleep(DefaultDeadline) require.Equal(t, 1, readerConn.incoming.Length()) err = writerConn.conn.Close() require.NoError(t, err) + t.Logf("Sleeping for %v\n", DefaultDeadline*2) runtime.Gosched() time.Sleep(DefaultDeadline * 2) runtime.Gosched() @@ -415,6 +418,7 @@ func TestAsyncTimeout(t *testing.T) { err = readerConn.Error() if err == nil { + t.Logf("Sleeping for %v\n", DefaultDeadline*3) runtime.Gosched() time.Sleep(DefaultDeadline * 3) runtime.Gosched() From b1aa650dc7da276ceac908a484aefb650aadb033 Mon Sep 17 00:00:00 2001 From: Shivansh Vij Date: Tue, 26 Mar 2024 14:45:40 -0700 Subject: [PATCH 4/4] Fixing race condition where connection close would cause an infinite hang --- async.go | 43 ++++++++++++++++++++++++++----------------- async_test.go | 4 ---- stream.go | 4 ++-- 3 files changed, 28 insertions(+), 23 deletions(-) diff --git a/async.go b/async.go index 170edef..c0f052a 100644 --- a/async.go +++ b/async.go @@ -183,7 +183,7 @@ func (c *Async) WritePacket(p *packet.Packet) error { if p.Metadata.Operation <= RESERVED9 { return InvalidOperation } - return c.writePacket(p) + return c.writePacket(p, true) } // ReadPacket is a blocking function that will wait until a Frisbee packet is available and then return it (and its content). @@ -300,7 +300,7 @@ func (c *Async) Close() error { } // write packet is the internal write packet function that does not check for reserved operations. -func (c *Async) writePacket(p *packet.Packet) error { +func (c *Async) writePacket(p *packet.Packet, closeOnErr bool) error { if int(p.Metadata.ContentLength) != len(*p.Content) { return InvalidContentLength } @@ -323,7 +323,10 @@ func (c *Async) writePacket(p *packet.Packet) error { return ConnectionClosed } c.Logger().Debug().Err(err).Uint16("Packet ID", p.Metadata.Id).Msg("error while setting write deadline before writing packet") - return c.closeWithError(err) + if closeOnErr { + return c.closeWithError(err) + } + return err } _, err = c.writer.Write(encodedMetadata[:]) metadata.PutBuffer(encodedMetadata) @@ -334,7 +337,10 @@ func (c *Async) writePacket(p *packet.Packet) error { return ConnectionClosed } c.Logger().Debug().Err(err).Uint16("Packet ID", p.Metadata.Id).Msg("error while writing encoded metadata") - return c.closeWithError(err) + if closeOnErr { + return c.closeWithError(err) + } + return err } if p.Metadata.ContentLength != 0 { _, err = c.writer.Write((*p.Content)[:p.Metadata.ContentLength]) @@ -345,7 +351,10 @@ func (c *Async) writePacket(p *packet.Packet) error { return ConnectionClosed } c.Logger().Debug().Err(err).Uint16("Packet ID", p.Metadata.Id).Msg("error while writing packet content") - return c.closeWithError(err) + if closeOnErr { + return c.closeWithError(err) + } + return err } } @@ -457,7 +466,7 @@ func (c *Async) pingLoop() { c.wg.Done() return case <-ticker.C: - err = c.writePacket(PINGPacket) + err = c.writePacket(PINGPacket, false) if err != nil { c.wg.Done() _ = c.closeWithError(err) @@ -516,7 +525,7 @@ func (c *Async) readLoop() { switch p.Metadata.Operation { case PING: c.Logger().Debug().Msg("PING Packet received by read loop, sending back PONG packet") - err = c.writePacket(PONGPacket) + err = c.writePacket(PONGPacket, false) if err != nil { c.wg.Done() _ = c.closeWithError(err) @@ -541,13 +550,13 @@ func (c *Async) readLoop() { default: if p.Metadata.ContentLength > 0 { if n-index < int(p.Metadata.ContentLength) { - min := int(p.Metadata.ContentLength) - p.Content.Write(buf[index:n]) + minSize := int(p.Metadata.ContentLength) - p.Content.Write(buf[index:n]) n = 0 - for cap(buf) < min { + for cap(buf) < minSize { buf = append(buf[:cap(buf)], 0) } buf = buf[:cap(buf)] - for n < min { + for n < minSize { var nn int err = c.conn.SetReadDeadline(time.Now().Add(DefaultDeadline)) if err != nil { @@ -558,7 +567,7 @@ func (c *Async) readLoop() { nn, err = c.conn.Read(buf[n:]) n += nn if err != nil { - if n < min { + if n < minSize { c.wg.Done() _ = c.closeWithError(err) return @@ -566,8 +575,8 @@ func (c *Async) readLoop() { break } } - p.Content.Write(buf[:min]) - index = min + p.Content.Write(buf[:minSize]) + index = minSize } else { index += p.Content.Write(buf[index : index+int(p.Metadata.ContentLength)]) } @@ -649,14 +658,14 @@ func (c *Async) readLoop() { index = n buf = buf[:cap(buf)] - min := metadata.Size - index - if len(buf) < min { + minSize := metadata.Size - index + if len(buf) < minSize { c.wg.Done() _ = c.closeWithError(InvalidBufferLength) return } n = 0 - for n < min { + for n < minSize { var nn int err = c.conn.SetReadDeadline(time.Now().Add(DefaultDeadline)) if err != nil { @@ -667,7 +676,7 @@ func (c *Async) readLoop() { nn, err = c.conn.Read(buf[index+n:]) n += nn if err != nil { - if n < min { + if n < minSize { c.wg.Done() _ = c.closeWithError(err) return diff --git a/async_test.go b/async_test.go index 0f2f941..12923a6 100644 --- a/async_test.go +++ b/async_test.go @@ -379,7 +379,6 @@ func TestAsyncTimeout(t *testing.T) { assert.Equal(t, uint32(0), p.Metadata.ContentLength) assert.Equal(t, 0, len(*p.Content)) - t.Logf("Sleeping for %v\n", DefaultDeadline*2) time.Sleep(DefaultDeadline * 2) err = writerConn.Error() @@ -393,14 +392,12 @@ func TestAsyncTimeout(t *testing.T) { packet.Put(p) - t.Logf("Sleeping for %v\n", DefaultDeadline) time.Sleep(DefaultDeadline) require.Equal(t, 1, readerConn.incoming.Length()) err = writerConn.conn.Close() require.NoError(t, err) - t.Logf("Sleeping for %v\n", DefaultDeadline*2) runtime.Gosched() time.Sleep(DefaultDeadline * 2) runtime.Gosched() @@ -418,7 +415,6 @@ func TestAsyncTimeout(t *testing.T) { err = readerConn.Error() if err == nil { - t.Logf("Sleeping for %v\n", DefaultDeadline*3) runtime.Gosched() time.Sleep(DefaultDeadline * 3) runtime.Gosched() diff --git a/stream.go b/stream.go index dd380e5..7aedc77 100644 --- a/stream.go +++ b/stream.go @@ -92,7 +92,7 @@ func (s *Stream) WritePacket(p *packet.Packet) error { } p.Metadata.Id = s.id p.Metadata.Operation = STREAM - return s.conn.writePacket(p) + return s.conn.writePacket(p, true) } // ID returns the stream's ID. @@ -116,7 +116,7 @@ func (s *Stream) Close() error { p := packet.Get() p.Metadata.Id = s.id p.Metadata.Operation = STREAM - err := s.conn.writePacket(p) + err := s.conn.writePacket(p, true) packet.Put(p) s.conn.streamsMu.Lock()