Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion s3proxy/internal/router/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,21 @@ func handleGetObject(client *s3.Client, key string, bucket string, log *logger.L
func handlePutObject(client *s3.Client, key string, bucket string, log *logger.Logger) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
log.WithField("path", req.URL.Path).WithField("method", req.Method).WithField("host", req.Host).Debug("intercepting")
body, err := io.ReadAll(req.Body)
var (
body []byte
err error
)
if req.ContentLength > 0 {
n := int(req.ContentLength)
// Preallocate the buffer from Content-Length and fill it with io.ReadFull.
// This avoids the incremental growth and extra copies that io.ReadAll incurs
// when the final size is unknown, which can blow up RAM on large payloads.
// If Content-Length is missing or bogus we fall back to ReadAll below.
body = make([]byte, n)
_, err = io.ReadFull(req.Body, body)
} else {
body, err = io.ReadAll(req.Body)
}
if err != nil {
log.WithField("error", err).Error("PutObject")
http.Error(w, fmt.Sprintf("reading body: %s", err.Error()), http.StatusInternalServerError)
Expand Down Expand Up @@ -113,6 +127,7 @@ func handlePutObject(client *s3.Client, key string, bucket string, log *logger.L
}

put(obj.put)(w, req)
defer req.Body.Close()
}
}

Expand Down
140 changes: 82 additions & 58 deletions s3proxy/internal/router/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"net/http"
"net/url"
"regexp"
"runtime/debug"
"strconv"
"strings"
"syscall"
Expand Down Expand Up @@ -55,6 +56,17 @@ type object struct {
log *logger.Logger
}

const freeOSMemoryThreshold int = 100 * 1024 * 1024 // 100 MiB

func setHeaderIfNonEmpty(h http.Header, key string, val *string) {
if val != nil {
v := strings.TrimSpace(*val)
if v != "" {
h.Set(key, v)
}
}
}

