Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 70 additions & 48 deletions brontide/noise.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"math"
"net"
"sync"
"time"

Expand Down Expand Up @@ -821,71 +822,92 @@ func (b *Machine) WriteMessage(p []byte) error {
//
// NOTE: It is safe to call this method again iff a timeout error is returned.
func (b *Machine) Flush(w io.Writer) (int, error) {
// First, write out the pending header bytes, if any exist. Any header
// bytes written will not count towards the total amount flushed.
if len(b.nextHeaderSend) > 0 {
// Write any remaining header bytes and shift the slice to point
// to the next segment of unwritten bytes. If an error is
// encountered, we can continue to write the header from where
// we left off on a subsequent call to Flush.
n, err := w.Write(b.nextHeaderSend)
b.nextHeaderSend = b.nextHeaderSend[n:]
if err != nil {
return 0, err
}
}

// Next, write the pending body bytes, if any exist. Only the number of
// bytes written that correspond to the ciphertext will be included in
// the total bytes written, bytes written as part of the MAC will not be
// counted.
// If there's nothing to flush, return early.
if len(b.nextHeaderSend) == 0 && len(b.nextBodySend) == 0 {
return 0, nil
}

// Record the initial header length so we can calculate how much of each
// buffer was written in case of a partial write.
headerLen := len(b.nextHeaderSend)

// Use net.Buffers to perform a vectored write (writev syscall) that
// combines both header and body into a single TCP segment. This avoids
// the inefficiency of sending two separate TCP packets for each
// brontide message. The WriteTo method will use writev on supported
// platforms, falling back to sequential writes otherwise.
buffers := net.Buffers{b.nextHeaderSend, b.nextBodySend}
n64, err := buffers.WriteTo(w)

n := int(n64)

// Calculate how much was written from header vs body. The total bytes
// written (n) spans both buffers sequentially.
var headerWritten, bodyWritten int
switch {
// A partial write within the header only, with none of the body being
// written.
case n <= headerLen:
headerWritten = n
bodyWritten = 0

// All of header written, plus some (or all) of body.
default:
headerWritten = headerLen
bodyWritten = n - headerLen
}

// Update the slice offsets to allow resuming from where we left off on
// a subsequent call to Flush in case of a timeout error.
b.nextHeaderSend = b.nextHeaderSend[headerWritten:]
b.nextBodySend = b.nextBodySend[bodyWritten:]

// Calculate the number of plaintext bytes flushed, excluding MAC bytes.
// Only bytes written from the body count towards the return value, and
// we must exclude any MAC bytes that were written.
//
// There are three possible scenarios we must handle to ensure the
// returned value is correct. In the first case, the write straddles
// both payload and MAC bytes, and we must subtract the number of MAC
// bytes written. In the second, only payload bytes are written, thus we
// return bodyWritten unmodified. The final scenario pertains to the
// case where only MAC bytes are written, none of which count towards
// the total.
//
// |-----------Payload------------|----MAC----|
// Straddle: S---------------------------------E--------0
// Payload-only: S------------------------E-----------------0
// MAC-only: S-------E-0
var nn int
if len(b.nextBodySend) > 0 {
// Write out all bytes excluding the mac and shift the body
// slice depending on the number of actual bytes written.
n, err := w.Write(b.nextBodySend)
b.nextBodySend = b.nextBodySend[n:]

// If we partially or fully wrote any of the body's MAC, we'll
// subtract that contribution from the total amount flushed to
// preserve the abstraction of returning the number of plaintext
// bytes written by the connection.
//
// There are three possible scenarios we must handle to ensure
// the returned value is correct. In the first case, the write
// straddles both payload and MAC bytes, and we must subtract
// the number of MAC bytes written from n. In the second, only
// payload bytes are written, thus we can return n unmodified.
// The final scenario pertains to the case where only MAC bytes
// are written, none of which count towards the total.
//
// |-----------Payload------------|----MAC----|
// Straddle: S---------------------------------E--------0
// Payload-only: S------------------------E-----------------0
// MAC-only: S-------E-0
start, end := n+len(b.nextBodySend), len(b.nextBodySend)
if bodyWritten > 0 {
start := bodyWritten + len(b.nextBodySend)
end := len(b.nextBodySend)
switch {

// Straddles payload and MAC bytes, subtract number of MAC bytes
// written from the actual number written.
case start > macSize && end <= macSize:
nn = n - (macSize - end)
nn = bodyWritten - (macSize - end)

// Only payload bytes are written, return n directly.
// Only payload bytes are written, return bodyWritten directly.
case start > macSize && end > macSize:
nn = n
nn = bodyWritten

// Only MAC bytes are written, return 0 bytes written.
default:
}
}

if err != nil {
return nn, err
}
// If there was an error, return with what we've calculated so far. The
// caller can retry and we'll continue from where we left off.
if err != nil {
return nn, err
}

// If both header and body have been fully flushed, release the pooled
// buffers back to their pools.
//
// TODO(roasbeef): move above again?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This TODO comment seems to be a leftover from development. Moving releaseBuffers() before the error check would introduce a bug: in case of a partial write with an error, the buffers would be released prematurely, preventing a successful retry on the next Flush() call. The current placement is correct. I suggest removing this comment to avoid future confusion.

if len(b.nextHeaderSend) == 0 && len(b.nextBodySend) == 0 {
b.releaseBuffers()
}
Expand Down
132 changes: 132 additions & 0 deletions brontide/noise_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -700,3 +700,135 @@
t.Fatalf("expected n: %d, got: %d", expN, nn)
}
}

// writeRecorder records all data written to it, tracking each Write call
// separately to verify write coalescing behavior.
type writeRecorder struct {
writes [][]byte
}

func (w *writeRecorder) Write(p []byte) (int, error) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

According to the style guide (rule 17), every function should be commented with its purpose. This method is missing a comment. Also, rule 18 states that function comments must begin with the function name.

I suggest adding a comment like:

// Write records the written bytes. A copy of the slice is taken to avoid
// issues with buffer reuse.
Suggested change
func (w *writeRecorder) Write(p []byte) (int, error) {
// Write records the written bytes. A copy of the slice is taken to avoid
// issues with buffer reuse.
func (w *writeRecorder) Write(p []byte) (int, error) {
References
  1. Rule 17: Every function must be commented with its purpose and assumptions. Rule 18: Function comments must begin with the function name. (link)

// Make a copy of the data to avoid issues with buffer reuse.
w.writes = append(w.writes, append([]byte{}, p...))

return len(p), nil
}

// totalBytes returns the total number of bytes written across all Write calls.
func (w *writeRecorder) totalBytes() int {
Comment on lines +717 to +718

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

According to the style guide (rule 18), function comments must begin with the function name. This comment starts with a lowercase letter.

Suggested change
// totalBytes returns the total number of bytes written across all Write calls.
func (w *writeRecorder) totalBytes() int {
// TotalBytes returns the total number of bytes written across all Write calls.
func (w *writeRecorder) totalBytes() int {
References
  1. Rule 18: Function comments must begin with the function name. (link)

total := 0
for _, write := range w.writes {
total += len(write)
}

return total
}

// TestFlushCoalescedWrite verifies that the Flush method writes header and body
// data correctly. Note: For non-TCP writers (like bytes.Buffer), net.Buffers
// falls back to sequential writes.
func TestFlushCoalescedWrite(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
payloadSize int
}{
{
name: "small message",
payloadSize: 10,
},
{
name: "medium message",
payloadSize: 500,
},
{
name: "large message",
payloadSize: 1400,
},
{
name: "max message",
payloadSize: math.MaxUint16,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var b Machine
b.split()

payload := bytes.Repeat([]byte("x"), tc.payloadSize)
err := b.WriteMessage(payload)
require.NoError(t, err)

recorder := &writeRecorder{}
n, err := b.Flush(recorder)
require.NoError(t, err)

// The returned value should be the plaintext payload size.

Check failure on line 768 in brontide/noise_test.go

View workflow job for this annotation

GitHub Actions / Lint code

the line is 83 characters long, which exceeds the maximum of 80 characters. (ll)
require.Equal(
t, tc.payloadSize, n,
"returned byte count should match payload size",
)

// Total bytes written should be header + payload + MAC.
expectedTotal := encHeaderSize + tc.payloadSize +
macSize
require.Equal(
t, expectedTotal, recorder.totalBytes(),
"total written bytes should be "+
"header + payload + MAC",
)

// Verify that subsequent Flush is a no-op.
n, err = b.Flush(recorder)
require.NoError(t, err)
require.Equal(
t, 0, n,
"subsequent flush should be no-op",
)
})
}
}

// TestFlushPartialWriteRecovery specifically tests that partial writes
// can be recovered by calling Flush again, and that the final result
// is correct.
func TestFlushPartialWriteRecovery(t *testing.T) {
t.Parallel()

// Create a test connection to get properly paired encrypt/decrypt
// machines.
localConn, remoteConn, err := establishTestConnection(t)
require.NoError(t, err)

localBrontide := localConn.(*Conn)

Check failure on line 805 in brontide/noise_test.go

View workflow job for this annotation

GitHub Actions / Lint code

type assertion must be checked (forcetypeassert)

payload := []byte("hello world")
err = localBrontide.noise.WriteMessage(payload)
require.NoError(t, err)

// First, write only part of the header using a timeout writer.
var partialOutput bytes.Buffer
partialWriter := NewTimeoutWriter(&partialOutput, 5)

n, err := localBrontide.noise.Flush(partialWriter)
require.ErrorIs(t, err, iotest.ErrTimeout)
require.Equal(t, 0, n, "no payload bytes should be reported yet")

// Now write the rest directly to the underlying connection. We need to
// write the partial output first, then flush the rest.
_, err = localBrontide.conn.Write(partialOutput.Bytes())
require.NoError(t, err)

// Flush remaining bytes to the connection.
n, err = localBrontide.noise.Flush(localBrontide.conn)
require.NoError(t, err)
require.Equal(t, len(payload), n)

// Read and verify the message on the remote end.
buf := make([]byte, len(payload))
_, err = io.ReadFull(remoteConn, buf)
require.NoError(t, err)
require.Equal(t, payload, buf)
}
Loading