diff --git a/.gitignore b/.gitignore index 0486736..adf47c4 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ .idea/ # environment -.env \ No newline at end of file +.env +.vercel diff --git a/api/me.go b/api/me.go new file mode 100644 index 0000000..f068c91 --- /dev/null +++ b/api/me.go @@ -0,0 +1,37 @@ +package api + +import ( + "net/http" + "net/url" + + "github.com/visiperf/visiauth/v3" + "github.com/visiperf/visiauth/v3/api/renderer" + "github.com/visiperf/visiauth/v3/errors" + "github.com/visiperf/visiauth/v3/neo4j" + "github.com/visiperf/visiauth/v3/redis" +) + +const ( + accessTokenQueryParamsKey = "token" +) + +func MeHandler(w http.ResponseWriter, r *http.Request) { + authenticable, err := func() (visiauth.Authenticable, error) { + vs, err := url.ParseQuery(r.URL.RawQuery) + if err != nil { + return nil, errors.InvalidArgument(err.Error(), "INVALID_QUERY_PARAMS") + } + + if len(vs.Get(accessTokenQueryParamsKey)) <= 0 { + return nil, errors.InvalidArgument("token in query params is required", "TOKEN_QUERY_PARAMS_REQUIRED") + } + + return visiauth.NewService(redis.NewJwkFetcher(), neo4j.NewUserRepository()).DecodeAccessToken(r.Context(), vs.Get(accessTokenQueryParamsKey)) + }() + if err != nil { + renderer.Error(err, w) + return + } + + renderer.Success(authenticable, http.StatusOK, w) +} diff --git a/api/renderer/render.go b/api/renderer/render.go new file mode 100644 index 0000000..3ac5d47 --- /dev/null +++ b/api/renderer/render.go @@ -0,0 +1,39 @@ +package renderer + +import ( + "encoding/json" + "errors" + "net/http" +) + +func Success(v interface{}, statusCode int, rw http.ResponseWriter) { + data, err := json.Marshal(map[string]interface{}{"data": v}) + if err != nil { + Error(err, rw) + return + } + + render(data, statusCode, "application/json", rw) +} + +func Error(err error, rw http.ResponseWriter) { + data, _ := json.Marshal(map[string]interface{}{"error": err}) + + var sce interface { + error + StatusCode() int + } + + sc := http.StatusInternalServerError + if errors.As(err, &sce) { + sc = sce.StatusCode() + } + + render(data, sc, "application/json", rw) +} + +func render(data []byte, statusCode int, contentType string, rw http.ResponseWriter) { + rw.Header().Set("Content-Type", contentType) + rw.WriteHeader(statusCode) + rw.Write(data) +} diff --git a/app.go b/app.go index ac603da..f0f47b3 100644 --- a/app.go +++ b/app.go @@ -1,5 +1,7 @@ package visiauth +import "encoding/json" + type App struct { id string } @@ -11,3 +13,11 @@ func NewApp(id string) *App { func (a App) ID() string { return a.id } + +func (a App) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + ID string `json:"id"` + }{ + ID: a.ID(), + }) +} diff --git a/errors/internal.go b/errors/internal.go new file mode 100644 index 0000000..2909a30 --- /dev/null +++ b/errors/internal.go @@ -0,0 +1,44 @@ +package errors + +import ( + "encoding/json" + "errors" + "net/http" +) + +type internal struct { + err error +} + +func Internal(err error) error { + return internal{err} +} + +func (i internal) Message() string { + return i.err.Error() +} + +func (i internal) StatusCode() int { + return http.StatusInternalServerError +} + +func (i internal) String() string { + return i.Message() +} + +func (i internal) Error() string { + return i.String() +} + +func (i internal) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Message string `json:"message"` + }{ + Message: i.Message(), + }) +} + +func IsInternal(err error) bool { + var i internal + return errors.As(err, &i) +} diff --git a/errors/invalid-argument.go b/errors/invalid-argument.go new file mode 100644 index 0000000..65466bd --- /dev/null +++ b/errors/invalid-argument.go @@ -0,0 +1,50 @@ +package errors + +import ( + "encoding/json" + "errors" + "net/http" +) + +type invalidArgument struct { + reason, code string +} + +func InvalidArgument(reason, code string) error { + return invalidArgument{reason, code} +} + +func (ia invalidArgument) Message() string { + return ia.reason +} + +func (ia invalidArgument) Code() string { + return ia.code +} + +func (ia invalidArgument) StatusCode() int { + return http.StatusBadRequest +} + +func (ia invalidArgument) String() string { + return ia.Message() +} + +func (ia invalidArgument) Error() string { + return ia.String() +} + +func (ia invalidArgument) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Message string `json:"message"` + Code string `json:"code"` + }{ + Message: ia.Message(), + Code: ia.Code(), + }) +} + +func IsInvalidArgument(err error) bool { + var ia invalidArgument + return errors.As(err, &ia) +} diff --git a/errors/not-found.go b/errors/not-found.go new file mode 100644 index 0000000..0cbb75a --- /dev/null +++ b/errors/not-found.go @@ -0,0 +1,43 @@ +package errors + +import ( + "errors" + "fmt" + "net/http" +) + +type notFound struct { + resource, code string +} + +func NotFound(resource, code string) error { + return notFound{ + resource: resource, + code: code, + } +} + +func (nf notFound) Message() string { + return fmt.Sprintf("%s not found", nf.resource) +} + +func (nf notFound) Code() string { + return nf.code +} + +func (nf notFound) StatusCode() int { + return http.StatusNotFound +} + +func (nf notFound) String() string { + return nf.Message() +} + +func (nf notFound) Error() string { + return nf.String() +} + +func IsNotFound(err error) bool { + var nf notFound + return errors.As(err, &nf) +} diff --git a/errors/unauthorized.go b/errors/unauthorized.go new file mode 100644 index 0000000..b6d6e4b --- /dev/null +++ b/errors/unauthorized.go @@ -0,0 +1,50 @@ +package errors + +import ( + "encoding/json" + "errors" + "net/http" +) + +type unauthorized struct { + reason, code string +} + +func Unauthorized(reason, code string) error { + return unauthorized{reason, code} +} + +func (u unauthorized) Message() string { + return u.reason +} + +func (u unauthorized) Code() string { + return u.code +} + +func (u unauthorized) StatusCode() int { + return http.StatusUnauthorized +} + +func (u unauthorized) String() string { + return u.Message() +} + +func (u unauthorized) Error() string { + return u.String() +} + +func (u unauthorized) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Message string `json:"message"` + Code string `json:"code"` + }{ + Message: u.Message(), + Code: u.Code(), + }) +} + +func IsUnauthorized(err error) bool { + var u unauthorized + return errors.As(err, &u) +} diff --git a/neo4j/user.go b/neo4j/user.go index d1706aa..47b7391 100644 --- a/neo4j/user.go +++ b/neo4j/user.go @@ -5,6 +5,7 @@ import ( "github.com/neo4j/neo4j-go-driver/v4/neo4j" "github.com/visiperf/visiauth/v3" + "github.com/visiperf/visiauth/v3/errors" ) type UserRepository struct { @@ -22,7 +23,7 @@ func (r *UserRepository) FetchUserByID(ctx context.Context, userID string, scope c.Log = neo4j.ConsoleLogger(neo4j.ERROR) }) if err != nil { - return nil, err + return nil, errors.Internal(err) } defer driver.Close() @@ -49,12 +50,12 @@ func (r *UserRepository) fetchUserLegacyID(_ context.Context, session neo4j.Sess "user_id": userID, }) if err != nil { - return "", err + return "", errors.Internal(err) } rec, err := res.Single() if err != nil { - return "", err + return "", errors.Internal(err) } return rec.Values[0].(string), nil @@ -74,7 +75,7 @@ func (r *UserRepository) fetchUserOrganizations(_ context.Context, session neo4j "user_id": userID, }) if err != nil { - return nil, nil, err + return nil, nil, errors.Internal(err) } m := make(map[string]string) @@ -89,7 +90,7 @@ func (r *UserRepository) fetchUserOrganizations(_ context.Context, session neo4j } if err := res.Err(); err != nil { - return nil, nil, err + return nil, nil, errors.Internal(err) } return m, s, nil diff --git a/redis/errors.go b/redis/errors.go new file mode 100644 index 0000000..e0c7bd8 --- /dev/null +++ b/redis/errors.go @@ -0,0 +1,9 @@ +package redis + +const ( + ErrRedisNilMessage = "redis: nil" +) + +func IsErrRedisNilMessage(err error) bool { + return err.Error() == ErrRedisNilMessage +} diff --git a/redis/jwk.go b/redis/jwk.go index 55c6af7..dc28dac 100644 --- a/redis/jwk.go +++ b/redis/jwk.go @@ -5,6 +5,7 @@ import ( "github.com/go-redis/redis/v8" "github.com/visiperf/visiauth/v3" + "github.com/visiperf/visiauth/v3/errors" ) type JwkFetcher struct { @@ -27,7 +28,11 @@ func (f *JwkFetcher) FetchJwk(ctx context.Context, kid string) (*visiauth.Jwk, e var jwk visiauth.Jwk if err := client.Get(ctx, kid).Scan(&jwk); err != nil { - return nil, err + if IsErrRedisNilMessage(err) { + return nil, errors.NotFound("jwk", "JWK_NOT_FOUND") + } + + return nil, errors.Internal(err) } return &jwk, nil diff --git a/service.go b/service.go index a1fc2db..4b6b366 100644 --- a/service.go +++ b/service.go @@ -2,8 +2,10 @@ package visiauth import ( "context" - "errors" + "fmt" "strings" + + "github.com/visiperf/visiauth/v3/errors" ) type Service struct { @@ -31,7 +33,7 @@ func (s *Service) DecodeAccessToken(ctx context.Context, accessToken string) (Au return s.app(ctx, t) } - return nil, errors.New("unknown token type") + return nil, errors.Internal(fmt.Errorf("unknown token type")) } func (s *Service) app(ctx context.Context, token *MachineToken) (Authenticable, error) { diff --git a/token.go b/token.go index cc4cc16..71b901e 100644 --- a/token.go +++ b/token.go @@ -4,13 +4,13 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io/ioutil" "net/http" "strings" "github.com/golang-jwt/jwt" + "github.com/visiperf/visiauth/v3/errors" "google.golang.org/grpc/metadata" ) @@ -24,8 +24,8 @@ const ( ) var ( - ErrMissingMetadata = errors.New("missing metadata") - ErrMissingAuthorization = errors.New("missing authorization") + ErrMissingMetadata = fmt.Errorf("missing metadata") + ErrMissingAuthorization = fmt.Errorf("missing authorization") ) var mTokenTypeTokenFactory = map[string]func(token *jwt.Token) Token{ @@ -146,7 +146,15 @@ func NewTokenParser(jwkFetcher JwkFetcher) *TokenParser { func (p *TokenParser) ParseToken(ctx context.Context, accessToken string) (Token, error) { token, err := jwt.Parse(accessToken, p.keyFunc(ctx)) if err != nil { - return nil, err + if e, ok := err.(*jwt.ValidationError); ok && e.Inner != nil { + err = e.Inner + } + + if errors.IsInternal(err) { + return nil, err + } + + return nil, errors.Unauthorized(err.Error(), "INVALID_TOKEN") } k := fmt.Sprintf("%s%s", token.Claims.(jwt.MapClaims)["iss"].(string), tokenTypeKey) @@ -154,7 +162,7 @@ func (p *TokenParser) ParseToken(ctx context.Context, accessToken string) (Token fn, ok := mTokenTypeTokenFactory[tt] if !ok { - return nil, errors.New("unknown token type") + return nil, errors.Internal(fmt.Errorf("unknown token type")) } return fn(token), nil diff --git a/user.go b/user.go index 2f05932..1d00347 100644 --- a/user.go +++ b/user.go @@ -2,6 +2,7 @@ package visiauth import ( "context" + "encoding/json" "github.com/bitrise-io/go-utils/sliceutil" "golang.org/x/exp/maps" @@ -84,3 +85,19 @@ func (u User) HighestRoleInOrganizations() map[string]string { func (u User) HighestRoleInOrganization(organizationId string) string { return u.organizationsRole[organizationId] } + +func (u User) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + ID string `json:"id"` + LegacyID string `json:"legacyId"` + Scopes []string `json:"scopes"` + OrganizationsRole map[string][]string `json:"organizationsRole"` + OrganizationLegacyIDs []string `json:"organizationLegacyIds"` + }{ + ID: u.ID(), + LegacyID: u.LegacyID(), + Scopes: u.Scopes(), + OrganizationsRole: u.OrganizationRoles(), + OrganizationLegacyIDs: u.OrganizationLegacyIds(), + }) +}