diff --git a/pkg/exprhelpers/crowdsec_cti.go b/pkg/exprhelpers/crowdsec_cti.go index e7418a5c458..dc67815d8dc 100644 --- a/pkg/exprhelpers/crowdsec_cti.go +++ b/pkg/exprhelpers/crowdsec_cti.go @@ -16,7 +16,7 @@ var CTIUrlSuffix = "/v2/smoke/" var CTIApiKey = "" // this is set for non-recoverable errors, such as 403 when querying API or empty API key -var CTIApiEnabled = true +var CTIApiEnabled = false // when hitting quotas or auth errors, we temporarily disable the API var CTIBackOffUntil time.Time @@ -25,9 +25,9 @@ var CTIBackOffDuration time.Duration = 5 * time.Minute var ctiClient *cticlient.CrowdsecCTIClient func InitCrowdsecCTI(Key *string, TTL *time.Duration, Size *int, LogLevel *log.Level) error { - if Key == nil { - CTIApiEnabled = false - return fmt.Errorf("CTI API key not set, CTI will not be available") + if Key == nil || *Key == "" { + log.Warningf("CTI API key not set or empty, CTI will not be available") + return cticlient.ErrDisabled } CTIApiKey = *Key if Size == nil { @@ -38,7 +38,6 @@ func InitCrowdsecCTI(Key *string, TTL *time.Duration, Size *int, LogLevel *log.L TTL = new(time.Duration) *TTL = 5 * time.Minute } - //dedicated logger clog := log.New() if err := types.ConfigureLogger(clog); err != nil { return errors.Wrap(err, "while configuring datasource logger") @@ -52,6 +51,7 @@ func InitCrowdsecCTI(Key *string, TTL *time.Duration, Size *int, LogLevel *log.L subLogger := clog.WithFields(customLog) CrowdsecCTIInitCache(*Size, *TTL) ctiClient = cticlient.NewCrowdsecCTIClient(cticlient.WithAPIKey(CTIApiKey), cticlient.WithLogger(subLogger)) + CTIApiEnabled = true return nil } @@ -60,7 +60,7 @@ func ShutdownCrowdsecCTI() { CTICache.Purge() } CTIApiKey = "" - CTIApiEnabled = true + CTIApiEnabled = false } // Cache for responses @@ -74,20 +74,13 @@ func CrowdsecCTIInitCache(size int, ttl time.Duration) { // func CrowdsecCTI(ip string) (*cticlient.SmokeItem, error) { func CrowdsecCTI(params ...any) (any, error) { - ip := params[0].(string) + var ip string if !CTIApiEnabled { - ctiClient.Logger.Warningf("Crowdsec CTI API is disabled, please check your configuration") return &cticlient.SmokeItem{}, cticlient.ErrDisabled } - - if CTIApiKey == "" { - ctiClient.Logger.Warningf("CrowdsecCTI : no key provided, skipping") - return &cticlient.SmokeItem{}, cticlient.ErrDisabled - } - - if ctiClient == nil { - ctiClient.Logger.Warningf("CrowdsecCTI: no client, skipping") - return &cticlient.SmokeItem{}, cticlient.ErrDisabled + var ok bool + if ip, ok = params[0].(string); !ok { + return &cticlient.SmokeItem{}, fmt.Errorf("invalid type for ip : %T", params[0]) } if val, err := CTICache.Get(ip); err == nil && val != nil { diff --git a/pkg/exprhelpers/crowdsec_cti_test.go b/pkg/exprhelpers/crowdsec_cti_test.go index c8d1c92fd40..84cd3347be5 100644 --- a/pkg/exprhelpers/crowdsec_cti_test.go +++ b/pkg/exprhelpers/crowdsec_cti_test.go @@ -106,6 +106,16 @@ func smokeHandler(req *http.Request) *http.Response { } } +func TestNillClient(t *testing.T) { + defer ShutdownCrowdsecCTI() + if err := InitCrowdsecCTI(ptr.Of(""), nil, nil, nil); err != cticlient.ErrDisabled { + t.Fatalf("failed to init CTI : %s", err) + } + item, err := CrowdsecCTI("1.2.3.4") + assert.Equal(t, err, cticlient.ErrDisabled) + assert.Equal(t, item, &cticlient.SmokeItem{}) +} + func TestInvalidAuth(t *testing.T) { defer ShutdownCrowdsecCTI() if err := InitCrowdsecCTI(ptr.Of("asdasd"), nil, nil, nil); err != nil { @@ -135,7 +145,7 @@ func TestInvalidAuth(t *testing.T) { func TestNoKey(t *testing.T) { defer ShutdownCrowdsecCTI() err := InitCrowdsecCTI(nil, nil, nil, nil) - assert.ErrorContains(t, err, "CTI API key not set") + assert.ErrorIs(t, err, cticlient.ErrDisabled) //Replace the client created by InitCrowdsecCTI with one that uses a custom transport ctiClient = cticlient.NewCrowdsecCTIClient(cticlient.WithAPIKey("asdasd"), cticlient.WithHTTPClient(&http.Client{ Transport: RoundTripFunc(smokeHandler),