Skip to content

Commit

Permalink
Refactor and split up into Register + Detect
Browse files Browse the repository at this point in the history
  • Loading branch information
RebeccaMahany committed Sep 23, 2024
1 parent 109d1ed commit 44e4a12
Showing 1 changed file with 159 additions and 66 deletions.
225 changes: 159 additions & 66 deletions ee/presencedetection/presencedetection_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"unsafe"

ole "github.com/go-ole/go-ole"
"github.com/kolide/kit/ulid"
"github.com/saltosystems/winrt-go"
"github.com/saltosystems/winrt-go/windows/foundation"
"github.com/saltosystems/winrt-go/windows/storage/streams"
Expand Down Expand Up @@ -107,51 +106,95 @@ type KeyCredentialAttestationResultVTable struct {
GetStatus uintptr
}

// Detect prompts the user via Hello.
// TODO RM:
// * the syscalls panic easily; we will probably need to wrap this in a recovery routine
// * for readability, we should refactor individual calls into functions hanging off the appropriate structs above
func Detect(reason string) (bool, error) {
// Register creates a credential under the given name for the given user.
func Register(credentialName string) error {
if err := ole.RoInitialize(1); err != nil {
return false, fmt.Errorf("initializing: %w", err)
return fmt.Errorf("initializing: %w", err)
}

// Get access to the KeyCredentialManager
factory, err := ole.RoGetActivationFactory("Windows.Security.Credentials.KeyCredentialManager", ole.IID_IInspectable)
if err != nil {
return false, fmt.Errorf("getting activation factory for KeyCredentialManager: %w", err)
return fmt.Errorf("getting activation factory for KeyCredentialManager: %w", err)
}
defer factory.Release()
managerObj, err := factory.QueryInterface(keyCredentialManagerGuid)
if err != nil {
return false, fmt.Errorf("getting KeyCredentialManager from factory: %w", err)
return fmt.Errorf("getting KeyCredentialManager from factory: %w", err)
}
defer managerObj.Release()
keyCredentialManager := (*KeyCredentialManager)(unsafe.Pointer(managerObj))

// Check to see if Hello is an option
isHelloSupported, err := isSupported(keyCredentialManager)
if err != nil {
return false, fmt.Errorf("determining whether Hello is supported: %w", err)
return fmt.Errorf("determining whether Hello is supported: %w", err)
}
if !isHelloSupported {
return false, errors.New("Hello is not supported")
return errors.New("presence detection via Hello is not supported")
}

// Create a credential that will be tied to the current user and this application
credentialName := ulid.New()
_, keyCredentialObj, err := requestCreate(keyCredentialManager, credentialName)
keyCredentialObj, err := register(keyCredentialManager, credentialName)
defer func() {
if keyCredentialObj != nil {
keyCredentialObj.Release()
}
}()
if err != nil {
return false, fmt.Errorf("creating credential, %w", err)
return fmt.Errorf("creating credential: %w", err)
}

// For now, we retrieve but do not store the pubkey and attestation. In the future
// we may want to store these.
credential := (*KeyCredential)(unsafe.Pointer(keyCredentialObj))
if _, err := getAttestationAsync(credential); err != nil {
if _, err := getPubkey(credential); err != nil {
return fmt.Errorf("getting pubkey from credential: %w", err)
}
if _, err := getAttestation(credential); err != nil {
return fmt.Errorf("getting attestation from credential: %w", err)
}

return nil
}

// Detect prompts the user via Hello.
func Detect(_ string, credentialName string) (bool, error) {
if err := ole.RoInitialize(1); err != nil {
return false, fmt.Errorf("initializing: %w", err)
}

// Get access to the KeyCredentialManager
factory, err := ole.RoGetActivationFactory("Windows.Security.Credentials.KeyCredentialManager", ole.IID_IInspectable)
if err != nil {
return false, fmt.Errorf("getting activation factory for KeyCredentialManager: %w", err)
}
defer factory.Release()
managerObj, err := factory.QueryInterface(keyCredentialManagerGuid)
if err != nil {
return false, fmt.Errorf("getting KeyCredentialManager from factory: %w", err)
}
defer managerObj.Release()
keyCredentialManager := (*KeyCredentialManager)(unsafe.Pointer(managerObj))

// Create a credential that will be tied to the current user and this application
keyCredentialObj, err := authenticate(keyCredentialManager, credentialName)
defer func() {
if keyCredentialObj != nil {
keyCredentialObj.Release()
}
}()
if err != nil {
return false, fmt.Errorf("creating credential: %w", err)
}

// For now, we retrieve but do not store the pubkey and attestation. In the future
// we may want to store these.
credential := (*KeyCredential)(unsafe.Pointer(keyCredentialObj))
if _, err := getPubkey(credential); err != nil {
return false, fmt.Errorf("getting pubkey from credential: %w", err)
}
if _, err := getAttestation(credential); err != nil {
return false, fmt.Errorf("getting attestation from credential: %w", err)
}

Expand Down Expand Up @@ -198,13 +241,13 @@ func isSupported(keyCredentialManager *KeyCredentialManager) (bool, error) {
return uintptr(res) > 0, nil
}

// requestCreate calls Windows.Security.Credentials.KeyCredentialManager.RequestCreateAsync.
// register calls Windows.Security.Credentials.KeyCredentialManager.RequestCreateAsync.
// It creates a new key credential for the current user and application.
// See: https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredentialmanager.requestcreateasync?view=winrt-26100
func requestCreate(keyCredentialManager *KeyCredentialManager, credentialName string) ([]byte, *ole.IDispatch, error) {
func register(keyCredentialManager *KeyCredentialManager, credentialName string) (*ole.IDispatch, error) {

Check failure on line 247 in ee/presencedetection/presencedetection_windows.go

View workflow job for this annotation

GitHub Actions / lint (windows-latest)

confusing-naming: Method 'register' differs only by capitalization to function 'Register' in the same source file (revive)

Check failure on line 247 in ee/presencedetection/presencedetection_windows.go

View workflow job for this annotation

GitHub Actions / lint (windows-latest)

confusing-naming: Method 'register' differs only by capitalization to function 'Register' in the same source file (revive)
credentialNameHString, err := ole.NewHString(credentialName)
if err != nil {
return nil, nil, fmt.Errorf("creating credential name hstring: %w", err)
return nil, fmt.Errorf("creating credential name hstring: %w", err)
}
defer ole.DeleteHString(credentialNameHString)

Expand All @@ -217,7 +260,7 @@ func requestCreate(keyCredentialManager *KeyCredentialManager, credentialName st
uintptr(unsafe.Pointer(&requestCreateAsyncOperation)), // Windows.Foundation.IAsyncOperation<KeyCredentialRetrievalResult>
)
if requestCreateReturn != 0 {
return nil, nil, fmt.Errorf("calling RequestCreateAsync: %w", ole.NewError(requestCreateReturn))
return nil, fmt.Errorf("calling RequestCreateAsync: %w", ole.NewError(requestCreateReturn))
}

// RequestCreateAsync returns Windows.Foundation.IAsyncOperation<KeyCredentialRetrievalResult>
Expand All @@ -232,25 +275,25 @@ func requestCreate(keyCredentialManager *KeyCredentialManager, credentialName st
select {
case operationStatus := <-statusChan:
if operationStatus != foundation.AsyncStatusCompleted {
return nil, nil, fmt.Errorf("RequestCreateAsync operation did not complete: status %d", operationStatus)
return nil, fmt.Errorf("RequestCreateAsync operation did not complete: status %d", operationStatus)
}
case <-time.After(1 * time.Minute):
return nil, nil, errors.New("timed out waiting for RequestCreateAsync operation to complete")
return nil, errors.New("timed out waiting for RequestCreateAsync operation to complete")
}

// Retrieve the results from the async operation
resPtr, err := requestCreateAsyncOperation.GetResults()
if err != nil {
return nil, nil, fmt.Errorf("getting results of RequestCreateAsync: %w", err)
return nil, fmt.Errorf("getting results of RequestCreateAsync: %w", err)
}

if uintptr(resPtr) == 0x0 {
return nil, nil, errors.New("no response to RequestCreateAsync")
return nil, errors.New("no response to RequestCreateAsync")
}

resultObj, err := (*ole.IUnknown)(resPtr).QueryInterface(keyCredentialRetrievalResultGuid)
if err != nil {
return nil, nil, fmt.Errorf("could not get KeyCredentialRetrievalResult from result of RequestCreateAsync: %w", err)
return nil, fmt.Errorf("could not get KeyCredentialRetrievalResult from result of RequestCreateAsync: %w", err)
}
defer resultObj.Release()
result := (*KeyCredentialRetrievalResult)(unsafe.Pointer(resultObj))
Expand All @@ -263,55 +306,133 @@ func requestCreate(keyCredentialManager *KeyCredentialManager, credentialName st
uintptr(unsafe.Pointer(&credentialPointer)),
)
if getCredentialReturn != 0 {
return nil, nil, fmt.Errorf("calling GetCredential on KeyCredentialRetrievalResult: %w", ole.NewError(getCredentialReturn))
return nil, fmt.Errorf("calling GetCredential on KeyCredentialRetrievalResult: %w", ole.NewError(getCredentialReturn))
}

keyCredentialObj, err := (*ole.IUnknown)(credentialPointer).QueryInterface(keyCredentialGuid)
if err != nil {
return nil, nil, fmt.Errorf("could not get KeyCredential from KeyCredentialRetrievalResult: %w", err)
return nil, fmt.Errorf("could not get KeyCredential from KeyCredentialRetrievalResult: %w", err)
}
defer keyCredentialObj.Release()
credential := (*KeyCredential)(unsafe.Pointer(keyCredentialObj))

// All right, things are going swimmingly. Let's retrieve the public key.
return keyCredentialObj, nil
}

// authenticate calls Windows.Security.Credentials.KeyCredentialManager.OpenAsync.
// It retrieves the key credential stored under `credentialName` for the given user and application.
// See: https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredentialmanager.openasync?view=winrt-26100
func authenticate(keyCredentialManager *KeyCredentialManager, credentialName string) (*ole.IDispatch, error) {
credentialNameHString, err := ole.NewHString(credentialName)
if err != nil {
return nil, fmt.Errorf("creating credential name hstring: %w", err)
}
defer ole.DeleteHString(credentialNameHString)

var openAsyncOperation *foundation.IAsyncOperation
openReturn, _, _ := syscall.SyscallN(
keyCredentialManager.VTable().OpenAsync,
0, // Because this is a static function, we don't pass in a reference to `this`
uintptr(unsafe.Pointer(&credentialNameHString)), // The name of the key credential to retrieve
uintptr(unsafe.Pointer(&openAsyncOperation)), // Windows.Foundation.IAsyncOperation<KeyCredentialRetrievalResult>
)
if openReturn != 0 {
return nil, fmt.Errorf("calling OpenAsync: %w", ole.NewError(openReturn))
}

// OpenAsync returns Windows.Foundation.IAsyncOperation<KeyCredentialRetrievalResult>
iid := winrt.ParameterizedInstanceGUID(foundation.GUIDAsyncOperationCompletedHandler, keyCredentialRetrievalResultSignature)
statusChan := make(chan foundation.AsyncStatus)
handler := foundation.NewAsyncOperationCompletedHandler(ole.NewGUID(iid), func(instance *foundation.AsyncOperationCompletedHandler, asyncInfo *foundation.IAsyncOperation, asyncStatus foundation.AsyncStatus) {
statusChan <- asyncStatus
})
defer handler.Release()
openAsyncOperation.SetCompleted(handler)

select {
case operationStatus := <-statusChan:
if operationStatus != foundation.AsyncStatusCompleted {
return nil, fmt.Errorf("OpenAsync operation did not complete: status %d", operationStatus)
}
case <-time.After(1 * time.Minute):
return nil, errors.New("timed out waiting for OpenAsync operation to complete")
}

// Retrieve the results from the async operation
resPtr, err := openAsyncOperation.GetResults()
if err != nil {
return nil, fmt.Errorf("getting results of OpenAsync: %w", err)
}

if uintptr(resPtr) == 0x0 {
return nil, errors.New("no response to OpenAsync")
}

resultObj, err := (*ole.IUnknown)(resPtr).QueryInterface(keyCredentialRetrievalResultGuid)
if err != nil {
return nil, fmt.Errorf("could not get KeyCredentialRetrievalResult from result of OpenAsync: %w", err)
}
defer resultObj.Release()
result := (*KeyCredentialRetrievalResult)(unsafe.Pointer(resultObj))

// Now, retrieve the KeyCredential from the KeyCredentialRetrievalResult
var credentialPointer unsafe.Pointer
getCredentialReturn, _, _ := syscall.SyscallN(
result.VTable().GetCredential,
uintptr(unsafe.Pointer(result)), // Since we're retrieving an object property, we need a reference to `this`
uintptr(unsafe.Pointer(&credentialPointer)),
)
if getCredentialReturn != 0 {
return nil, fmt.Errorf("calling GetCredential on KeyCredentialRetrievalResult: %w", ole.NewError(getCredentialReturn))
}

keyCredentialObj, err := (*ole.IUnknown)(credentialPointer).QueryInterface(keyCredentialGuid)
if err != nil {
return nil, fmt.Errorf("could not get KeyCredential from KeyCredentialRetrievalResult: %w", err)
}

return keyCredentialObj, nil
}

// getPubkey calls Windows.Security.Credentials.KeyCredential.RetrievePubkey.
// It returns the pubkey for the given key credential.
// See https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredential.retrievepublickey?view=winrt-26100.
func getPubkey(credential *KeyCredential) ([]byte, error) {
var pubkeyBufferPointer unsafe.Pointer
retrievePubKeyReturn, _, _ := syscall.SyscallN(
credential.VTable().RetrievePublicKeyWithDefaultBlobType,
uintptr(unsafe.Pointer(credential)), // Not a static method, so we need a reference to `this`
uintptr(unsafe.Pointer(&pubkeyBufferPointer)),
)
if retrievePubKeyReturn != 0 {
return nil, nil, fmt.Errorf("calling RetrievePublicKey on KeyCredential: %w", ole.NewError(retrievePubKeyReturn))
return nil, fmt.Errorf("calling RetrievePublicKey on KeyCredential: %w", ole.NewError(retrievePubKeyReturn))
}

pubkeyBufferObj, err := (*ole.IUnknown)(pubkeyBufferPointer).QueryInterface(ole.NewGUID(streams.GUIDIBuffer))
if err != nil {
return nil, nil, fmt.Errorf("could not get buffer from result of RetrievePublicKey: %w", err)
return nil, fmt.Errorf("could not get buffer from result of RetrievePublicKey: %w", err)
}
defer pubkeyBufferObj.Release()
pubkeyBuffer := (*streams.IBuffer)(unsafe.Pointer(pubkeyBufferObj))

pubkeyBufferLen, err := pubkeyBuffer.GetLength()
if err != nil {
return nil, nil, fmt.Errorf("could not get length of pubkey buffer: %w", err)
return nil, fmt.Errorf("could not get length of pubkey buffer: %w", err)
}
pubkeyReader, err := streams.DataReaderFromBuffer(pubkeyBuffer)
if err != nil {
return nil, nil, fmt.Errorf("could not create data reader for pubkey buffer: %w", err)
return nil, fmt.Errorf("could not create data reader for pubkey buffer: %w", err)
}
pubkeyBytes, err := pubkeyReader.ReadBytes(pubkeyBufferLen)
if err != nil {
return nil, nil, fmt.Errorf("reading from pubkey buffer: %w", err)
return nil, fmt.Errorf("reading from pubkey buffer: %w", err)
}

return pubkeyBytes, keyCredentialObj, nil
return pubkeyBytes, nil
}

// getAttestationAsync calls Windows.Security.Credentials.KeyCredential.GetAttestationAsync.
// getAttestation calls Windows.Security.Credentials.KeyCredential.GetAttestationAsync.
// It gets an attestation for a key credential.
// See: https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredential.getattestationasync?view=winrt-26100
func getAttestationAsync(credential *KeyCredential) ([]byte, error) {
// Now it's time to get the attestation. This is another async operation.
func getAttestation(credential *KeyCredential) ([]byte, error) {
var getAttestationAsyncOperation *foundation.IAsyncOperation
getAttestationReturn, _, _ := syscall.SyscallN(
credential.VTable().GetAttestationAsync,
Expand Down Expand Up @@ -392,31 +513,3 @@ func getAttestationAsync(credential *KeyCredential) ([]byte, error) {

return attestationBytes, nil
}

/*
// waitForAsyncOperation should allow us to abstract away the details of waiting for an async operation,
// but right now it only works for IsSupportedAsync; it results in an error 3 being returned from RequestCreateAsync.
// TODO RM -- fix.
func waitForAsyncOperation(signature string, timeout time.Duration, asyncOperation *foundation.IAsyncOperation) (unsafe.Pointer, error) {
statusChan := make(chan foundation.AsyncStatus)
iid := winrt.ParameterizedInstanceGUID(foundation.GUIDAsyncOperationCompletedHandler, signature)
handler := foundation.NewAsyncOperationCompletedHandler(ole.NewGUID(iid), func(instance *foundation.AsyncOperationCompletedHandler, asyncInfo *foundation.IAsyncOperation, asyncStatus foundation.AsyncStatus) {
statusChan <- asyncStatus
})
defer handler.Release()
asyncOperation.SetCompleted(handler)
select {
case operationStatus := <-statusChan:
if operationStatus != foundation.AsyncStatusCompleted {
return nil, fmt.Errorf("async operation did not complete: status %d", operationStatus)
}
case <-time.After(timeout):
return nil, errors.New("timed out waiting for operation to complete")
}
return asyncOperation.GetResults()
}
*/

0 comments on commit 44e4a12

Please sign in to comment.