Skip to content

Commit

Permalink
Retry middleware : store headers per attempts and propagate them when…
Browse files Browse the repository at this point in the history
… responding.
  • Loading branch information
jlevesy authored and traefiker committed Jan 8, 2019
1 parent d7bd697 commit fc8c24e
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 13 deletions.
20 changes: 14 additions & 6 deletions middlewares/retry/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ func (r *retry) ServeHTTP(rw http.ResponseWriter, req *http.Request) {

attempts := 1
for {
attemptsExhausted := attempts >= r.attempts
shouldRetry := !attemptsExhausted
shouldRetry := attempts < r.attempts
retryResponseWriter := newResponseWriter(rw, shouldRetry)

// Disable retries when the backend already received request data
Expand Down Expand Up @@ -118,6 +117,7 @@ type responseWriter interface {
func newResponseWriter(rw http.ResponseWriter, shouldRetry bool) responseWriter {
responseWriter := &responseWriterWithoutCloseNotify{
responseWriter: rw,
headers: make(http.Header),
shouldRetry: shouldRetry,
}
if _, ok := rw.(http.CloseNotifier); ok {
Expand All @@ -130,6 +130,7 @@ func newResponseWriter(rw http.ResponseWriter, shouldRetry bool) responseWriter

type responseWriterWithoutCloseNotify struct {
responseWriter http.ResponseWriter
headers http.Header
shouldRetry bool
}

Expand All @@ -142,10 +143,7 @@ func (r *responseWriterWithoutCloseNotify) DisableRetries() {
}

func (r *responseWriterWithoutCloseNotify) Header() http.Header {
if r.ShouldRetry() {
return make(http.Header)
}
return r.responseWriter.Header()
return r.headers
}

func (r *responseWriterWithoutCloseNotify) Write(buf []byte) (int, error) {
Expand All @@ -168,6 +166,16 @@ func (r *responseWriterWithoutCloseNotify) WriteHeader(code int) {
if r.ShouldRetry() {
return
}

// In that case retry case is set to false which means we at least managed
// to write headers to the backend : we are not going to perform any further retry.
// So it is now safe to alter current response headers with headers collected during
// the latest try before writing headers to client.
headers := r.responseWriter.Header()
for header, value := range r.headers {
headers[header] = value
}

r.responseWriter.WriteHeader(code)
}

Expand Down
46 changes: 46 additions & 0 deletions middlewares/retry/retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package retry

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/http/httptrace"
"strings"
"testing"

Expand Down Expand Up @@ -149,6 +151,50 @@ func TestRetryListeners(t *testing.T) {
}
}

func TestMultipleRetriesShouldNotLooseHeaders(t *testing.T) {
attempt := 0
expectedHeaderName := "X-Foo-Test-2"
expectedHeaderValue := "bar"

next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
headerName := fmt.Sprintf("X-Foo-Test-%d", attempt)
rw.Header().Add(headerName, expectedHeaderValue)
if attempt < 2 {
attempt++
return
}

// Request has been successfully written to backend
trace := httptrace.ContextClientTrace(req.Context())
trace.WroteHeaders()

// And we decide to answer to client
rw.WriteHeader(http.StatusNoContent)
})

retry, err := New(context.Background(), next, config.Retry{Attempts: 3}, &countingRetryListener{}, "traefikTest")
require.NoError(t, err)

responseRecorder := httptest.NewRecorder()
retry.ServeHTTP(responseRecorder, testhelpers.MustNewRequest(http.MethodGet, "http://test", http.NoBody))

headerValue := responseRecorder.Header().Get(expectedHeaderName)

// Validate if we have the correct header
if headerValue != expectedHeaderValue {
t.Errorf("Expected to have %s for header %s, got %s", expectedHeaderValue, expectedHeaderName, headerValue)
}

// Validate that we don't have headers from previous attempts
for i := 0; i < attempt; i++ {
headerName := fmt.Sprintf("X-Foo-Test-%d", i)
headerValue = responseRecorder.Header().Get("headerName")
if headerValue != "" {
t.Errorf("Expected no value for header %s, got %s", headerName, headerValue)
}
}
}

// countingRetryListener is a Listener implementation to count the times the Retried fn is called.
type countingRetryListener struct {
timesCalled int
Expand Down
21 changes: 14 additions & 7 deletions old/middlewares/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) {

attempts := 1
for {
attemptsExhausted := attempts >= retry.attempts

shouldRetry := !attemptsExhausted
shouldRetry := attempts < retry.attempts
retryResponseWriter := newRetryResponseWriter(rw, shouldRetry)

// Disable retries when the backend already received request data
Expand Down Expand Up @@ -99,6 +97,7 @@ type retryResponseWriter interface {
func newRetryResponseWriter(rw http.ResponseWriter, shouldRetry bool) retryResponseWriter {
responseWriter := &retryResponseWriterWithoutCloseNotify{
responseWriter: rw,
headers: make(http.Header),
shouldRetry: shouldRetry,
}
if _, ok := rw.(http.CloseNotifier); ok {
Expand All @@ -109,6 +108,7 @@ func newRetryResponseWriter(rw http.ResponseWriter, shouldRetry bool) retryRespo

type retryResponseWriterWithoutCloseNotify struct {
responseWriter http.ResponseWriter
headers http.Header
shouldRetry bool
}

Expand All @@ -121,10 +121,7 @@ func (rr *retryResponseWriterWithoutCloseNotify) DisableRetries() {
}

func (rr *retryResponseWriterWithoutCloseNotify) Header() http.Header {
if rr.ShouldRetry() {
return make(http.Header)
}
return rr.responseWriter.Header()
return rr.headers
}

func (rr *retryResponseWriterWithoutCloseNotify) Write(buf []byte) (int, error) {
Expand All @@ -147,6 +144,16 @@ func (rr *retryResponseWriterWithoutCloseNotify) WriteHeader(code int) {
if rr.ShouldRetry() {
return
}

// In that case retry case is set to false which means we at least managed
// to write headers to the backend : we are not going to perform any further retry.
// So it is now safe to alter current response headers with headers collected during
// the latest try before writing headers to client.
headers := rr.responseWriter.Header()
for header, value := range rr.headers {
headers[header] = value
}

rr.responseWriter.WriteHeader(code)
}

Expand Down
44 changes: 44 additions & 0 deletions old/middlewares/retry_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package middlewares

import (
"fmt"
"net/http"
"net/http/httptest"
"net/http/httptrace"
"strings"
"testing"

Expand Down Expand Up @@ -258,3 +260,45 @@ func TestRetryWithFlush(t *testing.T) {
t.Errorf("Wrong body %q want %q", responseRecorder.Body.String(), "FULL DATA")
}
}

func TestMultipleRetriesShouldNotLooseHeaders(t *testing.T) {
attempt := 0
expectedHeaderName := "X-Foo-Test-2"
expectedHeaderValue := "bar"

next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
headerName := fmt.Sprintf("X-Foo-Test-%d", attempt)
rw.Header().Add(headerName, expectedHeaderValue)
if attempt < 2 {
attempt++
return
}

// Request has been successfully written to backend
trace := httptrace.ContextClientTrace(req.Context())
trace.WroteHeaders()

// And we decide to answer to client
rw.WriteHeader(http.StatusNoContent)
})

retry := NewRetry(3, next, &countingRetryListener{})
responseRecorder := httptest.NewRecorder()
retry.ServeHTTP(responseRecorder, &http.Request{})

headerValue := responseRecorder.Header().Get(expectedHeaderName)

// Validate if we have the correct header
if headerValue != expectedHeaderValue {
t.Errorf("Expected to have %s for header %s, got %s", expectedHeaderValue, expectedHeaderName, headerValue)
}

// Validate that we don't have headers from previous attempts
for i := 0; i < attempt; i++ {
headerName := fmt.Sprintf("X-Foo-Test-%d", i)
headerValue = responseRecorder.Header().Get("headerName")
if headerValue != "" {
t.Errorf("Expected no value for header %s, got %s", headerName, headerValue)
}
}
}

0 comments on commit fc8c24e

Please sign in to comment.