Skip to content

Commit 485a0fc

Browse files
authored
Prevent cancellation from propagating to background request context (#17)
* Failing test for issue #16 * Cancellation does not propagate to background request context * Updating test description
1 parent b6ab5e1 commit 485a0fc

File tree

3 files changed

+101
-29
lines changed

3 files changed

+101
-29
lines changed

background_request.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package microcache
2+
3+
import (
4+
"context"
5+
"net/http"
6+
)
7+
8+
// newBackgroundRequest clones a request for use in background object revalidation.
9+
// This prevents a closed foreground request context from prematurely cancelling
10+
// the background request context.
11+
func newBackgroundRequest(r *http.Request) *http.Request {
12+
return r.Clone(bgContext{r.Context(), make(chan struct{})})
13+
}
14+
15+
type bgContext struct {
16+
context.Context
17+
done chan struct{}
18+
}
19+
20+
func (c bgContext) Done() <-chan struct{} {
21+
return c.done
22+
}

microcache.go

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ type microcache struct {
4040
collapseMutex *sync.Mutex
4141

4242
// Used to advance time for testing
43-
offset time.Duration
43+
offset time.Duration
44+
offsetMutex *sync.RWMutex
4445
}
4546

4647
type Config struct {
@@ -155,6 +156,7 @@ func New(o Config) *microcache {
155156
revalidateMutex: &sync.Mutex{},
156157
collapse: map[string]*sync.Mutex{},
157158
collapseMutex: &sync.Mutex{},
159+
offsetMutex: &sync.RWMutex{},
158160
}
159161
if o.Driver == nil {
160162
m.Driver = NewDriverLRU(1e4) // default 10k cache items
@@ -180,6 +182,9 @@ func New(o Config) *microcache {
180182
// chain.Append(mx.Middleware)
181183
//
182184
func (m *microcache) Middleware(h http.Handler) http.Handler {
185+
if m.Timeout > 0 {
186+
h = http.TimeoutHandler(h, m.Timeout, "Timed out")
187+
}
183188
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
184189
// Websocket passthrough
185190
upgrade := strings.ToLower(r.Header.Get("connection")) == "upgrade"
@@ -259,7 +264,7 @@ func (m *microcache) Middleware(h http.Handler) http.Handler {
259264
}
260265

261266
// Fresh response object found
262-
if obj.found && obj.expires.After(time.Now().Add(m.offset)) {
267+
if obj.found && obj.expires.After(m.now()) {
263268
if m.Monitor != nil {
264269
m.Monitor.Hit()
265270
}
@@ -273,7 +278,7 @@ func (m *microcache) Middleware(h http.Handler) http.Handler {
273278

274279
// Stale While Revalidate
275280
if obj.found && req.staleWhileRevalidate > 0 &&
276-
obj.expires.Add(req.staleWhileRevalidate).After(time.Now().Add(m.offset)) {
281+
obj.expires.Add(req.staleWhileRevalidate).After(m.now()) {
277282
if m.Monitor != nil {
278283
m.Monitor.Stale()
279284
}
@@ -291,15 +296,15 @@ func (m *microcache) Middleware(h http.Handler) http.Handler {
291296
}
292297
m.revalidateMutex.Unlock()
293298
if !revalidating {
299+
br := newBackgroundRequest(r)
294300
go func() {
295301
defer func() {
296302
// Clear revalidation lock
297303
m.revalidateMutex.Lock()
298304
delete(m.revalidating, objHash)
299305
m.revalidateMutex.Unlock()
300306
}()
301-
302-
m.handleBackendResponse(h, w, r, reqHash, req, objHash, obj, true)
307+
m.handleBackendResponse(h, w, br, reqHash, req, objHash, obj, true)
303308
}()
304309
}
305310

@@ -329,12 +334,7 @@ func (m *microcache) handleBackendResponse(
329334
beres := Response{header: http.Header{}}
330335

331336
// Execute request
332-
if m.Timeout > 0 {
333-
th := http.TimeoutHandler(h, m.Timeout, "Timed out")
334-
th.ServeHTTP(&beres, r)
335-
} else {
336-
h.ServeHTTP(&beres, r)
337-
}
337+
h.ServeHTTP(&beres, r)
338338

339339
if !beres.headerWritten {
340340
beres.status = http.StatusOK
@@ -347,10 +347,10 @@ func (m *microcache) handleBackendResponse(
347347

348348
// Serve Stale
349349
if beres.status >= 500 && obj.found {
350-
serveStale := obj.expires.Add(req.staleIfError).After(time.Now().Add(m.offset))
350+
serveStale := obj.expires.Add(req.staleIfError).After(m.now())
351351
// Extend stale response expiration by staleIfError grace period
352352
if req.found && serveStale && req.staleRecache {
353-
obj.expires = obj.date.Add(m.offset).Add(req.ttl)
353+
obj.expires = obj.date.Add(m.getOffset()).Add(req.ttl)
354354
m.store(objHash, obj)
355355
}
356356
if !background && serveStale {
@@ -376,7 +376,7 @@ func (m *microcache) handleBackendResponse(
376376
}
377377
// Cache response
378378
if !req.nocache {
379-
beres.expires = time.Now().Add(m.offset).Add(req.ttl)
379+
beres.expires = m.now().Add(req.ttl)
380380
m.store(objHash, beres)
381381
}
382382
}
@@ -418,7 +418,7 @@ func (m *microcache) Start() {
418418
// setAgeHeader sets the age header if not suppressed
419419
func (m *microcache) setAgeHeader(w http.ResponseWriter, obj Response) {
420420
if !m.SuppressAgeHeader {
421-
age := (time.Now().Add(m.offset).Unix() - obj.date.Unix())
421+
age := (m.now().Unix() - obj.date.Unix())
422422
w.Header().Set("age", fmt.Sprintf("%d", age))
423423
}
424424
}
@@ -444,5 +444,19 @@ func (m *microcache) Stop() {
444444

445445
// Increments the offset for testing purposes
446446
func (m *microcache) offsetIncr(o time.Duration) {
447+
m.offsetMutex.Lock()
448+
defer m.offsetMutex.Unlock()
447449
m.offset += o
448450
}
451+
452+
// Get offset
453+
func (m *microcache) getOffset() time.Duration {
454+
m.offsetMutex.RLock()
455+
defer m.offsetMutex.RUnlock()
456+
return m.offset
457+
}
458+
459+
// Get current time with offset
460+
func (m *microcache) now() time.Time {
461+
return time.Now().Add(m.getOffset())
462+
}

microcache_test.go

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package microcache
22

33
import (
4+
"context"
45
"fmt"
56
"net/http"
67
"net/http/httptest"
8+
"strings"
79
"sync"
810
"testing"
911
"time"
@@ -173,26 +175,16 @@ func TestCollapsedFowardingStaleWhileRevalidate(t *testing.T) {
173175
})
174176
defer cache.Stop()
175177
handler := cache.Middleware(http.HandlerFunc(timelySuccessHandler))
176-
batchGet(handler, []string{
177-
"/",
178-
})
178+
batchGet(handler, []string{"/"})
179179
cache.offsetIncr(31 * time.Second)
180180
start := time.Now()
181-
parallelGet(handler, []string{
182-
"/",
183-
"/",
184-
"/",
185-
"/",
186-
"/",
187-
"/",
188-
})
181+
parallelGet(handler, strings.Split(strings.Repeat(",/", 10)[1:], ","))
189182
end := time.Since(start)
190183
// Sleep for a little bit to give the StaleWhileRevalidate goroutines some time to start.
191184
time.Sleep(time.Millisecond * 10)
192-
if testMonitor.getMisses() != 1 || testMonitor.getStales() != 6 ||
185+
if testMonitor.getMisses() != 1 || testMonitor.getStales() != 10 ||
193186
testMonitor.getBackends() != 2 || end > 20*time.Millisecond {
194-
t.Logf("%#v", testMonitor)
195-
t.Fatal("CollapsedFowarding and StaleWhileRevalidate not respected - got", testMonitor.getBackends(), "backend")
187+
t.Fatalf("CollapsedFowarding and StaleWhileRevalidate not respected %s", dumpMonitor(testMonitor))
196188
}
197189
}
198190

@@ -300,6 +292,40 @@ func TestTimeout(t *testing.T) {
300292
}
301293
}
302294

295+
// Request context cancellation should not cause error from TimeoutHandler
296+
func TestRequestContextCancel(t *testing.T) {
297+
testMonitor := &monitorFunc{interval: 100 * time.Second, logFunc: func(Stats) {}}
298+
cache := New(Config{
299+
TTL: 30 * time.Second,
300+
StaleWhileRevalidate: 30 * time.Second,
301+
Timeout: 10 * time.Second,
302+
CollapsedForwarding: true,
303+
Monitor: testMonitor,
304+
Driver: NewDriverLRU(10),
305+
})
306+
defer cache.Stop()
307+
handler := cache.Middleware(http.HandlerFunc(timelySuccessHandler))
308+
batchGet(handler, []string{"/"})
309+
cache.offsetIncr(31 * time.Second)
310+
r, _ := http.NewRequest("GET", "/", nil)
311+
ctx, cancel := context.WithCancel(r.Context())
312+
r = r.WithContext(ctx)
313+
w := httptest.NewRecorder()
314+
handler.ServeHTTP(w, r)
315+
cancel()
316+
time.Sleep(1 * time.Millisecond)
317+
if testMonitor.getErrors() > 0 {
318+
t.Fatal("TimeoutHandler returned error")
319+
}
320+
cache.offsetIncr(31 * time.Second)
321+
cache.Timeout = 1 * time.Millisecond
322+
batchGet(cache.Middleware(http.HandlerFunc(slowSuccessHandler)), []string{"/"})
323+
time.Sleep(2 * time.Millisecond)
324+
if testMonitor.getErrors() != 1 {
325+
t.Fatal("Request did not time out")
326+
}
327+
}
328+
303329
// CollapsedFowarding
304330
func TestCollapsedFowarding(t *testing.T) {
305331
testMonitor := &monitorFunc{interval: 100 * time.Second, logFunc: func(Stats) {}}
@@ -745,3 +771,13 @@ func timelySuccessHandler(w http.ResponseWriter, r *http.Request) {
745771
time.Sleep(10 * time.Millisecond)
746772
http.Error(w, "done", 200)
747773
}
774+
775+
func dumpMonitor(m *monitorFunc) string {
776+
return fmt.Sprintf("Hits: %d, Misses: %d, Backend: %d, Stales: %d, Errors: %d",
777+
m.getHits(),
778+
m.getMisses(),
779+
m.getBackends(),
780+
m.getStales(),
781+
m.getErrors(),
782+
)
783+
}

0 commit comments

Comments
 (0)