diff --git a/context.go b/context.go
new file mode 100644
index 0000000..fe47270
--- /dev/null
+++ b/context.go
@@ -0,0 +1,25 @@
+// +build go1.7
+
+package csrf
+
+import (
+ "context"
+ "net/http"
+
+ "github.com/pkg/errors"
+)
+
+func contextGet(r *http.Request, key string) (interface{}, error) {
+ val := r.Context().Value(key)
+ if val == nil {
+ return nil, errors.Errorf("no value exists in the context for key %q", key)
+ }
+
+ return val, nil
+}
+
+func contextSave(r *http.Request, key string, val interface{}) *http.Request {
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, key, val)
+ return r.WithContext(ctx)
+}
diff --git a/context_legacy.go b/context_legacy.go
new file mode 100644
index 0000000..dabf0a6
--- /dev/null
+++ b/context_legacy.go
@@ -0,0 +1,24 @@
+// +build !go1.7
+
+package csrf
+
+import (
+ "net/http"
+
+ "github.com/gorilla/context"
+
+ "github.com/pkg/errors"
+)
+
+func contextGet(r *http.Request, key string) (interface{}, error) {
+ if val, ok := context.GetOk(r, key); ok {
+ return val, nil
+ }
+
+ return nil, errors.Errorf("no value exists in the context for key %q", key)
+}
+
+func contextSave(r *http.Request, key string, val interface{}) *http.Request {
+ context.Set(r, key, val)
+ return r
+}
diff --git a/csrf.go b/csrf.go
index fe5e933..60e1878 100644
--- a/csrf.go
+++ b/csrf.go
@@ -174,7 +174,7 @@ func Protect(authKey []byte, opts ...Option) func(http.Handler) http.Handler {
// Implements http.Handler for the csrf type.
func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Skip the check if directed to. This should always be a bool.
- if val, ok := context.GetOk(r, skipCheckKey); ok {
+ if val, err := contextGet(r, skipCheckKey); err == nil {
if skip, ok := val.(bool); ok {
if skip {
cs.h.ServeHTTP(w, r)
@@ -209,9 +209,9 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// Save the masked token to the request context
- context.Set(r, tokenKey, mask(realToken, r))
+ r = contextSave(r, tokenKey, mask(realToken, r))
// Save the field name to the request context
- context.Set(r, formKey, cs.opts.FieldName)
+ r = contextSave(r, formKey, cs.opts.FieldName)
// HTTP methods not defined as idempotent ("safe") under RFC7231 require
// inspection.
diff --git a/helpers.go b/helpers.go
index e0f23e0..7adb5ff 100644
--- a/helpers.go
+++ b/helpers.go
@@ -16,7 +16,7 @@ import (
// a JSON response body. An empty token will be returned if the middleware
// has not been applied (which will fail subsequent validation).
func Token(r *http.Request) string {
- if val, ok := context.GetOk(r, tokenKey); ok {
+ if val, err := contextGet(r, tokenKey); err == nil {
if maskedToken, ok := val.(string); ok {
return maskedToken
}
@@ -29,7 +29,7 @@ func Token(r *http.Request) string {
// This is useful when you want to log the cause of the error or report it to
// client.
func FailureReason(r *http.Request) error {
- if val, ok := context.GetOk(r, errorKey); ok {
+ if val, err := contextGet(r, errorKey); err == nil {
if err, ok := val.(error); ok {
return err
}
@@ -44,8 +44,8 @@ func FailureReason(r *http.Request) error {
// Note: You should not set this without otherwise securing the request from
// CSRF attacks. The primary use-case for this function is to turn off CSRF
// checks for non-browser clients using authorization tokens against your API.
-func UnsafeSkipCheck(r *http.Request) {
- context.Set(r, skipCheckKey, true)
+func UnsafeSkipCheck(r *http.Request) *http.Request {
+ return contextSave(r, skipCheckKey, true)
}
// TemplateField is a template helper for html/template that provides an field
@@ -60,8 +60,7 @@ func UnsafeSkipCheck(r *http.Request) {
//
//
func TemplateField(r *http.Request) template.HTML {
- name, ok := context.GetOk(r, formKey)
- if ok {
+ if name, err := contextGet(r, formKey); err == nil {
fragment := fmt.Sprintf(``,
name, Token(r))
diff --git a/helpers_test.go b/helpers_test.go
index f1340ea..30ee7b3 100644
--- a/helpers_test.go
+++ b/helpers_test.go
@@ -270,7 +270,7 @@ func TestUnsafeSkipCSRFCheck(t *testing.T) {
s := http.NewServeMux()
skipCheck := func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
- UnsafeSkipCheck(r)
+ r = UnsafeSkipCheck(r)
h.ServeHTTP(w, r)
}