Skip to content

Commit becfcf0

Browse files
authored
Merge pull request #19 from LyricTian/develop
Fix error handling
2 parents 69079cf + 217243c commit becfcf0

File tree

2 files changed

+38
-28
lines changed

2 files changed

+38
-28
lines changed

server/handler.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package server
22

33
import (
44
"net/http"
5-
65
"time"
76

87
"gopkg.in/oauth2.v3"
@@ -25,7 +24,7 @@ type UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (user
2524
type PasswordAuthorizationHandler func(username, password string) (userID string, err error)
2625

2726
// RefreshingScopeHandler Check the scope of the refreshing token
28-
type RefreshingScopeHandler func(newScope, oldScope string) (allowed bool)
27+
type RefreshingScopeHandler func(newScope, oldScope string) (allowed bool, err error)
2928

3029
// ResponseErrorHandler Response error handing
3130
type ResponseErrorHandler func(re *errors.Response)
@@ -34,13 +33,13 @@ type ResponseErrorHandler func(re *errors.Response)
3433
type InternalErrorHandler func(r *http.Request, err error)
3534

3635
// AuthorizeScopeHandler Set the authorized scope
37-
type AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string)
36+
type AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error)
3837

3938
// AccessTokenExpHandler Set expiration date for the access token
40-
type AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration)
39+
type AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error)
4140

4241
// ExtensionFieldsHandler In response to the access token with the extension of the field
43-
type ExtensionFieldsHandler func(w http.ResponseWriter, r *http.Request) (fieldsValue map[string]interface{})
42+
type ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{})
4443

4544
// ClientFormHandler Get client data from form
4645
func ClientFormHandler(r *http.Request) (clientID, clientSecret string, err error) {

server/server.go

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -322,13 +322,21 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request)
322322
req.UserID = userID
323323
// specify the scope of authorization
324324
if fn := s.AuthorizeScopeHandler; fn != nil {
325-
if scope := fn(w, r); scope != "" {
325+
scope, verr := fn(w, r)
326+
if verr != nil {
327+
err = verr
328+
return
329+
} else if scope != "" {
326330
req.Scope = scope
327331
}
328332
}
329333
// specify the expiration time of access token
330334
if fn := s.AccessTokenExpHandler; fn != nil {
331-
if exp := fn(w, r); exp > 0 {
335+
exp, verr := fn(w, r)
336+
if verr != nil {
337+
err = verr
338+
return
339+
} else if exp > 0 {
332340
req.AccessTokenExp = exp
333341
}
334342
}
@@ -403,8 +411,7 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe
403411
if err != nil {
404412
ierr = err
405413
return
406-
}
407-
if !allowed {
414+
} else if !allowed {
408415
rerr = errors.ErrUnauthorizedClient
409416
return
410417
}
@@ -427,8 +434,7 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe
427434
if err != nil {
428435
ierr = err
429436
return
430-
}
431-
if !allowed {
437+
} else if !allowed {
432438
rerr = errors.ErrInvalidScope
433439
return
434440
}
@@ -441,17 +447,23 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe
441447
}
442448
}
443449
case oauth2.Refreshing:
444-
if scope := tgr.Scope; scope != "" {
450+
// check scope
451+
if scope, scopeFn := tgr.Scope, s.RefreshingScopeHandler; scope != "" && scopeFn != nil {
445452
rti, err := s.Manager.LoadRefreshToken(tgr.Refresh)
446453
if err != nil {
447-
if err == errors.ErrInvalidRefreshToken {
448-
rerr = err
454+
if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
455+
rerr = errors.ErrInvalidGrant
449456
return
450457
}
451458
ierr = err
452459
return
453460
}
454-
if fn := s.RefreshingScopeHandler; fn != nil && !fn(scope, rti.GetScope()) {
461+
462+
allowed, err := scopeFn(scope, rti.GetScope())
463+
if err != nil {
464+
ierr = err
465+
return
466+
} else if !allowed {
455467
rerr = errors.ErrInvalidScope
456468
return
457469
}
@@ -461,8 +473,8 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe
461473
if ierr == errors.ErrInvalidClient {
462474
rerr = errors.ErrInvalidClient
463475
ierr = nil
464-
} else if ierr == errors.ErrInvalidRefreshToken {
465-
rerr = errors.ErrInvalidRefreshToken
476+
} else if ierr == errors.ErrInvalidRefreshToken || ierr == errors.ErrExpiredRefreshToken {
477+
rerr = errors.ErrInvalidGrant
466478
ierr = nil
467479
}
468480
}
@@ -484,6 +496,15 @@ func (s *Server) GetTokenData(ti oauth2.TokenInfo) (data map[string]interface{})
484496
if refresh := ti.GetRefresh(); refresh != "" {
485497
data["refresh_token"] = refresh
486498
}
499+
if fn := s.ExtensionFieldsHandler; fn != nil {
500+
ext := fn(ti)
501+
for k, v := range ext {
502+
if _, ok := data[k]; ok {
503+
continue
504+
}
505+
data[k] = v
506+
}
507+
}
487508
return
488509
}
489510

@@ -504,17 +525,7 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) (err
504525
err = s.resTokenError(w, r, rerr, ierr)
505526
return
506527
}
507-
tokenData := s.GetTokenData(ti)
508-
if fn := s.ExtensionFieldsHandler; fn != nil {
509-
ext := fn(w, r)
510-
for k, v := range ext {
511-
if _, ok := tokenData[k]; ok {
512-
continue
513-
}
514-
tokenData[k] = v
515-
}
516-
}
517-
err = s.resToken(w, tokenData)
528+
err = s.resToken(w, s.GetTokenData(ti))
518529
return
519530
}
520531

0 commit comments

Comments
 (0)