diff --git a/csrf.go b/csrf.go index f327f91..6d12cc0 100644 --- a/csrf.go +++ b/csrf.go @@ -16,6 +16,7 @@ const tokenLength = 32 // Context/session keys & prefixes const ( tokenKey string = "gorilla.csrf.Token" + formKey string = "gorilla.csrf.Form" errorKey string = "gorilla.csrf.Error" cookieName string = "_gorilla_csrf" errorPrefix string = "gorilla/csrf: " @@ -198,6 +199,8 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Save the masked token to the request context context.Set(r, tokenKey, mask(realToken, r)) + // Save the field name to the request context + context.Set(r, formKey, cs.opts.FieldName) // HTTP methods not defined as idempotent ("safe") under RFC7231 require // inspection. diff --git a/helpers.go b/helpers.go index 91e3733..fd96b3a 100644 --- a/helpers.go +++ b/helpers.go @@ -50,10 +50,15 @@ func FailureReason(r *http.Request) error { // // func TemplateField(r *http.Request) template.HTML { - fragment := fmt.Sprintf(``, - fieldName, Token(r)) + name, ok := context.GetOk(r, formKey) + if ok { + fragment := fmt.Sprintf(``, + name, Token(r)) - return template.HTML(fragment) + return template.HTML(fragment) + } + + return template.HTML("") } // mask returns a unique-per-request token to mitigate the BREACH attack diff --git a/helpers_test.go b/helpers_test.go index 73e92e5..ad641f4 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/rand" "encoding/base64" + "fmt" "io" "mime/multipart" "net/http" @@ -214,3 +215,42 @@ func TestGenerateRandomBytes(t *testing.T) { t.Fatalf("generateRandomBytes did not report a short read: only read %d bytes", len(b)) } } + +func TestTemplateField(t *testing.T) { + s := http.NewServeMux() + + // Make the token & template field available outside of the handler. + var token string + var templateField string + s.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token = Token(r) + templateField = string(TemplateField(r)) + t := template.Must((template.New("base").Parse(testTemplate))) + t.Execute(w, map[string]interface{}{ + TemplateTag: TemplateField(r), + }) + })) + + testFieldName := "custom_field_name" + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + p := Protect(testKey, FieldName(testFieldName))(s) + p.ServeHTTP(rr, r) + + expectedField := fmt.Sprintf(``, + testFieldName, token) + + if rr.Code != http.StatusOK { + t.Fatalf("middleware failed to pass to the next handler: got %v want %v", + rr.Code, http.StatusOK) + } + + if templateField != expectedField { + t.Fatalf("custom FieldName was not set correctly: got %v want %v", + templateField, expectedField) + } +}