// get is a http.HandlerFunc that implements the GET method for objects.
func (o object) get(w http.ResponseWriter, r *http.Request) {
requestID := uuid.New().String()
Expand Down Expand Up @@ -101,40 +113,44 @@ func (o object) get(w http.ResponseWriter, r *http.Request) {
if output.Expiration != nil {
w.Header().Set("x-amz-expiration", *output.Expiration)
}
if output.ChecksumCRC32 != nil {
w.Header().Set("x-amz-checksum-crc32", *output.ChecksumCRC32)
}
if output.ChecksumCRC32C != nil {
w.Header().Set("x-amz-checksum-crc32c", *output.ChecksumCRC32C)
}
if output.ChecksumSHA1 != nil {
w.Header().Set("x-amz-checksum-sha1", *output.ChecksumSHA1)
}
if output.ChecksumSHA256 != nil {
w.Header().Set("x-amz-checksum-sha256", *output.ChecksumSHA256)
}
if output.SSECustomerAlgorithm != nil {
w.Header().Set("x-amz-server-side-encryption-customer-algorithm", *output.SSECustomerAlgorithm)
}
if output.SSECustomerKeyMD5 != nil {
w.Header().Set("x-amz-server-side-encryption-customer-key-MD5", *output.SSECustomerKeyMD5)
}
if output.SSEKMSKeyId != nil {
w.Header().Set("x-amz-server-side-encryption-aws-kms-key-id", *output.SSEKMSKeyId)
}
if output.ServerSideEncryption != "" {
w.Header().Set("x-amz-server-side-encryption-context", string(output.ServerSideEncryption))
}
setHeaderIfNonEmpty(w.Header(), "x-amz-expiration", output.Expiration)
setHeaderIfNonEmpty(w.Header(), "x-amz-checksum-crc32", output.ChecksumCRC32)
setHeaderIfNonEmpty(w.Header(), "x-amz-checksum-crc32c", output.ChecksumCRC32C)
setHeaderIfNonEmpty(w.Header(), "x-amz-checksum-sha1", output.ChecksumSHA1)
setHeaderIfNonEmpty(w.Header(), "x-amz-checksum-sha256", output.ChecksumSHA256)
setHeaderIfNonEmpty(w.Header(), "x-amz-server-side-encryption-customer-algorithm", output.SSECustomerAlgorithm)
setHeaderIfNonEmpty(w.Header(), "x-amz-server-side-encryption-customer-key-MD5", output.SSECustomerKeyMD5)
setHeaderIfNonEmpty(w.Header(), "x-amz-server-side-encryption-aws-kms-key-id", output.SSEKMSKeyId)

body, err := io.ReadAll(output.Body)
if err != nil {
o.log.WithField("requestID", requestID).WithField("error", err).Error("GetObject reading S3 response")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
var body []byte
if output.ContentLength == nil {
// fallback on io.ReadAll if ContentLength is unknown
body, err = io.ReadAll(output.Body)
if err != nil {
o.log.WithField("requestID", requestID).WithField("error", err).Error("GetObject reading S3 response")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
} else {
n := int(*output.ContentLength)
// Preallocate the buffer from Content-Length and fill it with io.ReadFull.
// This avoids the incremental growth and extra copies that io.ReadAll incurs
// when the final size is unknown, which can blow up RAM on large payloads.
// If Content-Length is missing or bogus we fall back to ReadAll below.
body = make([]byte, n)
if _, err := io.ReadFull(output.Body, body); err != nil {
o.log.WithField("requestID", requestID).WithField("error", err).Error("GetObject reading S3 response")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}

plaintext := body
rawEncryptedDEK, ok := output.Metadata[dekTag]
defer output.Body.Close()
if ok {
encryptedDEK, err := hex.DecodeString(rawEncryptedDEK)
if err != nil {
Expand All @@ -144,16 +160,27 @@ func (o object) get(w http.ResponseWriter, r *http.Request) {
}

plaintext, err = crypto.Decrypt(body, encryptedDEK, o.kek)
// We do not need to keep body anymore. Because it can be gigabytes in size - free it ASAP
bodyLen := len(body)
body = nil
if bodyLen >= freeOSMemoryThreshold {
debug.FreeOSMemory()
}
if err != nil {
o.log.WithField("requestID", requestID).WithField("error", err).Error("GetObject decrypting response")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}

plaintextLen := len(plaintext)
select {
case <-r.Context().Done():
o.log.WithField("requestID", requestID).Info("Request was canceled by client")
plaintext = nil
if plaintextLen >= freeOSMemoryThreshold {
debug.FreeOSMemory()
}
return
default:
w.WriteHeader(http.StatusOK)
Expand All @@ -165,6 +192,10 @@ func (o object) get(w http.ResponseWriter, r *http.Request) {
}
}
}
plaintext = nil
if plaintextLen >= freeOSMemoryThreshold {
debug.FreeOSMemory()
}
}

// put is a http.HandlerFunc that implements the PUT method for objects.
Expand All @@ -179,6 +210,12 @@ func (o object) put(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// We do not need to keep data anymore. Because it can be gigabytes in size - free it ASAP
dataLen := len(o.data)
o.data = nil
if dataLen >= freeOSMemoryThreshold {
debug.FreeOSMemory()
}
o.metadata[dekTag] = hex.EncodeToString(encryptedDEK)

output, err := o.client.PutObject(context.WithoutCancel(r.Context()), o.bucket, o.key, o.tags, o.contentType, o.objectLockLegalHoldStatus, o.objectLockMode, o.sseCustomerAlgorithm, o.sseCustomerKey, o.sseCustomerKeyMD5, o.objectLockRetainUntilDate, o.metadata, ciphertext)
Expand All @@ -195,42 +232,29 @@ func (o object) put(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

w.Header().Set("x-amz-server-side-encryption", string(output.ServerSideEncryption))

if output.VersionId != nil {
w.Header().Set("x-amz-version-id", *output.VersionId)
cipherTextLen := len(ciphertext)
ciphertext = nil
if cipherTextLen > freeOSMemoryThreshold {
debug.FreeOSMemory()
}
ssd_enc := string(output.ServerSideEncryption)
if ssd_enc != "" { // It can be empty for empty files, at least on Hetzner storage
w.Header().Set("x-amz-server-side-encryption", ssd_enc)
}
if output.ETag != nil {
w.Header().Set("ETag", strings.Trim(*output.ETag, "\""))
}
if output.Expiration != nil {
w.Header().Set("x-amz-expiration", *output.Expiration)
}
if output.ChecksumCRC32 != nil {
w.Header().Set("x-amz-checksum-crc32", *output.ChecksumCRC32)
}
if output.ChecksumCRC32C != nil {
w.Header().Set("x-amz-checksum-crc32c", *output.ChecksumCRC32C)
}
if output.ChecksumSHA1 != nil {
w.Header().Set("x-amz-checksum-sha1", *output.ChecksumSHA1)
}
if output.ChecksumSHA256 != nil {
w.Header().Set("x-amz-checksum-sha256", *output.ChecksumSHA256)
}
if output.SSECustomerAlgorithm != nil {
w.Header().Set("x-amz-server-side-encryption-customer-algorithm", *output.SSECustomerAlgorithm)
}
if output.SSECustomerKeyMD5 != nil {
w.Header().Set("x-amz-server-side-encryption-customer-key-MD5", *output.SSECustomerKeyMD5)
}
if output.SSEKMSKeyId != nil {
w.Header().Set("x-amz-server-side-encryption-aws-kms-key-id", *output.SSEKMSKeyId)
}
if output.SSEKMSEncryptionContext != nil {
w.Header().Set("x-amz-server-side-encryption-context", *output.SSEKMSEncryptionContext)
}

setHeaderIfNonEmpty(w.Header(), "x-amz-version-id", output.VersionId)
setHeaderIfNonEmpty(w.Header(), "x-amz-expiration", output.Expiration)
setHeaderIfNonEmpty(w.Header(), "x-amz-checksum-crc32", output.ChecksumCRC32)
setHeaderIfNonEmpty(w.Header(), "x-amz-checksum-crc32c", output.ChecksumCRC32C)
setHeaderIfNonEmpty(w.Header(), "x-amz-checksum-sha1", output.ChecksumSHA1)
setHeaderIfNonEmpty(w.Header(), "x-amz-checksum-sha256", output.ChecksumSHA256)
setHeaderIfNonEmpty(w.Header(), "x-amz-server-side-encryption-customer-algorithm", output.SSECustomerAlgorithm)
setHeaderIfNonEmpty(w.Header(), "x-amz-server-side-encryption-customer-key-MD5", output.SSECustomerKeyMD5)
setHeaderIfNonEmpty(w.Header(), "x-amz-server-side-encryption-aws-kms-key-id", output.SSEKMSKeyId)
setHeaderIfNonEmpty(w.Header(), "x-amz-server-side-encryption-context", output.SSEKMSEncryptionContext)

w.WriteHeader(http.StatusOK)
if _, err := w.Write(nil); err != nil {
Expand Down
57 changes: 42 additions & 15 deletions s3proxy/internal/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"encoding/base64"
"fmt"
"io"
"strconv"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
Expand Down Expand Up @@ -60,21 +61,47 @@ func addCaptureRawResponseDeserializeMiddleware() func(*middleware.Stack) error
) {
out, metadata, err = next.HandleDeserialize(ctx, in)
if resp, ok := out.RawResponse.(*smithyhttp.Response); ok {
// Clone the response body
var buf bytes.Buffer
body := resp.Body
tee := io.NopCloser(io.TeeReader(body, &buf))

// Replace the body in the response with the cloned body
resp.Body = tee

bodyBytes, _ := io.ReadAll(resp.Body)

// Store the cloned body in metadata
metadata.Set(RawResponseKey{}, string(bodyBytes))

// Restore the original body for further processing
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// It is better not to clone the response body for successful responses
// because it can consume a lot of memory for large responses and we can not free it ASAP
if resp.StatusCode >= 400 {
var bodyBytes []byte

if cl := resp.Header.Get("Content-Length"); cl != "" {
if n64, perr := strconv.ParseInt(cl, 10, 64); perr == nil && n64 >= 0 {
n := int(n64)
bodyBytes = make([]byte, n)
// Preallocate the buffer from Content-Length and fill it with io.ReadFull.
// This avoids the incremental growth and extra copies that io.ReadAll incurs
// when the final size is unknown, which can blow up RAM on large payloads.
// If Content-Length is missing or bogus we fall back to ReadAll below.
if _, rerr := io.ReadFull(resp.Body, bodyBytes); rerr != nil {
wrap := fmt.Errorf("capture raw response (prealloc) failed: %w", rerr)
if err != nil {
return out, metadata, fmt.Errorf("%v; original deserialize error: %w", wrap, err)
}
return out, metadata, wrap
}
}
}

if bodyBytes == nil {
// Fallback: previous behavior (unbounded ReadAll).
// NOTE: this may allocate for large bodies; we only use it when CL is missing/invalid.
b, rerr := io.ReadAll(resp.Body)
if rerr != nil {
wrap := fmt.Errorf("capture raw response (ReadAll) failed: %w", rerr)
if err != nil {
return out, metadata, fmt.Errorf("%v; original deserialize error: %w", wrap, err)
}
return out, metadata, wrap
}
bodyBytes = b
}
// Restore the original body for further processing
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
} else {
metadata.Set(RawResponseKey{}, "")
}
return out, metadata, err
}), middleware.After)
Expand Down