diff --git a/README.md b/README.md index f385cb3..5d644d6 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,9 @@ POSTGRES_URL="postgres://postgres:password@localhost:5432/sekret_link_test?sslmo ### Send and receive data ```sh -curl -v --data-binary @go.mod localhost:8080/api/ | xargs -I {} curl localhost:8080{} +curl -v -H 'content-type: text/plain' --data-binary @go.mod localhost:8080/api/ | xargs -I {} curl localhost:8080{} ``` +```sh +curl -v -F 'secret=@README.md;type=text/x-markdown' localhost:8080/api/ | xargs -I {} curl -v localhost:8080{} +``` diff --git a/internal/api/createentry.go b/internal/api/createentry.go index d34f97e..9f02366 100644 --- a/internal/api/createentry.go +++ b/internal/api/createentry.go @@ -3,6 +3,7 @@ package api import ( "context" "errors" + "log" "net/http" "time" @@ -60,11 +61,9 @@ func (c CreateHandler) handle(w http.ResponseWriter, r *http.Request) error { return errors.Join(ErrRequestParseError, err) } - contentType := r.Header.Get("Content-Type") - ctx, cancel := context.WithCancel(context.Background()) defer cancel() - entry, key, err := c.entryManager.CreateEntry(ctx, contentType, data.Body, data.MaxReads, data.Expiration) + entry, key, err := c.entryManager.CreateEntry(ctx, data.ContentType, data.Body, data.MaxReads, data.Expiration) if err != nil { return err @@ -81,6 +80,7 @@ func (c CreateHandler) Handle(w http.ResponseWriter, r *http.Request) { err := c.handle(w, r) if err != nil { + log.Println("create error", err) c.view.RenderError(w, r, err) } } diff --git a/internal/api/createentry_test.go b/internal/api/createentry_test.go index 0cd59b9..52c62c3 100644 --- a/internal/api/createentry_test.go +++ b/internal/api/createentry_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "fmt" "net/http" "net/http/httptest" "testing" @@ -38,6 +39,7 @@ func (m *MockEntryManager) CreateEntry( maxReads int, expiration time.Duration, ) (*services.EntryMeta, key.Key, error) { + fmt.Println("content type", contentType) args := m.Called(ctx, contentType, body, maxReads, expiration) if args.Get(1) == nil { @@ -72,7 +74,9 @@ func Test_CreateEntryHandle(t *testing.T) { request.Header.Set("Content-Type", "text/plain") response := httptest.NewRecorder() - parser.On("Parse", request).Return(&parsers.CreateEntryRequestData{}, nil) + parser.On("Parse", request).Return(&parsers.CreateEntryRequestData{ + ContentType: "text/plain", + }, nil) retKey, err := key.NewGeneratedKey() if err != nil { @@ -126,7 +130,9 @@ func Test_CreateEntryHandleError(t *testing.T) { request.Header.Set("Content-Type", "text/plain") response := httptest.NewRecorder() - parser.On("Parse", request).Return(&parsers.CreateEntryRequestData{}, nil) + parser.On("Parse", request).Return(&parsers.CreateEntryRequestData{ + ContentType: "text/plain", + }, nil) k, err := key.NewGeneratedKey() assert.NoError(t, err) entryManager.On("CreateEntry", mock.Anything, "text/plain", mock.Anything, mock.Anything, mock.Anything).Return(&services.EntryMeta{}, *k, errors.New("error")) diff --git a/internal/parsers/createentry.go b/internal/parsers/createentry.go index b6cc0a4..3429c28 100644 --- a/internal/parsers/createentry.go +++ b/internal/parsers/createentry.go @@ -15,37 +15,41 @@ type CreateEntryParser struct { } type CreateEntryRequestData struct { - Body []byte - Expiration time.Duration - MaxReads int + ContentType string + Body []byte + Expiration time.Duration + MaxReads int } func NewCreateEntryParser(maxExpireSeconds int) CreateEntryParser { return CreateEntryParser{maxExpireSeconds: maxExpireSeconds} } -func parseMultiForm(r *http.Request) ([]byte, error) { +func parseMultiForm(r *http.Request) ([]byte, string, error) { err := r.ParseMultipartForm(1024 * 1024) if err != nil { - return nil, err + return nil, "", err } secret := r.PostForm.Get("secret") if secret != "" { body := []byte(secret) - return body, nil + return body, "text/plain", nil } - file, _, err := r.FormFile("secret") + file, header, err := r.FormFile("secret") + contentType := header.Header.Get("Content-Type") if err != nil { - return nil, err + return nil, "", err } - return io.ReadAll(file) + data, err := io.ReadAll(file) + + return data, contentType, err } -func getBody(r *http.Request) ([]byte, error) { +func getContentType(r *http.Request) string { ct := r.Header.Get("content-type") if ct == "" { ct = "application/octet-stream" @@ -54,14 +58,20 @@ func getBody(r *http.Request) ([]byte, error) { ct, _, err := mime.ParseMediaType(ct) if err != nil { - return nil, err + return "application/octet-stream" } + return ct +} + +func getContent(r *http.Request) ([]byte, string, error) { + ct := getContentType(r) switch { case ct == "multipart/form-data": return parseMultiForm(r) default: - return io.ReadAll(r.Body) + data, err := io.ReadAll(r.Body) + return data, ct, err } } @@ -98,7 +108,7 @@ func (c CreateEntryParser) getSecretMaxReads(r *http.Request) (int, error) { } func (c CreateEntryParser) Parse(r *http.Request) (*CreateEntryRequestData, error) { - body, err := getBody(r) + body, contentType, err := getContent(r) if err != nil { return nil, err @@ -121,9 +131,10 @@ func (c CreateEntryParser) Parse(r *http.Request) (*CreateEntryRequestData, erro } return &CreateEntryRequestData{ - Body: body, - Expiration: expiration, - MaxReads: maxReads, + ContentType: contentType, + Body: body, + Expiration: expiration, + MaxReads: maxReads, }, nil } diff --git a/internal/views/entryread.go b/internal/views/entryread.go index d50cc65..f5ce1ed 100644 --- a/internal/views/entryread.go +++ b/internal/views/entryread.go @@ -45,12 +45,17 @@ func NewEntryReadView() EntryReadView { func (e EntryReadView) Render(w http.ResponseWriter, r *http.Request, response EntryReadResponse) { if r.Header.Get("Accept") == "application/json" { + w.Header().Add("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { log.Println("JSON encode failed", err) } } else { if response.ContentType != "" { - w.Header().Add("Content-Type", response.ContentType) + headers := w.Header() + headers.Add("Content-Type", response.ContentType) + headers.Add("X-Frame-Options", "deny") + headers.Add("Content-Security-Policy", "default-src 'none'; style-src 'unsafe-inline'; img-src 'self' data:; frame-ancestors 'none'; upgrade-insecure-requests; sandbox;") + } w.WriteHeader(http.StatusOK) _, err := w.Write([]byte(response.Data))