Skip to content

Commit

Permalink
feat: Allow CORS policy to be configured (#484)
Browse files Browse the repository at this point in the history
* Add configurable CORS policy in OpenIDProvider

* Add configurable CORS policy to Server

* remove duplicated CORS middleware

* Allow nil CORS policy to be set to disable CORS middleware

* create a separate handler on webServer so type assertion works in tests
  • Loading branch information
korylprince authored Nov 17, 2023
1 parent ce55068 commit 7b64687
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
27 changes: 25 additions & 2 deletions pkg/op/op.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,19 @@ type OpenIDProvider interface {

type HttpInterceptor func(http.Handler) http.Handler

type corsOptioner interface {
CORSOptions() *cors.Options
}

func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router {
router := chi.NewRouter()
router.Use(cors.New(defaultCORSOptions).Handler)
if co, ok := o.(corsOptioner); ok {
if opts := co.CORSOptions(); opts != nil {
router.Use(cors.New(*opts).Handler)
}
} else {
router.Use(cors.New(defaultCORSOptions).Handler)
}
router.Use(intercept(o.IssuerFromRequest, interceptors...))
router.HandleFunc(healthEndpoint, healthHandler)
router.HandleFunc(readinessEndpoint, readyHandler(o.Probes()))
Expand Down Expand Up @@ -224,6 +234,7 @@ func NewProvider(config *Config, storage Storage, issuer func(insecure bool) (Is
storage: storage,
endpoints: DefaultEndpoints,
timer: make(<-chan time.Time),
corsOpts: &defaultCORSOptions,
logger: slog.Default(),
}

Expand Down Expand Up @@ -268,6 +279,7 @@ type Provider struct {
timer <-chan time.Time
accessTokenVerifierOpts []AccessTokenVerifierOpt
idTokenHintVerifierOpts []IDTokenHintVerifierOpt
corsOpts *cors.Options
logger *slog.Logger
}

Expand Down Expand Up @@ -427,6 +439,10 @@ func (o *Provider) Probes() []ProbesFn {
}
}

func (o *Provider) CORSOptions() *cors.Options {
return o.corsOpts
}

func (o *Provider) Logger() *slog.Logger {
return o.logger
}
Expand Down Expand Up @@ -587,6 +603,13 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option {
}
}

func WithCORSOptions(opts *cors.Options) Option {
return func(o *Provider) error {
o.corsOpts = opts
return nil
}
}

// WithLogger lets a logger other than slog.Default().
//
// EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20
Expand All @@ -603,6 +626,6 @@ func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handle
for i := len(interceptors) - 1; i >= 0; i-- {
handler = interceptors[i](handler)
}
return cors.New(defaultCORSOptions).Handler(issuerInterceptor.Handler(handler))
return issuerInterceptor.Handler(handler)
}
}
17 changes: 15 additions & 2 deletions pkg/op/server_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,19 @@ func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption)
server: server,
endpoints: endpoints,
decoder: decoder,
corsOpts: &defaultCORSOptions,
logger: slog.Default(),
}
ws.router.Use(cors.New(defaultCORSOptions).Handler)

for _, option := range options {
option(ws)
}

ws.createRouter()
ws.handler = ws.router
if ws.corsOpts != nil {
ws.handler = cors.New(*ws.corsOpts).Handler(ws.router)
}
return ws
}

Expand Down Expand Up @@ -66,6 +70,13 @@ func WithDecoder(decoder httphelper.Decoder) ServerOption {
}
}

// WithServerCORSOptions sets the CORS policy for the Server's router.
func WithServerCORSOptions(opts *cors.Options) ServerOption {
return func(s *webServer) {
s.corsOpts = opts
}
}

// WithFallbackLogger overrides the fallback logger, which
// is used when no logger was found in the context.
// Defaults to [slog.Default].
Expand All @@ -78,13 +89,15 @@ func WithFallbackLogger(logger *slog.Logger) ServerOption {
type webServer struct {
server Server
router *chi.Mux
handler http.Handler
endpoints Endpoints
decoder httphelper.Decoder
corsOpts *cors.Options
logger *slog.Logger
}

func (s *webServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.router.ServeHTTP(w, r)
s.handler.ServeHTTP(w, r)
}

func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
Expand Down

0 comments on commit 7b64687

Please sign in to comment.