Skip to content

Commit

Permalink
webhook filter: allow client redirects via 302 responses (zalando#3130)
Browse files Browse the repository at this point in the history
This commit changes the behaviour of the webhook filter when a 302 Found
response is recieved from the AuthN/AuthZ endpoint. As a result, it allows
front-end facing (i.e. non-API) traffic to be filtered via the webhook.

Documentation updates and increased test coverage is included.

Incidental: Prevent the webhook client from following redirects from the
AuthN/AuthZ endpoint: during testing I realised that the default `net/http`
behaviour was in use - i.e. redirects were followed.

Signed-off-by: Fergus Morrow <fergus@ometria.com>
  • Loading branch information
FergusInLondon committed Jun 27, 2024
1 parent 4fee4d3 commit fe41415
Show file tree
Hide file tree
Showing 10 changed files with 67 additions and 14 deletions.
1 change: 1 addition & 0 deletions docs/reference/filters.md
Original file line number Diff line number Diff line change
Expand Up @@ -1237,6 +1237,7 @@ headers to copy as an optional second argument to the filter.
Responses from the webhook will be treated as follows:

* Authorized if the status code is less than 300
* Redirection, using the `Location` header, if the status code is 302
* Forbidden if the status code is 403
* Unauthorized for remaining status codes

Expand Down
13 changes: 11 additions & 2 deletions filters/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ func reject(
reason rejectReason,
hostname,
debuginfo string,
destination string,
) {
if debuginfo == "" {
ctx.Logger().Debugf(
Expand All @@ -125,6 +126,10 @@ func reject(
Header: make(map[string][]string),
}

if status == http.StatusFound && destination != "" {
rsp.Header.Add("Location", destination)
}

if hostname != "" {
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.4.2
rsp.Header.Add("WWW-Authenticate", hostname)
Expand All @@ -133,12 +138,16 @@ func reject(
ctx.Serve(rsp)
}

func redirect(ctx filters.FilterContext, username string, reason rejectReason, destination, debuginfo string) {
reject(ctx, http.StatusFound, username, reason, "", debuginfo, destination)
}

func unauthorized(ctx filters.FilterContext, username string, reason rejectReason, hostname, debuginfo string) {
reject(ctx, http.StatusUnauthorized, username, reason, hostname, debuginfo)
reject(ctx, http.StatusUnauthorized, username, reason, hostname, debuginfo, "")
}

func forbidden(ctx filters.FilterContext, username string, reason rejectReason, debuginfo string) {
reject(ctx, http.StatusForbidden, username, reason, "", debuginfo)
reject(ctx, http.StatusForbidden, username, reason, "", debuginfo, "")
}

func authorized(ctx filters.FilterContext, username string) {
Expand Down
1 change: 1 addition & 0 deletions filters/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
const (
testToken = "test-token"
testWebhookInvalidScopeToken = "test-webhook-invalid-scope-token"
testWebhookRedirectToken = "test-webhook-redirect"
testUID = "jdoe"
testScope = "test-scope"
testScope2 = "test-scope2"
Expand Down
10 changes: 9 additions & 1 deletion filters/auth/authclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type tokeninfoClient interface {

var _ tokeninfoClient = &authClient{}

func newAuthClient(baseURL, spanName string, timeout time.Duration, maxIdleConns int, tracer opentracing.Tracer) (*authClient, error) {
func newAuthClient(baseURL, spanName string, timeout time.Duration, maxIdleConns int, tracer opentracing.Tracer, followRedirects bool) (*authClient, error) {
if tracer == nil {
tracer = opentracing.NoopTracer{}
}
Expand All @@ -49,13 +49,21 @@ func newAuthClient(baseURL, spanName string, timeout time.Duration, maxIdleConns
return nil, err
}

var checkRedirectFn func(req *http.Request, via []*http.Request) error
if followRedirects {
checkRedirectFn = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
}

cli := net.NewClient(net.Options{
ResponseHeaderTimeout: timeout,
TLSHandshakeTimeout: timeout,
MaxIdleConnsPerHost: maxIdleConns,
Tracer: tracer,
OpentracingComponentTag: "skipper",
OpentracingSpanName: spanName,
CheckRedirect: checkRedirectFn,
})

return &authClient{url: u, cli: cli}, nil
Expand Down
1 change: 1 addition & 0 deletions filters/auth/grantconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ func (c *OAuthConfig) Init() error {
c.ConnectionTimeout,
c.MaxIdleConnectionsPerHost,
c.Tracer,
false,
)
if err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion filters/auth/tokeninfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (o *TokeninfoOptions) getTokeninfoClient() (tokeninfoClient, error) {
func (o *TokeninfoOptions) newTokeninfoClient() (tokeninfoClient, error) {
var c tokeninfoClient

c, err := newAuthClient(o.URL, tokenInfoSpanName, o.Timeout, o.MaxIdleConns, o.Tracer)
c, err := newAuthClient(o.URL, tokenInfoSpanName, o.Timeout, o.MaxIdleConns, o.Tracer, false)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion filters/auth/tokenintrospection.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ func (s *tokenIntrospectionSpec) CreateFilter(args []interface{}) (filters.Filte
var ac *authClient
var ok bool
if ac, ok = issuerAuthClient[issuerURL]; !ok {
ac, err = newAuthClient(cfg.IntrospectionEndpoint, tokenIntrospectionSpanName, s.options.Timeout, s.options.MaxIdleConns, s.options.Tracer)
ac, err = newAuthClient(cfg.IntrospectionEndpoint, tokenIntrospectionSpanName, s.options.Timeout, s.options.MaxIdleConns, s.options.Tracer, false)
if err != nil {
return nil, filters.ErrInvalidFilterParameters
}
Expand Down
10 changes: 8 additions & 2 deletions filters/auth/webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (ws *webhookSpec) CreateFilter(args []interface{}) (filters.Filter, error)
var ac *authClient
var err error
if ac, ok = webhookAuthClient[s]; !ok {
ac, err = newAuthClient(s, webhookSpanName, ws.options.Timeout, ws.options.MaxIdleConns, ws.options.Tracer)
ac, err = newAuthClient(s, webhookSpanName, ws.options.Timeout, ws.options.MaxIdleConns, ws.options.Tracer, true)
if err != nil {
return nil, filters.ErrInvalidFilterParameters
}
Expand Down Expand Up @@ -121,7 +121,13 @@ func (f *webhookFilter) Request(ctx filters.FilterContext) {
return
}

// errors, redirects, auth errors, webhook errors
// redirect
if err == nil && resp.StatusCode == http.StatusFound {
redirect(ctx, "", invalidAccess, resp.Header.Get("Location"), filters.WebhookName)
return
}

// errors, auth errors, webhook errors
if err != nil || resp.StatusCode >= 300 {
unauthorized(ctx, "", invalidAccess, f.authClient.url.Hostname(), filters.WebhookName)
return
Expand Down
33 changes: 26 additions & 7 deletions filters/auth/webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@ import (
)

const headerToCopy = "X-Copy-Header"
const webhookRedirectLocation = "https://example.com/auth"

func TestWebhook(t *testing.T) {
for _, ti := range []struct {
msg string
token string
expected int
authorized bool
timeout bool
copyHeaders bool
msg string
token string
expected int
authorized bool
timeout bool
copyHeaders bool
expectRedirect bool
}{{
msg: "invalid-token-should-be-unauthorized",
token: "invalid-token",
Expand All @@ -47,6 +49,12 @@ func TestWebhook(t *testing.T) {
token: testWebhookInvalidScopeToken,
expected: http.StatusForbidden,
authorized: false,
}, {
msg: "auth-redirects-should-be-sent",
token: testWebhookRedirectToken,
expected: http.StatusFound,
expectRedirect: true,
authorized: false,
}} {
t.Run(ti.msg, func(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -78,6 +86,10 @@ func TestWebhook(t *testing.T) {
tok := r.Header.Get(authHeaderName)
tok = tok[len(authHeaderPrefix):]
switch tok {
case testWebhookRedirectToken:
w.Header().Set("Location", webhookRedirectLocation)
w.WriteHeader(http.StatusFound)
return
case testToken:
w.WriteHeader(http.StatusOK)
fmt.Fprintln(w, "OK - Got token: "+tok)
Expand Down Expand Up @@ -125,7 +137,7 @@ func TestWebhook(t *testing.T) {
}
req.Header.Set(authHeaderName, authHeaderPrefix+ti.token)

rsp, err := proxy.Client().Do(req)
rsp, err := proxy.ClientWithoutRedirectFollow().Do(req)
if err != nil {
t.Fatalf("failed to get response: %v", err)
}
Expand All @@ -141,6 +153,13 @@ func TestWebhook(t *testing.T) {
t.Fatalf("unexpected status code: %v != %v %d %s", rsp.StatusCode, ti.expected, n, buf)
}

// check that the location header is forwarded for a redirect
if ti.expectRedirect {
if loc := rsp.Header.Get("Location"); loc != webhookRedirectLocation {
t.Fatalf("expected webhook location header to be forwarded: %v != %v", loc, webhookRedirectLocation)
}
}

// check that the header was passed forward to the backend request, if it should have been
if ti.authorized && ti.copyHeaders {
if rsp.Header.Get(headerToCopy) != "test" {
Expand Down
8 changes: 8 additions & 0 deletions proxy/proxytest/proxytest.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ func (p *TestProxy) Client() *TestClient {
return &TestClient{p.server.Client()}
}

func (p *TestProxy) ClientWithoutRedirectFollow() *TestClient {
client := p.server.Client()
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
return &TestClient{client}
}

func (p *TestProxy) Close() error {
p.Log.Close()
if p.dc != nil {
Expand Down

0 comments on commit fe41415

Please sign in to comment.