From 95ed4192968903e19b02027a72d006a68d777bb4 Mon Sep 17 00:00:00 2001 From: marco Date: Tue, 24 Sep 2024 14:52:17 +0200 Subject: [PATCH 1/2] context propagation: pass ctx to UpdateScenario() --- cmd/crowdsec-cli/clicapi/capi.go | 2 +- cmd/crowdsec/lapiclient.go | 2 +- pkg/apiclient/auth_jwt.go | 7 +++++-- pkg/apiclient/config.go | 3 ++- pkg/apiserver/apic.go | 10 ++++------ pkg/apiserver/apic_test.go | 6 ++++-- 6 files changed, 17 insertions(+), 13 deletions(-) diff --git a/cmd/crowdsec-cli/clicapi/capi.go b/cmd/crowdsec-cli/clicapi/capi.go index 24c3ba054a9..cba66f11104 100644 --- a/cmd/crowdsec-cli/clicapi/capi.go +++ b/cmd/crowdsec-cli/clicapi/capi.go @@ -170,7 +170,7 @@ func queryCAPIStatus(ctx context.Context, hub *cwhub.Hub, credURL string, login // I don't believe papi is neede to check enrollement // PapiURL: papiURL, VersionPrefix: "v3", - UpdateScenario: func() ([]string, error) { + UpdateScenario: func(_ context.Context) ([]string, error) { return itemsForAPI, nil }, }) diff --git a/cmd/crowdsec/lapiclient.go b/cmd/crowdsec/lapiclient.go index 4556306825c..eed517f9df9 100644 --- a/cmd/crowdsec/lapiclient.go +++ b/cmd/crowdsec/lapiclient.go @@ -36,7 +36,7 @@ func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub. URL: apiURL, PapiURL: papiURL, VersionPrefix: "v1", - UpdateScenario: func() ([]string, error) { + UpdateScenario: func(_ context.Context) ([]string, error) { return itemsForAPI, nil }, }) diff --git a/pkg/apiclient/auth_jwt.go b/pkg/apiclient/auth_jwt.go index b202e382842..193486ff065 100644 --- a/pkg/apiclient/auth_jwt.go +++ b/pkg/apiclient/auth_jwt.go @@ -2,6 +2,7 @@ package apiclient import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -30,15 +31,17 @@ type JWTTransport struct { // Transport is the underlying HTTP transport to use when making requests. // It will default to http.DefaultTransport if nil. Transport http.RoundTripper - UpdateScenario func() ([]string, error) + UpdateScenario func(context.Context) ([]string, error) refreshTokenMutex sync.Mutex } func (t *JWTTransport) refreshJwtToken() error { var err error + ctx := context.TODO() + if t.UpdateScenario != nil { - t.Scenarios, err = t.UpdateScenario() + t.Scenarios, err = t.UpdateScenario(ctx) if err != nil { return fmt.Errorf("can't update scenario list: %w", err) } diff --git a/pkg/apiclient/config.go b/pkg/apiclient/config.go index b08452e74e0..29a8acf185e 100644 --- a/pkg/apiclient/config.go +++ b/pkg/apiclient/config.go @@ -1,6 +1,7 @@ package apiclient import ( + "context" "net/url" "github.com/go-openapi/strfmt" @@ -15,5 +16,5 @@ type Config struct { VersionPrefix string UserAgent string RegistrationToken string - UpdateScenario func() ([]string, error) + UpdateScenario func(context.Context) ([]string, error) } diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 9b56fef6549..c8768e71b0a 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -82,11 +82,9 @@ func randomDuration(d time.Duration, delta time.Duration) time.Duration { return ret } -func (a *apic) FetchScenariosListFromDB() ([]string, error) { +func (a *apic) FetchScenariosListFromDB(ctx context.Context) ([]string, error) { scenarios := make([]string, 0) - ctx := context.TODO() - machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, fmt.Errorf("while listing machines: %w", err) @@ -214,7 +212,7 @@ func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.PapiURL, err) } - ret.scenarioList, err = ret.FetchScenariosListFromDB() + ret.scenarioList, err = ret.FetchScenariosListFromDB(ctx) if err != nil { return nil, fmt.Errorf("while fetching scenarios from db: %w", err) } @@ -234,7 +232,7 @@ func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient // The watcher will be authenticated by the RoundTripper the first time it will call CAPI // Explicit authentication will provoke a useless supplementary call to CAPI - scenarios, err := ret.FetchScenariosListFromDB() + scenarios, err := ret.FetchScenariosListFromDB(ctx) if err != nil { return ret, fmt.Errorf("get scenario in db: %w", err) } @@ -944,7 +942,7 @@ func (a *apic) Pull(ctx context.Context) error { toldOnce := false for { - scenario, err := a.FetchScenariosListFromDB() + scenario, err := a.FetchScenariosListFromDB(ctx) if err != nil { log.Errorf("unable to fetch scenarios from db: %s", err) } diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 3bb158acf35..a215edb2fbd 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -143,6 +143,8 @@ func TestAPICCAPIPullIsOld(t *testing.T) { } func TestAPICFetchScenariosListFromDB(t *testing.T) { + ctx := context.Background() + tests := []struct { name string machineIDsWithScenarios map[string]string @@ -174,10 +176,10 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) { SetPassword(testPassword.String()). SetIpAddress("1.2.3.4"). SetScenarios(scenarios). - ExecX(context.Background()) + ExecX(ctx) } - scenarios, err := api.FetchScenariosListFromDB() + scenarios, err := api.FetchScenariosListFromDB(ctx) require.NoError(t, err) for machineID := range tc.machineIDsWithScenarios { From db2c0bcea56d02acd1852c7a856a60b2fda6e90e Mon Sep 17 00:00:00 2001 From: marco Date: Tue, 24 Sep 2024 15:20:14 +0200 Subject: [PATCH 2/2] context propagation: SendMetrics, SendUsageMetrics, plugin config --- pkg/apiserver/apic_metrics.go | 12 ++++-------- pkg/apiserver/apic_metrics_test.go | 12 +++++++----- pkg/apiserver/apiserver.go | 4 ++-- pkg/csplugin/notifier.go | 4 +--- pkg/protobufs/plugin_interface.go | 4 ++-- 5 files changed, 16 insertions(+), 20 deletions(-) diff --git a/pkg/apiserver/apic_metrics.go b/pkg/apiserver/apic_metrics.go index 16b2328dbe9..3d9e7b28a79 100644 --- a/pkg/apiserver/apic_metrics.go +++ b/pkg/apiserver/apic_metrics.go @@ -251,11 +251,9 @@ func (a *apic) fetchMachineIDs(ctx context.Context) ([]string, error) { // Metrics are sent at start, then at the randomized metricsIntervalFirst, // then at regular metricsInterval. If a change is detected in the list // of machines, the next metrics are sent immediately. -func (a *apic) SendMetrics(stop chan (bool)) { +func (a *apic) SendMetrics(ctx context.Context, stop chan (bool)) { defer trace.CatchPanic("lapi/metricsToAPIC") - ctx := context.TODO() - // verify the list of machines every interval const checkInt = 20 * time.Second @@ -321,7 +319,7 @@ func (a *apic) SendMetrics(stop chan (bool)) { if metrics != nil { log.Info("capi metrics: sending") - _, _, err = a.apiClient.Metrics.Add(context.Background(), metrics) + _, _, err = a.apiClient.Metrics.Add(ctx, metrics) if err != nil { log.Errorf("capi metrics: failed: %s", err) } @@ -339,11 +337,9 @@ func (a *apic) SendMetrics(stop chan (bool)) { } } -func (a *apic) SendUsageMetrics() { +func (a *apic) SendUsageMetrics(ctx context.Context) { defer trace.CatchPanic("lapi/usageMetricsToAPIC") - ctx := context.TODO() - firstRun := true log.Debugf("Start sending usage metrics to CrowdSec Central API (interval: %s once, then %s)", a.usageMetricsIntervalFirst, a.usageMetricsInterval) @@ -368,7 +364,7 @@ func (a *apic) SendUsageMetrics() { continue } - _, resp, err := a.apiClient.UsageMetrics.Add(context.Background(), metrics) + _, resp, err := a.apiClient.UsageMetrics.Add(ctx, metrics) if err != nil { log.Errorf("unable to send usage metrics: %s", err) diff --git a/pkg/apiserver/apic_metrics_test.go b/pkg/apiserver/apic_metrics_test.go index 78b16f9c8b7..13a24668f26 100644 --- a/pkg/apiserver/apic_metrics_test.go +++ b/pkg/apiserver/apic_metrics_test.go @@ -14,6 +14,8 @@ import ( ) func TestAPICSendMetrics(t *testing.T) { + ctx := context.Background() + tests := []struct { name string duration time.Duration @@ -34,7 +36,7 @@ func TestAPICSendMetrics(t *testing.T) { metricsInterval: time.Millisecond * 20, expectedCalls: 5, setUp: func(api *apic) { - api.dbClient.Ent.Machine.Delete().ExecX(context.Background()) + api.dbClient.Ent.Machine.Delete().ExecX(ctx) api.dbClient.Ent.Machine.Create(). SetMachineId("1234"). SetPassword(testPassword.String()). @@ -42,16 +44,16 @@ func TestAPICSendMetrics(t *testing.T) { SetScenarios("crowdsecurity/test"). SetLastPush(time.Time{}). SetUpdatedAt(time.Time{}). - ExecX(context.Background()) + ExecX(ctx) - api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background()) + api.dbClient.Ent.Bouncer.Delete().ExecX(ctx) api.dbClient.Ent.Bouncer.Create(). SetIPAddress("1.2.3.6"). SetName("someBouncer"). SetAPIKey("foobar"). SetRevoked(false). SetLastPull(time.Time{}). - ExecX(context.Background()) + ExecX(ctx) }, }, } @@ -86,7 +88,7 @@ func TestAPICSendMetrics(t *testing.T) { httpmock.ZeroCallCounters() - go api.SendMetrics(stop) + go api.SendMetrics(ctx, stop) time.Sleep(tc.duration) stop <- true diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 6b5d6803be9..2b2b453348a 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -357,12 +357,12 @@ func (s *APIServer) initAPIC(ctx context.Context) { } s.apic.metricsTomb.Go(func() error { - s.apic.SendMetrics(make(chan bool)) + s.apic.SendMetrics(ctx, make(chan bool)) return nil }) s.apic.metricsTomb.Go(func() error { - s.apic.SendUsageMetrics() + s.apic.SendUsageMetrics(ctx) return nil }) } diff --git a/pkg/csplugin/notifier.go b/pkg/csplugin/notifier.go index 2b5d57fbcff..ed4a4cc4149 100644 --- a/pkg/csplugin/notifier.go +++ b/pkg/csplugin/notifier.go @@ -40,9 +40,7 @@ func (m *GRPCClient) Notify(ctx context.Context, notification *protobufs.Notific } func (m *GRPCClient) Configure(ctx context.Context, config *protobufs.Config) (*protobufs.Empty, error) { - _, err := m.client.Configure( - context.Background(), config, - ) + _, err := m.client.Configure(ctx, config) return &protobufs.Empty{}, err } diff --git a/pkg/protobufs/plugin_interface.go b/pkg/protobufs/plugin_interface.go index fc89b2fa009..baa76c8941c 100644 --- a/pkg/protobufs/plugin_interface.go +++ b/pkg/protobufs/plugin_interface.go @@ -24,12 +24,12 @@ type NotifierPlugin struct { type GRPCClient struct{ client NotifierClient } func (m *GRPCClient) Notify(ctx context.Context, notification *Notification) (*Empty, error) { - _, err := m.client.Notify(context.Background(), notification) + _, err := m.client.Notify(ctx, notification) return &Empty{}, err } func (m *GRPCClient) Configure(ctx context.Context, config *Config) (*Empty, error) { - _, err := m.client.Configure(context.Background(), config) + _, err := m.client.Configure(ctx, config) return &Empty{}, err }