diff --git a/api/api.go b/api/handler.go similarity index 94% rename from api/api.go rename to api/handler.go index a01c3b4..d648e72 100644 --- a/api/api.go +++ b/api/handler.go @@ -6,8 +6,23 @@ import ( "github.com/danielgtaylor/huma/v2" "github.com/labstack/echo/v4" + "go.uber.org/fx" ) +type Handler interface { + Area() string + Version() string + Register(*echo.Echo, huma.API) +} + +func AsHandler(f any) any { + return fx.Annotate( + f, + fx.As(new(Handler)), + fx.ResultTags(`group:"api-handler"`), + ) +} + type ErrorTransformerFunc func(context.Context, error) error type CRUDInfo struct { diff --git a/api/huma_adapter.go b/api/huma_adapter.go new file mode 100644 index 0000000..4e66b5c --- /dev/null +++ b/api/huma_adapter.go @@ -0,0 +1,175 @@ +package api + +import ( + "context" + "crypto/tls" + "io" + "mime/multipart" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/danielgtaylor/huma/v2" + "github.com/labstack/echo/v4" +) + +// MultipartMaxMemory is the maximum memory to use when parsing multipart +// form data. +var MultipartMaxMemory int64 = 8 * 1024 + +type echoCtx struct { + op *huma.Operation + orig echo.Context + status int +} + +// check that echoCtx implements huma.Context +var _ huma.Context = &echoCtx{} + +func (c *echoCtx) EchoContext() echo.Context { + return c.orig +} + +func (c *echoCtx) Request() *http.Request { + return c.orig.Request() +} + +func (c *echoCtx) Operation() *huma.Operation { + return c.op +} + +func (c *echoCtx) Context() context.Context { + return c.orig.Request().Context() +} + +func (c *echoCtx) Method() string { + return c.orig.Request().Method +} + +func (c *echoCtx) Host() string { + return c.orig.Request().Host +} + +func (c *echoCtx) RemoteAddr() string { + return c.orig.Request().RemoteAddr +} + +func (c *echoCtx) URL() url.URL { + return *c.orig.Request().URL +} + +func (c *echoCtx) Param(name string) string { + return c.orig.Param(name) +} + +func (c *echoCtx) Query(name string) string { + return c.orig.QueryParam(name) +} + +func (c *echoCtx) Header(name string) string { + return c.orig.Request().Header.Get(name) +} + +func (c *echoCtx) EachHeader(cb func(name, value string)) { + for name, values := range c.orig.Request().Header { + for _, value := range values { + cb(name, value) + } + } +} + +func (c *echoCtx) BodyReader() io.Reader { + return c.orig.Request().Body +} + +func (c *echoCtx) GetMultipartForm() (*multipart.Form, error) { + err := c.orig.Request().ParseMultipartForm(MultipartMaxMemory) + return c.orig.Request().MultipartForm, err +} + +func (c *echoCtx) SetReadDeadline(deadline time.Time) error { + return huma.SetReadDeadline(c.orig.Response(), deadline) +} + +func (c *echoCtx) SetStatus(code int) { + c.status = code + c.orig.Response().WriteHeader(code) +} + +func (c *echoCtx) Status() int { + return c.status +} + +func (c *echoCtx) AppendHeader(name, value string) { + c.orig.Response().Header().Add(name, value) +} + +func (c *echoCtx) SetHeader(name, value string) { + c.orig.Response().Header().Set(name, value) +} + +func (c *echoCtx) BodyWriter() io.Writer { + return c.orig.Response() +} + +func (c *echoCtx) TLS() *tls.ConnectionState { + return c.orig.Request().TLS +} + +func (c *echoCtx) Version() huma.ProtoVersion { + r := c.orig.Request() + return huma.ProtoVersion{ + Proto: r.Proto, + ProtoMajor: r.ProtoMajor, + ProtoMinor: r.ProtoMinor, + } +} + +func (c *echoCtx) reset(op *huma.Operation, orig echo.Context) { + c.op = op + c.orig = orig + c.status = 0 +} + +type router interface { + Add(method, path string, handler echo.HandlerFunc, middlewares ...echo.MiddlewareFunc) *echo.Route +} + +type echoAdapter struct { + http.Handler + router router + pool *sync.Pool +} + +func (a *echoAdapter) Handle(op *huma.Operation, handler func(huma.Context)) { + // Convert {param} to :param + path := op.Path + path = strings.ReplaceAll(path, "{", ":") + path = strings.ReplaceAll(path, "}", "") + a.router.Add(op.Method, path, func(c echo.Context) error { + ctx := a.pool.Get().(*echoCtx) + ctx.reset(op, c) + + defer func() { + ctx.reset(nil, nil) + a.pool.Put(ctx) + }() + + handler(ctx) + return nil + }) +} + +func NewAdapter(r *echo.Echo, g *echo.Group) huma.Adapter { + return &echoAdapter{ + Handler: r, + router: g, + pool: &sync.Pool{ + New: func() any { + return new(echoCtx) + }, + }, + } +} diff --git a/api/middleware.go b/api/middleware.go new file mode 100644 index 0000000..477ec35 --- /dev/null +++ b/api/middleware.go @@ -0,0 +1,62 @@ +package api + +import ( + "net/http" + "strings" + + "github.com/danielgtaylor/huma/v2" + "github.com/labstack/echo/v4" + "go.uber.org/fx" + "go.uber.org/zap" +) + +type Middleware struct { + Name string + Middleware func(huma.API) func(huma.Context, func(huma.Context)) +} + +func NewMiddleware(name string, middleware func(huma.API) func(huma.Context, func(huma.Context))) Middleware { + return Middleware{ + Name: name, + Middleware: middleware, + } +} + +func AsMiddleware(middleware any) any { + return fx.Annotate( + middleware, + fx.ResultTags(`group:"api-middleware"`), + ) +} + +func AsMiddlewareFunc(name string, middleware func(huma.API) func(ctx huma.Context, next func(huma.Context))) any { + return AsMiddleware(NewMiddleware(name, middleware)) +} + +type Authorizer func(huma.Context) error + +func AuthorizationMiddleware(authorizer Authorizer, logger *zap.Logger) Middleware { + return NewMiddleware("authorization", func(api huma.API) func(huma.Context, func(huma.Context)) { + unauthorized := func(ctx huma.Context, errs ...error) { + status := http.StatusUnauthorized + message := http.StatusText(status) + + if strings.HasPrefix(strings.ToLower(ctx.Header(echo.HeaderAuthorization)), "basic ") { + ctx.SetHeader(echo.HeaderWWWAuthenticate, "basic realm=Restricted") + } + + if err := huma.WriteErr(api, ctx, status, message, errs...); err != nil { + logger.Error("huma api: failed to write error", zap.Error(err)) + } + } + + return func(ctx huma.Context, next func(huma.Context)) { + if err := authorizer(ctx); err != nil { + unauthorized(ctx) + logger.Error("huma api: failed to authorize", zap.Error(err)) + return + } + next(ctx) + } + }) +} diff --git a/authorization.go b/authorization.go deleted file mode 100644 index befb04c..0000000 --- a/authorization.go +++ /dev/null @@ -1,105 +0,0 @@ -package echox - -import ( - "context" - "errors" - - "github.com/gowool/echox/rbac" -) - -var ErrDeny = errors.New("huma api: authorizer decision `deny`") - -type ( - claimsKey struct{} - assertionsKey struct{} -) - -type Subject interface { - Identifier() string - Roles() []string -} - -type Claims struct { - Subject Subject - Metadata map[string]any -} - -func WithClaims(ctx context.Context, claims *Claims) context.Context { - return context.WithValue(ctx, claimsKey{}, claims) -} - -func CtxClaims(ctx context.Context) *Claims { - claims, _ := ctx.Value(claimsKey{}).(*Claims) - return claims -} - -func WithAssertions(ctx context.Context, assertions ...rbac.Assertion) context.Context { - return context.WithValue(ctx, assertionsKey{}, assertions) -} - -func CtxAssertions(ctx context.Context) []rbac.Assertion { - assertions, _ := ctx.Value(assertionsKey{}).([]rbac.Assertion) - return append(make([]rbac.Assertion, 0, len(assertions)), assertions...) -} - -type Target struct { - Action string - Assertions []rbac.Assertion - Metadata map[string]any -} - -type Decision int8 - -const ( - DecisionDeny = iota + 1 - DecisionAllow -) - -func (d Decision) String() string { - switch d { - case DecisionDeny: - return "deny" - case DecisionAllow: - return "allow" - default: - return "unknown" - } -} - -type Authorizer interface { - Authorize(ctx context.Context, claims *Claims, target *Target) (Decision, error) -} - -type DefaultAuthorizer struct { - rbac *rbac.RBAC -} - -func NewDefaultAuthorizer(rbac *rbac.RBAC) *DefaultAuthorizer { - return &DefaultAuthorizer{rbac: rbac} -} - -func (a *DefaultAuthorizer) Authorize(ctx context.Context, claims *Claims, target *Target) (d Decision, err error) { - d = DecisionDeny - err = ErrDeny - - if target == nil || target.Action == "" { - return - } - - if claims == nil || claims.Subject == nil { - return - } - - roles := make([]string, 0, len(claims.Subject.Roles())+1) - roles = append(roles, claims.Subject.Identifier()) - roles = append(roles, claims.Subject.Roles()...) - - for _, role := range roles { - granted, err1 := a.rbac.IsGrantedE(ctx, role, target.Action, target.Assertions...) - if granted && err1 == nil { - return DecisionAllow, nil - } - err = errors.Join(err, err1) - } - return -} diff --git a/config.go b/config.go index 019a3cc..07f256f 100644 --- a/config.go +++ b/config.go @@ -9,8 +9,6 @@ import ( "github.com/labstack/echo/v4/middleware" "go.uber.org/fx" "go.uber.org/zap" - - "github.com/gowool/echox/rbac" ) type MiddlewaresConfig struct { @@ -21,7 +19,6 @@ type MiddlewaresConfig struct { Secure SecureConfig `json:"secure,omitempty" yaml:"secure,omitempty"` CORS CORSConfig `json:"cors,omitempty" yaml:"cors,omitempty"` CSRF CSRFConfig `json:"csrf,omitempty" yaml:"csrf,omitempty"` - Session SessionConfig `json:"session,omitempty" yaml:"session,omitempty"` Logger RequestLoggerConfig `json:"logger,omitempty" yaml:"logger,omitempty"` } @@ -68,7 +65,6 @@ type RouterConfig struct { } type Config struct { - Security rbac.Config `json:"security,omitempty" yaml:"security,omitempty"` Middlewares MiddlewaresConfig `json:"middlewares,omitempty" yaml:"middlewares,omitempty"` Router RouterConfig `json:"router,omitempty" yaml:"router,omitempty"` } @@ -97,22 +93,6 @@ func (s SameSiteType) HTTP() http.SameSite { } } -type SessionConfig struct { - Skipper middleware.Skipper `json:"-" yaml:"-"` - CleanupInterval time.Duration `json:"cleanupInterval,omitempty" yaml:"cleanupInterval,omitempty"` - IdleTimeout time.Duration `json:"idleTimeout,omitempty" yaml:"idleTimeout,omitempty"` - Lifetime time.Duration `json:"lifetime,omitempty" yaml:"lifetime,omitempty"` - Cookie struct { - Name string `json:"name,omitempty" yaml:"name,omitempty"` - Domain string `json:"domain,omitempty" yaml:"domain,omitempty"` - Path string `json:"path,omitempty" yaml:"path,omitempty"` - Persist bool `json:"persist,omitempty" yaml:"persist,omitempty"` - Secure bool `json:"secure,omitempty" yaml:"secure,omitempty"` - HTTPOnly bool `json:"httpOnly,omitempty" yaml:"httpOnly,omitempty"` - SameSite SameSiteType `json:"sameSite,omitempty" yaml:"sameSite,omitempty"` - } `json:"cookie,omitempty" yaml:"cookie,omitempty"` -} - type RecoverConfig struct { // Skipper defines a function to skip middleware. Skipper middleware.Skipper diff --git a/echo.go b/echo.go index 154342b..6def4a1 100644 --- a/echo.go +++ b/echo.go @@ -9,9 +9,10 @@ import ( "strings" "github.com/danielgtaylor/huma/v2" - "github.com/danielgtaylor/huma/v2/adapters/humaecho" "github.com/labstack/echo/v4" "go.uber.org/fx" + + "github.com/gowool/echox/api" ) type areaKey struct{} @@ -28,11 +29,11 @@ type EchoParams struct { Renderer echo.Renderer Validator echo.Validator IPExtractor echo.IPExtractor - Filesystem fs.FS `name:"echo-fs"` - Handlers []Handler `group:"handler"` - Middlewares []Middleware `group:"middleware"` - APIHandlers []APIHandler `group:"api-handler"` - APIMiddlewares []APIMiddleware `group:"api-middleware"` + Filesystem fs.FS `name:"echo-fs"` + Handlers []Handler `group:"handler"` + Middlewares []Middleware `group:"middleware"` + APIHandlers []api.Handler `group:"api-handler"` + APIMiddlewares []api.Middleware `group:"api-middleware"` } func NewEcho(params EchoParams) *echo.Echo { @@ -67,7 +68,7 @@ func NewEcho(params EchoParams) *echo.Echo { handlers[handler.Area()] = append(handlers[handler.Area()], handler) } - apiHandlers := make(map[string][]APIHandler) + apiHandlers := make(map[string][]api.Handler) for _, apiHandler := range params.APIHandlers { key := fmt.Sprintf("%s-%s", apiHandler.Area(), apiHandler.Version()) apiHandlers[key] = append(apiHandlers[key], apiHandler) @@ -125,16 +126,16 @@ func NewEcho(params EchoParams) *echo.Echo { humaConfig.Components = &cfgAPI.Components humaConfig.Components.Schemas = schemas - api := humaecho.NewWithGroup(e, group.Group(cfgAPI.Path), humaConfig) + humaAPI := huma.NewAPI(humaConfig, api.NewAdapter(e, group.Group(cfgAPI.Path))) for _, name := range cfgAPI.Middlewares { if mdw, ok := apiMiddlewares[name]; ok { - api.UseMiddleware(mdw(api)) + humaAPI.UseMiddleware(mdw(humaAPI)) } } for _, h := range apiHandlers[fmt.Sprintf("%s-%s", area, version)] { - h.Register(e, api) + h.Register(e, humaAPI) } } } diff --git a/go.mod b/go.mod index 17f5116..604633d 100644 --- a/go.mod +++ b/go.mod @@ -3,24 +3,20 @@ module github.com/gowool/echox go 1.23.2 require ( - github.com/alexedwards/scs/v2 v2.8.0 - github.com/danielgtaylor/huma/v2 v2.25.0 + github.com/danielgtaylor/huma/v2 v2.26.0 github.com/google/uuid v1.6.0 github.com/gowool/cr v0.0.1 github.com/labstack/echo/v4 v4.12.0 - github.com/stretchr/testify v1.9.0 go.uber.org/fx v1.23.0 go.uber.org/zap v1.27.0 ) require ( - github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/labstack/gommon v0.4.2 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/oklog/ulid/v2 v2.1.1-0.20240413180941-96c4edf226ef // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect go.uber.org/dig v1.18.0 // indirect @@ -30,5 +26,4 @@ require ( golang.org/x/sys v0.27.0 // indirect golang.org/x/text v0.20.0 // indirect golang.org/x/time v0.8.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 3fdfa99..ecb9973 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ -github.com/alexedwards/scs/v2 v2.8.0 h1:h31yUYoycPuL0zt14c0gd+oqxfRwIj6SOjHdKRZxhEw= -github.com/alexedwards/scs/v2 v2.8.0/go.mod h1:ToaROZxyKukJKT/xLcVQAChi5k6+Pn1Gvmdl7h3RRj8= -github.com/danielgtaylor/huma/v2 v2.25.0 h1:8q/tZLozDs2oFPUHS1xaFVa1mlNYBXV8UbmSQUQeAXo= -github.com/danielgtaylor/huma/v2 v2.25.0/go.mod h1:NbSFXRoOMh3BVmiLJQ9EbUpnPas7D9BeOxF/pZBAGa0= +github.com/danielgtaylor/huma/v2 v2.26.0 h1:lON4pIcckuSQJNDi6WkOu0sS7mxvlNkTAGbc3BrRXTc= +github.com/danielgtaylor/huma/v2 v2.26.0/go.mod h1:NbSFXRoOMh3BVmiLJQ9EbUpnPas7D9BeOxF/pZBAGa0= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= @@ -54,7 +52,5 @@ golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg= golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/handler.go b/handler.go index 1216076..c19e0cc 100644 --- a/handler.go +++ b/handler.go @@ -1,7 +1,6 @@ package echox import ( - "github.com/danielgtaylor/huma/v2" "github.com/labstack/echo/v4" "go.uber.org/fx" ) @@ -18,17 +17,3 @@ func AsHandler(f any) any { fx.ResultTags(`group:"handler"`), ) } - -type APIHandler interface { - Area() string - Version() string - Register(*echo.Echo, huma.API) -} - -func AsAPIHandler(f any) any { - return fx.Annotate( - f, - fx.As(new(APIHandler)), - fx.ResultTags(`group:"api-handler"`), - ) -} diff --git a/middleware.go b/middleware.go index 9b0f126..58d6f84 100644 --- a/middleware.go +++ b/middleware.go @@ -2,20 +2,14 @@ package echox import ( "errors" - "fmt" "net/http" "strings" - "sync" - "github.com/alexedwards/scs/v2" - "github.com/danielgtaylor/huma/v2" "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" "go.uber.org/fx" "go.uber.org/zap" - - "github.com/gowool/echox/rbac" ) type Middleware struct { @@ -41,29 +35,6 @@ func AsMiddlewareFunc(name string, middleware echo.MiddlewareFunc) any { return AsMiddleware(NewMiddleware(name, middleware)) } -type APIMiddleware struct { - Name string - Middleware func(huma.API) func(huma.Context, func(huma.Context)) -} - -func NewAPIMiddleware(name string, middleware func(huma.API) func(huma.Context, func(huma.Context))) APIMiddleware { - return APIMiddleware{ - Name: name, - Middleware: middleware, - } -} - -func AsAPIMiddleware(middleware any) any { - return fx.Annotate( - middleware, - fx.ResultTags(`group:"api-middleware"`), - ) -} - -func AsAPIMiddlewareFunc(name string, middleware func(huma.API) func(ctx huma.Context, next func(huma.Context))) any { - return AsAPIMiddleware(NewAPIMiddleware(name, middleware)) -} - func RecoverMiddleware(cfg RecoverConfig, logger *zap.Logger) Middleware { return NewMiddleware("recover", middleware.RecoverWithConfig(middleware.RecoverConfig{ Skipper: cfg.Skipper, @@ -213,31 +184,6 @@ func CSRFMiddleware(cfg CSRFConfig) Middleware { })) } -func SessionMiddleware(cfg SessionConfig, sessionManager *scs.SessionManager) Middleware { - if cfg.Skipper == nil { - cfg.Skipper = middleware.DefaultSkipper - } - - return NewMiddleware("session", func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { - if cfg.Skipper(c) { - return next(c) - } - - sessionManager.ErrorFunc = func(_ http.ResponseWriter, _ *http.Request, err1 error) { - err = err1 - } - - sessionManager.LoadAndSave(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c.SetRequest(r) - c.SetResponse(echo.NewResponse(w, c.Echo())) - err = next(c) - })).ServeHTTP(c.Response(), c.Request()) - return - } - }) -} - func BasicAuthMiddleware(validator middleware.BasicAuthValidator) Middleware { return NewMiddleware("basic-auth", middleware.BasicAuthWithConfig(middleware.BasicAuthConfig{ Skipper: func(c echo.Context) bool { @@ -258,13 +204,9 @@ func BearerAuthMiddleware(validator middleware.KeyAuthValidator) Middleware { })) } -func AuthorizationMiddleware(authorizer Authorizer) Middleware { - pool := &sync.Pool{ - New: func() any { - return new(Target) - }, - } +type Authorizer func(*http.Request) error +func AuthorizationMiddleware(authorizer Authorizer) Middleware { unauthorized := func(c echo.Context, errs ...error) error { h := c.Request().Header.Get(echo.HeaderAuthorization) if strings.HasPrefix(strings.ToLower(h), "basic ") { @@ -273,122 +215,12 @@ func AuthorizationMiddleware(authorizer Authorizer) Middleware { return echo.ErrUnauthorized.WithInternal(errors.Join(errs...)) } - fn := func(c echo.Context) (err error) { - var decision Decision = DecisionDeny - defer func() { - if decision == DecisionDeny && err == nil { - err = ErrDeny - } - }() - - ctx := c.Request().Context() - claims := CtxClaims(ctx) - assertions := CtxAssertions(ctx) - - target := pool.Get().(*Target) - defer pool.Put(target) - - for _, action := range permissions(c.Request().Method, c.Request().URL.Path) { - target.Action = action - target.Assertions = assertions - - if decision, err = authorizer.Authorize(ctx, claims, target); decision == DecisionAllow { - return nil - } - } - return - } - return NewMiddleware("authorization", func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - if err := fn(c); err != nil { + if err := authorizer(c.Request()); err != nil { return unauthorized(c, err) } return next(c) } }) } - -func AuthorizationAPIMiddleware(authorizer Authorizer, logger *zap.Logger) APIMiddleware { - pool := &sync.Pool{ - New: func() any { - return new(Target) - }, - } - - return NewAPIMiddleware("authorization", func(api huma.API) func(huma.Context, func(huma.Context)) { - unauthorized := func(ctx huma.Context, errs ...error) { - status := http.StatusUnauthorized - message := http.StatusText(status) - - if strings.HasPrefix(strings.ToLower(ctx.Header(echo.HeaderAuthorization)), "basic ") { - ctx.SetHeader(echo.HeaderWWWAuthenticate, "basic realm=Restricted") - } - - if err := huma.WriteErr(api, ctx, status, message, errs...); err != nil { - logger.Error("huma api: failed to write error", zap.Error(err)) - } - } - - fn := func(c huma.Context) (err error) { - var decision Decision = DecisionDeny - defer func() { - if decision == DecisionDeny && err == nil { - err = ErrDeny - } - }() - - claims := CtxClaims(c.Context()) - assertions := CtxAssertions(c.Context()) - - if o := c.Operation(); o.Metadata != nil { - for _, value := range o.Metadata { - switch value := value.(type) { - case *Target: - value.Assertions = append(assertions, value.Assertions...) - decision, err = authorizer.Authorize(c.Context(), claims, value) - return - case rbac.Assertion: - assertions = append(assertions, value) - case []rbac.Assertion: - assertions = append(assertions, value...) - } - } - } - - target := pool.Get().(*Target) - defer pool.Put(target) - - for _, action := range permissions(c.Method(), c.URL().Path) { - target.Action = action - target.Assertions = assertions - - if decision, err = authorizer.Authorize(c.Context(), claims, target); decision == DecisionAllow { - return nil - } - } - return - } - - return func(ctx huma.Context, next func(huma.Context)) { - if err := fn(ctx); err != nil { - unauthorized(ctx) - logger.Error("huma api: failed to authorize", zap.Error(err)) - return - } - next(ctx) - } - }) -} - -func permissions(method, path string) []string { - if path == "" { - path = "/" - } - return []string{ - "*", - method, - path, - fmt.Sprintf("%s %s", method, path), - } -} diff --git a/options.go b/options.go index 0327472..5a3f175 100644 --- a/options.go +++ b/options.go @@ -3,20 +3,13 @@ package echox import ( "go.uber.org/fx" - "github.com/gowool/echox/rbac" + "github.com/gowool/echox/api" ) var ( OptionEcho = fx.Provide(NewEcho) OptionIPExtractor = fx.Provide(IPExtractor) - OptionRBAC = fx.Provide(rbac.New) - OptionRBACWithConfig = fx.Provide(rbac.NewWithConfig) - OptionAuthorizationChecker = fx.Provide(func(rbac *rbac.RBAC) rbac.AuthorizationChecker { return rbac }) - OptionAuthorizer = fx.Provide(fx.Annotate(NewDefaultAuthorizer, fx.As(new(Authorizer)))) - - OptionSessionManager = fx.Provide(NewSessionManager) - OptionRecoverMiddleware = fx.Provide(AsMiddleware(RecoverMiddleware)) OptionBodyLimitMiddleware = fx.Provide(AsMiddleware(BodyLimitMiddleware)) OptionCompressMiddleware = fx.Provide(AsMiddleware(CompressMiddleware)) @@ -26,10 +19,9 @@ var ( OptionSecureMiddleware = fx.Provide(AsMiddleware(SecureMiddleware)) OptionCORSMiddleware = fx.Provide(AsMiddleware(CORSMiddleware)) OptionCSRFMiddleware = fx.Provide(AsMiddleware(CSRFMiddleware)) - OptionSessionMiddleware = fx.Provide(AsMiddleware(SessionMiddleware)) OptionBasicAuthMiddleware = fx.Provide(AsMiddleware(BasicAuthMiddleware)) OptionBearerAuthMiddleware = fx.Provide(AsMiddleware(BearerAuthMiddleware)) OptionAuthorizationMiddleware = fx.Provide(AsMiddleware(AuthorizationMiddleware)) - OptionAuthorizationAPIMiddleware = fx.Provide(AsAPIMiddleware(AuthorizationAPIMiddleware)) + OptionAPIAuthorizationMiddleware = fx.Provide(api.AsMiddleware(api.AuthorizationMiddleware)) ) diff --git a/rbac/config.go b/rbac/config.go deleted file mode 100644 index 6654459..0000000 --- a/rbac/config.go +++ /dev/null @@ -1,73 +0,0 @@ -package rbac - -type RoleConfig struct { - Role string `json:"role,omitempty" yaml:"role,omitempty"` - Parents []string `json:"parents,omitempty" yaml:"parents,omitempty"` - Children []string `json:"children,omitempty" yaml:"children,omitempty"` -} - -type AccessConfig struct { - Role string `json:"role,omitempty" yaml:"role,omitempty"` - Permissions []string `json:"permissions,omitempty" yaml:"permissions,omitempty"` -} - -type Config struct { - CreateMissingRoles bool `json:"createMissingRoles,omitempty" yaml:"createMissingRoles,omitempty"` - RoleHierarchy []RoleConfig `json:"roleHierarchy,omitempty" yaml:"roleHierarchy,omitempty"` - AccessControl []AccessConfig `json:"accessControl,omitempty" yaml:"accessControl,omitempty"` -} - -func NewWithConfig(cfg Config) (*RBAC, error) { - rbac := New() - err := rbac.Apply(cfg) - return rbac, err -} - -func (rbac *RBAC) Apply(cfg Config) error { - rbac.SetCreateMissingRoles(cfg.CreateMissingRoles) - - for _, role := range cfg.RoleHierarchy { - if err := rbac.AddRole(role.Role); err != nil { - return err - } - } - - for _, role := range cfg.RoleHierarchy { - r, err := rbac.Role(role.Role) - if err != nil { - return err - } - - for _, parent := range role.Parents { - p, err := rbac.Role(parent) - if err != nil { - return err - } - if err = r.AddParent(p); err != nil { - return err - } - } - - for _, child := range role.Children { - c, err := rbac.Role(child) - if err != nil { - return err - } - if err = r.AddChild(c); err != nil { - return err - } - } - } - - for _, access := range cfg.AccessControl { - r, err := rbac.Role(access.Role) - if err != nil { - return err - } - if len(access.Permissions) == 0 { - continue - } - r.AddPermissions(access.Permissions[0], access.Permissions[1:]...) - } - return nil -} diff --git a/rbac/rbac.go b/rbac/rbac.go deleted file mode 100644 index 4eacb60..0000000 --- a/rbac/rbac.go +++ /dev/null @@ -1,140 +0,0 @@ -package rbac - -import ( - "context" - "errors" - "fmt" - "maps" - "slices" -) - -var ( - ErrRoleNotFound = errors.New("role not found") - ErrInvalidRole = errors.New("role must be a string or implement the Role interface") -) - -type Assertion interface { - Assert(ctx context.Context, role Role, permission string) (bool, error) -} - -type AssertionFunc func(ctx context.Context, role Role, permission string) (bool, error) - -func (f AssertionFunc) Assert(ctx context.Context, role Role, permission string) (bool, error) { - return f(ctx, role, permission) -} - -type AuthorizationChecker interface { - IsGranted(ctx context.Context, role any, permission string, assertions ...Assertion) bool -} - -type RBAC struct { - roles map[string]Role - createMissingRoles bool -} - -func New() *RBAC { - return &RBAC{roles: map[string]Role{}} -} - -func (rbac *RBAC) SetCreateMissingRoles(createMissingRoles bool) *RBAC { - rbac.createMissingRoles = createMissingRoles - return rbac -} - -func (rbac *RBAC) CreateMissingRoles() bool { - return rbac.createMissingRoles -} - -func (rbac *RBAC) Roles() []Role { - return slices.Collect(maps.Values(rbac.roles)) -} - -func (rbac *RBAC) Role(name string) (Role, error) { - if role, ok := rbac.roles[name]; ok { - return role, nil - } - return nil, fmt.Errorf(`%w: no role with name "%s" could be found`, ErrRoleNotFound, name) -} - -func (rbac *RBAC) HasRole(role any) (bool, error) { - switch role := role.(type) { - case string: - _, ok := rbac.roles[role] - return ok, nil - case Role, *DefaultRole, DefaultRole: - r, ok := rbac.roles[fmt.Sprintf("%s", role)] - return ok && r == role, nil - default: - return false, ErrInvalidRole - } -} - -func (rbac *RBAC) AddRole(role any, parents ...any) error { - var r Role - switch role := role.(type) { - case string: - r = NewRole(role) - case Role: - r = role - case DefaultRole: - r = &role - default: - return ErrInvalidRole - } - - for _, parent := range parents { - if rbac.createMissingRoles { - ok, err := rbac.HasRole(parent) - if err != nil { - return err - } - if !ok { - if err = rbac.AddRole(parent); err != nil { - return err - } - } - } - parentRole, err := rbac.Role(fmt.Sprintf("%s", parent)) - if err != nil { - return err - } - if err = parentRole.AddChild(r); err != nil { - return err - } - } - - rbac.roles[r.Name()] = r - return nil -} - -func (rbac *RBAC) IsGranted(ctx context.Context, role any, permission string, assertions ...Assertion) bool { - granted, err := rbac.IsGrantedE(ctx, role, permission, assertions...) - return granted && err == nil -} - -func (rbac *RBAC) IsGrantedE(ctx context.Context, role any, permission string, assertions ...Assertion) (bool, error) { - ok, err := rbac.HasRole(role) - if err != nil { - return false, err - } - if !ok { - return false, fmt.Errorf(`%w: no role with name "%s" could be found`, ErrRoleNotFound, role) - } - - r, err := rbac.Role(fmt.Sprintf("%s", role)) - if err != nil { - return false, err - } - - if !r.HasPermission(permission) { - return false, nil - } - - for _, assertion := range assertions { - if ok, err = assertion.Assert(ctx, r, permission); !ok || err != nil { - return false, err - } - } - - return true, nil -} diff --git a/rbac/rbac_test.go b/rbac/rbac_test.go deleted file mode 100644 index f0e87fd..0000000 --- a/rbac/rbac_test.go +++ /dev/null @@ -1,227 +0,0 @@ -package rbac - -import ( - "context" - "testing" - - "github.com/stretchr/testify/suite" -) - -type testRole struct { - Role -} - -type simpleTrueAssertion struct{} - -func (*simpleTrueAssertion) Assert(context.Context, Role, string) (bool, error) { - return true, nil -} - -type simpleFalseAssertion struct{} - -func (*simpleFalseAssertion) Assert(context.Context, Role, string) (bool, error) { - return false, nil -} - -type roleMustMatchAssertion struct{} - -func (*roleMustMatchAssertion) Assert(_ context.Context, role Role, _ string) (bool, error) { - return role.Name() == "foo", nil -} - -type rbacSuit struct { - suite.Suite - rbac *RBAC -} - -func TestRBACSuite(t *testing.T) { - s := new(rbacSuit) - suite.Run(t, s) -} - -func (s *rbacSuit) SetupTest() { - s.rbac = New() -} - -func (s *rbacSuit) TestIsGrantedAssertion() { - foo := NewRole("foo") - bar := NewRole("bar") - - _true := new(simpleTrueAssertion) - _false := new(simpleFalseAssertion) - - roleNoMatch := new(roleMustMatchAssertion) - roleMatch := new(roleMustMatchAssertion) - - foo.AddPermissions("can.foo") - bar.AddPermissions("can.bar") - - s.Nil(s.rbac.AddRole(foo)) - s.Nil(s.rbac.AddRole(bar)) - - s.True(s.rbac.IsGranted(context.Background(), foo, "can.foo", _true)) - s.False(s.rbac.IsGranted(context.Background(), bar, "can.bar", _false)) - - s.False(s.rbac.IsGranted(context.Background(), foo, "cannot", _true)) - s.False(s.rbac.IsGranted(context.Background(), bar, "cannot", _false)) - - s.False(s.rbac.IsGranted(context.Background(), bar, "can.bar", roleNoMatch)) - s.False(s.rbac.IsGranted(context.Background(), bar, "can.foo", roleNoMatch)) - - s.True(s.rbac.IsGranted(context.Background(), foo, "can.foo", roleMatch)) -} - -func (s *rbacSuit) TestIsGrantedSingleRole() { - foo := NewRole("foo") - foo.AddPermissions("can.bar") - - s.Nil(s.rbac.AddRole(foo)) - - s.True(s.rbac.IsGranted(context.Background(), "foo", "can.bar")) - s.False(s.rbac.IsGranted(context.Background(), "foo", "can.baz")) -} - -func (s *rbacSuit) TestIsGrantedChildRoles() { - foo := NewRole("foo") - bar := NewRole("bar") - - foo.AddPermissions("can.foo") - bar.AddPermissions("can.bar") - - s.Nil(s.rbac.AddRole(foo)) - s.Nil(s.rbac.AddRole(bar, foo)) - - s.True(s.rbac.IsGranted(context.Background(), "foo", "can.bar")) - s.True(s.rbac.IsGranted(context.Background(), "foo", "can.foo")) - s.True(s.rbac.IsGranted(context.Background(), "bar", "can.bar")) - - s.False(s.rbac.IsGranted(context.Background(), "foo", "can.baz")) - s.False(s.rbac.IsGranted(context.Background(), "bar", "can.baz")) -} - -func (s *rbacSuit) TestIsGrantedWithInvalidRole() { - granted, err := s.rbac.IsGrantedE(context.Background(), "foo", "permission") - - s.False(granted) - s.ErrorIs(err, ErrRoleNotFound) -} - -func (s *rbacSuit) TestHasRole() { - foo := NewRole("foo") - snafu := testRole{Role: NewRole("snafu")} - - s.Nil(s.rbac.AddRole("bar")) - s.Nil(s.rbac.AddRole(foo)) - s.Nil(s.rbac.AddRole("snafu")) - - s.True(s.rbac.HasRole(foo)) - s.True(s.rbac.HasRole("bar")) - - s.False(s.rbac.HasRole("baz")) - - roleSnafu, err := s.rbac.Role("snafu") - - s.NoError(err) - s.NotNil(roleSnafu) - s.NotEqual(snafu, roleSnafu) - - s.True(s.rbac.HasRole("snafu")) - s.False(s.rbac.HasRole(snafu)) -} - -func (s *rbacSuit) TestAddRoleWithParentsUsingRBAC() { - foo := NewRole("foo") - bar := NewRole("bar") - - s.Nil(s.rbac.AddRole(foo)) - s.Nil(s.rbac.AddRole(bar, foo)) - - s.ElementsMatch([]Role{foo}, bar.Parents()) - s.ElementsMatch([]Role{bar}, foo.Children()) -} - -func (s *rbacSuit) TestAddRoleWithAutomaticParentsUsingRBAC() { - foo := NewRole("foo") - bar := NewRole("bar") - - s.rbac.SetCreateMissingRoles(true) - s.True(s.rbac.CreateMissingRoles()) - - s.Nil(s.rbac.AddRole(bar, foo)) - - s.ElementsMatch([]Role{foo}, bar.Parents()) - s.ElementsMatch([]Role{bar}, foo.Children()) -} - -func (s *rbacSuit) TestAddMultipleParentRole() { - adminRole := NewRole("Administrator") - adminRole.AddPermissions("user.manage") - s.Nil(s.rbac.AddRole(adminRole)) - - managerRole := NewRole("Manager") - managerRole.AddPermissions("post.publish") - s.Nil(s.rbac.AddRole(managerRole, "Administrator")) - - editorRole := NewRole("Editor") - editorRole.AddPermissions("post.edit") - s.Nil(s.rbac.AddRole(editorRole)) - - viewerRole := NewRole("Viewer") - viewerRole.AddPermissions("post.view") - s.Nil(s.rbac.AddRole(viewerRole, "Editor", "Manager")) - - s.Equal("Viewer", editorRole.Children()[0].Name()) - s.Equal("Viewer", managerRole.Children()[0].Name()) - s.True(s.rbac.IsGranted(context.Background(), "Editor", "post.view")) - s.True(s.rbac.IsGranted(context.Background(), "Manager", "post.view")) - - s.ElementsMatch([]Role{editorRole, managerRole}, viewerRole.Parents()) - s.ElementsMatch([]Role{adminRole}, managerRole.Parents()) - - s.Empty(editorRole.Parents()) - s.Empty(adminRole.Parents()) -} - -func (s *rbacSuit) TestAddParentRole() { - adminRole := NewRole("Administrator") - adminRole.AddPermissions("user.manage") - s.Nil(s.rbac.AddRole(adminRole)) - - managerRole := NewRole("Manager") - managerRole.AddPermissions("post.publish") - s.Nil(managerRole.AddParent(adminRole)) - s.Nil(s.rbac.AddRole(managerRole)) - - editorRole := NewRole("Editor") - editorRole.AddPermissions("post.edit") - s.Nil(s.rbac.AddRole(editorRole)) - - viewerRole := NewRole("Viewer") - viewerRole.AddPermissions("post.view") - s.Nil(viewerRole.AddParent(editorRole)) - s.Nil(viewerRole.AddParent(managerRole)) - s.Nil(s.rbac.AddRole(viewerRole)) - - s.ElementsMatch([]Role{viewerRole}, editorRole.Children()) - s.ElementsMatch([]Role{viewerRole}, managerRole.Children()) - s.ElementsMatch([]Role{editorRole, managerRole}, viewerRole.Parents()) - s.ElementsMatch([]Role{adminRole}, managerRole.Parents()) - - s.Empty(editorRole.Parents()) - s.Empty(adminRole.Parents()) - - s.True(s.rbac.IsGranted(context.Background(), "Editor", "post.view")) - s.True(s.rbac.IsGranted(context.Background(), "Editor", "post.edit")) - s.True(s.rbac.IsGranted(context.Background(), "Administrator", "post.view")) - s.True(s.rbac.IsGranted(context.Background(), "Administrator", "post.publish")) - s.False(s.rbac.IsGranted(context.Background(), "Administrator", "post.edit")) - s.True(s.rbac.IsGranted(context.Background(), "Manager", "post.view")) - s.False(s.rbac.IsGranted(context.Background(), "Manager", "post.edit")) - s.False(s.rbac.IsGranted(context.Background(), "Manager", "user.manage")) - s.True(s.rbac.IsGranted(context.Background(), "Viewer", "post.view")) - s.False(s.rbac.IsGranted(context.Background(), "Viewer", "post.edit")) - s.False(s.rbac.IsGranted(context.Background(), "Viewer", "post.publish")) - s.False(s.rbac.IsGranted(context.Background(), "Viewer", "user.manage")) - s.False(s.rbac.IsGranted(context.Background(), "Editor", "user.manage")) - s.False(s.rbac.IsGranted(context.Background(), "Editor", "post.publish")) -} diff --git a/rbac/role.go b/rbac/role.go deleted file mode 100644 index 1cde69a..0000000 --- a/rbac/role.go +++ /dev/null @@ -1,141 +0,0 @@ -package rbac - -import ( - "errors" - "fmt" - "maps" - "slices" -) - -var ErrCircularReference = errors.New("circular reference detected") - -type Role interface { - fmt.Stringer - Name() string - AddPermissions(permission string, rest ...string) - HasPermission(permission string) bool - Permissions(children bool) []string - AddParent(Role) error - Parents() []Role - AddChild(Role) error - Children() []Role - HasAncestor(role Role) bool - HasDescendant(role Role) bool -} - -type DefaultRole struct { - name string - permissions map[string]struct{} - parents map[string]Role - children map[string]Role -} - -func NewRole(name string) *DefaultRole { - return &DefaultRole{ - name: name, - permissions: map[string]struct{}{}, - parents: map[string]Role{}, - children: map[string]Role{}, - } -} - -func (r *DefaultRole) String() string { - return r.Name() -} - -func (r *DefaultRole) Name() string { - return r.name -} - -func (r *DefaultRole) AddPermissions(permission string, rest ...string) { - r.permissions[permission] = struct{}{} - for _, p := range rest { - r.permissions[p] = struct{}{} - } -} - -func (r *DefaultRole) HasPermission(permission string) bool { - if _, ok := r.permissions[permission]; ok { - return true - } - - for _, child := range r.children { - if child.HasPermission(permission) { - return true - } - } - - return false -} - -func (r *DefaultRole) Permissions(children bool) []string { - permissions := maps.Clone(r.permissions) - if children { - for _, child := range r.children { - for _, permission := range child.Permissions(children) { - permissions[permission] = struct{}{} - } - } - } - return slices.Collect(maps.Keys(permissions)) -} - -func (r *DefaultRole) AddParent(parent Role) error { - if r.HasDescendant(parent) { - return fmt.Errorf(`%w: to prevent circular references, you cannot add role "%s" as parent`, ErrCircularReference, parent.Name()) - } - - if _, ok := r.parents[parent.Name()]; ok { - return nil - } - - r.parents[parent.Name()] = parent - return parent.AddChild(r) -} - -func (r *DefaultRole) Parents() []Role { - return slices.Collect(maps.Values(r.parents)) -} - -func (r *DefaultRole) AddChild(child Role) error { - if r.HasAncestor(child) { - return fmt.Errorf(`%w: to prevent circular references, you cannot add role "%s" as child`, ErrCircularReference, child.Name()) - } - - if _, ok := r.children[child.Name()]; ok { - return nil - } - - r.children[child.Name()] = child - return child.AddParent(r) -} - -func (r *DefaultRole) Children() []Role { - return slices.Collect(maps.Values(r.children)) -} - -func (r *DefaultRole) HasAncestor(role Role) bool { - if _, ok := r.parents[role.Name()]; ok { - return true - } - - for _, parent := range r.parents { - if parent.HasAncestor(role) { - return true - } - } - return false -} - -func (r *DefaultRole) HasDescendant(role Role) bool { - if _, ok := r.children[role.Name()]; ok { - return true - } - - for _, child := range r.children { - if child.HasDescendant(role) { - return true - } - } - return false -} diff --git a/rbac/role_test.go b/rbac/role_test.go deleted file mode 100644 index 63c981b..0000000 --- a/rbac/role_test.go +++ /dev/null @@ -1,125 +0,0 @@ -package rbac - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestDefaultRole_Name(t *testing.T) { - role := NewRole("test") - assert.Equal(t, "test", role.Name()) -} - -func TestDefaultRole_AddPermissions(t *testing.T) { - role := NewRole("test") - role.AddPermissions("bar", "baz") - assert.True(t, role.HasPermission("bar")) - assert.True(t, role.HasPermission("baz")) -} - -func TestDefaultRole_AddChild(t *testing.T) { - foo := NewRole("foo") - bar := NewRole("bar") - baz := NewRole("baz") - - assert.Nil(t, foo.AddChild(bar)) - assert.Nil(t, foo.AddChild(baz)) - assert.ElementsMatch(t, []Role{bar, baz}, foo.Children()) -} - -func TestDefaultRole_AddParent(t *testing.T) { - foo := NewRole("foo") - bar := NewRole("bar") - baz := NewRole("baz") - - assert.Nil(t, foo.AddParent(bar)) - assert.Nil(t, foo.AddParent(baz)) - assert.ElementsMatch(t, []Role{bar, baz}, foo.Parents()) -} - -func TestDefaultRole_PermissionHierarchy(t *testing.T) { - foo := NewRole("foo") - foo.AddPermissions("foo.permission") - - bar := NewRole("bar") - bar.AddPermissions("bar.permission") - - baz := NewRole("baz") - baz.AddPermissions("baz.permission") - - assert.Nil(t, foo.AddParent(bar)) - assert.Nil(t, foo.AddChild(baz)) - - assert.True(t, bar.HasPermission("bar.permission")) - assert.True(t, bar.HasPermission("foo.permission")) - assert.True(t, bar.HasPermission("baz.permission")) - - assert.False(t, foo.HasPermission("bar.permission")) - assert.True(t, foo.HasPermission("foo.permission")) - assert.True(t, foo.HasPermission("baz.permission")) - - assert.False(t, baz.HasPermission("bar.permission")) - assert.False(t, baz.HasPermission("foo.permission")) - assert.True(t, baz.HasPermission("baz.permission")) -} - -func TestDefaultRole_CircleReferenceWithChild(t *testing.T) { - foo := NewRole("foo") - bar := NewRole("bar") - baz := NewRole("baz") - baz.AddPermissions("baz") - - assert.Nil(t, foo.AddChild(bar)) - assert.Nil(t, bar.AddChild(baz)) - assert.ErrorIs(t, baz.AddChild(foo), ErrCircularReference) -} - -func TestDefaultRole_CircleReferenceWithParent(t *testing.T) { - foo := NewRole("foo") - bar := NewRole("bar") - baz := NewRole("baz") - baz.AddPermissions("baz") - - assert.Nil(t, foo.AddParent(bar)) - assert.Nil(t, bar.AddParent(baz)) - assert.ErrorIs(t, baz.AddParent(foo), ErrCircularReference) -} - -func TestDefaultRole_Permissions(t *testing.T) { - foo := NewRole("foo") - foo.AddPermissions("foo.permission", "foo.2nd-permission") - - bar := NewRole("bar") - bar.AddPermissions("bar.permission") - - baz := NewRole("baz") - baz.AddPermissions("baz.permission") - - assert.Nil(t, foo.AddParent(bar)) - assert.Nil(t, foo.AddChild(baz)) - - expected := []string{"foo.permission", "foo.2nd-permission", "bar.permission", "baz.permission"} - assert.ElementsMatch(t, expected, bar.Permissions(true)) - - assert.ElementsMatch(t, []string{"bar.permission"}, bar.Permissions(false)) - - expected = []string{"foo.permission", "foo.2nd-permission", "baz.permission"} - assert.ElementsMatch(t, expected, foo.Permissions(true)) - - expected = []string{"foo.permission", "foo.2nd-permission"} - assert.ElementsMatch(t, expected, foo.Permissions(false)) - - assert.ElementsMatch(t, []string{"baz.permission"}, baz.Permissions(true)) - assert.ElementsMatch(t, []string{"baz.permission"}, baz.Permissions(false)) -} - -func TestDefaultRole_AddSameParent(t *testing.T) { - foo := NewRole("foo") - bar := NewRole("bar") - - assert.Nil(t, foo.AddParent(bar)) - assert.Nil(t, foo.AddParent(bar)) - - assert.ElementsMatch(t, []Role{bar}, foo.Parents()) -} diff --git a/session.go b/session.go deleted file mode 100644 index b30ae0d..0000000 --- a/session.go +++ /dev/null @@ -1,48 +0,0 @@ -package echox - -import ( - "time" - - "github.com/alexedwards/scs/v2" - "go.uber.org/fx" -) - -type SessionManagerParams struct { - fx.In - Config SessionConfig - Store scs.Store -} - -func NewSessionManager(params SessionManagerParams) *scs.SessionManager { - cfg := params.Config - if cfg.CleanupInterval == 0 { - cfg.CleanupInterval = 5 * time.Minute - } - if cfg.Lifetime == 0 { - cfg.Lifetime = 24 * time.Hour - } - if cfg.Cookie.Name == "" { - cfg.Cookie.Name = "session" - } - if cfg.Cookie.Path == "" { - cfg.Cookie.Path = "/" - } - if cfg.Cookie.SameSite == "" { - cfg.Cookie.SameSite = SameSiteLax - } - - sm := scs.New() - sm.Store = params.Store - sm.IdleTimeout = cfg.IdleTimeout - sm.Lifetime = cfg.Lifetime - sm.Cookie = scs.SessionCookie{ - Name: cfg.Cookie.Name, - Persist: cfg.Cookie.Persist, - Domain: cfg.Cookie.Domain, - Path: cfg.Cookie.Path, - HttpOnly: cfg.Cookie.HTTPOnly, - Secure: cfg.Cookie.Secure, - SameSite: cfg.Cookie.SameSite.HTTP(), - } - return sm -}