diff --git a/events/lambda_function_urls.go b/events/lambda_function_urls.go index 36be4fbe..bef9869d 100644 --- a/events/lambda_function_urls.go +++ b/events/lambda_function_urls.go @@ -2,6 +2,14 @@ package events +import ( + "bytes" + "encoding/json" + "errors" + "io" + "net/http" +) + // LambdaFunctionURLRequest contains data coming from the HTTP request to a Lambda Function URL. type LambdaFunctionURLRequest struct { Version string `json:"version"` // Version is expected to be `"2.0"` @@ -59,3 +67,69 @@ type LambdaFunctionURLResponse struct { IsBase64Encoded bool `json:"isBase64Encoded"` Cookies []string `json:"cookies"` } + +// LambdaFunctionURLStreamingResponse models the response to a Lambda Function URL when InvokeMode is RESPONSE_STREAM. +// If the InvokeMode of the Function URL is BUFFERED (default), use LambdaFunctionURLResponse instead. +// +// Example: +// +// lambda.Start(func() (*events.LambdaFunctionURLStreamingResponse, error) { +// return &events.LambdaFunctionURLStreamingResponse{ +// StatusCode: 200, +// Headers: map[string]string{ +// "Content-Type": "text/html", +// }, +// Body: strings.NewReader("Hello World!"), +// }, nil +// }) +type LambdaFunctionURLStreamingResponse struct { + prelude *bytes.Buffer + + StatusCode int + Headers map[string]string + Body io.Reader + Cookies []string +} + +func (r *LambdaFunctionURLStreamingResponse) Read(p []byte) (n int, err error) { + if r.prelude == nil { + if r.StatusCode == 0 { + r.StatusCode = http.StatusOK + } + b, err := json.Marshal(struct { + StatusCode int `json:"statusCode"` + Headers map[string]string `json:"headers,omitempty"` + Cookies []string `json:"cookies,omitempty"` + }{ + StatusCode: r.StatusCode, + Headers: r.Headers, + Cookies: r.Cookies, + }) + if err != nil { + return 0, err + } + r.prelude = bytes.NewBuffer(append(b, 0, 0, 0, 0, 0, 0, 0, 0)) + } + if r.prelude.Len() > 0 { + return r.prelude.Read(p) + } + if r.Body == nil { + return 0, io.EOF + } + return r.Body.Read(p) +} + +func (r *LambdaFunctionURLStreamingResponse) Close() error { + if closer, ok := r.Body.(io.ReadCloser); ok { + return closer.Close() + } + return nil +} + +func (r *LambdaFunctionURLStreamingResponse) MarshalJSON() ([]byte, error) { + return nil, errors.New("not json") +} + +func (r *LambdaFunctionURLStreamingResponse) ContentType() string { + return "application/vnd.awslambda.http-integration-response" +} diff --git a/events/lambda_function_urls_test.go b/events/lambda_function_urls_test.go index 0b11a048..cbc15f45 100644 --- a/events/lambda_function_urls_test.go +++ b/events/lambda_function_urls_test.go @@ -4,10 +4,14 @@ package events import ( "encoding/json" + "errors" "io/ioutil" //nolint: staticcheck + "net/http" + "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestLambdaFunctionURLResponseMarshaling(t *testing.T) { @@ -55,3 +59,91 @@ func TestLambdaFunctionURLRequestMarshaling(t *testing.T) { assert.JSONEq(t, string(inputJSON), string(outputJSON)) } + +func TestLambdaFunctionURLStreamingResponseMarshaling(t *testing.T) { + for _, test := range []struct { + name string + response *LambdaFunctionURLStreamingResponse + expectedHead string + expectedBody string + }{ + { + "empty", + &LambdaFunctionURLStreamingResponse{}, + `{"statusCode":200}`, + "", + }, + { + "just the status code", + &LambdaFunctionURLStreamingResponse{ + StatusCode: http.StatusTeapot, + }, + `{"statusCode":418}`, + "", + }, + { + "status and headers and cookies and body", + &LambdaFunctionURLStreamingResponse{ + StatusCode: http.StatusTeapot, + Headers: map[string]string{"hello": "world"}, + Cookies: []string{"cookies", "are", "yummy"}, + Body: strings.NewReader(`Hello Hello`), + }, + `{"statusCode":418, "headers":{"hello":"world"}, "cookies":["cookies","are","yummy"]}`, + `Hello Hello`, + }, + } { + t.Run(test.name, func(t *testing.T) { + response, err := ioutil.ReadAll(test.response) + require.NoError(t, err) + sep := "\x00\x00\x00\x00\x00\x00\x00\x00" + responseParts := strings.Split(string(response), sep) + require.Len(t, responseParts, 2) + head := string(responseParts[0]) + body := string(responseParts[1]) + assert.JSONEq(t, test.expectedHead, head) + assert.Equal(t, test.expectedBody, body) + assert.NoError(t, test.response.Close()) + }) + } +} + +type readCloser struct { + closed bool + err error + reader *strings.Reader +} + +func (r *readCloser) Read(p []byte) (int, error) { + return r.reader.Read(p) +} + +func (r *readCloser) Close() error { + r.closed = true + return r.err +} + +func TestLambdaFunctionURLStreamingResponsePropogatesInnerClose(t *testing.T) { + for _, test := range []struct { + name string + closer *readCloser + err error + }{ + { + "closer no err", + &readCloser{}, + nil, + }, + { + "closer with err", + &readCloser{err: errors.New("yolo")}, + errors.New("yolo"), + }, + } { + t.Run(test.name, func(t *testing.T) { + response := &LambdaFunctionURLStreamingResponse{Body: test.closer} + assert.Equal(t, test.err, response.Close()) + assert.True(t, test.closer.closed) + }) + } +}