Skip to content

Commit

Permalink
🦩 zerocopy: split and rewrite CopyReadWriter
Browse files Browse the repository at this point in the history
  • Loading branch information
database64128 committed Feb 19, 2025
1 parent 28ff236 commit 5db1ab6
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 92 deletions.
6 changes: 3 additions & 3 deletions dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,12 +520,12 @@ func (r *Resolver) sendQueriesTCP(ctx context.Context, nameString string, querie
}

// Read.
crw := zerocopy.NewCopyReadWriter(rw)
cr := zerocopy.NewCopyReader(rw)
lengthBuf := make([]byte, 2)

for range 2 {
// Read length field.
_, err = io.ReadFull(crw, lengthBuf)
_, err = io.ReadFull(cr, lengthBuf)
if err != nil {
r.logger.Warn("Failed to read TCP DNS response length",
zap.String("resolver", r.name),
Expand All @@ -550,7 +550,7 @@ func (r *Resolver) sendQueriesTCP(ctx context.Context, nameString string, querie

// Read message.
msg := make([]byte, msgLen)
_, err = io.ReadFull(crw, msg)
_, err = io.ReadFull(cr, msg)
if err != nil {
r.logger.Warn("Failed to read TCP DNS response",
zap.String("resolver", r.name),
Expand Down
4 changes: 2 additions & 2 deletions ss2022/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ func testShadowStreamReadWriter(t *testing.T, allowSegmentedFixedLengthHeader bo
// Read excess payload.
b := make([]byte, len(clientInitialPayload))
copy(b, serverInitialPayload)
scrw := zerocopy.NewCopyReadWriter(srw)
_, serr = io.ReadFull(scrw, b[len(serverInitialPayload):])
scr := zerocopy.NewCopyReader(srw)
_, serr = io.ReadFull(scr, b[len(serverInitialPayload):])
serverInitialPayload = b
}
wg.Done()
Expand Down
192 changes: 105 additions & 87 deletions zerocopy/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,122 +451,140 @@ func ReadWriterTestFunc(t tester, l, r ReadWriter) {
wg.Wait()
}

// CopyReadWriter wraps a ReadWriter and provides the io.ReadWriter Read and Write methods
// by copying from and to internal buffers and using the zerocopy methods on them.
//
// The io.ReaderFrom ReadFrom method is implemented using the internal write buffer without copying.
type CopyReadWriter struct {
ReadWriter

readHeadroom Headroom
writeHeadroom Headroom

readBuf []byte
readBufStart int
readBufLength int

writeBuf []byte
// CopyReader implements [io.Reader] for a [Reader] by copying in and out of an internal buffer.
// [io.WriterTo] is implemented using the same buffer.
type CopyReader struct {
Reader
buf []byte
payloadBufStart int
payloadBufLen int
copyStart int
copyEnd int
}

func NewCopyReadWriter(rw ReadWriter) *CopyReadWriter {
ri := rw.ReaderInfo()
wi := rw.WriterInfo()

readBufSize := ri.MinPayloadBufferSizePerRead
if readBufSize == 0 {
readBufSize = defaultBufferSize
// NewCopyReader returns a new [CopyReader] for the given [Reader].
func NewCopyReader(r Reader) *CopyReader {
ri := r.ReaderInfo()
payloadBufLen := ri.MinPayloadBufferSizePerRead
if payloadBufLen == 0 {
payloadBufLen = defaultBufferSize
}

writeBufSize := wi.MaxPayloadSizePerWrite
if writeBufSize == 0 {
writeBufSize = defaultBufferSize
return &CopyReader{
Reader: r,
buf: make([]byte, ri.Headroom.Front+payloadBufLen+ri.Headroom.Rear),
payloadBufStart: ri.Headroom.Front,
payloadBufLen: payloadBufLen,
}
}

return &CopyReadWriter{
ReadWriter: rw,
readHeadroom: ri.Headroom,
writeHeadroom: wi.Headroom,
readBuf: make([]byte, ri.Headroom.Front+readBufSize+ri.Headroom.Rear),
writeBuf: make([]byte, wi.Headroom.Front+writeBufSize+wi.Headroom.Rear),
// Read implements [io.Reader].
func (r *CopyReader) Read(b []byte) (n int, err error) {
if r.copyStart == r.copyEnd {
n, err = r.Reader.ReadZeroCopy(r.buf, r.payloadBufStart, r.payloadBufLen)
if n == 0 {
return 0, err
}
r.copyStart = r.payloadBufStart
r.copyEnd = r.payloadBufStart + n
}
n = copy(b, r.buf[r.copyStart:r.copyEnd])
r.copyStart += n
return n, err
}

// Read implements the io.Reader Read method.
func (rw *CopyReadWriter) Read(b []byte) (n int, err error) {
if rw.readBufLength == 0 {
rw.readBufStart = rw.readHeadroom.Front
rw.readBufLength = len(rw.readBuf) - rw.readHeadroom.Front - rw.readHeadroom.Rear
rw.readBufLength, err = rw.ReadWriter.ReadZeroCopy(rw.readBuf, rw.readBufStart, rw.readBufLength)
if err != nil {
return
// WriteTo implements [io.WriterTo].
func (r *CopyReader) WriteTo(w io.Writer) (n int64, err error) {
for {
if r.copyStart == r.copyEnd {
pl, err := r.Reader.ReadZeroCopy(r.buf, r.payloadBufStart, r.payloadBufLen)
if pl == 0 {
if err == io.EOF {
return n, nil
}
return n, err
}
r.copyStart = r.payloadBufStart
r.copyEnd = r.payloadBufStart + pl
}

wn, werr := w.Write(r.buf[r.copyStart:r.copyEnd])
n += int64(wn)
r.copyStart += wn
if werr != nil {
return n, werr
}
}
}

n = copy(b, rw.readBuf[rw.readBufStart:rw.readBufStart+rw.readBufLength])
rw.readBufStart += n
rw.readBufLength -= n
return n, nil
// CopyWriter implements [io.Writer] for a [Writer] by copying in and out of an internal buffer.
// [io.ReaderFrom] is implemented using the same buffer.
type CopyWriter struct {
Writer
buf []byte
payloadStart int
payloadEnd int
}

// Write implements the io.Writer Write method.
func (rw *CopyReadWriter) Write(b []byte) (n int, err error) {
payloadBuf := rw.writeBuf[rw.writeHeadroom.Front : len(rw.writeBuf)-rw.writeHeadroom.Rear]
// NewCopyWriter returns a new [CopyWriter] for the given [Writer].
func NewCopyWriter(w Writer) *CopyWriter {
wi := w.WriterInfo()
payloadBufLen := wi.MaxPayloadSizePerWrite
if payloadBufLen == 0 {
payloadBufLen = defaultBufferSize
}
payloadEnd := wi.Headroom.Front + payloadBufLen
return &CopyWriter{
Writer: w,
buf: make([]byte, payloadEnd+wi.Headroom.Rear),
payloadStart: wi.Headroom.Front,
payloadEnd: payloadEnd,
}
}

// Write implements [io.Writer].
func (w *CopyWriter) Write(b []byte) (n int, err error) {
payloadBuf := w.buf[w.payloadStart:w.payloadEnd]
for n < len(b) {
payloadLength := copy(payloadBuf, b[n:])
var payloadWritten int
payloadWritten, err = rw.ReadWriter.WriteZeroCopy(rw.writeBuf, rw.writeHeadroom.Front, payloadLength)
payloadLen := copy(payloadBuf, b[n:])
payloadWritten, err := w.Writer.WriteZeroCopy(w.buf, w.payloadStart, payloadLen)
n += payloadWritten
if err != nil {
return
return n, err
}
}

return
return n, nil
}

// ReadFrom implements the io.ReaderFrom ReadFrom method.
func (rw *CopyReadWriter) ReadFrom(r io.Reader) (n int64, err error) {
// ReadFrom implements [io.ReaderFrom].
func (w *CopyWriter) ReadFrom(r io.Reader) (n int64, err error) {
payloadBuf := w.buf[w.payloadStart:w.payloadEnd]
for {
nr, err := r.Read(rw.writeBuf[rw.writeHeadroom.Front : len(rw.writeBuf)-rw.writeHeadroom.Rear])
n += int64(nr)
switch err {
case nil:
case io.EOF:
return n, nil
default:
return n, err
nr, err := r.Read(payloadBuf)
if nr > 0 {
n += int64(nr)
if _, err := w.Writer.WriteZeroCopy(w.buf, w.payloadStart, nr); err != nil {
return n, err
}
}

_, err = rw.ReadWriter.WriteZeroCopy(rw.writeBuf, rw.writeHeadroom.Front, nr)
if err != nil {
if err == io.EOF {
return n, nil
}
return n, err
}
}
}

func CopyWriteOnce(w Writer, b []byte) (n int, err error) {
wi := w.WriterInfo()
writeBufSize := wi.MaxPayloadSizePerWrite
if writeBufSize == 0 {
writeBufSize = defaultBufferSize
}
if writeBufSize > len(b) {
writeBufSize = len(b)
}

writeBuf := make([]byte, wi.Headroom.Front+writeBufSize+wi.Headroom.Rear)
payloadBuf := writeBuf[wi.Headroom.Front : wi.Headroom.Front+writeBufSize]
// CopyReadWriter joins a [CopyReader] and a [CopyWriter] into a single [ReadWriter].
type CopyReadWriter struct {
CopyReader
CopyWriter
}

for n < len(b) {
payloadLength := copy(payloadBuf, b[n:])
var payloadWritten int
payloadWritten, err = w.WriteZeroCopy(writeBuf, wi.Headroom.Front, payloadLength)
n += payloadWritten
if err != nil {
return
}
// NewCopyReadWriter returns a new [CopyReadWriter] for the given [ReadWriter].
func NewCopyReadWriter(rw ReadWriter) *CopyReadWriter {
return &CopyReadWriter{
CopyReader: *NewCopyReader(rw),
CopyWriter: *NewCopyWriter(rw),
}

return
}

0 comments on commit 5db1ab6

Please sign in to comment.