diff --git a/context.go b/context.go new file mode 100644 index 0000000..fe47270 --- /dev/null +++ b/context.go @@ -0,0 +1,25 @@ +// +build go1.7 + +package csrf + +import ( + "context" + "net/http" + + "github.com/pkg/errors" +) + +func contextGet(r *http.Request, key string) (interface{}, error) { + val := r.Context().Value(key) + if val == nil { + return nil, errors.Errorf("no value exists in the context for key %q", key) + } + + return val, nil +} + +func contextSave(r *http.Request, key string, val interface{}) *http.Request { + ctx := r.Context() + ctx = context.WithValue(ctx, key, val) + return r.WithContext(ctx) +} diff --git a/context_legacy.go b/context_legacy.go new file mode 100644 index 0000000..dabf0a6 --- /dev/null +++ b/context_legacy.go @@ -0,0 +1,24 @@ +// +build !go1.7 + +package csrf + +import ( + "net/http" + + "github.com/gorilla/context" + + "github.com/pkg/errors" +) + +func contextGet(r *http.Request, key string) (interface{}, error) { + if val, ok := context.GetOk(r, key); ok { + return val, nil + } + + return nil, errors.Errorf("no value exists in the context for key %q", key) +} + +func contextSave(r *http.Request, key string, val interface{}) *http.Request { + context.Set(r, key, val) + return r +} diff --git a/csrf.go b/csrf.go index fe5e933..60e1878 100644 --- a/csrf.go +++ b/csrf.go @@ -174,7 +174,7 @@ func Protect(authKey []byte, opts ...Option) func(http.Handler) http.Handler { // Implements http.Handler for the csrf type. func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Skip the check if directed to. This should always be a bool. - if val, ok := context.GetOk(r, skipCheckKey); ok { + if val, err := contextGet(r, skipCheckKey); err == nil { if skip, ok := val.(bool); ok { if skip { cs.h.ServeHTTP(w, r) @@ -209,9 +209,9 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // Save the masked token to the request context - context.Set(r, tokenKey, mask(realToken, r)) + r = contextSave(r, tokenKey, mask(realToken, r)) // Save the field name to the request context - context.Set(r, formKey, cs.opts.FieldName) + r = contextSave(r, formKey, cs.opts.FieldName) // HTTP methods not defined as idempotent ("safe") under RFC7231 require // inspection. diff --git a/helpers.go b/helpers.go index e0f23e0..7adb5ff 100644 --- a/helpers.go +++ b/helpers.go @@ -16,7 +16,7 @@ import ( // a JSON response body. An empty token will be returned if the middleware // has not been applied (which will fail subsequent validation). func Token(r *http.Request) string { - if val, ok := context.GetOk(r, tokenKey); ok { + if val, err := contextGet(r, tokenKey); err == nil { if maskedToken, ok := val.(string); ok { return maskedToken } @@ -29,7 +29,7 @@ func Token(r *http.Request) string { // This is useful when you want to log the cause of the error or report it to // client. func FailureReason(r *http.Request) error { - if val, ok := context.GetOk(r, errorKey); ok { + if val, err := contextGet(r, errorKey); err == nil { if err, ok := val.(error); ok { return err } @@ -44,8 +44,8 @@ func FailureReason(r *http.Request) error { // Note: You should not set this without otherwise securing the request from // CSRF attacks. The primary use-case for this function is to turn off CSRF // checks for non-browser clients using authorization tokens against your API. -func UnsafeSkipCheck(r *http.Request) { - context.Set(r, skipCheckKey, true) +func UnsafeSkipCheck(r *http.Request) *http.Request { + return contextSave(r, skipCheckKey, true) } // TemplateField is a template helper for html/template that provides an field @@ -60,8 +60,7 @@ func UnsafeSkipCheck(r *http.Request) { // // func TemplateField(r *http.Request) template.HTML { - name, ok := context.GetOk(r, formKey) - if ok { + if name, err := contextGet(r, formKey); err == nil { fragment := fmt.Sprintf(``, name, Token(r)) diff --git a/helpers_test.go b/helpers_test.go index f1340ea..30ee7b3 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -270,7 +270,7 @@ func TestUnsafeSkipCSRFCheck(t *testing.T) { s := http.NewServeMux() skipCheck := func(h http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { - UnsafeSkipCheck(r) + r = UnsafeSkipCheck(r) h.ServeHTTP(w, r) }