Skip to content

Commit 3dd5c4d

Browse files
committed
Prevent multiple Set-Cookie headers when calling RegenerateToken
Closes #61
1 parent ec1bc1f commit 3dd5c4d

6 files changed

+95
-16
lines changed

context.go

+17
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ type csrfContext struct {
1616
token string
1717
// reason for the failure of CSRF check
1818
reason error
19+
// wasSent is true if `Set-Cookie` was called
20+
// for the `name=csrf_token` already. This prevents
21+
// duplicate `Set-Cookie: csrf_token` headers.
22+
// For more information see:
23+
// https://github.com/justinas/nosurf/pull/61
24+
wasSent bool
1925
}
2026

2127
// Token takes an HTTP request and returns
@@ -53,6 +59,17 @@ func ctxSetToken(req *http.Request, token []byte) {
5359
ctx.token = b64encode(maskToken(token))
5460
}
5561

62+
func ctxSetSent(req *http.Request) {
63+
ctx := req.Context().Value(nosurfKey).(*csrfContext)
64+
ctx.wasSent = true
65+
}
66+
67+
func ctxWasSent(req *http.Request) bool {
68+
ctx := req.Context().Value(nosurfKey).(*csrfContext)
69+
70+
return ctx.wasSent
71+
}
72+
5673
func ctxSetReason(req *http.Request, reason error) {
5774
ctx := req.Context().Value(nosurfKey).(*csrfContext)
5875
if ctx.token == "" {

context_legacy.go

+32
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ type csrfContext struct {
1717
token string
1818
// reason for the failure of CSRF check
1919
reason error
20+
// wasSent is true if `Set-Cookie` was called
21+
// for the `name=csrf_token` already. This prevents
22+
// duplicate `Set-Cookie: csrf_token` headers.
23+
// For more information see:
24+
// https://github.com/justinas/nosurf/pull/61
25+
wasSent bool
2026
}
2127

2228
var (
@@ -79,6 +85,32 @@ func ctxSetToken(req *http.Request, token []byte) *http.Request {
7985
return req
8086
}
8187

88+
func ctxSetSent(req *http.Request) {
89+
cmMutex.Lock()
90+
defer cmMutex.Unlock()
91+
92+
ctx, ok := contextMap[req]
93+
if !ok {
94+
ctx = new(csrfContext)
95+
contextMap[req] = ctx
96+
}
97+
98+
ctx.wasSent = true
99+
}
100+
101+
func ctxWasSent(req *http.Request) bool {
102+
cmMutex.RLock()
103+
defer cmMutex.RUnlock()
104+
105+
ctx, ok := contextMap[req]
106+
107+
if !ok {
108+
return false
109+
}
110+
111+
return ctx.wasSent
112+
}
113+
82114
func ctxSetReason(req *http.Request, reason error) *http.Request {
83115
cmMutex.Lock()
84116
defer cmMutex.Unlock()

handler.go

+11
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,16 @@ func (h *CSRFHandler) handleFailure(w http.ResponseWriter, r *http.Request) {
195195

196196
// Generates a new token, sets it on the given request and returns it
197197
func (h *CSRFHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) string {
198+
if ctxWasSent(r) {
199+
// The CSRF Cookie was set already by an earlier call to `RegenerateToken`
200+
// in the same request context. It therefore does not make sense to regenerate
201+
// it again as it will lead to two or more `Set-Cookie` instructions which will in turn
202+
// cause CSRF to fail depending on the resulting order of the `Set-Cookie` instructions.
203+
//
204+
// No warning is necessary as the only caller to `setTokenCookie` is `RegenerateToken`.
205+
return Token(r)
206+
}
207+
198208
token := generateToken()
199209
h.setTokenCookie(w, r, token)
200210

@@ -210,6 +220,7 @@ func (h *CSRFHandler) setTokenCookie(w http.ResponseWriter, r *http.Request, tok
210220
cookie.Value = b64encode(token)
211221

212222
http.SetCookie(w, &cookie)
223+
ctxSetSent(r)
213224

214225
}
215226

handler_go17_test.go

+17
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,20 @@ func TestContextIsAccessibleWithContext(t *testing.T) {
2828

2929
hand.ServeHTTP(writer, req)
3030
}
31+
32+
func TestNoDoubleCookie(t *testing.T) {
33+
var n *CSRFHandler
34+
n = New(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
35+
n.RegenerateToken(w, r)
36+
}))
37+
38+
r := httptest.NewRequest("GET", "http://dummy.us", nil)
39+
w := httptest.NewRecorder()
40+
41+
n.ServeHTTP(w, r)
42+
43+
count := len(w.Result().Cookies())
44+
if count > 1 {
45+
t.Errorf("Expected one CSRF cookie, got %d", count)
46+
}
47+
}

handler_legacy_test.go

+18
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package nosurf
55
import (
66
"net/http"
77
"net/http/httptest"
8+
"strings"
89
"testing"
910
)
1011

@@ -20,3 +21,20 @@ func TestClearsContextAfterTheRequest(t *testing.T) {
2021
t.Errorf("Instead, the context entry remains: %v", contextMap[req])
2122
}
2223
}
24+
25+
func TestNoDoubleCookie(t *testing.T) {
26+
var n *CSRFHandler
27+
n = New(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
28+
n.RegenerateToken(w, r)
29+
}))
30+
31+
r := httptest.NewRequest("GET", "http://dummy.us", nil)
32+
w := httptest.NewRecorder()
33+
34+
n.ServeHTTP(w, r)
35+
36+
count := strings.Count(w.HeaderMap.Get("Set-Cookie"), "csrf_token")
37+
if count > 1 {
38+
t.Errorf("Expected one CSRF cookie, got %d", count)
39+
}
40+
}

handler_test.go

-16
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,6 @@ import (
99
"testing"
1010
)
1111

12-
func TestNoDoubleCookie(t *testing.T) {
13-
var n *CSRFHandler
14-
n = New(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
15-
n.RegenerateToken(w, r)
16-
}))
17-
18-
r := httptest.NewRequest("GET", "http://dummy.us", nil)
19-
w := httptest.NewRecorder()
20-
21-
n.ServeHTTP(w, r)
22-
23-
if len(w.Result().Cookies()) > 1 {
24-
t.Errorf("Expected one CSRF cookie, got %d", len(w.Result().Cookies()))
25-
}
26-
}
27-
2812
func TestDefaultFailureHandler(t *testing.T) {
2913
writer := httptest.NewRecorder()
3014
req := dummyGet()

0 commit comments

Comments
 (0)