Skip to content

Commit

Permalink
feat: pre/post auth policy
Browse files Browse the repository at this point in the history
  • Loading branch information
lsjostro committed Feb 23, 2024
1 parent 6fab342 commit 7368bc7
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 63 deletions.
46 changes: 39 additions & 7 deletions authz/authz.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"log/slog"
"maps"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -93,33 +94,42 @@ func (s *Service) Check(ctx context.Context, req *connect.Request[auth.CheckRequ
return connect.NewResponse(s.authResponse(false, envoy_type.StatusCode_Unauthorized, nil, nil, "no header matches any auth provider")), nil
}

span.AddEvent("provider",
span.AddEvent("provider config",
trace.WithAttributes(
attribute.String("issuer_url", provider.IssuerURL),
attribute.String("client_id", provider.ClientID),
attribute.String("callback_uri", provider.CallbackURI),
attribute.String("cookie_name_prefix", provider.CookieNamePrefix),
attribute.String("opa_policy", provider.OPAPolicy),
attribute.String("pre_auth_policy", provider.PreAuthPolicy),
attribute.String("post_auth_policy", provider.PostAuthPolicy),
attribute.Bool("secure_cookie", provider.SecureCookie),
attribute.StringSlice("scopes", provider.Scopes),
attribute.String("header_match_name", provider.HeaderMatch.Name),
),
)

// if OPA Policy is defined evaluate the request
if provider.OPAPolicy != "" {
allowed, err := policy.Eval(ctx, req.Msg, provider.OPAPolicy)
var reqInput map[string]interface{}
if provider.PreAuthPolicy != "" {
input, err := policy.RequestOrResponseToInput(req.Msg)
if err != nil {
span.RecordError(err, trace.WithStackTrace(true))
span.SetStatus(codes.Error, err.Error())
return connect.NewResponse(s.authResponse(false, envoy_type.StatusCode_BadGateway, nil, nil, err.Error())), nil
}

allowed, err := policy.Eval(ctx, input, provider.PreAuthPolicy)
if err != nil {
span.RecordError(err)
span.SetStatus(codes.Error, err.Error())
return connect.NewResponse(s.authResponse(false, envoy_type.StatusCode_BadGateway, nil, nil, err.Error())), nil
}

if !allowed {
slog.Debug("OPA Policy denied the request")
span.SetStatus(codes.Error, "OPA Policy denied request")
return connect.NewResponse(s.authResponse(false, envoy_type.StatusCode_Forbidden, nil, nil, "OPA Policy denied request")), nil
span.SetStatus(codes.Error, "PreAuth policy denied request")
return connect.NewResponse(s.authResponse(false, envoy_type.StatusCode_Forbidden, nil, nil, "PreAuth policy denied request")), nil
}
reqInput = input
}

resp, err := s.authProcess(ctx, httpReq, provider)
Expand All @@ -130,6 +140,28 @@ func (s *Service) Check(ctx context.Context, req *connect.Request[auth.CheckRequ
resp = s.authResponse(false, envoy_type.StatusCode_BadGateway, nil, nil, err.Error())
}

if provider.PostAuthPolicy != "" && resp.GetStatus().GetCode() == int32(rpc.OK) {
respInput, err := policy.RequestOrResponseToInput(resp)
if err != nil {
span.RecordError(err, trace.WithStackTrace(true))
span.SetStatus(codes.Error, err.Error())
return connect.NewResponse(s.authResponse(false, envoy_type.StatusCode_BadGateway, nil, nil, err.Error())), nil
}

// Merge response with request input
maps.Copy(respInput, reqInput)
allowed, err := policy.Eval(ctx, respInput, provider.PostAuthPolicy)
if err != nil {
span.RecordError(err, trace.WithStackTrace(true))
span.SetStatus(codes.Error, err.Error())
return connect.NewResponse(s.authResponse(false, envoy_type.StatusCode_BadGateway, nil, nil, err.Error())), nil
}
if !allowed {
span.SetStatus(codes.Error, "PostAuth Policy denied request")
return connect.NewResponse(s.authResponse(false, envoy_type.StatusCode_Forbidden, nil, nil, "PostAuth policy denied request")), nil
}
}

// Return response to envoy
span.SetStatus(codes.Ok, "success")
return connect.NewResponse(resp), nil
Expand Down
3 changes: 2 additions & 1 deletion authz/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ type OIDCProvider struct {
Scopes []string `yaml:"scopes"`
CookieNamePrefix string `yaml:"cookieNamePrefix"`
SecureCookie bool `yaml:"secureCookie"`
OPAPolicy string `yaml:"opaPolicy"`
PreAuthPolicy string `yaml:"preAuthPolicy"`
PostAuthPolicy string `yaml:"postAuthPolicy"`
HeaderMatch HeaderMatch `yaml:"headerMatch"`
}

Expand Down
80 changes: 28 additions & 52 deletions policy/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log/slog"
"net/url"
"strings"

auth "buf.build/gen/go/envoyproxy/envoy/protocolbuffers/go/envoy/service/auth/v3"
Expand All @@ -22,17 +21,10 @@ var (
)

// Eval evaluates the policy with the given input and returns the result.
func Eval(ctx context.Context, req *auth.CheckRequest, policy string) (bool, error) {
func Eval(ctx context.Context, input map[string]interface{}, policy string) (bool, error) {
ctx, span := tracer.Start(ctx, "PolicyEval")
defer span.End()

input, err := requestToInput(req)
if err != nil {
span.RecordError(err, trace.WithStackTrace(true))
span.SetStatus(codes.Error, err.Error())
return false, err
}

q, err := rego.New(
rego.Query("data.authz.allow"),
rego.Module("opaPolicy", policy),
Expand Down Expand Up @@ -69,52 +61,36 @@ func Eval(ctx context.Context, req *auth.CheckRequest, policy string) (bool, err
return true, nil
}

func requestToInput(req *auth.CheckRequest) (map[string]interface{}, error) {
func RequestOrResponseToInput(req any) (map[string]interface{}, error) {
var input map[string]interface{}
var bs []byte

bs, err := protojson.Marshal(req)
if err != nil {
return nil, err
}

err = util.UnmarshalJSON(bs, &input)
if err != nil {
return nil, err
}

path := req.GetAttributes().GetRequest().GetHttp().GetPath()
parsedPath, parsedQuery, err := getParsedPathAndQuery(path)
if err != nil {
return nil, err
}

input["parsed_path"] = parsedPath
input["parsed_query"] = parsedQuery

return input, nil
}

func getParsedPathAndQuery(path string) ([]interface{}, map[string]interface{}, error) {
parsedURL, err := url.Parse(path)
if err != nil {
return nil, nil, err
}

parsedPath := strings.Split(strings.TrimLeft(parsedURL.Path, "/"), "/")
parsedPathInterface := make([]interface{}, len(parsedPath))
for i, v := range parsedPath {
parsedPathInterface[i] = v
}

parsedQueryInterface := make(map[string]interface{})
for paramKey, paramValues := range parsedURL.Query() {
queryValues := make([]interface{}, len(paramValues))
for i, v := range paramValues {
queryValues[i] = v
// type switch for CheckRequest or CheckResponse
switch v := req.(type) {
case *auth.CheckRequest:
bs, err := protojson.Marshal(v)
if err != nil {
return nil, err
}
err = util.UnmarshalJSON(bs, &input)
if err != nil {
return nil, err
}
case *auth.CheckResponse:
bs, err := protojson.Marshal(v)
if err != nil {
return nil, err
}
err = util.UnmarshalJSON(bs, &input)
if err != nil {
return nil, err
}
for _, h := range v.GetOkResponse().GetHeaders() {
if h.GetHeader().GetKey() == "Authorization" {
input["parsed_jwt"] = strings.Split(h.GetHeader().GetValue(), " ")[1]
break
}
}
parsedQueryInterface[paramKey] = queryValues
}

return parsedPathInterface, parsedQueryInterface, nil
return input, nil
}
19 changes: 16 additions & 3 deletions run/config/providers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@ providers:
# clientSecret: test1234 # omit for PKCE auth
cookieNamePrefix: podinfo
# secureCookie: true # disable for local development
opaPolicy: |
preAuthPolicy: |
package authz
import input.attributes.request.http
import rego.v1
import input.attributes.request.http
default allow = false
Expand All @@ -34,6 +33,20 @@ providers:
http.method == "GET"
glob.match("/", ["/"], http.path)
}
postAuthPolicy: |
package authz
import rego.v1
default allow = false
allow if {
token.payload.email == "kilgore@kilgore.trout"
}
token := { "payload": payload } if {
[_, payload, _] := io.jwt.decode(input.parsed_jwt)
}
scopes:
- openid
- profile
Expand Down

0 comments on commit 7368bc7

Please sign in to comment.