Skip to content

Commit 4a9598b

Browse files
committed
Content type and duplicate headers
Signed-off-by: Trevor Bramwell <tbramwell@linuxfoundation.org>
1 parent 62ff146 commit 4a9598b

File tree

2 files changed

+67
-13
lines changed

2 files changed

+67
-13
lines changed

pkg/middlewares/awslambda/aws_lambda.go

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ func (a *awsLambda) GetTracingInformation() (string, ext.SpanKindEnum) {
139139
func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
140140
logger := log.FromContext(middlewares.GetLoggerCtx(req.Context(), a.name, typeName))
141141

142-
base64Encoded, reqBody, err := bodyToBase64(req)
142+
base64Encoded, contentType, body, err := bodyToBase64(req)
143143
if err != nil {
144144
msg := fmt.Sprintf("Error encoding Lambda request body: %v", err)
145145
logger.Error(msg)
@@ -149,6 +149,26 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
149149
return
150150
}
151151

152+
// If Content-Type is set, isn't set, assume it's JSON
153+
rCt := req.Header.Get("Content-Type")
154+
switch rCt {
155+
case "":
156+
logger.Debug("Content-Type not set")
157+
if !strings.HasPrefix(contentType, "text") {
158+
logger.Debugf("Content-Type not like text, setting to :%s", contentType)
159+
req.Header.Set("Content-Type", contentType)
160+
} else {
161+
req.Header.Set("Content-Type", "application/json")
162+
}
163+
case "application/x-www-form-urlencoded":
164+
if isJSON(rCt) {
165+
req.Header.Set("Content-Type", "application/json")
166+
}
167+
default:
168+
req.Header.Set("Content-Type", "application/json")
169+
}
170+
logger.Debugf("Content-Type set to: %s, originally %s", req.Header.Get("Content-Type"), rCt)
171+
152172
// Ensure tracing headers are included in the request before copying
153173
// them to the lambda request
154174
tracing.InjectRequestHeaders(req)
@@ -205,6 +225,12 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
205225
}
206226

207227
for key, values := range resp.MultiValueHeaders {
228+
// NOTE This maybe specific to Content-Type, but it's listed in
229+
// headers and multivalue headers so it ends up getting added twice.
230+
// Is a multivalue header with only one item really multivalue?
231+
if len(values) < 2 {
232+
continue
233+
}
208234
for _, value := range values {
209235
rw.Header().Add(key, value)
210236
}
@@ -234,7 +260,8 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
234260
}
235261

