Skip to content

Commit

Permalink
Add public accessors for request pattern and method (#175)
Browse files Browse the repository at this point in the history
These are very useful values to be able to access easily while processing requests. Let's make them public and reachable via the context.
  • Loading branch information
tomcoupland committed Sep 23, 2024
1 parent ab6ccb7 commit c2c0b01
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 9 deletions.
46 changes: 37 additions & 9 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ import (
// https://play.golang.org/p/MxhRiL37R-9
type routerContextKeyType struct{}
type routerRequestPatternContextKeyType struct{}
type routerRequestMethodContextKeyType struct{}

var (
routerContextKey = routerContextKeyType{}
routerRequestPatternContextKey = routerRequestPatternContextKeyType{}
routerRequestMethodContextKey = routerRequestMethodContextKeyType{}
routerComponentsRe = regexp.MustCompile(`(?:^|/)(\*\w*|:\w+)`)
)

Expand Down Expand Up @@ -53,6 +55,22 @@ func routerPathPatternForRequest(r Request) string {
return ""
}

// RequestPatternFromContext returns the pattern that was matched for the request, if available.
func RequestPatternFromContext(ctx context.Context) (string, bool) {
if v := ctx.Value(routerRequestPatternContextKey); v != nil {
return v.(string), true
}
return "", false
}

// RequestMethodFromContext returns the method of the request, if available.
func RequestMethodFromContext(ctx context.Context) (string, bool) {
if v := ctx.Value(routerRequestMethodContextKey); v != nil {
return v.(string), true
}
return "", false
}

func (r *Router) compile(pattern string) *regexp.Regexp {
re, pos := ``, 0
for _, m := range routerComponentsRe.FindAllStringSubmatchIndex(pattern, -1) {
Expand Down Expand Up @@ -134,6 +152,7 @@ func (r Router) Serve() Service {
}
req.Context = context.WithValue(req.Context, routerContextKey, &r)
req.Context = context.WithValue(req.Context, routerRequestPatternContextKey, pathPattern)
req.Context = context.WithValue(req.Context, routerRequestMethodContextKey, req.Method)
rsp := svc(req)
if rsp.Request == nil {
rsp.Request = &req
Expand All @@ -157,37 +176,46 @@ func (r Router) Params(req Request) map[string]string {
// Sugar

// GET is shorthand for:
// r.Register("GET", pattern, svc)
//
// r.Register("GET", pattern, svc)
func (r *Router) GET(pattern string, svc Service) { r.Register("GET", pattern, svc) }

// CONNECT is shorthand for:
// r.Register("CONNECT", pattern, svc)
//
// r.Register("CONNECT", pattern, svc)
func (r *Router) CONNECT(pattern string, svc Service) { r.Register("CONNECT", pattern, svc) }

// DELETE is shorthand for:
// r.Register("DELETE", pattern, svc)
//
// r.Register("DELETE", pattern, svc)
func (r *Router) DELETE(pattern string, svc Service) { r.Register("DELETE", pattern, svc) }

// HEAD is shorthand for:
// r.Register("HEAD", pattern, svc)
//
// r.Register("HEAD", pattern, svc)
func (r *Router) HEAD(pattern string, svc Service) { r.Register("HEAD", pattern, svc) }

// OPTIONS is shorthand for:
// r.Register("OPTIONS", pattern, svc)
//
// r.Register("OPTIONS", pattern, svc)
func (r *Router) OPTIONS(pattern string, svc Service) { r.Register("OPTIONS", pattern, svc) }

// PATCH is shorthand for:
// r.Register("PATCH", pattern, svc)
//
// r.Register("PATCH", pattern, svc)
func (r *Router) PATCH(pattern string, svc Service) { r.Register("PATCH", pattern, svc) }

// POST is shorthand for:
// r.Register("POST", pattern, svc)
//
// r.Register("POST", pattern, svc)
func (r *Router) POST(pattern string, svc Service) { r.Register("POST", pattern, svc) }

// PUT is shorthand for:
// r.Register("PUT", pattern, svc)
//
// r.Register("PUT", pattern, svc)
func (r *Router) PUT(pattern string, svc Service) { r.Register("PUT", pattern, svc) }

// TRACE is shorthand for:
// r.Register("TRACE", pattern, svc)
//
// r.Register("TRACE", pattern, svc)
func (r *Router) TRACE(pattern string, svc Service) { r.Register("TRACE", pattern, svc) }
22 changes: 22 additions & 0 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,25 @@ func TestRouterSetsRequest(t *testing.T) {
req.Context = rsp.Request.Context
assert.Equal(t, req, *rsp.Request)
}

func TestRouterSetsContextValues(t *testing.T) {
t.Parallel()

router := Router{}
router.GET("/", func(req Request) Response {
return Response{}
})

ctx := context.Background()
req := NewRequest(ctx, "GET", "/", map[string]string{"r": "foo"})
rsp := router.Serve()(req)
require.NotNil(t, rsp.Request)

ctxPattern, ok := RequestPatternFromContext(rsp.Request.Context)
assert.True(t, ok)
assert.Equal(t, "/", ctxPattern)

ctxMethod, ok := RequestMethodFromContext(rsp.Request.Context)
assert.True(t, ok)
assert.Equal(t, "GET", ctxMethod)
}

0 comments on commit c2c0b01

Please sign in to comment.