Skip to content

Commit

Permalink
Strip down to just the working parts
Browse files Browse the repository at this point in the history
  • Loading branch information
RebeccaMahany committed Sep 23, 2024
1 parent 262fce2 commit 64b0d22
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 120 deletions.
6 changes: 1 addition & 5 deletions ee/presencedetection/presencedetection_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@ package presencedetection

import "errors"

func Register(credentialName string) error {
return errors.New("not implemented")
}

func Detect(reason string, credentialName string) (bool, error) {
func Detect(reason string) (bool, error) {
return false, errors.New("not implemented")
}
122 changes: 7 additions & 115 deletions ee/presencedetection/presencedetection_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"unsafe"

ole "github.com/go-ole/go-ole"
"github.com/kolide/kit/ulid"
"github.com/saltosystems/winrt-go"
"github.com/saltosystems/winrt-go/windows/foundation"
"github.com/saltosystems/winrt-go/windows/storage/streams"
Expand Down Expand Up @@ -111,34 +112,23 @@ var roInitialize = sync.OnceFunc(func() {
ole.RoInitialize(1)
})

// Register creates a credential under the given name for the given user.
func Register(credentialName string) error {
// Detect prompts the user via Hello.
func Detect(reason string) (bool, error) {
roInitialize()

// Check to see if Hello is an option
isHelloSupported, err := isSupported()
if err != nil {
return fmt.Errorf("determining whether Hello is supported: %w", err)
return false, fmt.Errorf("determining whether Hello is supported: %w", err)
}
if !isHelloSupported {
return errors.New("presence detection via Hello is not supported")
return false, errors.New("presence detection via Hello is not supported")
}

// Create a credential that will be tied to the current user and this application
credentialName := ulid.New()
if err := register(credentialName); err != nil {
return fmt.Errorf("creating credential: %w", err)
}

return nil
}

// Detect prompts the user via Hello.
func Detect(_ string, credentialName string) (bool, error) {
roInitialize()

// Create a credential that will be tied to the current user and this application
if err := authenticate(credentialName); err != nil {
return false, fmt.Errorf("authenticating with credential: %w", err)
return false, fmt.Errorf("creating credential: %w", err)
}

return true, nil
Expand Down Expand Up @@ -296,104 +286,6 @@ func register(credentialName string) error {
return nil
}

// authenticate calls Windows.Security.Credentials.KeyCredentialManager.OpenAsync.
// It retrieves the key credential stored under `credentialName` for the given user and application.
// See: https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredentialmanager.openasync?view=winrt-26100
func authenticate(credentialName string) error {
// Get access to the KeyCredentialManager
factory, err := ole.RoGetActivationFactory("Windows.Security.Credentials.KeyCredentialManager", ole.IID_IInspectable)
if err != nil {
return fmt.Errorf("getting activation factory for KeyCredentialManager: %w", err)
}
defer factory.Release()
managerObj, err := factory.QueryInterface(keyCredentialManagerGuid)
if err != nil {
return fmt.Errorf("getting KeyCredentialManager from factory: %w", err)
}
defer managerObj.Release()
keyCredentialManager := (*KeyCredentialManager)(unsafe.Pointer(managerObj))

credentialNameHString, err := ole.NewHString(credentialName)
if err != nil {
return fmt.Errorf("creating credential name hstring: %w", err)
}
defer ole.DeleteHString(credentialNameHString)

var openAsyncOperation *foundation.IAsyncOperation
openReturn, _, _ := syscall.SyscallN(
keyCredentialManager.VTable().OpenAsync,
0, // Because this is a static function, we don't pass in a reference to `this`
uintptr(unsafe.Pointer(&credentialNameHString)), // The name of the key credential to retrieve
uintptr(unsafe.Pointer(&openAsyncOperation)), // Windows.Foundation.IAsyncOperation<KeyCredentialRetrievalResult>
)
if openReturn != 0 {
return fmt.Errorf("calling OpenAsync: %w", ole.NewError(openReturn))
}

// OpenAsync returns Windows.Foundation.IAsyncOperation<KeyCredentialRetrievalResult>
iid := winrt.ParameterizedInstanceGUID(foundation.GUIDAsyncOperationCompletedHandler, keyCredentialRetrievalResultSignature)
statusChan := make(chan foundation.AsyncStatus)
handler := foundation.NewAsyncOperationCompletedHandler(ole.NewGUID(iid), func(instance *foundation.AsyncOperationCompletedHandler, asyncInfo *foundation.IAsyncOperation, asyncStatus foundation.AsyncStatus) {
statusChan <- asyncStatus
})
defer handler.Release()
openAsyncOperation.SetCompleted(handler)

select {
case operationStatus := <-statusChan:
if operationStatus != foundation.AsyncStatusCompleted {
return fmt.Errorf("OpenAsync operation did not complete: status %d", operationStatus)
}
case <-time.After(1 * time.Minute):
return errors.New("timed out waiting for OpenAsync operation to complete")
}

// Retrieve the results from the async operation
resPtr, err := openAsyncOperation.GetResults()
if err != nil {
return fmt.Errorf("getting results of OpenAsync: %w", err)
}

if uintptr(resPtr) == 0x0 {
return errors.New("no response to OpenAsync")
}

resultObj, err := (*ole.IUnknown)(resPtr).QueryInterface(keyCredentialRetrievalResultGuid)
if err != nil {
return fmt.Errorf("could not get KeyCredentialRetrievalResult from result of OpenAsync: %w", err)
}
defer resultObj.Release()
result := (*KeyCredentialRetrievalResult)(unsafe.Pointer(resultObj))

// Now, retrieve the KeyCredential from the KeyCredentialRetrievalResult
var credentialPointer unsafe.Pointer
getCredentialReturn, _, _ := syscall.SyscallN(
result.VTable().GetCredential,
uintptr(unsafe.Pointer(result)), // Since we're retrieving an object property, we need a reference to `this`
uintptr(unsafe.Pointer(&credentialPointer)),
)
if getCredentialReturn != 0 {
return fmt.Errorf("calling GetCredential on KeyCredentialRetrievalResult: %w", ole.NewError(getCredentialReturn))
}

keyCredentialObj, err := (*ole.IUnknown)(credentialPointer).QueryInterface(keyCredentialGuid)
if err != nil {
return fmt.Errorf("could not get KeyCredential from KeyCredentialRetrievalResult: %w", err)
}
defer keyCredentialObj.Release()

// For now, we retrieve but do not return/store the pubkey and attestation. In the future
// we may want to store these.
if _, err := getPubkey(keyCredentialObj); err != nil {
return fmt.Errorf("getting pubkey from credential: %w", err)
}
if _, err := getAttestation(keyCredentialObj); err != nil {
return fmt.Errorf("getting attestation from credential: %w", err)
}

return nil
}

// getPubkey calls Windows.Security.Credentials.KeyCredential.RetrievePubkey.
// It returns the pubkey for the given key credential.
// See https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredential.retrievepublickey?view=winrt-26100.
Expand Down

0 comments on commit 64b0d22

Please sign in to comment.