diff --git a/ee/presencedetection/presencedetection_windows.go b/ee/presencedetection/presencedetection_windows.go index b864d6a61..9d8200d95 100644 --- a/ee/presencedetection/presencedetection_windows.go +++ b/ee/presencedetection/presencedetection_windows.go @@ -11,7 +11,6 @@ import ( "unsafe" ole "github.com/go-ole/go-ole" - "github.com/kolide/kit/ulid" "github.com/saltosystems/winrt-go" "github.com/saltosystems/winrt-go/windows/foundation" "github.com/saltosystems/winrt-go/windows/storage/streams" @@ -107,24 +106,21 @@ type KeyCredentialAttestationResultVTable struct { GetStatus uintptr } -// Detect prompts the user via Hello. -// TODO RM: -// * the syscalls panic easily; we will probably need to wrap this in a recovery routine -// * for readability, we should refactor individual calls into functions hanging off the appropriate structs above -func Detect(reason string) (bool, error) { +// Register creates a credential under the given name for the given user. +func Register(credentialName string) error { if err := ole.RoInitialize(1); err != nil { - return false, fmt.Errorf("initializing: %w", err) + return fmt.Errorf("initializing: %w", err) } // Get access to the KeyCredentialManager factory, err := ole.RoGetActivationFactory("Windows.Security.Credentials.KeyCredentialManager", ole.IID_IInspectable) if err != nil { - return false, fmt.Errorf("getting activation factory for KeyCredentialManager: %w", err) + return fmt.Errorf("getting activation factory for KeyCredentialManager: %w", err) } defer factory.Release() managerObj, err := factory.QueryInterface(keyCredentialManagerGuid) if err != nil { - return false, fmt.Errorf("getting KeyCredentialManager from factory: %w", err) + return fmt.Errorf("getting KeyCredentialManager from factory: %w", err) } defer managerObj.Release() keyCredentialManager := (*KeyCredentialManager)(unsafe.Pointer(managerObj)) @@ -132,26 +128,73 @@ func Detect(reason string) (bool, error) { // Check to see if Hello is an option isHelloSupported, err := isSupported(keyCredentialManager) if err != nil { - return false, fmt.Errorf("determining whether Hello is supported: %w", err) + return fmt.Errorf("determining whether Hello is supported: %w", err) } if !isHelloSupported { - return false, errors.New("Hello is not supported") + return errors.New("presence detection via Hello is not supported") } // Create a credential that will be tied to the current user and this application - credentialName := ulid.New() - _, keyCredentialObj, err := requestCreate(keyCredentialManager, credentialName) + keyCredentialObj, err := register(keyCredentialManager, credentialName) defer func() { if keyCredentialObj != nil { keyCredentialObj.Release() } }() if err != nil { - return false, fmt.Errorf("creating credential, %w", err) + return fmt.Errorf("creating credential: %w", err) } + // For now, we retrieve but do not store the pubkey and attestation. In the future + // we may want to store these. credential := (*KeyCredential)(unsafe.Pointer(keyCredentialObj)) - if _, err := getAttestationAsync(credential); err != nil { + if _, err := getPubkey(credential); err != nil { + return fmt.Errorf("getting pubkey from credential: %w", err) + } + if _, err := getAttestation(credential); err != nil { + return fmt.Errorf("getting attestation from credential: %w", err) + } + + return nil +} + +// Detect prompts the user via Hello. +func Detect(_ string, credentialName string) (bool, error) { + if err := ole.RoInitialize(1); err != nil { + return false, fmt.Errorf("initializing: %w", err) + } + + // Get access to the KeyCredentialManager + factory, err := ole.RoGetActivationFactory("Windows.Security.Credentials.KeyCredentialManager", ole.IID_IInspectable) + if err != nil { + return false, fmt.Errorf("getting activation factory for KeyCredentialManager: %w", err) + } + defer factory.Release() + managerObj, err := factory.QueryInterface(keyCredentialManagerGuid) + if err != nil { + return false, fmt.Errorf("getting KeyCredentialManager from factory: %w", err) + } + defer managerObj.Release() + keyCredentialManager := (*KeyCredentialManager)(unsafe.Pointer(managerObj)) + + // Create a credential that will be tied to the current user and this application + keyCredentialObj, err := authenticate(keyCredentialManager, credentialName) + defer func() { + if keyCredentialObj != nil { + keyCredentialObj.Release() + } + }() + if err != nil { + return false, fmt.Errorf("creating credential: %w", err) + } + + // For now, we retrieve but do not store the pubkey and attestation. In the future + // we may want to store these. + credential := (*KeyCredential)(unsafe.Pointer(keyCredentialObj)) + if _, err := getPubkey(credential); err != nil { + return false, fmt.Errorf("getting pubkey from credential: %w", err) + } + if _, err := getAttestation(credential); err != nil { return false, fmt.Errorf("getting attestation from credential: %w", err) } @@ -198,13 +241,13 @@ func isSupported(keyCredentialManager *KeyCredentialManager) (bool, error) { return uintptr(res) > 0, nil } -// requestCreate calls Windows.Security.Credentials.KeyCredentialManager.RequestCreateAsync. +// register calls Windows.Security.Credentials.KeyCredentialManager.RequestCreateAsync. // It creates a new key credential for the current user and application. // See: https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredentialmanager.requestcreateasync?view=winrt-26100 -func requestCreate(keyCredentialManager *KeyCredentialManager, credentialName string) ([]byte, *ole.IDispatch, error) { +func register(keyCredentialManager *KeyCredentialManager, credentialName string) (*ole.IDispatch, error) { credentialNameHString, err := ole.NewHString(credentialName) if err != nil { - return nil, nil, fmt.Errorf("creating credential name hstring: %w", err) + return nil, fmt.Errorf("creating credential name hstring: %w", err) } defer ole.DeleteHString(credentialNameHString) @@ -217,7 +260,7 @@ func requestCreate(keyCredentialManager *KeyCredentialManager, credentialName st uintptr(unsafe.Pointer(&requestCreateAsyncOperation)), // Windows.Foundation.IAsyncOperation ) if requestCreateReturn != 0 { - return nil, nil, fmt.Errorf("calling RequestCreateAsync: %w", ole.NewError(requestCreateReturn)) + return nil, fmt.Errorf("calling RequestCreateAsync: %w", ole.NewError(requestCreateReturn)) } // RequestCreateAsync returns Windows.Foundation.IAsyncOperation @@ -232,25 +275,25 @@ func requestCreate(keyCredentialManager *KeyCredentialManager, credentialName st select { case operationStatus := <-statusChan: if operationStatus != foundation.AsyncStatusCompleted { - return nil, nil, fmt.Errorf("RequestCreateAsync operation did not complete: status %d", operationStatus) + return nil, fmt.Errorf("RequestCreateAsync operation did not complete: status %d", operationStatus) } case <-time.After(1 * time.Minute): - return nil, nil, errors.New("timed out waiting for RequestCreateAsync operation to complete") + return nil, errors.New("timed out waiting for RequestCreateAsync operation to complete") } // Retrieve the results from the async operation resPtr, err := requestCreateAsyncOperation.GetResults() if err != nil { - return nil, nil, fmt.Errorf("getting results of RequestCreateAsync: %w", err) + return nil, fmt.Errorf("getting results of RequestCreateAsync: %w", err) } if uintptr(resPtr) == 0x0 { - return nil, nil, errors.New("no response to RequestCreateAsync") + return nil, errors.New("no response to RequestCreateAsync") } resultObj, err := (*ole.IUnknown)(resPtr).QueryInterface(keyCredentialRetrievalResultGuid) if err != nil { - return nil, nil, fmt.Errorf("could not get KeyCredentialRetrievalResult from result of RequestCreateAsync: %w", err) + return nil, fmt.Errorf("could not get KeyCredentialRetrievalResult from result of RequestCreateAsync: %w", err) } defer resultObj.Release() result := (*KeyCredentialRetrievalResult)(unsafe.Pointer(resultObj)) @@ -263,17 +306,96 @@ func requestCreate(keyCredentialManager *KeyCredentialManager, credentialName st uintptr(unsafe.Pointer(&credentialPointer)), ) if getCredentialReturn != 0 { - return nil, nil, fmt.Errorf("calling GetCredential on KeyCredentialRetrievalResult: %w", ole.NewError(getCredentialReturn)) + return nil, fmt.Errorf("calling GetCredential on KeyCredentialRetrievalResult: %w", ole.NewError(getCredentialReturn)) } keyCredentialObj, err := (*ole.IUnknown)(credentialPointer).QueryInterface(keyCredentialGuid) if err != nil { - return nil, nil, fmt.Errorf("could not get KeyCredential from KeyCredentialRetrievalResult: %w", err) + return nil, fmt.Errorf("could not get KeyCredential from KeyCredentialRetrievalResult: %w", err) } - defer keyCredentialObj.Release() - credential := (*KeyCredential)(unsafe.Pointer(keyCredentialObj)) - // All right, things are going swimmingly. Let's retrieve the public key. + return keyCredentialObj, nil +} + +// authenticate calls Windows.Security.Credentials.KeyCredentialManager.OpenAsync. +// It retrieves the key credential stored under `credentialName` for the given user and application. +// See: https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredentialmanager.openasync?view=winrt-26100 +func authenticate(keyCredentialManager *KeyCredentialManager, credentialName string) (*ole.IDispatch, error) { + credentialNameHString, err := ole.NewHString(credentialName) + if err != nil { + return nil, fmt.Errorf("creating credential name hstring: %w", err) + } + defer ole.DeleteHString(credentialNameHString) + + var openAsyncOperation *foundation.IAsyncOperation + openReturn, _, _ := syscall.SyscallN( + keyCredentialManager.VTable().OpenAsync, + 0, // Because this is a static function, we don't pass in a reference to `this` + uintptr(unsafe.Pointer(&credentialNameHString)), // The name of the key credential to retrieve + uintptr(unsafe.Pointer(&openAsyncOperation)), // Windows.Foundation.IAsyncOperation + ) + if openReturn != 0 { + return nil, fmt.Errorf("calling OpenAsync: %w", ole.NewError(openReturn)) + } + + // OpenAsync returns Windows.Foundation.IAsyncOperation + iid := winrt.ParameterizedInstanceGUID(foundation.GUIDAsyncOperationCompletedHandler, keyCredentialRetrievalResultSignature) + statusChan := make(chan foundation.AsyncStatus) + handler := foundation.NewAsyncOperationCompletedHandler(ole.NewGUID(iid), func(instance *foundation.AsyncOperationCompletedHandler, asyncInfo *foundation.IAsyncOperation, asyncStatus foundation.AsyncStatus) { + statusChan <- asyncStatus + }) + defer handler.Release() + openAsyncOperation.SetCompleted(handler) + + select { + case operationStatus := <-statusChan: + if operationStatus != foundation.AsyncStatusCompleted { + return nil, fmt.Errorf("OpenAsync operation did not complete: status %d", operationStatus) + } + case <-time.After(1 * time.Minute): + return nil, errors.New("timed out waiting for OpenAsync operation to complete") + } + + // Retrieve the results from the async operation + resPtr, err := openAsyncOperation.GetResults() + if err != nil { + return nil, fmt.Errorf("getting results of OpenAsync: %w", err) + } + + if uintptr(resPtr) == 0x0 { + return nil, errors.New("no response to OpenAsync") + } + + resultObj, err := (*ole.IUnknown)(resPtr).QueryInterface(keyCredentialRetrievalResultGuid) + if err != nil { + return nil, fmt.Errorf("could not get KeyCredentialRetrievalResult from result of OpenAsync: %w", err) + } + defer resultObj.Release() + result := (*KeyCredentialRetrievalResult)(unsafe.Pointer(resultObj)) + + // Now, retrieve the KeyCredential from the KeyCredentialRetrievalResult + var credentialPointer unsafe.Pointer + getCredentialReturn, _, _ := syscall.SyscallN( + result.VTable().GetCredential, + uintptr(unsafe.Pointer(result)), // Since we're retrieving an object property, we need a reference to `this` + uintptr(unsafe.Pointer(&credentialPointer)), + ) + if getCredentialReturn != 0 { + return nil, fmt.Errorf("calling GetCredential on KeyCredentialRetrievalResult: %w", ole.NewError(getCredentialReturn)) + } + + keyCredentialObj, err := (*ole.IUnknown)(credentialPointer).QueryInterface(keyCredentialGuid) + if err != nil { + return nil, fmt.Errorf("could not get KeyCredential from KeyCredentialRetrievalResult: %w", err) + } + + return keyCredentialObj, nil +} + +// getPubkey calls Windows.Security.Credentials.KeyCredential.RetrievePubkey. +// It returns the pubkey for the given key credential. +// See https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredential.retrievepublickey?view=winrt-26100. +func getPubkey(credential *KeyCredential) ([]byte, error) { var pubkeyBufferPointer unsafe.Pointer retrievePubKeyReturn, _, _ := syscall.SyscallN( credential.VTable().RetrievePublicKeyWithDefaultBlobType, @@ -281,37 +403,36 @@ func requestCreate(keyCredentialManager *KeyCredentialManager, credentialName st uintptr(unsafe.Pointer(&pubkeyBufferPointer)), ) if retrievePubKeyReturn != 0 { - return nil, nil, fmt.Errorf("calling RetrievePublicKey on KeyCredential: %w", ole.NewError(retrievePubKeyReturn)) + return nil, fmt.Errorf("calling RetrievePublicKey on KeyCredential: %w", ole.NewError(retrievePubKeyReturn)) } pubkeyBufferObj, err := (*ole.IUnknown)(pubkeyBufferPointer).QueryInterface(ole.NewGUID(streams.GUIDIBuffer)) if err != nil { - return nil, nil, fmt.Errorf("could not get buffer from result of RetrievePublicKey: %w", err) + return nil, fmt.Errorf("could not get buffer from result of RetrievePublicKey: %w", err) } defer pubkeyBufferObj.Release() pubkeyBuffer := (*streams.IBuffer)(unsafe.Pointer(pubkeyBufferObj)) pubkeyBufferLen, err := pubkeyBuffer.GetLength() if err != nil { - return nil, nil, fmt.Errorf("could not get length of pubkey buffer: %w", err) + return nil, fmt.Errorf("could not get length of pubkey buffer: %w", err) } pubkeyReader, err := streams.DataReaderFromBuffer(pubkeyBuffer) if err != nil { - return nil, nil, fmt.Errorf("could not create data reader for pubkey buffer: %w", err) + return nil, fmt.Errorf("could not create data reader for pubkey buffer: %w", err) } pubkeyBytes, err := pubkeyReader.ReadBytes(pubkeyBufferLen) if err != nil { - return nil, nil, fmt.Errorf("reading from pubkey buffer: %w", err) + return nil, fmt.Errorf("reading from pubkey buffer: %w", err) } - return pubkeyBytes, keyCredentialObj, nil + return pubkeyBytes, nil } -// getAttestationAsync calls Windows.Security.Credentials.KeyCredential.GetAttestationAsync. +// getAttestation calls Windows.Security.Credentials.KeyCredential.GetAttestationAsync. // It gets an attestation for a key credential. // See: https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredential.getattestationasync?view=winrt-26100 -func getAttestationAsync(credential *KeyCredential) ([]byte, error) { - // Now it's time to get the attestation. This is another async operation. +func getAttestation(credential *KeyCredential) ([]byte, error) { var getAttestationAsyncOperation *foundation.IAsyncOperation getAttestationReturn, _, _ := syscall.SyscallN( credential.VTable().GetAttestationAsync, @@ -392,31 +513,3 @@ func getAttestationAsync(credential *KeyCredential) ([]byte, error) { return attestationBytes, nil } - -/* -// waitForAsyncOperation should allow us to abstract away the details of waiting for an async operation, -// but right now it only works for IsSupportedAsync; it results in an error 3 being returned from RequestCreateAsync. -// TODO RM -- fix. -func waitForAsyncOperation(signature string, timeout time.Duration, asyncOperation *foundation.IAsyncOperation) (unsafe.Pointer, error) { - statusChan := make(chan foundation.AsyncStatus) - - iid := winrt.ParameterizedInstanceGUID(foundation.GUIDAsyncOperationCompletedHandler, signature) - handler := foundation.NewAsyncOperationCompletedHandler(ole.NewGUID(iid), func(instance *foundation.AsyncOperationCompletedHandler, asyncInfo *foundation.IAsyncOperation, asyncStatus foundation.AsyncStatus) { - statusChan <- asyncStatus - }) - defer handler.Release() - - asyncOperation.SetCompleted(handler) - - select { - case operationStatus := <-statusChan: - if operationStatus != foundation.AsyncStatusCompleted { - return nil, fmt.Errorf("async operation did not complete: status %d", operationStatus) - } - case <-time.After(timeout): - return nil, errors.New("timed out waiting for operation to complete") - } - - return asyncOperation.GetResults() -} -*/