Skip to content

Commit

Permalink
🏒 ss2022, zerocopy: improve stream relay
Browse files Browse the repository at this point in the history
  • Loading branch information
database64128 committed Sep 20, 2024
1 parent 0d41eb3 commit 993a982
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 25 deletions.
13 changes: 8 additions & 5 deletions ss2022/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,16 @@ func testShadowStreamReadWriterReplay(t *testing.T, ctx context.Context, clientC
s := NewTCPServer(false, userCipherConfig, identityCipherConfig, nil, nil)
s.ReplaceUserLookupMap(userLookupMap)

var cerr, serr error
ctrlCh := make(chan struct{})
var (
wg sync.WaitGroup
cerr error
)

// Start client.
wg.Add(1)
go func() {
_, _, cerr = c.Dial(ctx, clientTargetAddr, nil)
close(ctrlCh)
wg.Done()
}()

// Hijack client request and save it in b.
Expand All @@ -110,7 +113,7 @@ func testShadowStreamReadWriterReplay(t *testing.T, ctx context.Context, clientC
}

// Ensure client success.
<-ctrlCh
wg.Wait()
if cerr != nil {
t.Fatal(cerr)
}
Expand All @@ -119,7 +122,7 @@ func testShadowStreamReadWriterReplay(t *testing.T, ctx context.Context, clientC
go sendFunc()

// Start server.
_, _, _, _, serr = s.Accept(pr)
_, _, _, _, serr := s.Accept(pr)
if serr != nil {
t.Fatal(serr)
}
Expand Down
41 changes: 21 additions & 20 deletions zerocopy/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package zerocopy
import (
"bytes"
"context"
"errors"
"io"
"sync"
)
Expand Down Expand Up @@ -218,23 +219,23 @@ type ReadWriter interface {
// It returns the number of bytes sent from left to right, from right to left,
// and any error occurred during transfer.
func TwoWayRelay(left, right ReadWriter) (nl2r, nr2l int64, err error) {
var l2rErr error
l2rDone := make(chan struct{})
var (
wg sync.WaitGroup
l2rErr error
)

wg.Add(1)
go func() {
nl2r, l2rErr = Relay(right, left)
right.CloseWrite()
close(l2rDone)
_ = right.CloseWrite()
wg.Done()
}()

nr2l, err = Relay(left, right)
left.CloseWrite()
<-l2rDone
_ = left.CloseWrite()
wg.Wait()

if l2rErr != nil {
err = l2rErr
}
return
return nl2r, nr2l, errors.Join(l2rErr, err)
}

// DirectReadWriteCloser extends io.ReadWriteCloser with CloseRead and CloseWrite.
Expand All @@ -248,23 +249,23 @@ type DirectReadWriteCloser interface {
// It returns the number of bytes sent from left to right, from right to left,
// and any error occurred during transfer.
func DirectTwoWayRelay(left, right DirectReadWriteCloser) (nl2r, nr2l int64, err error) {
var l2rErr error
l2rDone := make(chan struct{})
var (
wg sync.WaitGroup
l2rErr error
)

wg.Add(1)
go func() {
nl2r, l2rErr = io.Copy(right, left)
right.CloseWrite()
close(l2rDone)
_ = right.CloseWrite()
wg.Done()
}()

nr2l, err = io.Copy(left, right)
left.CloseWrite()
<-l2rDone
_ = left.CloseWrite()
wg.Wait()

if l2rErr != nil {
err = l2rErr
}
return
return nl2r, nr2l, errors.Join(l2rErr, err)
}

// DirectReadWriteCloserOpener provides the Open method to open a [DirectReadWriteCloser].
Expand Down

0 comments on commit 993a982

Please sign in to comment.