236262
// bodyToBase64 ensures the request body is base64 encoded.
237-
func bodyToBase64(req *http.Request) (bool, string, error) {
263+
func bodyToBase64(req *http.Request) (bool, string, string, error) {
264+
contentType := ""
238265
base64Encoded := false
239266
body := ""
240267
// base64 encode non-text request body
@@ -246,15 +273,15 @@ func bodyToBase64(req *http.Request) (bool, string, error) {
246273
// Read the request body and reset it to be read again if needed
247274
bodyBytes, err := io.ReadAll(req.Body)
248275
if err != nil {
249-
return base64Encoded, body, err
276+
return base64Encoded, contentType, body, err
250277
}
251278
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
252279

253280
body = string(bodyBytes)
254281

255282
// Any non 'text/*' MIME types should be base64 encoded.
256283
// DetectContentType does not check for 'application/json'
257-
contentType := http.DetectContentType(bodyBytes)
284+
contentType = http.DetectContentType(bodyBytes)
258285
if !strings.HasPrefix(contentType, "text") {
259286
base64Encoded = true
260287
}
@@ -267,17 +294,17 @@ func bodyToBase64(req *http.Request) (bool, string, error) {
267294

268295
_, err := io.Copy(encoder, bytes.NewReader(bodyBytes))
269296
if err != nil {
270-
return base64Encoded, body, err
297+
return base64Encoded, contentType, body, err
271298
}
272299
if err = encoder.Close(); err != nil {
273-
return base64Encoded, body, err
300+
return base64Encoded, contentType, body, err
274301
}
275302
// Set body to b64 encoded version
276303
body = b64buf.String()
277304
}
278305
}
279306

280-
return base64Encoded, body, nil
307+
return base64Encoded, contentType, body, nil
281308
}
282309

283310
func (a *awsLambda) invokeFunction(ctx context.Context, request events.APIGatewayProxyRequest) (*events.APIGatewayProxyResponse, error) {
@@ -429,3 +456,10 @@ func valuesToMultiMap(i url.Values) map[string][]string {
429456

430457
return values
431458
}
459+
460+
// Check if a string looks like JSON
461+
func isJSON(s string) bool {
462+
var js interface{}
463+
return json.Unmarshal([]byte(s), &js) == nil
464+
465+
}

pkg/middlewares/awslambda/aws_lambda_test.go

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ func Test_AWSLambdaMiddleware_InvokeBasic(t *testing.T) {
108108
assert.Equal(t, "/test/example/path", lReq.Path)
109109
assert.Equal(t, map[string]string{"a": "1", "b": "2"}, lReq.QueryStringParameters)
110110
assert.Equal(t, map[string][]string{"c": {"3", "4"}, "d[]": {"5", "6"}}, lReq.MultiValueQueryStringParameters)
111-
assert.Equal(t, map[string]string{"Content-Type": "text/plain"}, lReq.Headers)
111+
assert.Equal(t, map[string]string{"Content-Type": "application/json"}, lReq.Headers)
112112
assert.Equal(t, map[string][]string{"X-Test": {"foo", "foobar"}}, lReq.MultiValueHeaders)
113113
assert.Equal(t, "This is the body", lReq.Body)
114114

@@ -144,7 +144,7 @@ func Test_AWSLambdaMiddleware_InvokeBasic(t *testing.T) {
144144
if err != nil {
145145
t.Fatal(err)
146146
}
147-
req.Header.Set("Content-Type", "text/plain")
147+
req.Header.Set("Content-Type", "application/json")
148148
req.Header.Add("X-Test", "foo")
149149
req.Header.Add("X-Test", "foobar")
150150

@@ -178,10 +178,11 @@ func Test_AWSLambdaMiddleware_GetTracingInformation(t *testing.T) {
178178
func Test_AWSLambdaMiddleware_bodyToBase64_empty(t *testing.T) {
179179
req, err := http.NewRequest(http.MethodGet, "/", nil)
180180
require.NoError(t, err)
181-
isEncoded, body, err := bodyToBase64(req)
181+
isEncoded, contentType, body, err := bodyToBase64(req)
182182

183183
assert.False(t, isEncoded)
184184
assert.Equal(t, "", body)
185+
assert.Equal(t, "", contentType)
185186
require.NoError(t, err)
186187
}
187188

@@ -191,10 +192,27 @@ func Test_AWSLambdaMiddleware_bodyToBase64_notEncodedJSON(t *testing.T) {
191192

192193
req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader(reqBody))
193194
require.NoError(t, err)
194-
isEncoded, body, err := bodyToBase64(req)
195+
isEncoded, contentType, body, err := bodyToBase64(req)
195196

196197
assert.False(t, isEncoded)
197198
assert.Equal(t, reqBody, body)
199+
assert.Equal(t, "text/plain; charset=utf-8", contentType)
200+
require.NoError(t, err)
201+
}
202+
203+
func Test_AWSLambdaMiddleware_bodyToBase64_EncodedJSON(t *testing.T) {
204+
bodyBytes, err := json.Marshal(`{"test": "encoded"}`)
205+
if err != nil {
206+
t.Fatal(err)
207+
}
208+
209+
req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader(string(bodyBytes)))
210+
require.NoError(t, err)
211+
isEncoded, contentType, body, err := bodyToBase64(req)
212+
213+
assert.False(t, isEncoded)
214+
assert.Equal(t, string(bodyBytes), body)
215+
assert.Equal(t, "text/plain; charset=utf-8", contentType)
198216
require.NoError(t, err)
199217
}
200218

@@ -206,10 +224,11 @@ func Test_AWSLambdaMiddleware_bodyToBase64_withcontent(t *testing.T) {
206224

207225
req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader(reqBody))
208226
require.NoError(t, err)
209-
isEncoded, body, err := bodyToBase64(req)
227+
isEncoded, contentType, body, err := bodyToBase64(req)
210228

211229
assert.True(t, isEncoded)
212230
assert.Equal(t, expected, body)
231+
assert.Equal(t, "application/zip", contentType)
213232
require.NoError(t, err)
214233

215234
// image/jpeg
@@ -218,9 +237,10 @@ func Test_AWSLambdaMiddleware_bodyToBase64_withcontent(t *testing.T) {
218237

219238
req2, err2 := http.NewRequest(http.MethodPost, "/", strings.NewReader(reqBody2))
220239
require.NoError(t, err2)
221-
isEncoded2, body2, err2 := bodyToBase64(req2)
240+
isEncoded2, contentType2, body2, err2 := bodyToBase64(req2)
222241

223242
assert.True(t, isEncoded2)
224243
assert.Equal(t, expected2, body2)
244+
assert.Equal(t, "image/jpeg", contentType2)
225245
require.NoError(t, err2)
226246
}

0 commit comments

Comments
 (0)