@@ -22,16 +22,20 @@ func NewServer(cfg *Config, manager oauth2.Manager) *Server {
2222 if err := manager .CheckInterface (); err != nil {
2323 panic (err )
2424 }
25+
2526 srv := & Server {
2627 Config : cfg ,
2728 Manager : manager ,
2829 }
30+
2931 // default handler
3032 srv .ClientInfoHandler = ClientBasicHandler
33+
3134 srv .UserAuthorizationHandler = func (w http.ResponseWriter , r * http.Request ) (userID string , err error ) {
3235 err = errors .ErrAccessDenied
3336 return
3437 }
38+
3539 srv .PasswordAuthorizationHandler = func (username , password string ) (userID string , err error ) {
3640 err = errors .ErrAccessDenied
3741 return
@@ -86,10 +90,12 @@ func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, statu
8690 w .Header ().Set ("Content-Type" , "application/json;charset=UTF-8" )
8791 w .Header ().Set ("Cache-Control" , "no-store" )
8892 w .Header ().Set ("Pragma" , "no-cache" )
93+
8994 status := http .StatusOK
9095 if len (statusCode ) > 0 && statusCode [0 ] > 0 {
9196 status = statusCode [0 ]
9297 }
98+
9399 w .WriteHeader (status )
94100 err = json .NewEncoder (w ).Encode (data )
95101 return
@@ -101,13 +107,16 @@ func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface
101107 if err != nil {
102108 return
103109 }
110+
104111 q := u .Query ()
105112 if req .State != "" {
106113 q .Set ("state" , req .State )
107114 }
115+
108116 for k , v := range data {
109117 q .Set (k , fmt .Sprint (v ))
110118 }
119+
111120 switch req .ResponseType {
112121 case oauth2 .Code :
113122 u .RawQuery = q .Encode ()
@@ -118,6 +127,7 @@ func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface
118127 return
119128 }
120129 }
130+
121131 uri = u .String ()
122132 return
123133}
@@ -138,6 +148,7 @@ func (s *Server) ValidationAuthorizeRequest(r *http.Request) (req *AuthorizeRequ
138148 if err != nil {
139149 return
140150 }
151+
141152 clientID := r .FormValue ("client_id" )
142153 if r .Method != "GET" ||
143154 clientID == "" ||
@@ -147,6 +158,7 @@ func (s *Server) ValidationAuthorizeRequest(r *http.Request) (req *AuthorizeRequ
147158 }
148159
149160 resType := oauth2 .ResponseType (r .FormValue ("response_type" ))
161+
150162 if resType .String () == "" {
151163 err = errors .ErrUnsupportedResponseType
152164 return
@@ -170,9 +182,11 @@ func (s *Server) GetAuthorizeToken(req *AuthorizeRequest) (ti oauth2.TokenInfo,
170182 // check the client allows the grant type
171183 if fn := s .ClientAuthorizedHandler ; fn != nil {
172184 gt := oauth2 .AuthorizationCode
185+
173186 if req .ResponseType == oauth2 .Token {
174187 gt = oauth2 .Implicit
175188 }
189+
176190 allowed , verr := fn (req .ClientID , gt )
177191 if verr != nil {
178192 err = verr
@@ -185,6 +199,7 @@ func (s *Server) GetAuthorizeToken(req *AuthorizeRequest) (ti oauth2.TokenInfo,
185199
186200 // check the client allows the authorized scope
187201 if fn := s .ClientScopeHandler ; fn != nil {
202+
188203 allowed , verr := fn (req .ClientID , req .Scope )
189204 if verr != nil {
190205 err = verr
@@ -202,6 +217,7 @@ func (s *Server) GetAuthorizeToken(req *AuthorizeRequest) (ti oauth2.TokenInfo,
202217 Scope : req .Scope ,
203218 AccessTokenExp : req .AccessTokenExp ,
204219 }
220+
205221 ti , err = s .Manager .GenerateAuthToken (req .ResponseType , tgr )
206222 return
207223}
@@ -228,16 +244,19 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request)
228244
229245 // user authorization
230246 userID , verr := s .UserAuthorizationHandler (w , r )
247+
231248 if verr != nil {
232249 err = s .redirectError (w , req , verr )
233250 return
234251 } else if userID == "" {
235252 return
236253 }
254+
237255 req .UserID = userID
238256
239257 // specify the scope of authorization
240258 if fn := s .AuthorizeScopeHandler ; fn != nil {
259+
241260 scope , verr := fn (w , r )
242261 if verr != nil {
243262 err = verr
@@ -249,6 +268,7 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request)
249268
250269 // specify the expiration time of access token
251270 if fn := s .AccessTokenExpHandler ; fn != nil {
271+
252272 exp , verr := fn (w , r )
253273 if verr != nil {
254274 err = verr
@@ -262,6 +282,7 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request)
262282 err = s .redirectError (w , req , verr )
263283 return
264284 }
285+
265286 err = s .redirect (w , req , s .GetAuthorizeData (req .ResponseType , ti ))
266287 return
267288}
@@ -273,44 +294,59 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (gt oauth2.GrantType, t
273294 err = errors .ErrInvalidRequest
274295 return
275296 }
297+
276298 gt = oauth2 .GrantType (r .FormValue ("grant_type" ))
299+
277300 if gt .String () == "" {
278301 err = errors .ErrUnsupportedGrantType
279302 return
280303 }
304+
281305 clientID , clientSecret , err := s .ClientInfoHandler (r )
282306 if err != nil {
283307 return
284308 }
309+
285310 tgr = & oauth2.TokenGenerateRequest {
286311 ClientID : clientID ,
287312 ClientSecret : clientSecret ,
288313 }
314+
289315 switch gt {
290316 case oauth2 .AuthorizationCode :
291317 tgr .RedirectURI = r .FormValue ("redirect_uri" )
292318 tgr .Code = r .FormValue ("code" )
319+
293320 if tgr .RedirectURI == "" ||
294321 tgr .Code == "" {
295322 err = errors .ErrInvalidRequest
296323 return
297324 }
298325 case oauth2 .PasswordCredentials :
299326 tgr .Scope = r .FormValue ("scope" )
300- userID , verr := s .PasswordAuthorizationHandler (r .FormValue ("username" ), r .FormValue ("password" ))
327+ username , password := r .FormValue ("username" ), r .FormValue ("password" )
328+
329+ if username == "" || password == "" {
330+ err = errors .ErrInvalidRequest
331+ return
332+ }
333+
334+ userID , verr := s .PasswordAuthorizationHandler (username , password )
301335 if verr != nil {
302336 err = verr
303337 return
304338 } else if userID == "" {
305- err = errors .ErrInvalidRequest
339+ err = errors .ErrInvalidGrant
306340 return
307341 }
342+
308343 tgr .UserID = userID
309344 case oauth2 .ClientCredentials :
310345 tgr .Scope = r .FormValue ("scope" )
311346 case oauth2 .Refreshing :
312347 tgr .Refresh = r .FormValue ("refresh_token" )
313348 tgr .Scope = r .FormValue ("scope" )
349+
314350 if tgr .Refresh == "" {
315351 err = errors .ErrInvalidRequest
316352 }
@@ -350,6 +386,7 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe
350386 case oauth2 .AuthorizationCode :
351387 ati , verr := s .Manager .GenerateAccessToken (gt , tgr )
352388 if verr != nil {
389+
353390 if verr == errors .ErrInvalidAuthorizeCode {
354391 err = errors .ErrInvalidGrant
355392 } else if verr == errors .ErrInvalidClient {
@@ -362,6 +399,7 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe
362399 ti = ati
363400 case oauth2 .PasswordCredentials , oauth2 .ClientCredentials :
364401 if fn := s .ClientScopeHandler ; fn != nil {
402+
365403 allowed , verr := fn (tgr .ClientID , tgr .Scope )
366404 if verr != nil {
367405 err = verr
@@ -375,6 +413,7 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe
375413 case oauth2 .Refreshing :
376414 // check scope
377415 if scope , scopeFn := tgr .Scope , s .RefreshingScopeHandler ; scope != "" && scopeFn != nil {
416+
378417 rti , verr := s .Manager .LoadRefreshToken (tgr .Refresh )
379418 if verr != nil {
380419 if verr == errors .ErrInvalidRefreshToken || verr == errors .ErrExpiredRefreshToken {
@@ -394,6 +433,7 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe
394433 return
395434 }
396435 }
436+
397437 rti , verr := s .Manager .RefreshAccessToken (tgr )
398438 if verr != nil {
399439 if verr == errors .ErrInvalidRefreshToken || verr == errors .ErrExpiredRefreshToken {
@@ -416,12 +456,15 @@ func (s *Server) GetTokenData(ti oauth2.TokenInfo) (data map[string]interface{})
416456 "token_type" : s .Config .TokenType ,
417457 "expires_in" : int64 (ti .GetAccessExpiresIn () / time .Second ),
418458 }
459+
419460 if scope := ti .GetScope (); scope != "" {
420461 data ["scope" ] = scope
421462 }
463+
422464 if refresh := ti .GetRefresh (); refresh != "" {
423465 data ["refresh_token" ] = refresh
424466 }
467+
425468 if fn := s .ExtensionFieldsHandler ; fn != nil {
426469 ext := fn (ti )
427470 for k , v := range ext {
@@ -441,6 +484,7 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) (err
441484 err = s .tokenError (w , verr )
442485 return
443486 }
487+
444488 ti , verr := s .GetAccessToken (gt , tgr )
445489 if verr != nil {
446490 err = s .tokenError (w , verr )
@@ -479,16 +523,21 @@ func (s *Server) GetErrorData(err error) (data map[string]interface{}, statusCod
479523 data = map [string ]interface {}{
480524 "error" : re .Error .Error (),
481525 }
526+
482527 if v := re .ErrorCode ; v != 0 {
483528 data ["error_code" ] = v
484529 }
530+
485531 if v := re .Description ; v != "" {
486532 data ["error_description" ] = v
487533 }
534+
488535 if v := re .URI ; v != "" {
489536 data ["error_uri" ] = v
490537 }
538+
491539 statusCode = 400
540+
492541 if v := re .StatusCode ; v > 0 {
493542 statusCode = v
494543 }
@@ -521,6 +570,7 @@ func (s *Server) ValidationBearerToken(r *http.Request) (ti oauth2.TokenInfo, er
521570 err = errors .ErrInvalidAccessToken
522571 return
523572 }
573+
524574 ti , err = s .Manager .LoadAccessToken (accessToken )
525575
526576 return
0 commit comments