Skip to content

Fix/validate cookies #240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,11 @@ func (c ContextNoBody) Header(key string) string {
return c.Request().Header.Get(key)
}

// Has request header
func (c ContextNoBody) HasHeader(key string) bool {
return c.Header(key) != ""
}

// Sets response header
func (c ContextNoBody) SetHeader(key, value string) {
c.Response().Header().Set(key, value)
Expand All @@ -215,6 +220,12 @@ func (c ContextNoBody) Cookie(name string) (*http.Cookie, error) {
return c.Request().Cookie(name)
}

// Has request cookie
func (c ContextNoBody) HasCookie(name string) bool {
_, err := c.Cookie(name)
return err == nil
}

// Sets response cookie
func (c ContextNoBody) SetCookie(cookie http.Cookie) {
http.SetCookie(c.Response(), &cookie)
Expand Down
16 changes: 0 additions & 16 deletions ctx_params_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,4 @@ func TestParam(t *testing.T) {
require.Equal(t, "hey18true", w.Body.String())
})
})

t.Run("Should enforce Required", func(t *testing.T) {
s := fuego.NewServer()

fuego.Get(s, "/test", func(c fuego.ContextNoBody) (string, error) {
name := c.QueryParam("name")
return name, nil
},
option.Query("name", "Name", param.Required(), param.Example("example1", "you")),
)
r := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
s.Mux.ServeHTTP(w, r)
require.Equal(t, http.StatusBadRequest, w.Code)
require.Contains(t, w.Body.String(), "name is a required query param")
})
}
2 changes: 1 addition & 1 deletion documentation/docs/tutorials/02-crud.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func main() {
s := fuego.NewServer()
// ....

// Declare the ressource
// Declare the resource
booksResources := controllers.BooksResources{
BooksService: controllers.RealBooksService{},
// Other services & dependencies, like a DB etc.
Expand Down
22 changes: 1 addition & 21 deletions serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,33 +54,13 @@ func (s *Server) proto() string {
return "http"
}

func validateQueryParams(c ContextNoBody) error {
for k, param := range c.params {
if param.Default != nil {
// skip: param has a default
continue
}

if param.Required && !c.urlValues.Has(k) {
err := fmt.Errorf("%s is a required query param", k)
return BadRequestError{
Title: "Query Param Not Found",
Err: err,
Detail: "cannot parse request parameter: " + err.Error(),
}
}
}

return nil
}

// initializes any Context type with the base ContextNoBody context.
//
// var ctx ContextWithBody[any] // does not work because it will create a ContextWithBody[any] with a nil value
func initContext[Contextable ctx[Body], Body any](baseContext ContextNoBody) (Contextable, error) {
var c Contextable

err := validateQueryParams(baseContext)
err := validateParams(baseContext)
if err != nil {
return c, err
}
Expand Down
46 changes: 46 additions & 0 deletions validateParams.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package fuego

import "fmt"

func validateParams(c ContextNoBody) error {
for k, param := range c.params {
if param.Default != nil {
// skip: param has a default
continue
}

if param.Required {
switch param.Type {
case QueryParamType:
if !c.urlValues.Has(k) {
err := fmt.Errorf("%s is a required query param", k)
return BadRequestError{
Title: "Query Param Not Found",
Err: err,
Detail: "cannot parse request parameter: " + err.Error(),
}
}
case HeaderParamType:
if !c.HasHeader(k) {
err := fmt.Errorf("%s is a required header", k)
return BadRequestError{
Title: "Header Not Found",
Err: err,
Detail: "cannot parse request parameter: " + err.Error(),
}
}
case CookieParamType:
if !c.HasCookie(k) {
err := fmt.Errorf("%s is a required cookie", k)
return BadRequestError{
Title: "Cookie Not Found",
Err: err,
Detail: "cannot parse request parameter: " + err.Error(),
}
}
}
}
}

return nil
}
54 changes: 54 additions & 0 deletions validateParams_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package fuego_test

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/require"

"github.com/go-fuego/fuego"
"github.com/go-fuego/fuego/option"
"github.com/go-fuego/fuego/param"
)

func TestParamsValidation(t *testing.T) {
t.Run("Should enforce Required query parameter", func(t *testing.T) {
s := fuego.NewServer()

fuego.Get(s, "/test", dummyController,
option.Query("name", "Name", param.Required(), param.Example("example1", "you")),
)
r := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
s.Mux.ServeHTTP(w, r)
require.Equal(t, http.StatusBadRequest, w.Code)
require.Contains(t, w.Body.String(), "name is a required query param")
})

t.Run("Should enforce Required header", func(t *testing.T) {
s := fuego.NewServer()

fuego.Get(s, "/test", dummyController,
option.Header("foo", "header that is foo", param.Required()),
)
r := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
s.Mux.ServeHTTP(w, r)
require.Equal(t, http.StatusBadRequest, w.Code)
require.Contains(t, w.Body.String(), "foo is a required header")
})

t.Run("Should enforce Required cookie", func(t *testing.T) {
s := fuego.NewServer()

fuego.Get(s, "/test", dummyController,
option.Cookie("bar", "cookie that is bar", param.Required()),
)
r := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
s.Mux.ServeHTTP(w, r)
require.Equal(t, http.StatusBadRequest, w.Code)
require.Contains(t, w.Body.String(), "bar is a required cookie")
})
}