Skip to content

Commit 3db72b0

Browse files
authored
Convert github auth code to use x/oauth2 (#48598)
Replaces the dependency on go-oidc/oauth2 with x/oauth2 in the github connector auth flows. One result of the switch is that it permitted a lot of simplification - the github client cache used by the auth server was removed entirely.
1 parent 79a1680 commit 3db72b0

File tree

2 files changed

+39
-106
lines changed

2 files changed

+39
-106
lines changed

lib/auth/auth.go

Lines changed: 3 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ import (
4949
"sync"
5050
"time"
5151

52-
"github.com/coreos/go-oidc/oauth2"
5352
"github.com/google/uuid"
5453
liblicense "github.com/gravitational/license"
5554
"github.com/gravitational/trace"
@@ -494,7 +493,6 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
494493
Authority: cfg.Authority,
495494
AuthServiceName: cfg.AuthServiceName,
496495
ServerID: cfg.HostUUID,
497-
githubClients: make(map[string]*githubClient),
498496
cancelFunc: cancelFunc,
499497
closeCtx: closeCtx,
500498
emitter: cfg.Emitter,
@@ -886,10 +884,9 @@ type ReadOnlyCache = readonly.Cache
886884
// - same for users and their sessions
887885
// - checks public keys to see if they're signed by it (can be trusted or not)
888886
type Server struct {
889-
lock sync.RWMutex
890-
githubClients map[string]*githubClient
891-
clock clockwork.Clock
892-
bk backend.Backend
887+
lock sync.RWMutex
888+
clock clockwork.Clock
889+
bk backend.Backend
893890

894891
closeCtx context.Context
895892
cancelFunc context.CancelFunc
@@ -7523,43 +7520,6 @@ func (k *authKeepAliver) Close() error {
75237520
return nil
75247521
}
75257522

7526-
// githubClient is internal structure that stores Github OAuth 2client and its config
7527-
type githubClient struct {
7528-
client *oauth2.Client
7529-
config oauth2.Config
7530-
}
7531-
7532-
// oauth2ConfigsEqual returns true if the provided OAuth2 configs are equal
7533-
func oauth2ConfigsEqual(a, b oauth2.Config) bool {
7534-
if a.Credentials.ID != b.Credentials.ID {
7535-
return false
7536-
}
7537-
if a.Credentials.Secret != b.Credentials.Secret {
7538-
return false
7539-
}
7540-
if a.RedirectURL != b.RedirectURL {
7541-
return false
7542-
}
7543-
if len(a.Scope) != len(b.Scope) {
7544-
return false
7545-
}
7546-
for i := range a.Scope {
7547-
if a.Scope[i] != b.Scope[i] {
7548-
return false
7549-
}
7550-
}
7551-
if a.AuthURL != b.AuthURL {
7552-
return false
7553-
}
7554-
if a.TokenURL != b.TokenURL {
7555-
return false
7556-
}
7557-
if a.AuthMethod != b.AuthMethod {
7558-
return false
7559-
}
7560-
return true
7561-
}
7562-
75637523
// DefaultDNSNamesForRole returns default DNS names for the specified role.
75647524
func DefaultDNSNamesForRole(role types.SystemRole) []string {
75657525
if (types.SystemRoles{role}).IncludeAny(

lib/auth/github.go

Lines changed: 36 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ import (
3232
"strings"
3333
"time"
3434

35-
"github.com/coreos/go-oidc/oauth2"
3635
"github.com/gravitational/trace"
3736
"github.com/sirupsen/logrus"
37+
"golang.org/x/oauth2"
3838

3939
"github.com/gravitational/teleport"
4040
"github.com/gravitational/teleport/api/constants"
@@ -143,7 +143,7 @@ func (g *GithubConverter) UpdateGithubConnector(ctx context.Context, connector t
143143

144144
// CreateGithubAuthRequest creates a new request for Github OAuth2 flow
145145
func (a *Server) CreateGithubAuthRequest(ctx context.Context, req types.GithubAuthRequest) (*types.GithubAuthRequest, error) {
146-
connector, client, err := a.getGithubConnectorAndClient(ctx, req)
146+
connector, err := a.getGithubConnector(ctx, req)
147147
if err != nil {
148148
return nil, trace.Wrap(err)
149149
}
@@ -163,7 +163,10 @@ func (a *Server) CreateGithubAuthRequest(ctx context.Context, req types.GithubAu
163163
if err != nil {
164164
return nil, trace.Wrap(err)
165165
}
166-
req.RedirectURL = client.AuthCodeURL(req.StateToken, "", "")
166+
167+
config := newGithubOAuth2Config(connector)
168+
169+
req.RedirectURL = config.AuthCodeURL(req.StateToken)
167170
log.WithFields(logrus.Fields{teleport.ComponentKey: "github"}).Debugf(
168171
"Redirect URL: %v.", req.RedirectURL)
169172
req.SetExpiry(a.GetClock().Now().UTC().Add(defaults.GithubAuthRequestTTL))
@@ -487,86 +490,51 @@ func validateGithubAuthCallbackHelper(ctx context.Context, m githubManager, diag
487490
return auth, nil
488491
}
489492

490-
func (a *Server) getGithubConnectorAndClient(ctx context.Context, request types.GithubAuthRequest) (types.GithubConnector, *oauth2.Client, error) {
493+
func (a *Server) getGithubConnector(ctx context.Context, request types.GithubAuthRequest) (types.GithubConnector, error) {
491494
if request.SSOTestFlow {
492495
if request.ConnectorSpec == nil {
493-
return nil, nil, trace.BadParameter("ConnectorSpec cannot be nil when SSOTestFlow is true")
496+
return nil, trace.BadParameter("ConnectorSpec cannot be nil when SSOTestFlow is true")
494497
}
495498

496499
if request.ConnectorID == "" {
497-
return nil, nil, trace.BadParameter("ConnectorID cannot be empty")
500+
return nil, trace.BadParameter("ConnectorID cannot be empty")
498501
}
499502

500503
// stateless test flow
501504
connector, err := services.NewGithubConnector(request.ConnectorID, *request.ConnectorSpec)
502505
if err != nil {
503-
return nil, nil, trace.Wrap(err)
504-
}
505-
506-
// construct client directly.
507-
config := newGithubOAuth2Config(connector)
508-
client, err := oauth2.NewClient(http.DefaultClient, config)
509-
if err != nil {
510-
return nil, nil, trace.Wrap(err)
506+
return nil, trace.Wrap(err)
511507
}
512508

513-
return connector, client, nil
509+
return connector, nil
514510
}
515511

516512
// regular execution flow
517513
connector, err := a.GetGithubConnector(ctx, request.ConnectorID, true)
518514
if err != nil {
519-
return nil, nil, trace.Wrap(err)
515+
return nil, trace.Wrap(err)
520516
}
521517
connector, err = services.InitGithubConnector(connector)
522518
if err != nil {
523-
return nil, nil, trace.Wrap(err)
524-
}
525-
526-
client, err := a.getGithubOAuth2Client(connector)
527-
if err != nil {
528-
return nil, nil, trace.Wrap(err)
519+
return nil, trace.Wrap(err)
529520
}
530521

531-
return connector, client, nil
522+
return connector, nil
532523
}
533524

534525
func newGithubOAuth2Config(connector types.GithubConnector) oauth2.Config {
535526
return oauth2.Config{
536-
Credentials: oauth2.ClientCredentials{
537-
ID: connector.GetClientID(),
538-
Secret: connector.GetClientSecret(),
527+
ClientID: connector.GetClientID(),
528+
ClientSecret: connector.GetClientSecret(),
529+
RedirectURL: connector.GetRedirectURL(),
530+
Scopes: GithubScopes,
531+
Endpoint: oauth2.Endpoint{
532+
AuthURL: fmt.Sprintf("%s/%s", connector.GetEndpointURL(), GithubAuthPath),
533+
TokenURL: fmt.Sprintf("%s/%s", connector.GetEndpointURL(), GithubTokenPath),
539534
},
540-
RedirectURL: connector.GetRedirectURL(),
541-
Scope: GithubScopes,
542-
AuthURL: fmt.Sprintf("%s/%s", connector.GetEndpointURL(), GithubAuthPath),
543-
TokenURL: fmt.Sprintf("%s/%s", connector.GetEndpointURL(), GithubTokenPath),
544535
}
545536
}
546537

547-
func (a *Server) getGithubOAuth2Client(connector types.GithubConnector) (*oauth2.Client, error) {
548-
config := newGithubOAuth2Config(connector)
549-
550-
a.lock.Lock()
551-
defer a.lock.Unlock()
552-
553-
cachedClient, ok := a.githubClients[connector.GetName()]
554-
if ok && oauth2ConfigsEqual(cachedClient.config, config) {
555-
return cachedClient.client, nil
556-
}
557-
558-
delete(a.githubClients, connector.GetName())
559-
client, err := oauth2.NewClient(http.DefaultClient, config)
560-
if err != nil {
561-
return nil, trace.Wrap(err)
562-
}
563-
a.githubClients[connector.GetName()] = &githubClient{
564-
client: client,
565-
config: config,
566-
}
567-
return client, nil
568-
}
569-
570538
// ValidateGithubAuthCallback validates Github auth callback redirect
571539
func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *SSODiagContext, q url.Values) (*authclient.GithubAuthResponse, error) {
572540
logger := log.WithFields(logrus.Fields{teleport.ComponentKey: "github"})
@@ -584,19 +552,19 @@ func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *SSODia
584552

585553
// optional parameter: error_description
586554
errDesc := q.Get("error_description")
587-
oauthErr := trace.OAuth2(oauth2.ErrorInvalidRequest, errParam, q)
555+
oauthErr := trace.OAuth2("invalid_request", errParam, q)
588556
return nil, trace.WithUserMessage(oauthErr, "GitHub returned error: %v [%v]", errDesc, errParam)
589557
}
590558

591559
code := q.Get("code")
592560
if code == "" {
593-
oauthErr := trace.OAuth2(oauth2.ErrorInvalidRequest, "code query param must be set", q)
561+
oauthErr := trace.OAuth2("invalid_request", "code query param must be set", q)
594562
return nil, trace.WithUserMessage(oauthErr, "Invalid parameters received from GitHub.")
595563
}
596564

597565
stateToken := q.Get("state")
598566
if stateToken == "" {
599-
oauthErr := trace.OAuth2(oauth2.ErrorInvalidRequest, "missing state query param", q)
567+
oauthErr := trace.OAuth2("invalid_request", "missing state query param", q)
600568
return nil, trace.WithUserMessage(oauthErr, "Invalid parameters received from GitHub.")
601569
}
602570
diagCtx.RequestID = stateToken
@@ -607,15 +575,15 @@ func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *SSODia
607575
}
608576
diagCtx.Info.TestFlow = req.SSOTestFlow
609577

610-
connector, client, err := a.getGithubConnectorAndClient(ctx, *req)
578+
connector, err := a.getGithubConnector(ctx, *req)
611579
if err != nil {
612580
return nil, trace.Wrap(err, "Failed to get GitHub connector and client.")
613581
}
614582
diagCtx.Info.GithubTeamsToLogins = connector.GetTeamsToLogins()
615583
diagCtx.Info.GithubTeamsToRoles = connector.GetTeamsToRoles()
616584
logger.Debugf("Connector %q teams to logins: %v, roles: %v", connector.GetName(), connector.GetTeamsToLogins(), connector.GetTeamsToRoles())
617585

618-
userResp, teamsResp, err := a.getGithubUserAndTeams(ctx, connector, code, client, diagCtx, logger)
586+
userResp, teamsResp, err := a.getGithubUserAndTeams(ctx, connector, code, diagCtx, logger)
619587
if err != nil {
620588
return nil, trace.Wrap(err)
621589
}
@@ -752,7 +720,6 @@ func (a *Server) getGithubUserAndTeams(
752720
ctx context.Context,
753721
connector types.GithubConnector,
754722
code string,
755-
client *oauth2.Client,
756723
diagCtx *SSODiagContext,
757724
logger *logrus.Entry,
758725
) (*GithubUserResponse, []GithubTeamResponse, error) {
@@ -762,20 +729,26 @@ func (a *Server) getGithubUserAndTeams(
762729
return a.GithubUserAndTeamsOverride()
763730
}
764731

732+
config := newGithubOAuth2Config(connector)
733+
765734
// exchange the authorization code received by the callback for an access token
766-
token, err := client.RequestToken(oauth2.GrantTypeAuthCode, code)
735+
token, err := config.Exchange(ctx, code)
767736
if err != nil {
768737
return nil, nil, trace.Wrap(err, "Requesting GitHub OAuth2 token failed.")
769738
}
770739

740+
scope, ok := token.Extra("scope").(string)
741+
if !ok {
742+
return nil, nil, trace.BadParameter("missing or invalid scope found in GitHub OAuth2 token")
743+
}
771744
diagCtx.Info.GithubTokenInfo = &types.GithubTokenInfo{
772745
TokenType: token.TokenType,
773-
Expires: int64(token.Expires),
774-
Scope: token.Scope,
746+
Expires: token.ExpiresIn,
747+
Scope: scope,
775748
}
776749

777750
logger.Debugf("Obtained OAuth2 token: Type=%v Expires=%v Scope=%v.",
778-
token.TokenType, token.Expires, token.Scope)
751+
token.TokenType, token.ExpiresIn, scope)
779752

780753
// Get the Github organizations the user is a member of so we don't
781754
// make unnecessary API requests

0 commit comments

Comments
 (0)