Skip to content

Commit

Permalink
Merge branch 'main' into control-protocol-chunking
Browse files Browse the repository at this point in the history
  • Loading branch information
pierrehilbert authored Dec 19, 2023
2 parents 34faf07 + dcc6493 commit ac063fb
Show file tree
Hide file tree
Showing 19 changed files with 146 additions and 120 deletions.
4 changes: 3 additions & 1 deletion magefile.go
Original file line number Diff line number Diff line change
Expand Up @@ -1887,7 +1887,9 @@ func createTestRunner(matrix bool, singleTest string, goTestFlags string, batche
}

} else if stackProvisionerMode == ess.ProvisionerServerless {
stackProvisioner, err = ess.NewServerlessProvisioner(provisionCfg)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
stackProvisioner, err = ess.NewServerlessProvisioner(ctx, provisionCfg)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/testing/ess/serverless.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ func (srv *ServerlessClient) DeploymentIsReady(ctx context.Context) (bool, error
}

// DeleteDeployment deletes the deployment
func (srv *ServerlessClient) DeleteDeployment() error {
func (srv *ServerlessClient) DeleteDeployment(ctx context.Context) error {
endpoint := fmt.Sprintf("%s/api/v1/serverless/projects/%s/%s", serverlessURL, srv.proj.Type, srv.proj.ID)
req, err := http.NewRequestWithContext(context.Background(), "DELETE", endpoint, nil)
req, err := http.NewRequestWithContext(ctx, "DELETE", endpoint, nil)
if err != nil {
return fmt.Errorf("error creating HTTP request: %w", err)
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/testing/ess/serverless_provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ type ServerlessRegions struct {
}

// NewServerlessProvisioner creates a new StackProvisioner instance for serverless
func NewServerlessProvisioner(cfg ProvisionerConfig) (runner.StackProvisioner, error) {
func NewServerlessProvisioner(ctx context.Context, cfg ProvisionerConfig) (runner.StackProvisioner, error) {
prov := &ServerlessProvisioner{
cfg: cfg,
log: &defaultLogger{wrapped: logp.L()},
}
err := prov.CheckCloudRegion()
err := prov.CheckCloudRegion(ctx)
if err != nil {
return nil, fmt.Errorf("error checking region setting: %w", err)
}
Expand Down Expand Up @@ -178,7 +178,7 @@ func (prov *ServerlessProvisioner) Delete(ctx context.Context, stack runner.Stac
client.proj.Credentials.Password = stack.Password

prov.log.Logf("Destroying serverless stack %s [stack_id: %s, deployment_id: %s]", stack.Version, stack.ID, deploymentID)
err = client.DeleteDeployment()
err = client.DeleteDeployment(ctx)
if err != nil {
return fmt.Errorf("error removing serverless stack %s [stack_id: %s, deployment_id: %s]: %w", stack.Version, stack.ID, deploymentID, err)
}
Expand All @@ -188,10 +188,10 @@ func (prov *ServerlessProvisioner) Delete(ctx context.Context, stack runner.Stac
// CheckCloudRegion checks to see if the provided region is valid for the serverless
// if we have an invalid region, overwrite with a valid one.
// The "normal" and serverless ESS APIs have different regions, hence why we need this.
func (prov *ServerlessProvisioner) CheckCloudRegion() error {
func (prov *ServerlessProvisioner) CheckCloudRegion(ctx context.Context) error {
urlPath := fmt.Sprintf("%s/api/v1/serverless/regions", serverlessURL)

httpHandler, err := http.NewRequestWithContext(context.Background(), "GET", urlPath, nil)
httpHandler, err := http.NewRequestWithContext(ctx, "GET", urlPath, nil)
if err != nil {
return fmt.Errorf("error creating new httpRequest: %w", err)
}
Expand Down
16 changes: 10 additions & 6 deletions pkg/testing/ess/serverless_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ import (
)

func TestProvisionGetRegions(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Minute)
defer cancel()

_ = logp.DevelopmentSetup()
key, found, err := GetESSAPIKey()
if !found {
Expand All @@ -29,13 +32,16 @@ func TestProvisionGetRegions(t *testing.T) {
cfg: cfg,
log: &defaultLogger{wrapped: logp.L()},
}
err = prov.CheckCloudRegion()
err = prov.CheckCloudRegion(ctx)
require.NoError(t, err)
require.NotEqual(t, "bad-region-ID", prov.cfg.Region)

}

func TestStackProvisioner(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()

_ = logp.DevelopmentSetup()
key, found, err := GetESSAPIKey()
if !found {
Expand All @@ -45,12 +51,10 @@ func TestStackProvisioner(t *testing.T) {
require.True(t, found)

cfg := ProvisionerConfig{Region: "aws-eu-west-1", APIKey: key}
provClient, err := NewServerlessProvisioner(cfg)
provClient, err := NewServerlessProvisioner(ctx, cfg)
require.NoError(t, err)
request := runner.StackRequest{ID: "stack-test-one", Version: "8.9.0"}

ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
defer cancel()
stack, err := provClient.Create(ctx, request)
require.NoError(t, err)
t.Logf("got results:")
Expand Down Expand Up @@ -78,7 +82,7 @@ func TestStartServerless(t *testing.T) {
key,
&defaultLogger{wrapped: logp.L()})

ctx, cancel := context.WithTimeout(context.Background(), time.Second*240)
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Minute)
defer cancel()

req := ServerlessRequest{Name: "ingest-e2e-test", RegionID: "aws-eu-west-1"}
Expand All @@ -95,6 +99,6 @@ func TestStartServerless(t *testing.T) {
t.Logf("got endpoints: %#v", clientHandle.proj.Endpoints)
t.Logf("got auth: %#v", clientHandle.proj.Credentials)

err = clientHandle.DeleteDeployment()
err = clientHandle.DeleteDeployment(ctx)
require.NoError(t, err)
}
2 changes: 1 addition & 1 deletion pkg/testing/fixture.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ func (f *Fixture) Run(ctx context.Context, states ...State) error {
}
case state := <-stateCh:
if smInstance != nil {
cfg, cont, err := smInstance.next(state)
cfg, cont, err := smInstance.next(ctx, state)
if err != nil {
killProc()
return fmt.Errorf("state management failed with unexpected error: %w", err)
Expand Down
27 changes: 17 additions & 10 deletions pkg/testing/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package testing

import (
"context"
"errors"
"fmt"

Expand Down Expand Up @@ -153,10 +154,12 @@ type State struct {

// Before is called once when this state is the next state trying to be resolved.
// This is called before the configuration is sent to the Elastic Agent if `Configuration` is set.
Before func() error
// Before is passed the context that was passed to fixture.Run().
Before func(ctx context.Context) error

// After is called once after this state has been resolved and the next state is going to be tried.
After func() error
// After is passed the context that was passed to fixture.Run().
After func(ctx context.Context) error
}

// Validate ensures correctness of state definition.
Expand Down Expand Up @@ -208,7 +211,7 @@ func newStateMachine(states []State) (*stateMachine, error) {
}, nil
}

func (sm *stateMachine) next(agentState *client.AgentState) (string, bool, error) {
func (sm *stateMachine) next(ctx context.Context, agentState *client.AgentState) (string, bool, error) {
if sm.current >= len(sm.states) {
// already made it to the end, should be stopped
return "", false, nil
Expand All @@ -224,7 +227,7 @@ func (sm *stateMachine) next(agentState *client.AgentState) (string, bool, error
}
if reached {
if state.After != nil {
if err := state.After(); err != nil {
if err := state.After(ctx); err != nil {
return "", false, fmt.Errorf("failed to perform After on state %d: %w", sm.current, err)
}
}
Expand All @@ -235,7 +238,7 @@ func (sm *stateMachine) next(agentState *client.AgentState) (string, bool, error
}
next := sm.states[sm.current]
if next.Before != nil {
if err := next.Before(); err != nil {
if err := next.Before(ctx); err != nil {
return "", false, fmt.Errorf("failed to perform Before on state %d: %w", sm.current, err)
}
}
Expand All @@ -244,7 +247,7 @@ func (sm *stateMachine) next(agentState *client.AgentState) (string, bool, error
return next.Configure, true, nil
}
// no configuration on this state; so we can determine if this next state has already been reached as well
return sm.next(agentState)
return sm.next(ctx, agentState)
}
return "", true, nil
}
Expand Down Expand Up @@ -283,7 +286,9 @@ func stateComponentsReached(components map[string]ComponentState, agentComponent
return false
}
found := make(map[string]bool)
for _, agentComp := range agentComponents {
for i := range agentComponents {
// Index the array to avoid aliasing a temporary loop value when taking the address below.
agentComp := agentComponents[i]
state, ok := components[agentComp.ID]
if !ok {
if strict {
Expand All @@ -298,7 +303,7 @@ func stateComponentsReached(components map[string]ComponentState, agentComponent
return false
}
}
for compID, _ := range components {
for compID := range components {
_, ok := found[compID]
if !ok {
// was not found
Expand Down Expand Up @@ -334,7 +339,9 @@ func stateComponentUnitsReached(units map[ComponentUnitKey]ComponentUnitState, c
return false
}
found := make(map[ComponentUnitKey]bool)
for _, compUnit := range compUnits {
for i := range compUnits {
// Index the array to avoid aliasing a temporary loop value when taking the address below.
compUnit := compUnits[i]
key := ComponentUnitKey{UnitType: compUnit.UnitType, UnitID: compUnit.UnitID}
state, ok := units[key]
if !ok {
Expand All @@ -350,7 +357,7 @@ func stateComponentUnitsReached(units map[ComponentUnitKey]ComponentUnitState, c
return false
}
}
for key, _ := range units {
for key := range units {
_, ok := found[key]
if !ok {
// was not found
Expand Down
23 changes: 15 additions & 8 deletions pkg/testing/machine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package testing

import (
"context"
"errors"
"testing"

Expand Down Expand Up @@ -789,6 +790,9 @@ func TestStateMachine(t *testing.T) {

for _, scenario := range scenarios {
t.Run(scenario.Name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

m, err := newStateMachine(scenario.States)
if scenario.Err != nil {
require.Error(t, err)
Expand All @@ -797,7 +801,7 @@ func TestStateMachine(t *testing.T) {
require.NoError(t, err)
}
for _, nextCall := range scenario.AgentStates {
cfg, cont, err := m.next(nextCall.AgentState)
cfg, cont, err := m.next(ctx, nextCall.AgentState)
if nextCall.Err != nil {
require.Error(t, err)
require.Equal(t, nextCall.Err.Error(), err.Error())
Expand All @@ -821,32 +825,35 @@ func TestStateMachine_Before_After(t *testing.T) {
{
Configure: "my config",
AgentState: NewClientState(client.Configuring),
Before: func() error {
Before: func(ctx context.Context) error {
firstBefore = true
return nil
},
After: func() error {
After: func(ctx context.Context) error {
firstAfter = true
return nil
},
},
{
AgentState: NewClientState(client.Healthy),
Before: func() error {
Before: func(ctx context.Context) error {
secondBefore = true
return nil
},
After: func() error {
After: func(ctx context.Context) error {
secondAfter = true
return nil
},
},
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

m, err := newStateMachine(states)
require.NoError(t, err)

cfg, cont, err := m.next(&client.AgentState{
cfg, cont, err := m.next(ctx, &client.AgentState{
State: client.Configuring,
})
require.NoError(t, err)
Expand All @@ -855,7 +862,7 @@ func TestStateMachine_Before_After(t *testing.T) {
require.True(t, firstBefore)
require.False(t, firstAfter)

cfg, cont, err = m.next(&client.AgentState{
cfg, cont, err = m.next(ctx, &client.AgentState{
State: client.Configuring,
})
require.NoError(t, err)
Expand All @@ -865,7 +872,7 @@ func TestStateMachine_Before_After(t *testing.T) {
require.True(t, secondBefore)
require.False(t, secondAfter)

cfg, cont, err = m.next(&client.AgentState{
cfg, cont, err = m.next(ctx, &client.AgentState{
State: client.Healthy,
})
require.NoError(t, err)
Expand Down
9 changes: 5 additions & 4 deletions pkg/testing/tools/check/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ import (
// ConnectedToFleet checks if the agent defined in the fixture is connected to
// Fleet Server. It uses assert.Eventually and if it fails the last error will
// be printed. It returns if the agent is connected to Fleet Server or not.
func ConnectedToFleet(t *testing.T, fixture *integrationtest.Fixture, timeout time.Duration) bool {
func ConnectedToFleet(ctx context.Context, t *testing.T, fixture *integrationtest.Fixture, timeout time.Duration) bool {
t.Helper()

var err error
var agentStatus integrationtest.AgentStatusOutput
assertFn := func() bool {
agentStatus, err = fixture.ExecStatus(context.Background())
agentStatus, err = fixture.ExecStatus(ctx)
return agentStatus.FleetState == int(cproto.State_HEALTHY)
}

Expand All @@ -45,12 +45,13 @@ func ConnectedToFleet(t *testing.T, fixture *integrationtest.Fixture, timeout ti
// FleetAgentStatus returns a niladic function that returns true if the agent
// has reached expectedStatus; false otherwise. The returned function is intended
// for use with assert.Eventually or require.Eventually.
func FleetAgentStatus(t *testing.T,
func FleetAgentStatus(ctx context.Context,
t *testing.T,
client *kibana.Client,
policyID,
expectedStatus string) func() bool {
return func() bool {
currentStatus, err := fleettools.GetAgentStatus(client, policyID)
currentStatus, err := fleettools.GetAgentStatus(ctx, client, policyID)
if err != nil {
t.Errorf("unable to determine agent status: %s", err.Error())
return false
Expand Down
Loading

0 comments on commit ac063fb

Please sign in to comment.