Skip to content

Commit 4c69c58

Browse files
authored
Merge pull request #47 from LyricTian/develop
optimization of error handling
2 parents ecbcbe6 + b2d52f6 commit 4c69c58

File tree

6 files changed

+88
-31
lines changed

6 files changed

+88
-31
lines changed

README.md

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ import (
4343
"log"
4444
"net/http"
4545

46+
"gopkg.in/oauth2.v3/errors"
4647
"gopkg.in/oauth2.v3/manage"
4748
"gopkg.in/oauth2.v3/models"
4849
"gopkg.in/oauth2.v3/server"
@@ -67,8 +68,13 @@ func main() {
6768
srv.SetAllowGetAccessRequest(true)
6869
srv.SetClientInfoHandler(server.ClientFormHandler)
6970

70-
srv.SetInternalErrorHandler(func(err error) {
71-
log.Println("OAuth2 Error:", err.Error())
71+
srv.SetInternalErrorHandler(func(err error) (re *errors.Response) {
72+
log.Println("Internal Error:", err.Error())
73+
return
74+
})
75+
76+
srv.SetResponseErrorHandler(func(re *errors.Response) {
77+
log.Println("Response Error:", re.Error.Error())
7278
})
7379

7480
http.HandleFunc("/authorize", func(w http.ResponseWriter, r *http.Request) {
@@ -82,8 +88,9 @@ func main() {
8288
srv.HandleTokenRequest(w, r)
8389
})
8490

85-
http.ListenAndServe(":9096", nil)
91+
log.Fatal(http.ListenAndServe(":9096", nil))
8692
}
93+
8794
```
8895

8996
### Build and run
@@ -139,8 +146,8 @@ Copyright (c) 2016 Lyric
139146
[License-Image]: https://img.shields.io/npm/l/express.svg
140147
[Build-Status-Url]: https://travis-ci.org/go-oauth2/oauth2
141148
[Build-Status-Image]: https://travis-ci.org/go-oauth2/oauth2.svg?branch=master
142-
[Release-Url]: https://github.com/go-oauth2/oauth2/releases/tag/v3.6.3
143-
[Release-image]: http://img.shields.io/badge/release-v3.6.3-1eb0fc.svg
149+
[Release-Url]: https://github.com/go-oauth2/oauth2/releases/tag/v3.7.0
150+
[Release-image]: http://img.shields.io/badge/release-v3.7.0-1eb0fc.svg
144151
[ReportCard-Url]: https://goreportcard.com/report/gopkg.in/oauth2.v3
145152
[ReportCard-Image]: https://goreportcard.com/badge/gopkg.in/oauth2.v3
146153
[GoDoc-Url]: https://godoc.org/gopkg.in/oauth2.v3

errors/response.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package errors
22

3-
import "errors"
3+
import (
4+
"errors"
5+
"net/http"
6+
)
47

58
// Response error response
69
type Response struct {
@@ -9,6 +12,24 @@ type Response struct {
912
Description string
1013
URI string
1114
StatusCode int
15+
Header http.Header
16+
}
17+
18+
// NewResponse create the response pointer
19+
func NewResponse(err error, statusCode int) *Response {
20+
return &Response{
21+
Error: err,
22+
StatusCode: statusCode,
23+
}
24+
}
25+
26+
// SetHeader sets the header entries associated with key to
27+
// the single element value.
28+
func (r *Response) SetHeader(key, value string) {
29+
if r.Header == nil {
30+
r.Header = make(http.Header)
31+
}
32+
r.Header.Set(key, value)
1233
}
1334

1435
// https://tools.ietf.org/html/rfc6749#section-5.2

example/server/server.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"net/url"
77
"os"
88

9+
"gopkg.in/oauth2.v3/errors"
910
"gopkg.in/oauth2.v3/manage"
1011
"gopkg.in/oauth2.v3/models"
1112
"gopkg.in/oauth2.v3/server"
@@ -37,8 +38,14 @@ func main() {
3738

3839
srv := server.NewServer(server.NewConfig(), manager)
3940
srv.SetUserAuthorizationHandler(userAuthorizeHandler)
40-
srv.SetInternalErrorHandler(func(err error) {
41-
log.Println("[oauth2] error:", err.Error())
41+
42+
srv.SetInternalErrorHandler(func(err error) (re *errors.Response) {
43+
log.Println("Internal Error:", err.Error())
44+
return
45+
})
46+
47+
srv.SetResponseErrorHandler(func(re *errors.Response) {
48+
log.Println("Response Error:", re.Error.Error())
4249
})
4350

4451
http.HandleFunc("/login", loginHandler)

server/handler.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ type (
2828
RefreshingScopeHandler func(newScope, oldScope string) (allowed bool, err error)
2929

3030
// ResponseErrorHandler response error handing
31-
ResponseErrorHandler func(err error) (re *errors.Response)
31+
ResponseErrorHandler func(re *errors.Response)
3232

3333
// InternalErrorHandler internal error handing
34-
InternalErrorHandler func(err error)
34+
InternalErrorHandler func(err error) (re *errors.Response)
3535

3636
// AuthorizeScopeHandler set the authorized scope
3737
AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error)

server/server.go

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err
6565
uerr = err
6666
return
6767
}
68-
data, _ := s.GetErrorData(err)
68+
data, _, _ := s.GetErrorData(err)
6969
err = s.redirect(w, req, data)
7070
return
7171
}
@@ -81,16 +81,21 @@ func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map
8181
}
8282

8383
func (s *Server) tokenError(w http.ResponseWriter, err error) (uerr error) {
84-
data, statusCode := s.GetErrorData(err)
85-
uerr = s.token(w, data, statusCode)
84+
data, statusCode, header := s.GetErrorData(err)
85+
86+
uerr = s.token(w, data, header, statusCode)
8687
return
8788
}
8889

89-
func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, statusCode ...int) (err error) {
90+
func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) (err error) {
9091
w.Header().Set("Content-Type", "application/json;charset=UTF-8")
9192
w.Header().Set("Cache-Control", "no-store")
9293
w.Header().Set("Pragma", "no-cache")
9394

95+
for key := range header {
96+
w.Header().Set(key, header.Get(key))
97+
}
98+
9499
status := http.StatusOK
95100
if len(statusCode) > 0 && statusCode[0] > 0 {
96101
status = statusCode[0]
@@ -490,38 +495,45 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) (err
490495
err = s.tokenError(w, verr)
491496
return
492497
}
493-
err = s.token(w, s.GetTokenData(ti))
498+
499+
err = s.token(w, s.GetTokenData(ti), nil)
494500
return
495501
}
496502

497503
// GetErrorData get error response data
498-
func (s *Server) GetErrorData(err error) (data map[string]interface{}, statusCode int) {
499-
re := &errors.Response{}
504+
func (s *Server) GetErrorData(err error) (data map[string]interface{}, statusCode int, header http.Header) {
505+
re := new(errors.Response)
500506

501507
if v, ok := errors.Descriptions[err]; ok {
502508
re.Error = err
503509
re.Description = v
504510
re.StatusCode = errors.StatusCodes[err]
505511
} else {
506512
if fn := s.InternalErrorHandler; fn != nil {
507-
fn(err)
513+
if vre := fn(err); vre != nil {
514+
re = vre
515+
}
516+
}
517+
518+
if re.Error == nil {
519+
re.Error = errors.ErrServerError
520+
re.Description = errors.Descriptions[errors.ErrServerError]
521+
re.StatusCode = errors.StatusCodes[errors.ErrServerError]
508522
}
509523
}
510524

511525
if fn := s.ResponseErrorHandler; fn != nil {
512-
if vre := fn(err); vre != nil {
513-
re = vre
526+
fn(re)
527+
528+
if re == nil {
529+
re = new(errors.Response)
514530
}
515531
}
516532

517-
if re.Error == nil {
518-
re.Error = errors.ErrServerError
519-
re.Description = errors.Descriptions[errors.ErrServerError]
520-
re.StatusCode = errors.StatusCodes[errors.ErrServerError]
521-
}
533+
data = make(map[string]interface{})
522534

523-
data = map[string]interface{}{
524-
"error": re.Error.Error(),
535+
if err := re.Error; err != nil {
536+
data["error"] = err.Error()
525537
}
526538

527539
if v := re.ErrorCode; v != 0 {
@@ -536,11 +548,13 @@ func (s *Server) GetErrorData(err error) (data map[string]interface{}, statusCod
536548
data["error_uri"] = v
537549
}
538550

539-
statusCode = 400
551+
header = re.Header
540552

553+
statusCode = http.StatusInternalServerError
541554
if v := re.StatusCode; v > 0 {
542555
statusCode = v
543556
}
557+
544558
return
545559
}
546560

server/server_test.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
package server_test
22

33
import (
4-
"errors"
4+
"fmt"
55
"net/http"
66
"net/http/httptest"
77
"net/url"
88
"testing"
99

1010
"github.com/gavv/httpexpect"
1111
"gopkg.in/oauth2.v3"
12+
"gopkg.in/oauth2.v3/errors"
1213
"gopkg.in/oauth2.v3/manage"
1314
"gopkg.in/oauth2.v3/models"
1415
"gopkg.in/oauth2.v3/server"
@@ -144,7 +145,7 @@ func TestPasswordCredentials(t *testing.T) {
144145
userID = "000000"
145146
return
146147
}
147-
err = errors.New("user not found")
148+
err = fmt.Errorf("user not found")
148149
return
149150
})
150151

@@ -174,9 +175,16 @@ func TestClientCredentials(t *testing.T) {
174175

175176
srv = server.NewDefaultServer(manager)
176177
srv.SetClientInfoHandler(server.ClientFormHandler)
177-
srv.SetInternalErrorHandler(func(err error) {
178+
179+
srv.SetInternalErrorHandler(func(err error) (re *errors.Response) {
178180
t.Log("OAuth 2.0 Error:", err.Error())
181+
return
179182
})
183+
184+
srv.SetResponseErrorHandler(func(re *errors.Response) {
185+
t.Log("Response Error:", re.Error)
186+
})
187+
180188
srv.SetAllowedGrantType(oauth2.ClientCredentials)
181189
srv.SetAllowGetAccessRequest(false)
182190
srv.SetExtensionFieldsHandler(func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) {

0 commit comments

Comments
 (0)