Skip to content

Commit

Permalink
Add context and lock functionality to client interface (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
directionless authored Jun 27, 2023
1 parent d2e851b commit 6bcabfb
Show file tree
Hide file tree
Showing 10 changed files with 520 additions and 38 deletions.
191 changes: 161 additions & 30 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,76 @@ import (

"github.com/osquery/osquery-go/gen/osquery"
"github.com/osquery/osquery-go/transport"
"github.com/pkg/errors"

"github.com/apache/thrift/lib/go/thrift"
"github.com/pkg/errors"
)

const (
defaultWaitTime = 200 * time.Millisecond
defaultMaxWaitTime = 1 * time.Minute
)

// ExtensionManagerClient is a wrapper for the osquery Thrift extensions API.
type ExtensionManagerClient struct {
Client osquery.ExtensionManager
client osquery.ExtensionManager
transport thrift.TTransport

waitTime time.Duration
maxWaitTime time.Duration
lock *locker
}

type ClientOption func(*ExtensionManagerClient)

// WaitTime sets the default amount of wait time for the osquery socket to free up. You can override this on a per
// call basis by setting a context deadline
func DefaultWaitTime(d time.Duration) ClientOption {
return func(c *ExtensionManagerClient) {
c.waitTime = d
}
}

// MaxWaitTime is the maximum amount of time something is allowed to wait for the osquery socket. This takes precedence
// over the context deadline.
func MaxWaitTime(d time.Duration) ClientOption {
return func(c *ExtensionManagerClient) {
c.maxWaitTime = d
}
}

// NewClient creates a new client communicating to osquery over the socket at
// the provided path. If resolving the address or connecting to the socket
// fails, this function will error.
func NewClient(path string, timeout time.Duration) (*ExtensionManagerClient, error) {
trans, err := transport.Open(path, timeout)
if err != nil {
return nil, err
func NewClient(path string, socketOpenTimeout time.Duration, opts ...ClientOption) (*ExtensionManagerClient, error) {
c := &ExtensionManagerClient{
waitTime: defaultWaitTime,
maxWaitTime: defaultMaxWaitTime,
}

client := osquery.NewExtensionManagerClientFactory(
trans,
thrift.NewTBinaryProtocolFactoryDefault(),
)
for _, opt := range opts {
opt(c)
}

return &ExtensionManagerClient{client, trans}, nil
if c.waitTime > c.maxWaitTime {
return nil, errors.New("default wait time larger than max wait time")
}

c.lock = NewLocker(c.waitTime, c.maxWaitTime)

if c.client == nil {
trans, err := transport.Open(path, socketOpenTimeout)
if err != nil {
return nil, err
}

c.client = osquery.NewExtensionManagerClientFactory(
trans,
thrift.NewTBinaryProtocolFactoryDefault(),
)
}

return c, nil
}

// Close should be called to close the transport when use of the client is
Expand All @@ -42,48 +86,120 @@ func (c *ExtensionManagerClient) Close() {
}
}

// Ping requests metadata from the extension manager.
// Ping requests metadata from the extension manager, using a new background context
func (c *ExtensionManagerClient) Ping() (*osquery.ExtensionStatus, error) {
return c.Client.Ping(context.Background())
return c.PingContext(context.Background())
}

// Call requests a call to an extension (or core) registry plugin.
// PingContext requests metadata from the extension manager.
func (c *ExtensionManagerClient) PingContext(ctx context.Context) (*osquery.ExtensionStatus, error) {
if err := c.lock.Lock(ctx); err != nil {
return nil, err
}
defer c.lock.Unlock()
return c.client.Ping(ctx)
}

// Call requests a call to an extension (or core) registry plugin, using a new background context
func (c *ExtensionManagerClient) Call(registry, item string, request osquery.ExtensionPluginRequest) (*osquery.ExtensionResponse, error) {
return c.Client.Call(context.Background(), registry, item, request)
return c.CallContext(context.Background(), registry, item, request)
}

// Extensions requests the list of active registered extensions.
// CallContext requests a call to an extension (or core) registry plugin.
func (c *ExtensionManagerClient) CallContext(ctx context.Context, registry, item string, request osquery.ExtensionPluginRequest) (*osquery.ExtensionResponse, error) {
if err := c.lock.Lock(ctx); err != nil {
return nil, err
}
defer c.lock.Unlock()
return c.client.Call(ctx, registry, item, request)
}

// Extensions requests the list of active registered extensions, using a new background context
func (c *ExtensionManagerClient) Extensions() (osquery.InternalExtensionList, error) {
return c.Client.Extensions(context.Background())
return c.ExtensionsContext(context.Background())
}

// ExtensionsContext requests the list of active registered extensions.
func (c *ExtensionManagerClient) ExtensionsContext(ctx context.Context) (osquery.InternalExtensionList, error) {
if err := c.lock.Lock(ctx); err != nil {
return nil, err
}
defer c.lock.Unlock()
return c.client.Extensions(ctx)
}

// RegisterExtension registers the extension plugins with the osquery process.
// RegisterExtension registers the extension plugins with the osquery process, using a new background context
func (c *ExtensionManagerClient) RegisterExtension(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) {
return c.Client.RegisterExtension(context.Background(), info, registry)
return c.RegisterExtensionContext(context.Background(), info, registry)
}

// RegisterExtensionContext registers the extension plugins with the osquery process.
func (c *ExtensionManagerClient) RegisterExtensionContext(ctx context.Context, info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) {
if err := c.lock.Lock(ctx); err != nil {
return nil, err
}
defer c.lock.Unlock()
return c.client.RegisterExtension(ctx, info, registry)
}

// DeregisterExtension de-registers the extension plugins with the osquery process.
// DeregisterExtension de-registers the extension plugins with the osquery process, using a new background context
func (c *ExtensionManagerClient) DeregisterExtension(uuid osquery.ExtensionRouteUUID) (*osquery.ExtensionStatus, error) {
return c.Client.DeregisterExtension(context.Background(), uuid)
return c.DeregisterExtensionContext(context.Background(), uuid)
}

// DeregisterExtensionContext de-registers the extension plugins with the osquery process.
func (c *ExtensionManagerClient) DeregisterExtensionContext(ctx context.Context, uuid osquery.ExtensionRouteUUID) (*osquery.ExtensionStatus, error) {
if err := c.lock.Lock(ctx); err != nil {
return nil, err
}
defer c.lock.Unlock()
return c.client.DeregisterExtension(ctx, uuid)
}

// Options requests the list of bootstrap or configuration options.
// Options requests the list of bootstrap or configuration options, using a new background context.
func (c *ExtensionManagerClient) Options() (osquery.InternalOptionList, error) {
return c.Client.Options(context.Background())
return c.OptionsContext(context.Background())
}

// OptionsContext requests the list of bootstrap or configuration options.
func (c *ExtensionManagerClient) OptionsContext(ctx context.Context) (osquery.InternalOptionList, error) {
if err := c.lock.Lock(ctx); err != nil {
return nil, err
}
defer c.lock.Unlock()
return c.client.Options(ctx)
}

// Query requests a query to be run and returns the extension response.
// Query requests a query to be run and returns the extension
// response, using a new background context. Consider using the
// QueryRow or QueryRows helpers for a more friendly interface.
func (c *ExtensionManagerClient) Query(sql string) (*osquery.ExtensionResponse, error) {
return c.QueryContext(context.Background(), sql)
}

// QueryContext requests a query to be run and returns the extension response.
// Consider using the QueryRow or QueryRows helpers for a more friendly
// interface.
func (c *ExtensionManagerClient) Query(sql string) (*osquery.ExtensionResponse, error) {
return c.Client.Query(context.Background(), sql)
func (c *ExtensionManagerClient) QueryContext(ctx context.Context, sql string) (*osquery.ExtensionResponse, error) {
if err := c.lock.Lock(ctx); err != nil {
return nil, err
}
defer c.lock.Unlock()
return c.client.Query(ctx, sql)
}

// QueryRows is a helper that executes the requested query and returns the
// results. It handles checking both the transport level errors and the osquery
// internal errors by returning a normal Go error type.
func (c *ExtensionManagerClient) QueryRows(sql string) ([]map[string]string, error) {
res, err := c.Query(sql)
return c.QueryRowsContext(context.Background(), sql)
}

// QueryRowsContext is a helper that executes the requested query and returns the
// results. It handles checking both the transport level errors and the osquery
// internal errors by returning a normal Go error type.
func (c *ExtensionManagerClient) QueryRowsContext(ctx context.Context, sql string) ([]map[string]string, error) {
res, err := c.QueryContext(ctx, sql)
if err != nil {
return nil, errors.Wrap(err, "transport error in query")
}
Expand All @@ -100,7 +216,13 @@ func (c *ExtensionManagerClient) QueryRows(sql string) ([]map[string]string, err
// QueryRow behaves similarly to QueryRows, but it returns an error if the
// query does not return exactly one row.
func (c *ExtensionManagerClient) QueryRow(sql string) (map[string]string, error) {
res, err := c.QueryRows(sql)
return c.QueryRowContext(context.Background(), sql)
}

// QueryRowContext behaves similarly to QueryRows, but it returns an error if the
// query does not return exactly one row.
func (c *ExtensionManagerClient) QueryRowContext(ctx context.Context, sql string) (map[string]string, error) {
res, err := c.QueryRowsContext(ctx, sql)
if err != nil {
return nil, err
}
Expand All @@ -110,7 +232,16 @@ func (c *ExtensionManagerClient) QueryRow(sql string) (map[string]string, error)
return res[0], nil
}

// GetQueryColumns requests the columns returned by the parsed query.
// GetQueryColumns requests the columns returned by the parsed query, using a new background context.
func (c *ExtensionManagerClient) GetQueryColumns(sql string) (*osquery.ExtensionResponse, error) {
return c.Client.GetQueryColumns(context.Background(), sql)
return c.GetQueryColumnsContext(context.Background(), sql)
}

// GetQueryColumnsContext requests the columns returned by the parsed query.
func (c *ExtensionManagerClient) GetQueryColumnsContext(ctx context.Context, sql string) (*osquery.ExtensionResponse, error) {
if err := c.lock.Lock(ctx); err != nil {
return nil, err
}
defer c.lock.Unlock()
return c.client.GetQueryColumns(ctx, sql)
}
95 changes: 94 additions & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,23 @@ package osquery
import (
"context"
"errors"
"fmt"
"os"
"sync"
"testing"
"time"

"github.com/osquery/osquery-go/gen/osquery"
"github.com/osquery/osquery-go/mock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestQueryRows(t *testing.T) {
t.Parallel()
mock := &mock.ExtensionManager{}
client := &ExtensionManagerClient{Client: mock}
client, err := NewClient("", 5*time.Second, WithOsqueryThriftClient(mock))
require.NoError(t, err)

// Transport related error
mock.QueryFunc = func(ctx context.Context, sql string) (*osquery.ExtensionResponse, error) {
Expand Down Expand Up @@ -77,3 +84,89 @@ func TestQueryRows(t *testing.T) {
row, err = client.QueryRow("select 1 union select 2")
assert.NotNil(t, err)
}

// TestLocking tests the the client correctly locks access to the osquery socket. Thrift only supports a single
// actor on the socket at a time, this means that in parallel go code, it's very easy to have messages get
// crossed and generate errors. This tests to ensure the locking works
func TestLocking(t *testing.T) {
t.Parallel()

sock := os.Getenv("OSQ_SOCKET")
if sock == "" {
t.Skip("no osquery socket specified")
}

osq, err := NewClient(sock, 5*time.Second)
require.NoError(t, err)

// The issue we're testing is about multithreaded access. Let's hammer on it!
wait := sync.WaitGroup{}
for i := 0; i < 100; i++ {
wait.Add(1)
go func() {
defer wait.Done()

status, err := osq.Ping()
require.NoError(t, err, "call to Ping()")
if err != nil {
require.Equal(t, 0, status.Code, fmt.Errorf("ping returned %d: %s", status.Code, status.Message))
}
}()
}

wait.Wait()
}

func TestLockTimeouts(t *testing.T) {
t.Parallel()
mock := &mock.ExtensionManager{}
client, err := NewClient("", 5*time.Second, WithOsqueryThriftClient(mock), DefaultWaitTime(100*time.Millisecond), DefaultWaitTime(5*time.Second))
require.NoError(t, err)

wait := sync.WaitGroup{}

errChan := make(chan error, 10)
for i := 0; i < 3; i++ {
wait.Add(1)
go func() {
defer wait.Done()

ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
defer cancel()

errChan <- client.SlowLocker(ctx, 75*time.Millisecond)
}()
}

wait.Wait()
close(errChan)

var successCount, errCount int
for err := range errChan {
if err == nil {
successCount += 1
} else {
errCount += 1
}
}

assert.Equal(t, 2, successCount, "expected success count")
assert.Equal(t, 1, errCount, "expected error count")
}

// WithOsqueryThriftClient sets the underlying thrift client. This can be used to set a mock
func WithOsqueryThriftClient(client osquery.ExtensionManager) ClientOption {
return func(c *ExtensionManagerClient) {
c.client = client
}
}

// SlowLocker attempts to emulate a slow sql routine, so we can test how lock timeouts work.
func (c *ExtensionManagerClient) SlowLocker(ctx context.Context, d time.Duration) error {
if err := c.lock.Lock(ctx); err != nil {
return err
}
defer c.lock.Unlock()
time.Sleep(d)
return nil
}
10 changes: 7 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ module github.com/osquery/osquery-go
require (
github.com/Microsoft/go-winio v0.4.9
github.com/apache/thrift v0.16.0
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pkg/errors v0.8.0
github.com/stretchr/testify v1.8.3
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/testify v1.2.2
golang.org/x/sys v0.1.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

go 1.16
go 1.19
Loading

0 comments on commit 6bcabfb

Please sign in to comment.