diff --git a/brontide/noise.go b/brontide/noise.go index a1b7cd4dd94..148d70142db 100644 --- a/brontide/noise.go +++ b/brontide/noise.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "math" + "net" "sync" "time" @@ -821,71 +822,92 @@ func (b *Machine) WriteMessage(p []byte) error { // // NOTE: It is safe to call this method again iff a timeout error is returned. func (b *Machine) Flush(w io.Writer) (int, error) { - // First, write out the pending header bytes, if any exist. Any header - // bytes written will not count towards the total amount flushed. - if len(b.nextHeaderSend) > 0 { - // Write any remaining header bytes and shift the slice to point - // to the next segment of unwritten bytes. If an error is - // encountered, we can continue to write the header from where - // we left off on a subsequent call to Flush. - n, err := w.Write(b.nextHeaderSend) - b.nextHeaderSend = b.nextHeaderSend[n:] - if err != nil { - return 0, err - } - } - - // Next, write the pending body bytes, if any exist. Only the number of - // bytes written that correspond to the ciphertext will be included in - // the total bytes written, bytes written as part of the MAC will not be - // counted. + // If there's nothing to flush, return early. + if len(b.nextHeaderSend) == 0 && len(b.nextBodySend) == 0 { + return 0, nil + } + + // Record the initial header length so we can calculate how much of each + // buffer was written in case of a partial write. + headerLen := len(b.nextHeaderSend) + + // Use net.Buffers to perform a vectored write (writev syscall) that + // combines both header and body into a single TCP segment. This avoids + // the inefficiency of sending two separate TCP packets for each + // brontide message. The WriteTo method will use writev on supported + // platforms, falling back to sequential writes otherwise. + buffers := net.Buffers{b.nextHeaderSend, b.nextBodySend} + n64, err := buffers.WriteTo(w) + + n := int(n64) + + // Calculate how much was written from header vs body. The total bytes + // written (n) spans both buffers sequentially. + var headerWritten, bodyWritten int + switch { + // A partial write within the header only, with none of the body being + // written. + case n <= headerLen: + headerWritten = n + bodyWritten = 0 + + // All of header written, plus some (or all) of body. + default: + headerWritten = headerLen + bodyWritten = n - headerLen + } + + // Update the slice offsets to allow resuming from where we left off on + // a subsequent call to Flush in case of a timeout error. + b.nextHeaderSend = b.nextHeaderSend[headerWritten:] + b.nextBodySend = b.nextBodySend[bodyWritten:] + + // Calculate the number of plaintext bytes flushed, excluding MAC bytes. + // Only bytes written from the body count towards the return value, and + // we must exclude any MAC bytes that were written. + // + // There are three possible scenarios we must handle to ensure the + // returned value is correct. In the first case, the write straddles + // both payload and MAC bytes, and we must subtract the number of MAC + // bytes written. In the second, only payload bytes are written, thus we + // return bodyWritten unmodified. The final scenario pertains to the + // case where only MAC bytes are written, none of which count towards + // the total. + // + // |-----------Payload------------|----MAC----| + // Straddle: S---------------------------------E--------0 + // Payload-only: S------------------------E-----------------0 + // MAC-only: S-------E-0 var nn int - if len(b.nextBodySend) > 0 { - // Write out all bytes excluding the mac and shift the body - // slice depending on the number of actual bytes written. - n, err := w.Write(b.nextBodySend) - b.nextBodySend = b.nextBodySend[n:] - - // If we partially or fully wrote any of the body's MAC, we'll - // subtract that contribution from the total amount flushed to - // preserve the abstraction of returning the number of plaintext - // bytes written by the connection. - // - // There are three possible scenarios we must handle to ensure - // the returned value is correct. In the first case, the write - // straddles both payload and MAC bytes, and we must subtract - // the number of MAC bytes written from n. In the second, only - // payload bytes are written, thus we can return n unmodified. - // The final scenario pertains to the case where only MAC bytes - // are written, none of which count towards the total. - // - // |-----------Payload------------|----MAC----| - // Straddle: S---------------------------------E--------0 - // Payload-only: S------------------------E-----------------0 - // MAC-only: S-------E-0 - start, end := n+len(b.nextBodySend), len(b.nextBodySend) + if bodyWritten > 0 { + start := bodyWritten + len(b.nextBodySend) + end := len(b.nextBodySend) switch { // Straddles payload and MAC bytes, subtract number of MAC bytes // written from the actual number written. case start > macSize && end <= macSize: - nn = n - (macSize - end) + nn = bodyWritten - (macSize - end) - // Only payload bytes are written, return n directly. + // Only payload bytes are written, return bodyWritten directly. case start > macSize && end > macSize: - nn = n + nn = bodyWritten // Only MAC bytes are written, return 0 bytes written. default: } + } - if err != nil { - return nn, err - } + // If there was an error, return with what we've calculated so far. The + // caller can retry and we'll continue from where we left off. + if err != nil { + return nn, err } // If both header and body have been fully flushed, release the pooled // buffers back to their pools. + // + // TODO(roasbeef): move above again? if len(b.nextHeaderSend) == 0 && len(b.nextBodySend) == 0 { b.releaseBuffers() } diff --git a/brontide/noise_test.go b/brontide/noise_test.go index cd4dc4cd9a6..ddfaf048b40 100644 --- a/brontide/noise_test.go +++ b/brontide/noise_test.go @@ -700,3 +700,135 @@ func assertFlush(t *testing.T, b *Machine, w io.Writer, n int64, expN int, t.Fatalf("expected n: %d, got: %d", expN, nn) } } + +// writeRecorder records all data written to it, tracking each Write call +// separately to verify write coalescing behavior. +type writeRecorder struct { + writes [][]byte +} + +func (w *writeRecorder) Write(p []byte) (int, error) { + // Make a copy of the data to avoid issues with buffer reuse. + w.writes = append(w.writes, append([]byte{}, p...)) + + return len(p), nil +} + +// totalBytes returns the total number of bytes written across all Write calls. +func (w *writeRecorder) totalBytes() int { + total := 0 + for _, write := range w.writes { + total += len(write) + } + + return total +} + +// TestFlushCoalescedWrite verifies that the Flush method writes header and body +// data correctly. Note: For non-TCP writers (like bytes.Buffer), net.Buffers +// falls back to sequential writes. +func TestFlushCoalescedWrite(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + payloadSize int + }{ + { + name: "small message", + payloadSize: 10, + }, + { + name: "medium message", + payloadSize: 500, + }, + { + name: "large message", + payloadSize: 1400, + }, + { + name: "max message", + payloadSize: math.MaxUint16, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var b Machine + b.split() + + payload := bytes.Repeat([]byte("x"), tc.payloadSize) + err := b.WriteMessage(payload) + require.NoError(t, err) + + recorder := &writeRecorder{} + n, err := b.Flush(recorder) + require.NoError(t, err) + + // The returned value should be the plaintext payload size. + require.Equal( + t, tc.payloadSize, n, + "returned byte count should match payload size", + ) + + // Total bytes written should be header + payload + MAC. + expectedTotal := encHeaderSize + tc.payloadSize + + macSize + require.Equal( + t, expectedTotal, recorder.totalBytes(), + "total written bytes should be "+ + "header + payload + MAC", + ) + + // Verify that subsequent Flush is a no-op. + n, err = b.Flush(recorder) + require.NoError(t, err) + require.Equal( + t, 0, n, + "subsequent flush should be no-op", + ) + }) + } +} + +// TestFlushPartialWriteRecovery specifically tests that partial writes +// can be recovered by calling Flush again, and that the final result +// is correct. +func TestFlushPartialWriteRecovery(t *testing.T) { + t.Parallel() + + // Create a test connection to get properly paired encrypt/decrypt + // machines. + localConn, remoteConn, err := establishTestConnection(t) + require.NoError(t, err) + + localBrontide := localConn.(*Conn) + + payload := []byte("hello world") + err = localBrontide.noise.WriteMessage(payload) + require.NoError(t, err) + + // First, write only part of the header using a timeout writer. + var partialOutput bytes.Buffer + partialWriter := NewTimeoutWriter(&partialOutput, 5) + + n, err := localBrontide.noise.Flush(partialWriter) + require.ErrorIs(t, err, iotest.ErrTimeout) + require.Equal(t, 0, n, "no payload bytes should be reported yet") + + // Now write the rest directly to the underlying connection. We need to + // write the partial output first, then flush the rest. + _, err = localBrontide.conn.Write(partialOutput.Bytes()) + require.NoError(t, err) + + // Flush remaining bytes to the connection. + n, err = localBrontide.noise.Flush(localBrontide.conn) + require.NoError(t, err) + require.Equal(t, len(payload), n) + + // Read and verify the message on the remote end. + buf := make([]byte, len(payload)) + _, err = io.ReadFull(remoteConn, buf) + require.NoError(t, err) + require.Equal(t, payload, buf) +}