From 0b71230005d8e13ba10386f8a80d6a22a31c67bc Mon Sep 17 00:00:00 2001 From: Victor Elias Date: Fri, 16 Sep 2022 22:36:45 -0300 Subject: [PATCH] api: Add authorization to viewership API (#60) * api: Return a 404 on asset not found * api: Allow asset views to be fetched only by ID * api: Create consts for the param names * api/auth: Use path param on authorization instead * api: Add auth middleware to views API * api: Forward Origin header on auth request * api/auth: Forward CORS response headers back * api/auth: Fix authorization error We were sending the auth req url not the original url * api: Add Allow-Credentials to cors headers * api/auth: Rename auth request vars for clarity * api/auth: Add some logs to auth middleware Log origiunal request * api/auth: Grab original URI correctly * Revert "api/auth: Add some logs to auth middleware" This reverts commit 06d7a329ab85933ef7eb2c6c655858d2484222a7. --- api/authorization.go | 58 +++++++++++++++++++++++++++++++++----------- api/errors.go | 5 +++- api/handler.go | 22 ++++++++++++----- api/streamStatus.go | 2 +- views/client.go | 7 +++--- 5 files changed, 68 insertions(+), 26 deletions(-) diff --git a/api/authorization.go b/api/authorization.go index 05c0cf06..6cc94b50 100644 --- a/api/authorization.go +++ b/api/authorization.go @@ -14,8 +14,17 @@ import ( ) var ( - authorizationHeaders = []string{"Authorization", "Cookie"} - authTimeout = 3 * time.Second + authorizationHeaders = []string{"Authorization", "Cookie", "Origin"} + // the response headers proxied from the auth request are basically cors headers + proxiedResponseHeaders = []string{ + "Access-Control-Allow-Origin", + "Access-Control-Allow-Credentials", + "Access-Control-Allow-Methods", + "Access-Control-Allow-Headers", + "Access-Control-Expose-Headers", + "Access-Control-Max-Age", + } + authTimeout = 3 * time.Second authRequestDuration = metrics.Factory.NewSummaryVec( prometheus.SummaryOpts{ @@ -34,33 +43,54 @@ func authorization(authUrl string) middleware { ctx, cancel := context.WithTimeout(r.Context(), authTimeout) defer cancel() - status := getStreamStatus(r) - req, err := http.NewRequestWithContext(ctx, r.Method, authUrl, nil) + authReq, err := http.NewRequestWithContext(ctx, r.Method, authUrl, nil) if err != nil { respondError(rw, http.StatusInternalServerError, err) return } - req.Header.Set("X-Original-Uri", req.URL.String()) - req.Header.Set("X-Livepeer-Stream-Id", status.ID) - for _, header := range authorizationHeaders { - req.Header[header] = r.Header[header] + authReq.Header.Set("X-Original-Uri", originalReqUri(r)) + if streamID := apiParam(r, streamIDParam); streamID != "" { + authReq.Header.Set("X-Livepeer-Stream-Id", streamID) + } else if assetID := apiParam(r, assetIDParam); assetID != "" { + authReq.Header.Set("X-Livepeer-Asset-Id", assetID) } - res, err := httpClient.Do(req) + copyHeaders(authorizationHeaders, r.Header, authReq.Header) + authRes, err := httpClient.Do(authReq) if err != nil { respondError(rw, http.StatusInternalServerError, fmt.Errorf("error authorizing request: %w", err)) return } + copyHeaders(proxiedResponseHeaders, authRes.Header, rw.Header()) - if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusNoContent { - if contentType := res.Header.Get("Content-Type"); contentType != "" { + if authRes.StatusCode != http.StatusOK && authRes.StatusCode != http.StatusNoContent { + if contentType := authRes.Header.Get("Content-Type"); contentType != "" { rw.Header().Set("Content-Type", contentType) } - rw.WriteHeader(res.StatusCode) - if _, err := io.Copy(rw, res.Body); err != nil { - glog.Errorf("Error writing auth error response. err=%q, status=%d, headers=%+v", err, res.StatusCode, res.Header) + rw.WriteHeader(authRes.StatusCode) + if _, err := io.Copy(rw, authRes.Body); err != nil { + glog.Errorf("Error writing auth error response. err=%q, status=%d, headers=%+v", err, authRes.StatusCode, authRes.Header) } return } next.ServeHTTP(rw, r) }) } + +func originalReqUri(r *http.Request) string { + proto := "http" + if r.TLS != nil { + proto = "https" + } + if fwdProto := r.Header.Get("X-Forwarded-Proto"); fwdProto != "" { + proto = fwdProto + } + return fmt.Sprintf("%s://%s%s", proto, r.Host, r.URL.RequestURI()) +} + +func copyHeaders(headers []string, src, dest http.Header) { + for _, header := range headers { + if vals := src[header]; len(vals) > 0 { + dest[header] = vals + } + } +} diff --git a/api/errors.go b/api/errors.go index eedef132..1fffdedf 100644 --- a/api/errors.go +++ b/api/errors.go @@ -7,6 +7,7 @@ import ( "github.com/golang/glog" "github.com/livepeer/livepeer-data/health" + "github.com/livepeer/livepeer-data/views" ) type errorResponse struct { @@ -18,7 +19,9 @@ func respondError(rw http.ResponseWriter, defaultStatus int, errs ...error) { response := errorResponse{} for _, err := range errs { response.Errors = append(response.Errors, err.Error()) - if errors.Is(err, health.ErrStreamNotFound) || errors.Is(err, health.ErrEventNotFound) { + if errors.Is(err, health.ErrStreamNotFound) || + errors.Is(err, health.ErrEventNotFound) || + errors.Is(err, views.ErrAssetNotFound) { status = http.StatusNotFound } } diff --git a/api/handler.go b/api/handler.go index 08e60dff..d02bc383 100644 --- a/api/handler.go +++ b/api/handler.go @@ -21,6 +21,9 @@ const ( sseRetryBackoff = 10 * time.Second ssePingDelay = 20 * time.Second sseBufferSize = 128 + + streamIDParam = "streamId" + assetIDParam = "assetId" ) type APIHandlerOptions struct { @@ -54,14 +57,14 @@ func NewHandler(serverCtx context.Context, opts APIHandlerOptions, healthcore *h func addStreamHealthHandlers(router *httprouter.Router, handler *apiHandler) { healthcore, opts := handler.core, handler.opts middlewares := []middleware{ - streamStatus(healthcore, "streamId"), + streamStatus(healthcore), regionProxy(opts.RegionalHostFormat, opts.OwnRegion), } if opts.AuthURL != "" { middlewares = append(middlewares, authorization(opts.AuthURL)) } addApiHandler := func(apiPath, name string, handler http.HandlerFunc) { - fullPath := path.Join(opts.APIRoot, "/stream/:streamId", apiPath) + fullPath := path.Join(opts.APIRoot, "/stream/:"+streamIDParam, apiPath) fullHandler := prepareHandlerFunc(name, opts.Prometheus, handler, middlewares...) router.Handler("GET", fullPath, fullHandler) } @@ -71,10 +74,13 @@ func addStreamHealthHandlers(router *httprouter.Router, handler *apiHandler) { func addViewershipHandlers(router *httprouter.Router, handler *apiHandler) { opts := handler.opts - // TODO: Add authorization to views API + middlewares := []middleware{} + if opts.AuthURL != "" { + middlewares = append(middlewares, authorization(opts.AuthURL)) + } addApiHandler := func(apiPath, name string, handler http.HandlerFunc) { - fullPath := path.Join(opts.APIRoot, "/views/:assetId", apiPath) - fullHandler := prepareHandlerFunc(name, opts.Prometheus, handler) + fullPath := path.Join(opts.APIRoot, "/views/:"+assetIDParam, apiPath) + fullHandler := prepareHandlerFunc(name, opts.Prometheus, handler, middlewares...) router.Handler("GET", fullPath, fullHandler) } addApiHandler("/total", "get_total_views", handler.getTotalViews) @@ -87,6 +93,10 @@ func (h *apiHandler) cors() middleware { } rw.Header().Set("Access-Control-Allow-Origin", "*") rw.Header().Set("Access-Control-Allow-Headers", "*") + if origin := r.Header.Get("Origin"); origin != "" { + rw.Header().Set("Access-Control-Allow-Origin", origin) + rw.Header().Set("Access-Control-Allow-Credentials", "true") + } next.ServeHTTP(rw, r) }) } @@ -100,7 +110,7 @@ func (h *apiHandler) healthcheck(rw http.ResponseWriter, r *http.Request) { } func (h *apiHandler) getTotalViews(rw http.ResponseWriter, r *http.Request) { - views, err := h.views.GetTotalViews(r.Context(), apiParam(r, "assetId")) + views, err := h.views.GetTotalViews(r.Context(), apiParam(r, assetIDParam)) if err != nil { respondError(rw, http.StatusInternalServerError, err) return diff --git a/api/streamStatus.go b/api/streamStatus.go index 899bf171..2b7dc402 100644 --- a/api/streamStatus.go +++ b/api/streamStatus.go @@ -14,7 +14,7 @@ const ( streamStatusKey contextKey = iota ) -func streamStatus(healthcore *health.Core, streamIDParam string) middleware { +func streamStatus(healthcore *health.Core) middleware { return inlineMiddleware(func(rw http.ResponseWriter, r *http.Request, next http.Handler) { streamID := apiParam(r, streamIDParam) if streamID == "" { diff --git a/views/client.go b/views/client.go index 1f613f6d..20dccb00 100644 --- a/views/client.go +++ b/views/client.go @@ -13,6 +13,8 @@ import ( "github.com/prometheus/common/model" ) +var ErrAssetNotFound = errors.New("asset not found") + type TotalViews struct { ID string `json:"id"` StartViews int64 `json:"startViews"` @@ -41,10 +43,7 @@ func NewClient(opts ClientOptions) (*Client, error) { func (c *Client) GetTotalViews(ctx context.Context, id string) ([]TotalViews, error) { asset, err := c.lp.GetAsset(id) if errors.Is(err, livepeer.ErrNotExists) { - asset, err = c.lp.GetAssetByPlaybackID(id, false) - } - if errors.Is(err, livepeer.ErrNotExists) { - return nil, errors.New("asset not found") + return nil, ErrAssetNotFound } else if err != nil { return nil, fmt.Errorf("error getting asset: %w", err) }