diff --git a/csrf.go b/csrf.go
index f327f91..6d12cc0 100644
--- a/csrf.go
+++ b/csrf.go
@@ -16,6 +16,7 @@ const tokenLength = 32
// Context/session keys & prefixes
const (
tokenKey string = "gorilla.csrf.Token"
+ formKey string = "gorilla.csrf.Form"
errorKey string = "gorilla.csrf.Error"
cookieName string = "_gorilla_csrf"
errorPrefix string = "gorilla/csrf: "
@@ -198,6 +199,8 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Save the masked token to the request context
context.Set(r, tokenKey, mask(realToken, r))
+ // Save the field name to the request context
+ context.Set(r, formKey, cs.opts.FieldName)
// HTTP methods not defined as idempotent ("safe") under RFC7231 require
// inspection.
diff --git a/helpers.go b/helpers.go
index 91e3733..fd96b3a 100644
--- a/helpers.go
+++ b/helpers.go
@@ -50,10 +50,15 @@ func FailureReason(r *http.Request) error {
//
//
func TemplateField(r *http.Request) template.HTML {
- fragment := fmt.Sprintf(``,
- fieldName, Token(r))
+ name, ok := context.GetOk(r, formKey)
+ if ok {
+ fragment := fmt.Sprintf(``,
+ name, Token(r))
- return template.HTML(fragment)
+ return template.HTML(fragment)
+ }
+
+ return template.HTML("")
}
// mask returns a unique-per-request token to mitigate the BREACH attack
diff --git a/helpers_test.go b/helpers_test.go
index 73e92e5..ad641f4 100644
--- a/helpers_test.go
+++ b/helpers_test.go
@@ -4,6 +4,7 @@ import (
"bytes"
"crypto/rand"
"encoding/base64"
+ "fmt"
"io"
"mime/multipart"
"net/http"
@@ -214,3 +215,42 @@ func TestGenerateRandomBytes(t *testing.T) {
t.Fatalf("generateRandomBytes did not report a short read: only read %d bytes", len(b))
}
}
+
+func TestTemplateField(t *testing.T) {
+ s := http.NewServeMux()
+
+ // Make the token & template field available outside of the handler.
+ var token string
+ var templateField string
+ s.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ token = Token(r)
+ templateField = string(TemplateField(r))
+ t := template.Must((template.New("base").Parse(testTemplate)))
+ t.Execute(w, map[string]interface{}{
+ TemplateTag: TemplateField(r),
+ })
+ }))
+
+ testFieldName := "custom_field_name"
+ r, err := http.NewRequest("GET", "/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ rr := httptest.NewRecorder()
+ p := Protect(testKey, FieldName(testFieldName))(s)
+ p.ServeHTTP(rr, r)
+
+ expectedField := fmt.Sprintf(``,
+ testFieldName, token)
+
+ if rr.Code != http.StatusOK {
+ t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
+ rr.Code, http.StatusOK)
+ }
+
+ if templateField != expectedField {
+ t.Fatalf("custom FieldName was not set correctly: got %v want %v",
+ templateField, expectedField)
+ }
+}