Skip to content

Commit

Permalink
context propagation: pass ctx to UpdateScenario()
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc committed Sep 24, 2024
1 parent 3945a99 commit 95ed419
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 13 deletions.
2 changes: 1 addition & 1 deletion cmd/crowdsec-cli/clicapi/capi.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func queryCAPIStatus(ctx context.Context, hub *cwhub.Hub, credURL string, login
// I don't believe papi is neede to check enrollement
// PapiURL: papiURL,
VersionPrefix: "v3",
UpdateScenario: func() ([]string, error) {
UpdateScenario: func(_ context.Context) ([]string, error) {
return itemsForAPI, nil
},
})
Expand Down
2 changes: 1 addition & 1 deletion cmd/crowdsec/lapiclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub.
URL: apiURL,
PapiURL: papiURL,
VersionPrefix: "v1",
UpdateScenario: func() ([]string, error) {
UpdateScenario: func(_ context.Context) ([]string, error) {
return itemsForAPI, nil
},
})
Expand Down
7 changes: 5 additions & 2 deletions pkg/apiclient/auth_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package apiclient

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -30,15 +31,17 @@ type JWTTransport struct {
// Transport is the underlying HTTP transport to use when making requests.
// It will default to http.DefaultTransport if nil.
Transport http.RoundTripper
UpdateScenario func() ([]string, error)
UpdateScenario func(context.Context) ([]string, error)
refreshTokenMutex sync.Mutex
}

func (t *JWTTransport) refreshJwtToken() error {
var err error

ctx := context.TODO()

if t.UpdateScenario != nil {
t.Scenarios, err = t.UpdateScenario()
t.Scenarios, err = t.UpdateScenario(ctx)

Check warning on line 44 in pkg/apiclient/auth_jwt.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiclient/auth_jwt.go#L44

Added line #L44 was not covered by tests
if err != nil {
return fmt.Errorf("can't update scenario list: %w", err)
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/apiclient/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package apiclient

import (
"context"
"net/url"

"github.com/go-openapi/strfmt"
Expand All @@ -15,5 +16,5 @@ type Config struct {
VersionPrefix string
UserAgent string
RegistrationToken string
UpdateScenario func() ([]string, error)
UpdateScenario func(context.Context) ([]string, error)
}
10 changes: 4 additions & 6 deletions pkg/apiserver/apic.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,9 @@ func randomDuration(d time.Duration, delta time.Duration) time.Duration {
return ret
}

func (a *apic) FetchScenariosListFromDB() ([]string, error) {
func (a *apic) FetchScenariosListFromDB(ctx context.Context) ([]string, error) {
scenarios := make([]string, 0)

ctx := context.TODO()

machines, err := a.dbClient.ListMachines(ctx)
if err != nil {
return nil, fmt.Errorf("while listing machines: %w", err)
Expand Down Expand Up @@ -214,7 +212,7 @@ func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient
return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.PapiURL, err)
}

ret.scenarioList, err = ret.FetchScenariosListFromDB()
ret.scenarioList, err = ret.FetchScenariosListFromDB(ctx)
if err != nil {
return nil, fmt.Errorf("while fetching scenarios from db: %w", err)
}
Expand All @@ -234,7 +232,7 @@ func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient

// The watcher will be authenticated by the RoundTripper the first time it will call CAPI
// Explicit authentication will provoke a useless supplementary call to CAPI
scenarios, err := ret.FetchScenariosListFromDB()
scenarios, err := ret.FetchScenariosListFromDB(ctx)
if err != nil {
return ret, fmt.Errorf("get scenario in db: %w", err)
}
Expand Down Expand Up @@ -944,7 +942,7 @@ func (a *apic) Pull(ctx context.Context) error {
toldOnce := false

for {
scenario, err := a.FetchScenariosListFromDB()
scenario, err := a.FetchScenariosListFromDB(ctx)
if err != nil {
log.Errorf("unable to fetch scenarios from db: %s", err)
}
Expand Down
6 changes: 4 additions & 2 deletions pkg/apiserver/apic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ func TestAPICCAPIPullIsOld(t *testing.T) {
}

func TestAPICFetchScenariosListFromDB(t *testing.T) {
ctx := context.Background()

tests := []struct {
name string
machineIDsWithScenarios map[string]string
Expand Down Expand Up @@ -174,10 +176,10 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
SetPassword(testPassword.String()).
SetIpAddress("1.2.3.4").
SetScenarios(scenarios).
ExecX(context.Background())
ExecX(ctx)
}

scenarios, err := api.FetchScenariosListFromDB()
scenarios, err := api.FetchScenariosListFromDB(ctx)
require.NoError(t, err)

for machineID := range tc.machineIDsWithScenarios {
Expand Down

0 comments on commit 95ed419

Please sign in to comment.