From 993a982edc004d61e6efd95aab78555cd58064f1 Mon Sep 17 00:00:00 2001 From: database64128 Date: Fri, 20 Sep 2024 16:21:42 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=8F=92=20ss2022,=20zerocopy:=20improve=20?= =?UTF-8?q?stream=20relay?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ss2022/stream_test.go | 13 ++++++++----- zerocopy/stream.go | 41 +++++++++++++++++++++-------------------- 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/ss2022/stream_test.go b/ss2022/stream_test.go index 6ce9afd..4e4395a 100644 --- a/ss2022/stream_test.go +++ b/ss2022/stream_test.go @@ -87,13 +87,16 @@ func testShadowStreamReadWriterReplay(t *testing.T, ctx context.Context, clientC s := NewTCPServer(false, userCipherConfig, identityCipherConfig, nil, nil) s.ReplaceUserLookupMap(userLookupMap) - var cerr, serr error - ctrlCh := make(chan struct{}) + var ( + wg sync.WaitGroup + cerr error + ) // Start client. + wg.Add(1) go func() { _, _, cerr = c.Dial(ctx, clientTargetAddr, nil) - close(ctrlCh) + wg.Done() }() // Hijack client request and save it in b. @@ -110,7 +113,7 @@ func testShadowStreamReadWriterReplay(t *testing.T, ctx context.Context, clientC } // Ensure client success. - <-ctrlCh + wg.Wait() if cerr != nil { t.Fatal(cerr) } @@ -119,7 +122,7 @@ func testShadowStreamReadWriterReplay(t *testing.T, ctx context.Context, clientC go sendFunc() // Start server. - _, _, _, _, serr = s.Accept(pr) + _, _, _, _, serr := s.Accept(pr) if serr != nil { t.Fatal(serr) } diff --git a/zerocopy/stream.go b/zerocopy/stream.go index 3488f09..85a8ae0 100644 --- a/zerocopy/stream.go +++ b/zerocopy/stream.go @@ -3,6 +3,7 @@ package zerocopy import ( "bytes" "context" + "errors" "io" "sync" ) @@ -218,23 +219,23 @@ type ReadWriter interface { // It returns the number of bytes sent from left to right, from right to left, // and any error occurred during transfer. func TwoWayRelay(left, right ReadWriter) (nl2r, nr2l int64, err error) { - var l2rErr error - l2rDone := make(chan struct{}) + var ( + wg sync.WaitGroup + l2rErr error + ) + wg.Add(1) go func() { nl2r, l2rErr = Relay(right, left) - right.CloseWrite() - close(l2rDone) + _ = right.CloseWrite() + wg.Done() }() nr2l, err = Relay(left, right) - left.CloseWrite() - <-l2rDone + _ = left.CloseWrite() + wg.Wait() - if l2rErr != nil { - err = l2rErr - } - return + return nl2r, nr2l, errors.Join(l2rErr, err) } // DirectReadWriteCloser extends io.ReadWriteCloser with CloseRead and CloseWrite. @@ -248,23 +249,23 @@ type DirectReadWriteCloser interface { // It returns the number of bytes sent from left to right, from right to left, // and any error occurred during transfer. func DirectTwoWayRelay(left, right DirectReadWriteCloser) (nl2r, nr2l int64, err error) { - var l2rErr error - l2rDone := make(chan struct{}) + var ( + wg sync.WaitGroup + l2rErr error + ) + wg.Add(1) go func() { nl2r, l2rErr = io.Copy(right, left) - right.CloseWrite() - close(l2rDone) + _ = right.CloseWrite() + wg.Done() }() nr2l, err = io.Copy(left, right) - left.CloseWrite() - <-l2rDone + _ = left.CloseWrite() + wg.Wait() - if l2rErr != nil { - err = l2rErr - } - return + return nl2r, nr2l, errors.Join(l2rErr, err) } // DirectReadWriteCloserOpener provides the Open method to open a [DirectReadWriteCloser].