From 4b6cefec207337f4a0e27d2dc43ce613f04c100d Mon Sep 17 00:00:00 2001 From: Matt Silverlock Date: Mon, 30 Nov 2015 09:31:02 +0800 Subject: [PATCH 1/2] [feature] Custom field names are now passed to TemplateField implicitly. --- csrf.go | 3 +++ helpers.go | 11 ++++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/csrf.go b/csrf.go index f327f91..6d12cc0 100644 --- a/csrf.go +++ b/csrf.go @@ -16,6 +16,7 @@ const tokenLength = 32 // Context/session keys & prefixes const ( tokenKey string = "gorilla.csrf.Token" + formKey string = "gorilla.csrf.Form" errorKey string = "gorilla.csrf.Error" cookieName string = "_gorilla_csrf" errorPrefix string = "gorilla/csrf: " @@ -198,6 +199,8 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Save the masked token to the request context context.Set(r, tokenKey, mask(realToken, r)) + // Save the field name to the request context + context.Set(r, formKey, cs.opts.FieldName) // HTTP methods not defined as idempotent ("safe") under RFC7231 require // inspection. diff --git a/helpers.go b/helpers.go index 91e3733..fd96b3a 100644 --- a/helpers.go +++ b/helpers.go @@ -50,10 +50,15 @@ func FailureReason(r *http.Request) error { // // func TemplateField(r *http.Request) template.HTML { - fragment := fmt.Sprintf(``, - fieldName, Token(r)) + name, ok := context.GetOk(r, formKey) + if ok { + fragment := fmt.Sprintf(``, + name, Token(r)) - return template.HTML(fragment) + return template.HTML(fragment) + } + + return template.HTML("") } // mask returns a unique-per-request token to mitigate the BREACH attack From 12a0c0e84f7227a7c9f9a64a62051a83cb144bcb Mon Sep 17 00:00:00 2001 From: Matt Silverlock Date: Mon, 30 Nov 2015 09:39:57 +0800 Subject: [PATCH 2/2] [tests] Added test for TemplateField using custom FieldNames. --- helpers_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/helpers_test.go b/helpers_test.go index 73e92e5..ad641f4 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/rand" "encoding/base64" + "fmt" "io" "mime/multipart" "net/http" @@ -214,3 +215,42 @@ func TestGenerateRandomBytes(t *testing.T) { t.Fatalf("generateRandomBytes did not report a short read: only read %d bytes", len(b)) } } + +func TestTemplateField(t *testing.T) { + s := http.NewServeMux() + + // Make the token & template field available outside of the handler. + var token string + var templateField string + s.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token = Token(r) + templateField = string(TemplateField(r)) + t := template.Must((template.New("base").Parse(testTemplate))) + t.Execute(w, map[string]interface{}{ + TemplateTag: TemplateField(r), + }) + })) + + testFieldName := "custom_field_name" + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + p := Protect(testKey, FieldName(testFieldName))(s) + p.ServeHTTP(rr, r) + + expectedField := fmt.Sprintf(``, + testFieldName, token) + + if rr.Code != http.StatusOK { + t.Fatalf("middleware failed to pass to the next handler: got %v want %v", + rr.Code, http.StatusOK) + } + + if templateField != expectedField { + t.Fatalf("custom FieldName was not set correctly: got %v want %v", + templateField, expectedField) + } +}