diff --git a/ee/desktop/runner/runner.go b/ee/desktop/runner/runner.go index bd45a510d..8a95be32d 100644 --- a/ee/desktop/runner/runner.go +++ b/ee/desktop/runner/runner.go @@ -29,6 +29,7 @@ import ( "github.com/kolide/launcher/ee/desktop/user/client" "github.com/kolide/launcher/ee/desktop/user/menu" "github.com/kolide/launcher/ee/desktop/user/notify" + "github.com/kolide/launcher/ee/presencedetection" "github.com/kolide/launcher/ee/ui/assets" "github.com/kolide/launcher/pkg/backoff" "github.com/kolide/launcher/pkg/rungroup" @@ -282,26 +283,25 @@ func (r *DesktopUsersProcessesRunner) Interrupt(_ error) { ) } -func (r *DesktopUsersProcessesRunner) DetectPresence(reason string, interval time.Duration) (bool, error) { +func (r *DesktopUsersProcessesRunner) DetectPresence(reason string, interval time.Duration) (time.Duration, error) { if r.uidProcs == nil || len(r.uidProcs) == 0 { - return false, errors.New("no desktop processes running") + return presencedetection.DetectionFailedDurationValue, errors.New("no desktop processes running") } var lastErr error for _, proc := range r.uidProcs { client := client.New(r.userServerAuthToken, proc.socketPath) - success, err := client.DetectPresence(reason, interval) + durationSinceLastDetection, err := client.DetectPresence(reason, interval) - // not sure how to handle the possiblity of multiple users - // so just return the first success - if success { - return success, err + if err != nil { + lastErr = err + continue } - lastErr = err + return durationSinceLastDetection, nil } - return false, fmt.Errorf("no desktop processes detected presence, last error: %w", lastErr) + return presencedetection.DetectionFailedDurationValue, fmt.Errorf("no desktop processes detected presence, last error: %w", lastErr) } // killDesktopProcesses kills any existing desktop processes diff --git a/ee/desktop/user/client/client.go b/ee/desktop/user/client/client.go index 6c82358ae..be7f83d5d 100644 --- a/ee/desktop/user/client/client.go +++ b/ee/desktop/user/client/client.go @@ -12,6 +12,7 @@ import ( "github.com/kolide/launcher/ee/desktop/user/notify" "github.com/kolide/launcher/ee/desktop/user/server" + "github.com/kolide/launcher/ee/presencedetection" ) type transport struct { @@ -62,13 +63,13 @@ func (c *client) ShowDesktop() error { return c.get("show") } -func (c *client) DetectPresence(reason string, interval time.Duration) (bool, error) { +func (c *client) DetectPresence(reason string, interval time.Duration) (time.Duration, error) { encodedReason := url.QueryEscape(reason) encodedInterval := url.QueryEscape(interval.String()) resp, requestErr := c.base.Get(fmt.Sprintf("http://unix/detect_presence?reason=%s&interval=%s", encodedReason, encodedInterval)) if requestErr != nil { - return false, fmt.Errorf("getting presence: %w", requestErr) + return presencedetection.DetectionFailedDurationValue, fmt.Errorf("getting presence: %w", requestErr) } var response server.DetectPresenceResponse @@ -76,7 +77,7 @@ func (c *client) DetectPresence(reason string, interval time.Duration) (bool, er defer resp.Body.Close() if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { - return false, fmt.Errorf("decoding response: %w", err) + return presencedetection.DetectionFailedDurationValue, fmt.Errorf("decoding response: %w", err) } } @@ -85,7 +86,12 @@ func (c *client) DetectPresence(reason string, interval time.Duration) (bool, er err = errors.New(response.Error) } - return response.Success, err + durationSinceLastDetection, parseErr := time.ParseDuration(response.DurationSinceLastDetection) + if parseErr != nil { + return presencedetection.DetectionFailedDurationValue, fmt.Errorf("parsing time since last detection: %w", parseErr) + } + + return durationSinceLastDetection, err } func (c *client) Notify(n notify.Notification) error { diff --git a/ee/desktop/user/server/server.go b/ee/desktop/user/server/server.go index 8efa2c6e6..fd674a4ce 100644 --- a/ee/desktop/user/server/server.go +++ b/ee/desktop/user/server/server.go @@ -174,8 +174,8 @@ func (s *UserServer) showDesktop(w http.ResponseWriter, req *http.Request) { } type DetectPresenceResponse struct { - Success bool `json:"success"` - Error string `json:"error,omitempty"` + DurationSinceLastDetection string `json:"duration_since_last_detection,omitempty"` + Error string `json:"error,omitempty"` } func (s *UserServer) detectPresence(w http.ResponseWriter, req *http.Request) { @@ -201,13 +201,20 @@ func (s *UserServer) detectPresence(w http.ResponseWriter, req *http.Request) { } // detect presence - success, err := s.presenceDetector.DetectPresence(reason, interval) + durationSinceLastDetection, err := s.presenceDetector.DetectPresence(reason, interval) response := DetectPresenceResponse{ - Success: success, + DurationSinceLastDetection: durationSinceLastDetection.String(), } if err != nil { response.Error = err.Error() + + s.slogger.Log(context.TODO(), slog.LevelDebug, + "detecting presence", + "reason", reason, + "interval", interval, + "err", err, + ) } // convert response to json diff --git a/ee/localserver/krypto-ec-middleware.go b/ee/localserver/krypto-ec-middleware.go index ca0f6245d..028d28bf3 100644 --- a/ee/localserver/krypto-ec-middleware.go +++ b/ee/localserver/krypto-ec-middleware.go @@ -25,12 +25,13 @@ import ( ) const ( - timestampValidityRange = 150 - kolideKryptoEccHeader20230130Value = "2023-01-30" - kolideKryptoHeaderKey = "X-Kolide-Krypto" - kolideSessionIdHeaderKey = "X-Kolide-Session" - kolidePresenceDetectionInterval = "X-Kolide-Presence-Detection-Interval" - kolidePresenceDetectionReason = "X-Kolide-Presence-Detection-Reason" + timestampValidityRange = 150 + kolideKryptoEccHeader20230130Value = "2023-01-30" + kolideKryptoHeaderKey = "X-Kolide-Krypto" + kolideSessionIdHeaderKey = "X-Kolide-Session" + kolidePresenceDetectionInterval = "X-Kolide-Presence-Detection-Interval" + kolidePresenceDetectionReason = "X-Kolide-Presence-Detection-Reason" + kolideDurationSinceLastPresenceDetection = "X-Kolide-Duration-Since-Last-Presence-Detection" ) type v2CmdRequestType struct { diff --git a/ee/localserver/mocks/presenceDetector.go b/ee/localserver/mocks/presenceDetector.go index f3a08c477..0af0518f7 100644 --- a/ee/localserver/mocks/presenceDetector.go +++ b/ee/localserver/mocks/presenceDetector.go @@ -14,22 +14,22 @@ type PresenceDetector struct { } // DetectPresence provides a mock function with given fields: reason, interval -func (_m *PresenceDetector) DetectPresence(reason string, interval time.Duration) (bool, error) { +func (_m *PresenceDetector) DetectPresence(reason string, interval time.Duration) (time.Duration, error) { ret := _m.Called(reason, interval) if len(ret) == 0 { panic("no return value specified for DetectPresence") } - var r0 bool + var r0 time.Duration var r1 error - if rf, ok := ret.Get(0).(func(string, time.Duration) (bool, error)); ok { + if rf, ok := ret.Get(0).(func(string, time.Duration) (time.Duration, error)); ok { return rf(reason, interval) } - if rf, ok := ret.Get(0).(func(string, time.Duration) bool); ok { + if rf, ok := ret.Get(0).(func(string, time.Duration) time.Duration); ok { r0 = rf(reason, interval) } else { - r0 = ret.Get(0).(bool) + r0 = ret.Get(0).(time.Duration) } if rf, ok := ret.Get(1).(func(string, time.Duration) error); ok { diff --git a/ee/localserver/presence-detection-middleware_test.go b/ee/localserver/presence-detection-middleware_test.go index 1d9de7c0c..d00e8d3f6 100644 --- a/ee/localserver/presence-detection-middleware_test.go +++ b/ee/localserver/presence-detection-middleware_test.go @@ -4,8 +4,11 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/kolide/launcher/ee/localserver/mocks" + "github.com/kolide/launcher/ee/presencedetection" + "github.com/kolide/launcher/pkg/log/multislogger" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -15,16 +18,18 @@ func TestPresenceDetectionHandler(t *testing.T) { t.Parallel() tests := []struct { - name string - expectDetectPresenceCall bool - intervalHeader, reasonHeader string - presenceDetectionSuccess bool - presenceDetectionError error - expectedStatusCode int + name string + expectDetectPresenceCall bool + intervalHeader, reasonHeader string + durationSinceLastDetection time.Duration + presenceDetectionError error + shouldHavePresenceDetectionDurationResponseHeader bool + expectedStatusCode int }{ { name: "no presence detection headers", expectedStatusCode: http.StatusOK, + shouldHavePresenceDetectionDurationResponseHeader: false, }, { name: "invalid presence detection interval", @@ -32,29 +37,32 @@ func TestPresenceDetectionHandler(t *testing.T) { expectedStatusCode: http.StatusBadRequest, }, { - name: "valid presence detection, detection fails", - expectDetectPresenceCall: true, - intervalHeader: "10s", - reasonHeader: "test reason", - presenceDetectionSuccess: false, - expectedStatusCode: http.StatusUnauthorized, + name: "valid presence detection, detection fails", + expectDetectPresenceCall: true, + intervalHeader: "10s", + reasonHeader: "test reason", + durationSinceLastDetection: presencedetection.DetectionFailedDurationValue, + expectedStatusCode: http.StatusOK, + shouldHavePresenceDetectionDurationResponseHeader: true, }, { - name: "valid presence detection, detection succeeds", - expectDetectPresenceCall: true, - intervalHeader: "10s", - reasonHeader: "test reason", - presenceDetectionSuccess: true, - expectedStatusCode: http.StatusOK, + name: "valid presence detection, detection succeeds", + expectDetectPresenceCall: true, + intervalHeader: "10s", + reasonHeader: "test reason", + durationSinceLastDetection: 0, + expectedStatusCode: http.StatusOK, + shouldHavePresenceDetectionDurationResponseHeader: true, }, { - name: "presence detection error", - expectDetectPresenceCall: true, - intervalHeader: "10s", - reasonHeader: "test reason", - presenceDetectionSuccess: false, - presenceDetectionError: assert.AnError, - expectedStatusCode: http.StatusUnauthorized, + name: "presence detection error", + expectDetectPresenceCall: true, + intervalHeader: "10s", + reasonHeader: "test reason", + durationSinceLastDetection: presencedetection.DetectionFailedDurationValue, + presenceDetectionError: assert.AnError, + expectedStatusCode: http.StatusOK, + shouldHavePresenceDetectionDurationResponseHeader: true, }, } @@ -66,11 +74,12 @@ func TestPresenceDetectionHandler(t *testing.T) { mockPresenceDetector := mocks.NewPresenceDetector(t) if tt.expectDetectPresenceCall { - mockPresenceDetector.On("DetectPresence", mock.AnythingOfType("string"), mock.AnythingOfType("Duration")).Return(tt.presenceDetectionSuccess, tt.presenceDetectionError) + mockPresenceDetector.On("DetectPresence", mock.AnythingOfType("string"), mock.AnythingOfType("Duration")).Return(tt.durationSinceLastDetection, tt.presenceDetectionError) } server := &localServer{ presenceDetector: mockPresenceDetector, + slogger: multislogger.NewNopLogger(), } // Create a test handler for the middleware to call @@ -94,6 +103,9 @@ func TestPresenceDetectionHandler(t *testing.T) { rr := httptest.NewRecorder() handlerToTest.ServeHTTP(rr, req) + if tt.shouldHavePresenceDetectionDurationResponseHeader { + require.NotEmpty(t, rr.Header().Get(kolideDurationSinceLastPresenceDetection)) + } require.Equal(t, tt.expectedStatusCode, rr.Code) }) } diff --git a/ee/localserver/server.go b/ee/localserver/server.go index ea689fe5e..812268eaf 100644 --- a/ee/localserver/server.go +++ b/ee/localserver/server.go @@ -12,6 +12,7 @@ import ( "log/slog" "net" "net/http" + "runtime" "strings" "time" @@ -66,7 +67,7 @@ const ( ) type presenceDetector interface { - DetectPresence(reason string, interval time.Duration) (bool, error) + DetectPresence(reason string, interval time.Duration) (time.Duration, error) } func New(ctx context.Context, k types.Knapsack, presenceDetector presenceDetector) (*localServer, error) { @@ -127,7 +128,7 @@ func New(ctx context.Context, k types.Knapsack, presenceDetector presenceDetecto // curl localhost:40978/acceleratecontrol --data '{"interval":"250ms", "duration":"1s"}' // mux.Handle("/acceleratecontrol", ls.requestAccelerateControlHandler()) // curl localhost:40978/id - // mux.Handle("/id", ls.requestIdHandler()) + mux.Handle("/id", ls.requestIdHandler()) srv := &http.Server{ Handler: otelhttp.NewHandler( @@ -411,6 +412,12 @@ func (ls *localServer) rateLimitHandler(next http.Handler) http.Handler { func (ls *localServer) presenceDetectionHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + if runtime.GOOS != "darwin" { + next.ServeHTTP(w, r) + return + } + // can test this by adding an unauthed endpoint to the mux and running, for example: // curl -H "X-Kolide-Presence-Detection-Interval: 10s" -H "X-Kolide-Presence-Detection-Reason: my reason" localhost:12519/id detectionIntervalStr := r.Header.Get(kolidePresenceDetectionInterval) @@ -422,6 +429,8 @@ func (ls *localServer) presenceDetectionHandler(next http.Handler) http.Handler detectionIntervalDuration, err := time.ParseDuration(detectionIntervalStr) if err != nil { + // this is the only time this should returna non-200 status code + // asked for presence detection, but the interval is invalid http.Error(w, err.Error(), http.StatusBadRequest) return } @@ -433,17 +442,21 @@ func (ls *localServer) presenceDetectionHandler(next http.Handler) http.Handler reason = reasonHeader } - success, err := ls.presenceDetector.DetectPresence(reason, detectionIntervalDuration) - if err != nil { - http.Error(w, err.Error(), http.StatusUnauthorized) - return - } + durationSinceLastDetection, err := ls.presenceDetector.DetectPresence(reason, detectionIntervalDuration) - if !success { - http.Error(w, "presence detection failed", http.StatusUnauthorized) - return - } + ls.slogger.Log(r.Context(), slog.LevelError, + "presence_detection", + "reason", reason, + "interval", detectionIntervalDuration, + "duration_since_last_detection", durationSinceLastDetection, + "err", err, + ) + + // if there was an error, we still want to return a 200 status code + // and send the request through + // allow the server to decide what to do based on last detection duration + w.Header().Add(kolideDurationSinceLastPresenceDetection, durationSinceLastDetection.String()) next.ServeHTTP(w, r) }) } diff --git a/ee/presencedetection/presencedetection.go b/ee/presencedetection/presencedetection.go index bc532bcbc..0cb2ab57f 100644 --- a/ee/presencedetection/presencedetection.go +++ b/ee/presencedetection/presencedetection.go @@ -6,28 +6,46 @@ import ( "time" ) +const DetectionFailedDurationValue = -1 * time.Second + type PresenceDetector struct { - lastDetectionUTC time.Time - mutext sync.Mutex + lastDetectionUTC time.Time + mutext sync.Mutex + hadOneSuccessfulDetection bool + // detectFunc that can be set for testing + detectFunc func(string) (bool, error) } -func (pd *PresenceDetector) DetectPresence(reason string, detectionInterval time.Duration) (bool, error) { +// DetectPresence checks if the user is present by detecting the presence of a user. +// It returns the duration since the last detection. +func (pd *PresenceDetector) DetectPresence(reason string, detectionInterval time.Duration) (time.Duration, error) { pd.mutext.Lock() defer pd.mutext.Unlock() - // Check if the last detection was within the detection interval - if time.Since(pd.lastDetectionUTC) < detectionInterval { - return true, nil + if pd.detectFunc == nil { + pd.detectFunc = Detect } - success, err := Detect(reason) - if err != nil { - return false, fmt.Errorf("detecting presence: %w", err) + // Check if the last detection was within the detection interval + if pd.hadOneSuccessfulDetection && time.Since(pd.lastDetectionUTC) < detectionInterval { + return time.Since(pd.lastDetectionUTC), nil } - if success { + success, err := pd.detectFunc(reason) + + switch { + case err != nil && pd.hadOneSuccessfulDetection: + return time.Since(pd.lastDetectionUTC), fmt.Errorf("detecting presence: %w", err) + + case err != nil: // error without initial successful detection + return DetectionFailedDurationValue, fmt.Errorf("detecting presence: %w", err) + + case success: pd.lastDetectionUTC = time.Now().UTC() - } + pd.hadOneSuccessfulDetection = true + return 0, nil - return success, nil + default: // failed detection without error, maybe not possible? + return time.Since(pd.lastDetectionUTC), fmt.Errorf("detection failed without error") + } } diff --git a/ee/presencedetection/presencedetection_test.go b/ee/presencedetection/presencedetection_test.go new file mode 100644 index 000000000..88c7e583a --- /dev/null +++ b/ee/presencedetection/presencedetection_test.go @@ -0,0 +1,86 @@ +package presencedetection + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestPresenceDetector_DetectPresence(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + interval time.Duration + detectFunc func(string) (bool, error) + initialLastDetectionUTC time.Time + hadOneSuccessfulDetection bool + errAssertion assert.ErrorAssertionFunc + expectedLastDetectionDelta time.Duration + }{ + { + name: "first detection success", + detectFunc: func(string) (bool, error) { + return true, nil + }, + errAssertion: assert.NoError, + expectedLastDetectionDelta: 0, + }, + { + name: "detection within interval", + detectFunc: func(string) (bool, error) { + return false, errors.New("should not have called detectFunc, since within interval") + }, + errAssertion: assert.NoError, + initialLastDetectionUTC: time.Now().UTC(), + interval: time.Minute, + hadOneSuccessfulDetection: true, + }, + { + name: "error first detection", + detectFunc: func(string) (bool, error) { + return false, errors.New("error") + }, + errAssertion: assert.Error, + expectedLastDetectionDelta: -1, + }, + { + name: "error after first detection", + detectFunc: func(string) (bool, error) { + return false, errors.New("error") + }, + errAssertion: assert.Error, + initialLastDetectionUTC: time.Now().UTC(), + hadOneSuccessfulDetection: true, + }, + { + name: "detection failed without OS error", + detectFunc: func(string) (bool, error) { + return false, nil + }, + errAssertion: assert.Error, + initialLastDetectionUTC: time.Now().UTC(), + hadOneSuccessfulDetection: true, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + pd := &PresenceDetector{ + detectFunc: tt.detectFunc, + lastDetectionUTC: tt.initialLastDetectionUTC, + hadOneSuccessfulDetection: tt.hadOneSuccessfulDetection, + } + + timeSinceLastDetection, err := pd.DetectPresence("this is a test", tt.interval) + tt.errAssertion(t, err) + + delta := timeSinceLastDetection - tt.expectedLastDetectionDelta + assert.LessOrEqual(t, delta, time.Second) + }) + } +}