Skip to content

Commit

Permalink
Add support to restrict localserver response handling to specific ori…
Browse files Browse the repository at this point in the history
…gins (#1641)
  • Loading branch information
directionless authored Mar 7, 2024
1 parent 10a4aee commit 8f63010
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 7 deletions.
50 changes: 45 additions & 5 deletions ee/localserver/krypto-ec-middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/ecdsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
Expand All @@ -20,7 +21,6 @@ import (
"github.com/kolide/launcher/pkg/log/multislogger"
"github.com/kolide/launcher/pkg/traces"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
)

const (
Expand All @@ -35,6 +35,7 @@ type v2CmdRequestType struct {
Body []byte
CallbackUrl string
CallbackHeaders map[string][]string
AllowedOrigins []string
}

func (cmdReq v2CmdRequestType) CallbackReq() (*http.Request, error) {
Expand Down Expand Up @@ -79,8 +80,9 @@ func newKryptoEcMiddleware(slogger *slog.Logger, localDbSigner, hardwareSigner c
type callbackErrors string

const (
timeOutOfRangeErr callbackErrors = "time-out-of-range"
responseFailureErr callbackErrors = "response-failure"
timeOutOfRangeErr callbackErrors = "time-out-of-range"
responseFailureErr callbackErrors = "response-failure"
originDisallowedErr callbackErrors = "origin-disallowed"
)

type callbackDataStruct struct {
Expand All @@ -96,7 +98,7 @@ type callbackDataStruct struct {
// Also, because the URL is the box, we cannot cleanly do this through middleware. It reqires a lot of passing data
// around through context. Doing it here, as part of kryptoEcMiddleware, allows for a fairly succint defer.
//
// Note that this should be a goroutine.
// Note that because this is a network call, it should be called in a goroutine.
func (e *kryptoEcMiddleware) sendCallback(req *http.Request, data *callbackDataStruct) {
if req == nil {
return
Expand Down Expand Up @@ -216,12 +218,47 @@ func (e *kryptoEcMiddleware) Wrap(next http.Handler) http.Handler {
}()
}

// Check if the origin is in the allowed list. See https://github.com/kolide/k2/issues/9634
if len(cmdReq.AllowedOrigins) > 0 {
allowed := false
for _, ao := range cmdReq.AllowedOrigins {
if strings.EqualFold(ao, r.Header.Get("Origin")) {
allowed = true
break
}
}

if !allowed {
span.SetAttributes(attribute.String("origin", r.Header.Get("Origin")))
traces.SetError(span, fmt.Errorf("origin %s is not allowed", r.Header.Get("Origin")))
e.slogger.Log(r.Context(), slog.LevelError,
"origin is not allowed",
"allowlist", cmdReq.AllowedOrigins,
"origin", r.Header.Get("Origin"),
)

w.WriteHeader(http.StatusUnauthorized)
callbackData.Error = originDisallowedErr
return
}

e.slogger.Log(r.Context(), slog.LevelDebug,
"origin matches allowlist",
"origin", r.Header.Get("Origin"),
)
} else {
e.slogger.Log(r.Context(), slog.LevelDebug,
"origin is allowed by default, no allowlist",
"origin", r.Header.Get("Origin"),
)
}

// Check the timestamp, this prevents people from saving a challenge and then
// reusing it a bunch. However, it will fail if the clocks are too far out of sync.
timestampDelta := time.Now().Unix() - challengeBox.Timestamp()
if timestampDelta > timestampValidityRange || timestampDelta < -timestampValidityRange {
span.SetAttributes(attribute.Int64("timestamp_delta", timestampDelta))
span.SetStatus(codes.Error, "timestamp is out of range")
traces.SetError(span, errors.New("timestamp is out of range"))
e.slogger.Log(r.Context(), slog.LevelError,
"timestamp is out of range",
"delta", timestampDelta,
Expand All @@ -234,13 +271,16 @@ func (e *kryptoEcMiddleware) Wrap(next http.Handler) http.Handler {

newReq := &http.Request{
Method: http.MethodPost,
Header: make(http.Header),
URL: &url.URL{
Scheme: r.URL.Scheme,
Host: r.Host,
Path: cmdReq.Path,
},
}

newReq.Header.Set("Origin", r.Header.Get("Origin"))

// setting the newReq context to the current request context
// allows the trace to continue to the inner request,
// maintains the same lifetime as the original request,
Expand Down
147 changes: 147 additions & 0 deletions ee/localserver/krypto-ec-middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,153 @@ func TestKryptoEcMiddleware(t *testing.T) {
}
}

func Test_AllowedOrigin(t *testing.T) {
t.Parallel()

counterpartyKey, err := echelper.GenerateEcdsaKey()
require.NoError(t, err)

challengeId := []byte(ulid.New())
challengeData := []byte(ulid.New())

var tests = []struct {
name string
requestOrigin string
allowedOrigins []string
logStr string
expectedStatus int
}{
{
name: "no allowed specified",
requestOrigin: "https://auth.example.com",
expectedStatus: http.StatusOK,
logStr: "origin is allowed by default",
},
{
name: "no allowed specified missing origin",
expectedStatus: http.StatusOK,
logStr: "origin is allowed by default",
},
{
name: "allowed specified missing origin",
allowedOrigins: []string{"https://auth.example.com", "https://login.example.com"},
expectedStatus: http.StatusUnauthorized,
logStr: "origin is not allowed",
},
{
name: "allowed specified origin mismatch",
allowedOrigins: []string{"https://auth.example.com", "https://login.example.com"},
requestOrigin: "https://not-it.example.com",
expectedStatus: http.StatusUnauthorized,
logStr: "origin is not allowed",
},
{
name: "scheme mismatch",
allowedOrigins: []string{"https://auth.example.com"},
requestOrigin: "http://auth.example.com",
expectedStatus: http.StatusUnauthorized,
logStr: "origin is not allowed",
},
{
name: "allowed specified origin matches",
allowedOrigins: []string{"https://auth.example.com", "https://login.example.com"},
requestOrigin: "https://auth.example.com",
expectedStatus: http.StatusOK,
logStr: "origin matches allowlist",
},
{
name: "allowed specified origin matches 2",
allowedOrigins: []string{"https://auth.example.com", "https://login.example.com"},
requestOrigin: "https://login.example.com",
expectedStatus: http.StatusOK,
logStr: "origin matches allowlist",
},
{
name: "allowed specified origin matches casing",
allowedOrigins: []string{"https://auth.example.com", "https://login.example.com"},
requestOrigin: "https://AuTh.ExAmPlE.cOm",
expectedStatus: http.StatusOK,
logStr: "origin matches allowlist",
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

cmdReqBody := []byte(randomStringWithSqlCharacters(t, 100000))

cmdReq := v2CmdRequestType{
Path: "whatevs",
Body: cmdReqBody,
AllowedOrigins: tt.allowedOrigins,
}

challengeBytes, privateEncryptionKey, err := challenge.Generate(counterpartyKey, challengeId, challengeData, mustMarshal(t, cmdReq))
require.NoError(t, err)
encodedChallenge := base64.StdEncoding.EncodeToString(challengeBytes)

responseData := []byte(ulid.New())

testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqBodyRaw, err := io.ReadAll(r.Body)
require.NoError(t, err)
defer r.Body.Close()

require.Equal(t, cmdReqBody, reqBodyRaw)
w.Write(responseData)
})

var logBytes bytes.Buffer
slogger := multislogger.New(slog.NewTextHandler(&logBytes, &slog.HandlerOptions{
Level: slog.LevelDebug,
})).Logger

// set up middlewares
kryptoEcMiddleware := newKryptoEcMiddleware(slogger, ecdsaKey(t), nil, counterpartyKey.PublicKey)
require.NoError(t, err)

h := kryptoEcMiddleware.Wrap(testHandler)

req := makeGetRequest(t, encodedChallenge)
req.Header.Set("origin", tt.requestOrigin)

rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)

// FIXME: add some log string tests
//spew.Dump(logBytes.String(), tt.logStr)

require.Equal(t, tt.expectedStatus, rr.Code)

if tt.logStr != "" {
assert.Contains(t, logBytes.String(), tt.logStr)
}

if tt.expectedStatus != http.StatusOK {
return
}

// try to open the response
returnedResponseBytes, err := base64.StdEncoding.DecodeString(rr.Body.String())
require.NoError(t, err)

responseUnmarshalled, err := challenge.UnmarshalResponse(returnedResponseBytes)
require.NoError(t, err)
require.Equal(t, challengeId, responseUnmarshalled.ChallengeId)

opened, err := responseUnmarshalled.Open(*privateEncryptionKey)
require.NoError(t, err)
require.Equal(t, challengeData, opened.ChallengeData)
require.Equal(t, responseData, opened.ResponseData)
require.WithinDuration(t, time.Now(), time.Unix(opened.Timestamp, 0), time.Second*5)

})
}

}

func ecdsaKey(t *testing.T) *ecdsa.PrivateKey {
key, err := echelper.GenerateEcdsaKey()
require.NoError(t, err)
Expand Down
2 changes: 2 additions & 0 deletions ee/localserver/request-id.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type (
Nonce string
Timestamp time.Time
Status status
Origin string
}

status struct {
Expand Down Expand Up @@ -78,6 +79,7 @@ func (ls *localServer) requestIdHandlerFunc(w http.ResponseWriter, r *http.Reque
response := requestIdsResponse{
Nonce: ulid.New(),
Timestamp: time.Now(),
Origin: r.Header.Get("Origin"),
Status: status{
EnrollmentStatus: string(enrollmentStatus),
},
Expand Down
6 changes: 4 additions & 2 deletions ee/localserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,10 @@ func (ls *localServer) startListener() (net.Listener, error) {

func (ls *localServer) preflightCorsHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Think harder, maybe?
// https://stackoverflow.com/questions/12830095/setting-http-headers
// We don't believe we can meaningfully enforce a CORS style check here -- those are enforced by the browser.
// And we recognize there are some patterns that bypass the browsers CORS enforcement. However, we do implement
// origin enforcement as an allowlist inside kryptoEcMiddleware
// See https://github.com/kolide/k2/issues/9634
if origin := r.Header.Get("Origin"); origin != "" {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
Expand Down

0 comments on commit 8f63010

Please sign in to comment.