Skip to content

Commit e0364ca

Browse files
committed
Better handling in middleware for WebSocket
Signed-off-by: Vishal Rana <vr@labstack.com>
1 parent 95f72a5 commit e0364ca

File tree

6 files changed

+69
-41
lines changed

6 files changed

+69
-41
lines changed

echo.go

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import (
2020

2121
type (
2222
Echo struct {
23-
Router *router
23+
router *Router
2424
prefix string
2525
middleware []MiddlewareFunc
2626
http2 bool
@@ -33,10 +33,12 @@ type (
3333
pool sync.Pool
3434
debug bool
3535
}
36+
3637
HTTPError struct {
3738
code int
3839
message string
3940
}
41+
4042
Middleware interface{}
4143
MiddlewareFunc func(HandlerFunc) HandlerFunc
4244
Handler interface{}
@@ -99,6 +101,13 @@ const (
99101
ContentLength = "Content-Length"
100102
ContentType = "Content-Type"
101103
Authorization = "Authorization"
104+
Upgrade = "Upgrade"
105+
106+
//-----------
107+
// Protocols
108+
//-----------
109+
110+
WebSocket = "websocket"
102111
)
103112

104113
var (
@@ -122,30 +131,12 @@ var (
122131
RendererNotRegistered = errors.New("echo ⇒ renderer not registered")
123132
)
124133

125-
func NewHTTPError(code int, msg ...string) *HTTPError {
126-
he := &HTTPError{code: code, message: http.StatusText(code)}
127-
for _, m := range msg {
128-
he.message = m
129-
}
130-
return he
131-
}
132-
133-
// Code returns code.
134-
func (e *HTTPError) Code() int {
135-
return e.code
136-
}
137-
138-
// Error returns message.
139-
func (e *HTTPError) Error() string {
140-
return e.message
141-
}
142-
143134
// New creates an Echo instance.
144135
func New() (e *Echo) {
145136
e = &Echo{
146137
uris: make(map[Handler]string),
147138
}
148-
e.Router = NewRouter(e)
139+
e.router = NewRouter(e)
149140
e.pool.New = func() interface{} {
150141
return NewContext(nil, new(Response), e)
151142
}
@@ -196,6 +187,11 @@ func (e *Echo) Group(pfx string, m ...Middleware) *Echo {
196187
return &g
197188
}
198189

190+
// Router returns router.
191+
func (e *Echo) Router() *Router {
192+
return e.router
193+
}
194+
199195
// HTTP2 enables HTTP2 support.
200196
func (e *Echo) HTTP2(on bool) {
201197
e.http2 = on
@@ -302,7 +298,7 @@ func (e *Echo) WebSocket(path string, h HandlerFunc) {
302298
func (e *Echo) add(method, path string, h Handler) {
303299
key := runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name()
304300
e.uris[key] = path
305-
e.Router.Add(method, e.prefix+path, wrapHandler(h), e)
301+
e.router.Add(method, e.prefix+path, wrapHandler(h), e)
306302
}
307303

308304
// Index serves index file.
@@ -361,7 +357,7 @@ func (e *Echo) URL(h Handler, params ...interface{}) string {
361357

362358
func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
363359
c := e.pool.Get().(*Context)
364-
h, echo := e.Router.Find(r.Method, r.URL.Path, c)
360+
h, echo := e.router.Find(r.Method, r.URL.Path, c)
365361
if echo != nil {
366362
e = echo
367363
}
@@ -419,6 +415,24 @@ func (e *Echo) run(s *http.Server, files ...string) {
419415
}
420416
}
421417

418+
func NewHTTPError(code int, msg ...string) *HTTPError {
419+
he := &HTTPError{code: code, message: http.StatusText(code)}
420+
for _, m := range msg {
421+
he.message = m
422+
}
423+
return he
424+
}
425+
426+
// Code returns code.
427+
func (e *HTTPError) Code() int {
428+
return e.code
429+
}
430+
431+
// Error returns message.
432+
func (e *HTTPError) Error() string {
433+
return e.message
434+
}
435+
422436
// wraps middleware
423437
func wrapMiddleware(m Middleware) MiddlewareFunc {
424438
switch m := m.(type) {

group.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package echo
2+
3+
type (
4+
Group struct {
5+
*Echo
6+
}
7+
)

middleware/auth.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ const (
2121
// For invalid credentials, it sends "401 - Unauthorized" response.
2222
func BasicAuth(fn AuthFunc) echo.HandlerFunc {
2323
return func(c *echo.Context) error {
24+
// Skip for WebSocket
25+
if (c.Request().Header.Get(echo.Upgrade)) == echo.WebSocket {
26+
return nil
27+
}
28+
2429
auth := c.Request().Header.Get(echo.Authorization)
2530
i := 0
2631
code := http.StatusBadRequest

middleware/compress.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ import (
55
"io"
66
"strings"
77

8-
"github.com/labstack/echo"
98
"net/http"
9+
10+
"github.com/labstack/echo"
1011
)
1112

1213
type (
@@ -27,7 +28,8 @@ func Gzip() echo.MiddlewareFunc {
2728

2829
return func(h echo.HandlerFunc) echo.HandlerFunc {
2930
return func(c *echo.Context) error {
30-
if strings.Contains(c.Request().Header.Get(echo.AcceptEncoding), scheme) {
31+
if (c.Request().Header.Get(echo.Upgrade)) != echo.WebSocket && // Skip for WebSocket
32+
strings.Contains(c.Request().Header.Get(echo.AcceptEncoding), scheme) {
3133
w := gzip.NewWriter(c.Response().Writer())
3234
defer w.Close()
3335
gw := gzipWriter{Writer: w, ResponseWriter: c.Response().Writer()}

router.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package echo
33
import "net/http"
44

55
type (
6-
router struct {
6+
Router struct {
77
trees map[string]*node
88
echo *Echo
99
}
@@ -27,8 +27,8 @@ const (
2727
mtype
2828
)
2929

30-
func NewRouter(e *Echo) (r *router) {
31-
r = &router{
30+
func NewRouter(e *Echo) (r *Router) {
31+
r = &Router{
3232
trees: make(map[string]*node),
3333
echo: e,
3434
}
@@ -41,7 +41,7 @@ func NewRouter(e *Echo) (r *router) {
4141
return
4242
}
4343

44-
func (r *router) Add(method, path string, h HandlerFunc, echo *Echo) {
44+
func (r *Router) Add(method, path string, h HandlerFunc, echo *Echo) {
4545
var pnames []string // Param names
4646

4747
for i, l := 0, len(path); i < l; i++ {
@@ -71,7 +71,7 @@ func (r *router) Add(method, path string, h HandlerFunc, echo *Echo) {
7171
r.insert(method, path, h, stype, pnames, echo)
7272
}
7373

74-
func (r *router) insert(method, path string, h HandlerFunc, t ntype, pnames []string, echo *Echo) {
74+
func (r *Router) insert(method, path string, h HandlerFunc, t ntype, pnames []string, echo *Echo) {
7575
cn := r.trees[method] // Current node as root
7676
search := path
7777

@@ -201,7 +201,7 @@ func lcp(a, b string) (i int) {
201201
return
202202
}
203203

204-
func (r *router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *Echo) {
204+
func (r *Router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *Echo) {
205205
cn := r.trees[method] // Current node as root
206206
search := path
207207

@@ -305,7 +305,7 @@ func (r *router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *E
305305
}
306306
}
307307

308-
func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
308+
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
309309
c := r.echo.pool.Get().(*Context)
310310
h, _ := r.Find(req.Method, req.URL.Path, c)
311311
c.reset(w, req, r.echo)

router_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ var (
280280
)
281281

282282
func TestRouterStatic(t *testing.T) {
283-
r := New().Router
283+
r := New().router
284284
b := new(bytes.Buffer)
285285
path := "/folders/a/files/echo.gif"
286286
r.Add(GET, path, func(*Context) error {
@@ -299,7 +299,7 @@ func TestRouterStatic(t *testing.T) {
299299
}
300300

301301
func TestRouterParam(t *testing.T) {
302-
r := New().Router
302+
r := New().router
303303
r.Add(GET, "/users/:id", func(c *Context) error {
304304
return nil
305305
}, nil)
@@ -314,7 +314,7 @@ func TestRouterParam(t *testing.T) {
314314
}
315315

316316
func TestRouterTwoParam(t *testing.T) {
317-
r := New().Router
317+
r := New().router
318318
r.Add(GET, "/users/:uid/files/:fid", func(*Context) error {
319319
return nil
320320
}, nil)
@@ -338,7 +338,7 @@ func TestRouterTwoParam(t *testing.T) {
338338
}
339339

340340
func TestRouterMatchAny(t *testing.T) {
341-
r := New().Router
341+
r := New().router
342342
r.Add(GET, "/users/*", func(*Context) error {
343343
return nil
344344
}, nil)
@@ -363,7 +363,7 @@ func TestRouterMatchAny(t *testing.T) {
363363
}
364364

365365
func TestRouterMicroParam(t *testing.T) {
366-
r := New().Router
366+
r := New().router
367367
r.Add(GET, "/:a/:b/:c", func(c *Context) error {
368368
return nil
369369
}, nil)
@@ -384,7 +384,7 @@ func TestRouterMicroParam(t *testing.T) {
384384
}
385385

386386
func TestRouterMultiRoute(t *testing.T) {
387-
r := New().Router
387+
r := New().router
388388
b := new(bytes.Buffer)
389389

390390
// Routes
@@ -425,7 +425,7 @@ func TestRouterMultiRoute(t *testing.T) {
425425
}
426426

427427
func TestRouterPriority(t *testing.T) {
428-
r := New().Router
428+
r := New().router
429429

430430
// Routes
431431
r.Add(GET, "/users", func(c *Context) error {
@@ -536,7 +536,7 @@ func TestRouterPriority(t *testing.T) {
536536
}
537537

538538
func TestRouterParamNames(t *testing.T) {
539-
r := New().Router
539+
r := New().router
540540
b := new(bytes.Buffer)
541541

542542
// Routes
@@ -596,7 +596,7 @@ func TestRouterParamNames(t *testing.T) {
596596
}
597597

598598
func TestRouterAPI(t *testing.T) {
599-
r := New().Router
599+
r := New().router
600600
for _, route := range api {
601601
r.Add(route.method, route.path, func(c *Context) error {
602602
for i, n := range c.pnames {
@@ -618,7 +618,7 @@ func TestRouterAPI(t *testing.T) {
618618
}
619619

620620
func TestRouterServeHTTP(t *testing.T) {
621-
r := New().Router
621+
r := New().router
622622
r.Add(GET, "/users", func(*Context) error {
623623
return nil
624624
}, nil)

0 commit comments

Comments
 (0)