Skip to content

Commit

Permalink
feat: improve content type handling
Browse files Browse the repository at this point in the history
Parse content type for the file from multipart form data request

Add security headers to avoid malicious use of the application when
serving an html file
  • Loading branch information
Ajnasz committed May 22, 2024
1 parent 4b56708 commit 29e21b3
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 23 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ POSTGRES_URL="postgres://postgres:password@localhost:5432/sekret_link_test?sslmo

### Send and receive data
```sh
curl -v --data-binary @go.mod localhost:8080/api/ | xargs -I {} curl localhost:8080{}
curl -v -H 'content-type: text/plain' --data-binary @go.mod localhost:8080/api/ | xargs -I {} curl localhost:8080{}
```

```sh
curl -v -F 'secret=@README.md;type=text/x-markdown' localhost:8080/api/ | xargs -I {} curl -v localhost:8080{}
```
6 changes: 3 additions & 3 deletions internal/api/createentry.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"context"
"errors"
"log"
"net/http"
"time"

Expand Down Expand Up @@ -60,11 +61,9 @@ func (c CreateHandler) handle(w http.ResponseWriter, r *http.Request) error {
return errors.Join(ErrRequestParseError, err)
}

contentType := r.Header.Get("Content-Type")

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
entry, key, err := c.entryManager.CreateEntry(ctx, contentType, data.Body, data.MaxReads, data.Expiration)
entry, key, err := c.entryManager.CreateEntry(ctx, data.ContentType, data.Body, data.MaxReads, data.Expiration)

if err != nil {
return err
Expand All @@ -81,6 +80,7 @@ func (c CreateHandler) Handle(w http.ResponseWriter, r *http.Request) {
err := c.handle(w, r)

if err != nil {
log.Println("create error", err)
c.view.RenderError(w, r, err)
}
}
10 changes: 8 additions & 2 deletions internal/api/createentry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -38,6 +39,7 @@ func (m *MockEntryManager) CreateEntry(
maxReads int,
expiration time.Duration,
) (*services.EntryMeta, key.Key, error) {
fmt.Println("content type", contentType)
args := m.Called(ctx, contentType, body, maxReads, expiration)

if args.Get(1) == nil {
Expand Down Expand Up @@ -72,7 +74,9 @@ func Test_CreateEntryHandle(t *testing.T) {
request.Header.Set("Content-Type", "text/plain")
response := httptest.NewRecorder()

parser.On("Parse", request).Return(&parsers.CreateEntryRequestData{}, nil)
parser.On("Parse", request).Return(&parsers.CreateEntryRequestData{
ContentType: "text/plain",
}, nil)

retKey, err := key.NewGeneratedKey()
if err != nil {
Expand Down Expand Up @@ -126,7 +130,9 @@ func Test_CreateEntryHandleError(t *testing.T) {
request.Header.Set("Content-Type", "text/plain")
response := httptest.NewRecorder()

parser.On("Parse", request).Return(&parsers.CreateEntryRequestData{}, nil)
parser.On("Parse", request).Return(&parsers.CreateEntryRequestData{
ContentType: "text/plain",
}, nil)
k, err := key.NewGeneratedKey()
assert.NoError(t, err)
entryManager.On("CreateEntry", mock.Anything, "text/plain", mock.Anything, mock.Anything, mock.Anything).Return(&services.EntryMeta{}, *k, errors.New("error"))
Expand Down
43 changes: 27 additions & 16 deletions internal/parsers/createentry.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,41 @@ type CreateEntryParser struct {
}

type CreateEntryRequestData struct {
Body []byte
Expiration time.Duration
MaxReads int
ContentType string
Body []byte
Expiration time.Duration
MaxReads int
}

func NewCreateEntryParser(maxExpireSeconds int) CreateEntryParser {
return CreateEntryParser{maxExpireSeconds: maxExpireSeconds}
}

func parseMultiForm(r *http.Request) ([]byte, error) {
func parseMultiForm(r *http.Request) ([]byte, string, error) {
err := r.ParseMultipartForm(1024 * 1024)
if err != nil {
return nil, err
return nil, "", err
}

secret := r.PostForm.Get("secret")
if secret != "" {
body := []byte(secret)
return body, nil
return body, "text/plain", nil
}

file, _, err := r.FormFile("secret")
file, header, err := r.FormFile("secret")
contentType := header.Header.Get("Content-Type")

if err != nil {
return nil, err
return nil, "", err
}

return io.ReadAll(file)
data, err := io.ReadAll(file)

return data, contentType, err
}

func getBody(r *http.Request) ([]byte, error) {
func getContentType(r *http.Request) string {
ct := r.Header.Get("content-type")
if ct == "" {
ct = "application/octet-stream"
Expand All @@ -54,14 +58,20 @@ func getBody(r *http.Request) ([]byte, error) {
ct, _, err := mime.ParseMediaType(ct)

if err != nil {
return nil, err
return "application/octet-stream"
}

return ct
}

func getContent(r *http.Request) ([]byte, string, error) {
ct := getContentType(r)
switch {
case ct == "multipart/form-data":
return parseMultiForm(r)
default:
return io.ReadAll(r.Body)
data, err := io.ReadAll(r.Body)
return data, ct, err
}
}

Expand Down Expand Up @@ -98,7 +108,7 @@ func (c CreateEntryParser) getSecretMaxReads(r *http.Request) (int, error) {
}

func (c CreateEntryParser) Parse(r *http.Request) (*CreateEntryRequestData, error) {
body, err := getBody(r)
body, contentType, err := getContent(r)

if err != nil {
return nil, err
Expand All @@ -121,9 +131,10 @@ func (c CreateEntryParser) Parse(r *http.Request) (*CreateEntryRequestData, erro
}

return &CreateEntryRequestData{
Body: body,
Expiration: expiration,
MaxReads: maxReads,
ContentType: contentType,
Body: body,
Expiration: expiration,
MaxReads: maxReads,
}, nil

}
7 changes: 6 additions & 1 deletion internal/views/entryread.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,17 @@ func NewEntryReadView() EntryReadView {

func (e EntryReadView) Render(w http.ResponseWriter, r *http.Request, response EntryReadResponse) {
if r.Header.Get("Accept") == "application/json" {
w.Header().Add("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
log.Println("JSON encode failed", err)
}
} else {
if response.ContentType != "" {
w.Header().Add("Content-Type", response.ContentType)
headers := w.Header()
headers.Add("Content-Type", response.ContentType)
headers.Add("X-Frame-Options", "deny")
headers.Add("Content-Security-Policy", "default-src 'none'; style-src 'unsafe-inline'; img-src 'self' data:; frame-ancestors 'none'; upgrade-insecure-requests; sandbox;")

}
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte(response.Data))
Expand Down

0 comments on commit 29e21b3

Please sign in to comment.