From fe41415685ecf2efe1a9f00175693c721476d697 Mon Sep 17 00:00:00 2001 From: Fergus Morrow Date: Thu, 27 Jun 2024 11:50:46 +0100 Subject: [PATCH] webhook filter: allow client redirects via 302 responses (#3130) This commit changes the behaviour of the webhook filter when a 302 Found response is recieved from the AuthN/AuthZ endpoint. As a result, it allows front-end facing (i.e. non-API) traffic to be filtered via the webhook. Documentation updates and increased test coverage is included. Incidental: Prevent the webhook client from following redirects from the AuthN/AuthZ endpoint: during testing I realised that the default `net/http` behaviour was in use - i.e. redirects were followed. Signed-off-by: Fergus Morrow --- docs/reference/filters.md | 1 + filters/auth/auth.go | 13 ++++++++++-- filters/auth/auth_test.go | 1 + filters/auth/authclient.go | 10 ++++++++- filters/auth/grantconfig.go | 1 + filters/auth/tokeninfo.go | 2 +- filters/auth/tokenintrospection.go | 2 +- filters/auth/webhook.go | 10 +++++++-- filters/auth/webhook_test.go | 33 +++++++++++++++++++++++------- proxy/proxytest/proxytest.go | 8 ++++++++ 10 files changed, 67 insertions(+), 14 deletions(-) diff --git a/docs/reference/filters.md b/docs/reference/filters.md index 026031bea0..466dfbe33b 100644 --- a/docs/reference/filters.md +++ b/docs/reference/filters.md @@ -1237,6 +1237,7 @@ headers to copy as an optional second argument to the filter. Responses from the webhook will be treated as follows: * Authorized if the status code is less than 300 +* Redirection, using the `Location` header, if the status code is 302 * Forbidden if the status code is 403 * Unauthorized for remaining status codes diff --git a/filters/auth/auth.go b/filters/auth/auth.go index b138602c79..1115b2fee3 100644 --- a/filters/auth/auth.go +++ b/filters/auth/auth.go @@ -105,6 +105,7 @@ func reject( reason rejectReason, hostname, debuginfo string, + destination string, ) { if debuginfo == "" { ctx.Logger().Debugf( @@ -125,6 +126,10 @@ func reject( Header: make(map[string][]string), } + if status == http.StatusFound && destination != "" { + rsp.Header.Add("Location", destination) + } + if hostname != "" { // https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.4.2 rsp.Header.Add("WWW-Authenticate", hostname) @@ -133,12 +138,16 @@ func reject( ctx.Serve(rsp) } +func redirect(ctx filters.FilterContext, username string, reason rejectReason, destination, debuginfo string) { + reject(ctx, http.StatusFound, username, reason, "", debuginfo, destination) +} + func unauthorized(ctx filters.FilterContext, username string, reason rejectReason, hostname, debuginfo string) { - reject(ctx, http.StatusUnauthorized, username, reason, hostname, debuginfo) + reject(ctx, http.StatusUnauthorized, username, reason, hostname, debuginfo, "") } func forbidden(ctx filters.FilterContext, username string, reason rejectReason, debuginfo string) { - reject(ctx, http.StatusForbidden, username, reason, "", debuginfo) + reject(ctx, http.StatusForbidden, username, reason, "", debuginfo, "") } func authorized(ctx filters.FilterContext, username string) { diff --git a/filters/auth/auth_test.go b/filters/auth/auth_test.go index 3e7b6fbe37..13aa7a14ad 100644 --- a/filters/auth/auth_test.go +++ b/filters/auth/auth_test.go @@ -7,6 +7,7 @@ import ( const ( testToken = "test-token" testWebhookInvalidScopeToken = "test-webhook-invalid-scope-token" + testWebhookRedirectToken = "test-webhook-redirect" testUID = "jdoe" testScope = "test-scope" testScope2 = "test-scope2" diff --git a/filters/auth/authclient.go b/filters/auth/authclient.go index ca5463eade..b4a04e1235 100644 --- a/filters/auth/authclient.go +++ b/filters/auth/authclient.go @@ -36,7 +36,7 @@ type tokeninfoClient interface { var _ tokeninfoClient = &authClient{} -func newAuthClient(baseURL, spanName string, timeout time.Duration, maxIdleConns int, tracer opentracing.Tracer) (*authClient, error) { +func newAuthClient(baseURL, spanName string, timeout time.Duration, maxIdleConns int, tracer opentracing.Tracer, followRedirects bool) (*authClient, error) { if tracer == nil { tracer = opentracing.NoopTracer{} } @@ -49,6 +49,13 @@ func newAuthClient(baseURL, spanName string, timeout time.Duration, maxIdleConns return nil, err } + var checkRedirectFn func(req *http.Request, via []*http.Request) error + if followRedirects { + checkRedirectFn = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + } + cli := net.NewClient(net.Options{ ResponseHeaderTimeout: timeout, TLSHandshakeTimeout: timeout, @@ -56,6 +63,7 @@ func newAuthClient(baseURL, spanName string, timeout time.Duration, maxIdleConns Tracer: tracer, OpentracingComponentTag: "skipper", OpentracingSpanName: spanName, + CheckRedirect: checkRedirectFn, }) return &authClient{url: u, cli: cli}, nil diff --git a/filters/auth/grantconfig.go b/filters/auth/grantconfig.go index 7d216ee549..4e7744b4f3 100644 --- a/filters/auth/grantconfig.go +++ b/filters/auth/grantconfig.go @@ -179,6 +179,7 @@ func (c *OAuthConfig) Init() error { c.ConnectionTimeout, c.MaxIdleConnectionsPerHost, c.Tracer, + false, ) if err != nil { return err diff --git a/filters/auth/tokeninfo.go b/filters/auth/tokeninfo.go index a38dd37f5a..8fab083701 100644 --- a/filters/auth/tokeninfo.go +++ b/filters/auth/tokeninfo.go @@ -85,7 +85,7 @@ func (o *TokeninfoOptions) getTokeninfoClient() (tokeninfoClient, error) { func (o *TokeninfoOptions) newTokeninfoClient() (tokeninfoClient, error) { var c tokeninfoClient - c, err := newAuthClient(o.URL, tokenInfoSpanName, o.Timeout, o.MaxIdleConns, o.Tracer) + c, err := newAuthClient(o.URL, tokenInfoSpanName, o.Timeout, o.MaxIdleConns, o.Tracer, false) if err != nil { return nil, err } diff --git a/filters/auth/tokenintrospection.go b/filters/auth/tokenintrospection.go index 1495a2a4eb..f8da72c922 100644 --- a/filters/auth/tokenintrospection.go +++ b/filters/auth/tokenintrospection.go @@ -289,7 +289,7 @@ func (s *tokenIntrospectionSpec) CreateFilter(args []interface{}) (filters.Filte var ac *authClient var ok bool if ac, ok = issuerAuthClient[issuerURL]; !ok { - ac, err = newAuthClient(cfg.IntrospectionEndpoint, tokenIntrospectionSpanName, s.options.Timeout, s.options.MaxIdleConns, s.options.Tracer) + ac, err = newAuthClient(cfg.IntrospectionEndpoint, tokenIntrospectionSpanName, s.options.Timeout, s.options.MaxIdleConns, s.options.Tracer, false) if err != nil { return nil, filters.ErrInvalidFilterParameters } diff --git a/filters/auth/webhook.go b/filters/auth/webhook.go index e17711f17d..0f22cc114c 100644 --- a/filters/auth/webhook.go +++ b/filters/auth/webhook.go @@ -93,7 +93,7 @@ func (ws *webhookSpec) CreateFilter(args []interface{}) (filters.Filter, error) var ac *authClient var err error if ac, ok = webhookAuthClient[s]; !ok { - ac, err = newAuthClient(s, webhookSpanName, ws.options.Timeout, ws.options.MaxIdleConns, ws.options.Tracer) + ac, err = newAuthClient(s, webhookSpanName, ws.options.Timeout, ws.options.MaxIdleConns, ws.options.Tracer, true) if err != nil { return nil, filters.ErrInvalidFilterParameters } @@ -121,7 +121,13 @@ func (f *webhookFilter) Request(ctx filters.FilterContext) { return } - // errors, redirects, auth errors, webhook errors + // redirect + if err == nil && resp.StatusCode == http.StatusFound { + redirect(ctx, "", invalidAccess, resp.Header.Get("Location"), filters.WebhookName) + return + } + + // errors, auth errors, webhook errors if err != nil || resp.StatusCode >= 300 { unauthorized(ctx, "", invalidAccess, f.authClient.url.Hostname(), filters.WebhookName) return diff --git a/filters/auth/webhook_test.go b/filters/auth/webhook_test.go index e5720caf2f..4b31cd4455 100644 --- a/filters/auth/webhook_test.go +++ b/filters/auth/webhook_test.go @@ -15,15 +15,17 @@ import ( ) const headerToCopy = "X-Copy-Header" +const webhookRedirectLocation = "https://example.com/auth" func TestWebhook(t *testing.T) { for _, ti := range []struct { - msg string - token string - expected int - authorized bool - timeout bool - copyHeaders bool + msg string + token string + expected int + authorized bool + timeout bool + copyHeaders bool + expectRedirect bool }{{ msg: "invalid-token-should-be-unauthorized", token: "invalid-token", @@ -47,6 +49,12 @@ func TestWebhook(t *testing.T) { token: testWebhookInvalidScopeToken, expected: http.StatusForbidden, authorized: false, + }, { + msg: "auth-redirects-should-be-sent", + token: testWebhookRedirectToken, + expected: http.StatusFound, + expectRedirect: true, + authorized: false, }} { t.Run(ti.msg, func(t *testing.T) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -78,6 +86,10 @@ func TestWebhook(t *testing.T) { tok := r.Header.Get(authHeaderName) tok = tok[len(authHeaderPrefix):] switch tok { + case testWebhookRedirectToken: + w.Header().Set("Location", webhookRedirectLocation) + w.WriteHeader(http.StatusFound) + return case testToken: w.WriteHeader(http.StatusOK) fmt.Fprintln(w, "OK - Got token: "+tok) @@ -125,7 +137,7 @@ func TestWebhook(t *testing.T) { } req.Header.Set(authHeaderName, authHeaderPrefix+ti.token) - rsp, err := proxy.Client().Do(req) + rsp, err := proxy.ClientWithoutRedirectFollow().Do(req) if err != nil { t.Fatalf("failed to get response: %v", err) } @@ -141,6 +153,13 @@ func TestWebhook(t *testing.T) { t.Fatalf("unexpected status code: %v != %v %d %s", rsp.StatusCode, ti.expected, n, buf) } + // check that the location header is forwarded for a redirect + if ti.expectRedirect { + if loc := rsp.Header.Get("Location"); loc != webhookRedirectLocation { + t.Fatalf("expected webhook location header to be forwarded: %v != %v", loc, webhookRedirectLocation) + } + } + // check that the header was passed forward to the backend request, if it should have been if ti.authorized && ti.copyHeaders { if rsp.Header.Get(headerToCopy) != "test" { diff --git a/proxy/proxytest/proxytest.go b/proxy/proxytest/proxytest.go index 4d11f29e15..a1a7cf5036 100644 --- a/proxy/proxytest/proxytest.go +++ b/proxy/proxytest/proxytest.go @@ -123,6 +123,14 @@ func (p *TestProxy) Client() *TestClient { return &TestClient{p.server.Client()} } +func (p *TestProxy) ClientWithoutRedirectFollow() *TestClient { + client := p.server.Client() + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + return &TestClient{client} +} + func (p *TestProxy) Close() error { p.Log.Close() if p.dc != nil {