Skip to content

Commit

Permalink
Resolves conflicts(gorilla#515)
Browse files Browse the repository at this point in the history
  • Loading branch information
rkilingr committed Jan 26, 2021
2 parents dde8a3e + d07530f commit 0f80a44
Show file tree
Hide file tree
Showing 13 changed files with 370 additions and 174 deletions.
103 changes: 43 additions & 60 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
@@ -1,87 +1,70 @@
version: 2.0
version: 2.1

jobs:
# Base test configuration for Go library tests Each distinct version should
# inherit this base, and override (at least) the container image used.
"test": &test
"test":
parameters:
version:
type: string
default: "latest"
golint:
type: boolean
default: true
modules:
type: boolean
default: true
goproxy:
type: string
default: ""
docker:
- image: circleci/golang:latest
- image: "circleci/golang:<< parameters.version >>"
working_directory: /go/src/github.com/gorilla/mux
steps: &steps
# Our build steps: we checkout the repo, fetch our deps, lint, and finally
# run "go test" on the package.
environment:
GO111MODULE: "on"
GOPROXY: "<< parameters.goproxy >>"
steps:
- checkout
# Logs the version in our build logs, for posterity
- run: go version
- run:
name: "Print the Go version"
command: >
go version
- run:
name: "Fetch dependencies"
command: >
go get -t -v ./...
if [[ << parameters.modules >> = true ]]; then
go mod download
export GO111MODULE=on
else
go get -v ./...
fi
# Only run gofmt, vet & lint against the latest Go version
- run:
name: "Run golint"
command: >
if [ "${LATEST}" = true ] && [ -z "${SKIP_GOLINT}" ]; then
if [ << parameters.version >> = "latest" ] && [ << parameters.golint >> = true ]; then
go get -u golang.org/x/lint/golint
golint ./...
fi
- run:
name: "Run gofmt"
command: >
if [[ "${LATEST}" = true ]]; then
if [[ << parameters.version >> = "latest" ]]; then
diff -u <(echo -n) <(gofmt -d -e .)
fi
- run:
name: "Run go vet"
command: >
if [[ "${LATEST}" = true ]]; then
command: >
if [[ << parameters.version >> = "latest" ]]; then
go vet -v ./...
fi
- run: go test -v -race ./...

"latest":
<<: *test
environment:
LATEST: true

"1.12":
<<: *test
docker:
- image: circleci/golang:1.12

"1.11":
<<: *test
docker:
- image: circleci/golang:1.11

"1.10":
<<: *test
docker:
- image: circleci/golang:1.10

"1.9":
<<: *test
docker:
- image: circleci/golang:1.9

"1.8":
<<: *test
docker:
- image: circleci/golang:1.8

"1.7":
<<: *test
docker:
- image: circleci/golang:1.7
- run:
name: "Run go test (+ race detector)"
command: >
go test -v -race ./...
workflows:
version: 2
build:
tests:
jobs:
- "latest"
- "1.12"
- "1.11"
- "1.10"
- "1.9"
- "1.8"
- "1.7"
- test:
matrix:
parameters:
version: ["latest", "1.15", "1.14", "1.13", "1.12", "1.11"]
18 changes: 0 additions & 18 deletions context.go

This file was deleted.

30 changes: 0 additions & 30 deletions context_test.go

This file was deleted.

25 changes: 10 additions & 15 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,22 +58,17 @@ func CORSMethodMiddleware(r *Router) MiddlewareFunc {
func getAllMethodsForRoute(r *Router, req *http.Request) ([]string, error) {
var allMethods []string

err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
for _, m := range route.matchers {
if _, ok := m.(*routeRegexp); ok {
if m.Match(req, &RouteMatch{}) {
methods, err := route.GetMethods()
if err != nil {
return err
}

allMethods = append(allMethods, methods...)
}
break
for _, route := range r.routes {
var match RouteMatch
if route.Match(req, &match) || match.MatchErr == ErrMethodMismatch {
methods, err := route.GetMethods()
if err != nil {
return nil, err
}

allMethods = append(allMethods, methods...)
}
return nil
})
}

return allMethods, err
return allMethods, nil
}
20 changes: 20 additions & 0 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,26 @@ func TestCORSMethodMiddleware(t *testing.T) {
}
}

func TestCORSMethodMiddlewareSubrouter(t *testing.T) {
router := NewRouter().StrictSlash(true)

subrouter := router.PathPrefix("/test").Subrouter()
subrouter.HandleFunc("/hello", stringHandler("a")).Methods(http.MethodGet, http.MethodOptions, http.MethodPost)
subrouter.HandleFunc("/hello/{name}", stringHandler("b")).Methods(http.MethodGet, http.MethodOptions)

subrouter.Use(CORSMethodMiddleware(subrouter))

rw := NewRecorder()
req := newRequest("GET", "/test/hello/asdf")
router.ServeHTTP(rw, req)

actualMethods := rw.Header().Get("Access-Control-Allow-Methods")
expectedMethods := "GET,OPTIONS"
if actualMethods != expectedMethods {
t.Fatalf("expected methods %q but got: %q", expectedMethods, actualMethods)
}
}

func TestMiddlewareOnMultiSubrouter(t *testing.T) {
first := "first"
second := "second"
Expand Down
25 changes: 13 additions & 12 deletions mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package mux

import (
"context"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -58,8 +59,7 @@ type Router struct {

// If true, do not clear the request context after handling the request.
//
// Deprecated: No effect when go1.7+ is used, since the context is stored
// on the request itself.
// Deprecated: No effect, since the context is stored on the request itself.
KeepContext bool

// Slice of middlewares to be called after a match is found
Expand Down Expand Up @@ -195,8 +195,8 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
var handler http.Handler
if r.Match(req, &match) {
handler = match.Handler
req = setVars(req, match.Vars)
req = setCurrentRoute(req, match.Route)
req = requestWithVars(req, match.Vars)
req = requestWithRoute(req, match.Route)
}

if handler == nil && match.MatchErr == ErrMethodMismatch {
Expand Down Expand Up @@ -426,7 +426,7 @@ const (

// Vars returns the route variables for the current request, if any.
func Vars(r *http.Request) map[string]string {
if rv := contextGet(r, varsKey); rv != nil {
if rv := r.Context().Value(varsKey); rv != nil {
return rv.(map[string]string)
}
return nil
Expand All @@ -435,21 +435,22 @@ func Vars(r *http.Request) map[string]string {
// CurrentRoute returns the matched route for the current request, if any.
// This only works when called inside the handler of the matched route
// because the matched route is stored in the request context which is cleared
// after the handler returns, unless the KeepContext option is set on the
// Router.
// after the handler returns.
func CurrentRoute(r *http.Request) *Route {
if rv := contextGet(r, routeKey); rv != nil {
if rv := r.Context().Value(routeKey); rv != nil {
return rv.(*Route)
}
return nil
}

func setVars(r *http.Request, val interface{}) *http.Request {
return contextSet(r, varsKey, val)
func requestWithVars(r *http.Request, vars map[string]string) *http.Request {
ctx := context.WithValue(r.Context(), varsKey, vars)
return r.WithContext(ctx)
}

func setCurrentRoute(r *http.Request, val interface{}) *http.Request {
return contextSet(r, routeKey, val)
func requestWithRoute(r *http.Request, route *Route) *http.Request {
ctx := context.WithValue(r.Context(), routeKey, route)
return r.WithContext(ctx)
}

// ----------------------------------------------------------------------------
Expand Down
49 changes: 49 additions & 0 deletions mux_httpserver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// +build go1.9

package mux

import (
"bytes"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
)

func TestSchemeMatchers(t *testing.T) {
router := NewRouter()
router.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("hello http world"))
}).Schemes("http")
router.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("hello https world"))
}).Schemes("https")

assertResponseBody := func(t *testing.T, s *httptest.Server, expectedBody string) {
resp, err := s.Client().Get(s.URL)
if err != nil {
t.Fatalf("unexpected error getting from server: %v", err)
}
if resp.StatusCode != 200 {
t.Fatalf("expected a status code of 200, got %v", resp.StatusCode)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("unexpected error reading body: %v", err)
}
if !bytes.Equal(body, []byte(expectedBody)) {
t.Fatalf("response should be hello world, was: %q", string(body))
}
}

t.Run("httpServer", func(t *testing.T) {
s := httptest.NewServer(router)
defer s.Close()
assertResponseBody(t, s, "hello http world")
})
t.Run("httpsServer", func(t *testing.T) {
s := httptest.NewTLSServer(router)
defer s.Close()
assertResponseBody(t, s, "hello https world")
})
}
Loading

0 comments on commit 0f80a44

Please sign in to comment.