diff --git a/api/resource_groups.go b/api/resource_groups.go index 35b15aa9a..c06b0add7 100644 --- a/api/resource_groups.go +++ b/api/resource_groups.go @@ -22,6 +22,7 @@ import ( _ "embed" "encoding/json" "fmt" + "strings" "time" "github.com/pkg/errors" @@ -95,14 +96,52 @@ func (svc *ResourceGroupsService) List() (response ResourceGroupsResponse, err e return rawResponse, err } + err = sanitizeFieldsInRawResponseList(&rawResponse, &response) + if err != nil { + return rawResponse, err + } + return rawResponse, nil } +func sanitizeFieldsInRawResponse(rawResponse *ResourceGroupResponse, response interface{}) error { + // update filters keys to match the query template + updateFiltersKeys(&rawResponse.Data) + + j, err := json.Marshal(rawResponse) + if err != nil { + return err + } + + return json.Unmarshal(j, &response) +} + +func sanitizeFieldsInRawResponseList(rawResponse *ResourceGroupsResponse, response interface{}) error { + for i := range rawResponse.Data { + // update filters keys to match the query template + updateFiltersKeys(&rawResponse.Data[i]) + } + + j, err := json.Marshal(rawResponse) + if err != nil { + return err + } + + return json.Unmarshal(j, &response) +} + func (svc *ResourceGroupsService) Create(group ResourceGroupData) ( response ResourceGroupResponse, err error, ) { - err = svc.create(group, &response) + var rawResponse ResourceGroupResponse + err = svc.create(group, &rawResponse) + if err != nil { + return + } + + err = sanitizeFieldsInRawResponse(&rawResponse, &response) + return } @@ -117,14 +156,54 @@ func (svc *ResourceGroupsService) Update(data *ResourceGroupData) ( guid := data.ID() data.ResetResourceGUID() - err = svc.update(guid, data, &response) + var rawResponse ResourceGroupResponse + err = svc.update(guid, data, &rawResponse) + if err != nil { return } + err = sanitizeFieldsInRawResponse(&rawResponse, &response) + return } +func collectFilterNames(children []*RGChild, filterNames map[string]string) { + for _, child := range children { + if child.FilterName != "" { + normalizedKey := strings.ReplaceAll(strings.ToLower(child.FilterName), "_", "") + filterNames[normalizedKey] = child.FilterName + } + if len(child.Children) > 0 { + collectFilterNames(child.Children, filterNames) + } + } +} + +/* +updateFiltersKeys updates the keys in the Filters map of ResourceGroupData to ensure they match the filter names +defined in the nested children of the query expression. This is necessary because JSON decoding/encoding can +convert keys to camel case, causing mismatches. The function normalizes the keys by removing underscores and +converting them to lower case, then compares them with the filter names. If a mismatch is found, the key is +updated to the value in RGExpression.Children +*/ +func updateFiltersKeys(data *ResourceGroupData) { + filterNames := make(map[string]string) + collectFilterNames(data.Query.Expression.Children, filterNames) + + updatedFilters := make(map[string]*RGFilter) + for key, value := range data.Query.Filters { + normalizedKey := strings.ReplaceAll(strings.ToLower(key), "_", "") + if _, exists := filterNames[normalizedKey]; exists { + updatedFilters[filterNames[normalizedKey]] = value + } else { + updatedFilters[key] = value + } + } + + data.Query.Filters = updatedFilters +} + func (group *ResourceGroupData) ResetResourceGUID() { group.ResourceGroupGuid = "" group.UpdatedBy = "" @@ -149,20 +228,17 @@ func (svc *ResourceGroupsService) Delete(guid string) error { func (svc *ResourceGroupsService) Get(guid string, response interface{}) error { var rawResponse ResourceGroupResponse + err := svc.get(guid, &rawResponse) if err != nil { return err } - j, err := json.Marshal(rawResponse) + err = sanitizeFieldsInRawResponse(&rawResponse, response) if err != nil { return err } - err = json.Unmarshal(j, &response) - if err != nil { - return err - } return nil } diff --git a/api/resource_groups_test.go b/api/resource_groups_test.go index a87a23c6d..ceaadffe0 100644 --- a/api/resource_groups_test.go +++ b/api/resource_groups_test.go @@ -19,6 +19,7 @@ package api_test import ( + "encoding/json" "fmt" "net/http" "strings" @@ -132,6 +133,113 @@ func TestResourceGroupGet(t *testing.T) { }) } +func TestResourceGroupsGetCorrectlyParsersFilterNames(t *testing.T) { + var ( + queryJson = ` + { + "expression": { + "children": [ + { + "filterName": "filter_account" + }, + { + "filterName": "filter1" + }, + { + "filterName": "filter2" + }, + { + "children": [ + { + "filterName": "team_Account" + } + ], + "operator": "OR" + } + + ], + "operator": "AND" + }, + "filters": { + "filter1": { + "field": "Resource Tag", + "key": "Hostname", + "operation": "INCLUDES", + "values": [ + "*" + ] + }, + "filter2": { + "field": "Region", + "operation": "STARTS_WITH", + "values": [ + "ap-south" + ] + }, + "filter_account": { + "field": "Account", + "operation": "EQUALS", + "values": [ + "123456789012" + ] + }, + "team_Account": { + "field": "Account", + "operation": "EQUALS", + "values": [ + "123456789012" + ] + } + } + } + ` + resourceGUID = intgguid.New() + vanillaType = "VANILLA" + apiPath = fmt.Sprintf("ResourceGroups/%s", resourceGUID) + vanillaGroup = singleVanillaResourceGroup(resourceGUID, vanillaType, queryJson) + fakeServer = lacework.MockServer() + ) + + fakeServer.MockToken("TOKEN") + defer fakeServer.Close() + + fakeServer.MockAPI(apiPath, + func(w http.ResponseWriter, r *http.Request) { + if assert.Equal(t, "GET", r.Method, "Get() should be a GET method") { + fmt.Fprintf(w, generateResourceGroupResponse(vanillaGroup)) + } + }, + ) + + c, err := api.NewClient("test", + api.WithToken("TOKEN"), + api.WithURL(fakeServer.URL()), + ) + + assert.Nil(t, err) + + t.Run("when resource groups GET is called. Filter keys are correctly parsed", func(t *testing.T) { + var response api.ResourceGroupResponse + err := c.V2.ResourceGroups.Get(resourceGUID, &response) + assert.Nil(t, err) + if assert.NotNil(t, response) { + assert.Equal(t, resourceGUID, response.Data.ResourceGroupGuid) + assert.Equal(t, "group_name", response.Data.Name) + assert.Equal(t, "VANILLA", response.Data.Type) + // assert that the filter names in queryjson matach RGQuery + var expectedQuery api.RGQuery + err = json.Unmarshal([]byte(queryJson), &expectedQuery) + assert.Nil(t, err) + + assert.NotNil(t, response.Data.Query.Filters["filter_account"]) + assert.Equal(t, expectedQuery.Filters["filter_account"], response.Data.Query.Filters["filter_account"]) + + assert.NotNil(t, response.Data.Query.Filters["team_Account"]) + assert.Equal(t, expectedQuery.Filters["team_Account"], response.Data.Query.Filters["team_Account"]) + } + }) +} + func TestResourceGroupsDelete(t *testing.T) { var ( resourceGUID = intgguid.New() diff --git a/cli/cmd/resource_groups.go b/cli/cmd/resource_groups.go index 576ed5224..236bdc066 100644 --- a/cli/cmd/resource_groups.go +++ b/cli/cmd/resource_groups.go @@ -208,24 +208,7 @@ func promptCreateResourceGroup() error { return err } - switch group { - case "AWS": - return createResourceGroup("AWS") - case "AZURE": - return createResourceGroup("AZURE") - case "GCP": - return createResourceGroup("GCP") - case "CONTAINER": - return createResourceGroup("CONTAINER") - case "MACHINE": - return createResourceGroup("MACHINE") - case "OCI": - return createResourceGroup("OCI") - case "KUBERNETES": - return createResourceGroup("KUBERNETES") - default: - return errors.New("unknown resource group type") - } + return createResourceGroup(group) } func init() {