From f93b869f261dac1819a3eea7426cd6d56aa695d1 Mon Sep 17 00:00:00 2001
From: jacobshandling <61553566+jacobshandling@users.noreply.github.com>
Date: Thu, 23 Jan 2025 12:38:57 -0800
Subject: [PATCH] Update label membership by host IDs directly (#25687)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
## For #25261
### [Demo
video](https://drive.google.com/file/d/1ZFcrizkZ6zNODnTXjRC1f-Oeght5zOP4/view?usp=sharing)
- [x] Changes file added for user-visible changes in `changes/`,
- [x] Added/updated automated tests
- [ ] A detailed QA plan exists on the associated ticket (if it isn't
there, work with the product group's QA engineer to add it)
- [x] Manual QA for all new/changed functionality
---------
Co-authored-by: Jacob Shandling
---
...25261-identical-hostnames-label-membership | 2 +
.../LiveQuery/TargetsInput/TargetsInput.tsx | 4 +-
.../ManualLabelForm/ManualLabelForm.tsx | 4 +-
server/datastore/mysql/labels.go | 58 ++++++
server/datastore/mysql/labels_test.go | 165 ++++++++++++++++++
server/fleet/datastore.go | 4 +
server/mock/datastore_mock.go | 12 ++
server/service/integration_core_test.go | 19 ++
server/service/labels.go | 36 +---
9 files changed, 272 insertions(+), 32 deletions(-)
create mode 100644 changes/25261-identical-hostnames-label-membership
diff --git a/changes/25261-identical-hostnames-label-membership b/changes/25261-identical-hostnames-label-membership
new file mode 100644
index 000000000000..ddd5f61d253e
--- /dev/null
+++ b/changes/25261-identical-hostnames-label-membership
@@ -0,0 +1,2 @@
+- Fixed a bug where adding or removing a host with an identical name to/from a label caused the
+ same action to be performed on other host(s) with the same name as well.
diff --git a/frontend/components/LiveQuery/TargetsInput/TargetsInput.tsx b/frontend/components/LiveQuery/TargetsInput/TargetsInput.tsx
index 70c5106227c8..c38ba8d5e4c4 100644
--- a/frontend/components/LiveQuery/TargetsInput/TargetsInput.tsx
+++ b/frontend/components/LiveQuery/TargetsInput/TargetsInput.tsx
@@ -12,7 +12,6 @@ import TableContainer from "components/TableContainer";
import { ITargestInputHostTableConfig } from "./TargetsInputHostsTableConfig";
interface ITargetsInputProps {
- tabIndex?: number;
searchText: string;
searchResults: IHost[];
isTargetsLoading: boolean;
@@ -35,7 +34,6 @@ const baseClass = "targets-input";
const DEFAULT_LABEL = "Target specific hosts";
const TargetsInput = ({
- tabIndex,
searchText,
searchResults,
isTargetsLoading,
@@ -52,7 +50,7 @@ const TargetsInput = ({
}: ITargetsInputProps): JSX.Element => {
const dropdownRef = useRef(null);
const dropdownHosts =
- searchResults && pullAllBy(searchResults, targetedHosts, "display_name");
+ searchResults && pullAllBy(searchResults, targetedHosts, "id");
const [isActiveSearch, setIsActiveSearch] = useState(false);
diff --git a/frontend/pages/labels/components/ManualLabelForm/ManualLabelForm.tsx b/frontend/pages/labels/components/ManualLabelForm/ManualLabelForm.tsx
index c4f52af7b634..e8d75b98c675 100644
--- a/frontend/pages/labels/components/ManualLabelForm/ManualLabelForm.tsx
+++ b/frontend/pages/labels/components/ManualLabelForm/ManualLabelForm.tsx
@@ -71,7 +71,7 @@ const ManualLabelForm = ({
}, [debounceSearch, searchQuery]);
const {
- data: hostTargets,
+ data: searchResults,
isLoading: isLoadingSearchResults,
isError: isErrorSearchResults,
} = useQuery(
@@ -139,7 +139,7 @@ const ManualLabelForm = ({
selectedHostsTableConifg={selectedHostsTableConfig}
isTargetsLoading={isLoadingSearchResults || isDebouncing}
hasFetchError={isErrorSearchResults}
- searchResults={hostTargets ?? []}
+ searchResults={searchResults ?? []}
targetedHosts={targetedHosts}
setSearchText={onChangeSearchQuery}
handleRowSelect={onHostSelect}
diff --git a/server/datastore/mysql/labels.go b/server/datastore/mysql/labels.go
index cb98aa67f4b0..b420ba0917dd 100644
--- a/server/datastore/mysql/labels.go
+++ b/server/datastore/mysql/labels.go
@@ -123,6 +123,64 @@ func batchHostnames(hostnames []string) [][]string {
return batches
}
+func (ds *Datastore) UpdateLabelMembershipByHostIDs(ctx context.Context, labelID uint, hostIds []uint) (err error) {
+ err = ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
+ // delete all label membership
+ sql := `
+ DELETE FROM label_membership WHERE label_id = ?
+ `
+ _, err := tx.ExecContext(ctx, sql, labelID)
+ if err != nil {
+ return ctxerr.Wrap(ctx, err, "clear membership for ID")
+ }
+
+ if len(hostIds) == 0 {
+ return nil
+ }
+
+ // Split hostIds into batches to avoid parameter limit in MySQL.
+ for _, hostIds := range batchHostIds(hostIds) {
+ // Use ignore because duplicate hostIds could appear in
+ // different batches and would result in duplicate key errors.
+ values := []interface{}{}
+ placeholders := []string{}
+
+ for _, hostID := range hostIds {
+ values = append(values, labelID, hostID)
+ placeholders = append(placeholders, "(?, ?)")
+ }
+
+ // Build the final SQL query with the dynamically generated placeholders
+ sql := `
+INSERT IGNORE INTO label_membership (label_id, host_id)
+VALUES ` + strings.Join(placeholders, ", ")
+ sql, args, err := sqlx.In(sql, values...)
+ if err != nil {
+ return ctxerr.Wrap(ctx, err, "build membership IN statement")
+ }
+ _, err = tx.ExecContext(ctx, sql, args...)
+ if err != nil {
+ return ctxerr.Wrap(ctx, err, "execute membership INSERT")
+ }
+ }
+ return nil
+ })
+
+ return ctxerr.Wrap(ctx, err, "UpdateLabelMembershipByHostIDs transaction")
+}
+
+func batchHostIds(hostIds []uint) [][]uint {
+ // same functionality as `batchHostnames`, but for host IDs
+ const batchSize = 50000 // Large, but well under the undocumented limit
+ batches := make([][]uint, 0, (len(hostIds)+batchSize-1)/batchSize)
+
+ for batchSize < len(hostIds) {
+ hostIds, batches = hostIds[batchSize:], append(batches, hostIds[0:batchSize:batchSize])
+ }
+ batches = append(batches, hostIds)
+ return batches
+}
+
func (ds *Datastore) GetLabelSpecs(ctx context.Context) ([]*fleet.LabelSpec, error) {
var specs []*fleet.LabelSpec
// Get basic specs
diff --git a/server/datastore/mysql/labels_test.go b/server/datastore/mysql/labels_test.go
index 6aa9250b5478..08deb44ba96e 100644
--- a/server/datastore/mysql/labels_test.go
+++ b/server/datastore/mysql/labels_test.go
@@ -41,6 +41,27 @@ func TestBatchHostnamesLarge(t *testing.T) {
assert.Equal(t, large[200000:230000], batched[4])
}
+func TestBatchHostIdsSmall(t *testing.T) {
+ small := []uint{1, 2, 3}
+ batched := batchHostIds(small)
+ require.Equal(t, 1, len(batched))
+ assert.Equal(t, small, batched[0])
+}
+
+func TestBatchHostIdsLarge(t *testing.T) {
+ large := []uint{}
+ for i := 0; i < 230000; i++ {
+ large = append(large, uint(i)) //nolint:gosec // dismiss G115
+ }
+ batched := batchHostIds(large)
+ require.Equal(t, 5, len(batched))
+ assert.Equal(t, large[:50000], batched[0])
+ assert.Equal(t, large[50000:100000], batched[1])
+ assert.Equal(t, large[100000:150000], batched[2])
+ assert.Equal(t, large[150000:200000], batched[3])
+ assert.Equal(t, large[200000:230000], batched[4])
+}
+
func TestLabels(t *testing.T) {
ds := CreateMySQLDS(t)
@@ -60,6 +81,7 @@ func TestLabels(t *testing.T) {
{"ChangeDetails", testLabelsChangeDetails},
{"GetSpec", testLabelsGetSpec},
{"ApplySpecsRoundtrip", testLabelsApplySpecsRoundtrip},
+ {"UpdateLabelMembershipByHostIDs", testUpdateLabelMembershipByHostIDs},
{"IDsByName", testLabelsIDsByName},
{"ByName", testLabelsByName},
{"Save", testLabelsSave},
@@ -1744,3 +1766,146 @@ func labelIDFromName(t *testing.T, ds fleet.Datastore, name string) uint {
}
return 0
}
+
+func testUpdateLabelMembershipByHostIDs(t *testing.T, ds *Datastore) {
+ ctx := context.Background()
+ host1, err := ds.NewHost(ctx, &fleet.Host{
+ OsqueryHostID: ptr.String("1"),
+ NodeKey: ptr.String("1"),
+ UUID: "1",
+ Hostname: "foo.local",
+ Platform: "darwin",
+ })
+ require.NoError(t, err)
+ host2, err := ds.NewHost(ctx, &fleet.Host{
+ OsqueryHostID: ptr.String("2"),
+ NodeKey: ptr.String("2"),
+ UUID: "2",
+ Hostname: "bar.local",
+ Platform: "windows",
+ })
+ require.NoError(t, err)
+ // hosts 2 and 3 have the same hostname
+ host3, err := ds.NewHost(ctx, &fleet.Host{
+ OsqueryHostID: ptr.String("3"),
+ NodeKey: ptr.String("3"),
+ UUID: "3",
+ Hostname: "bar.local",
+ Platform: "windows",
+ })
+ require.NoError(t, err)
+
+ label1, err := ds.NewLabel(ctx, &fleet.Label{
+ Name: "label1",
+ Query: "",
+ LabelType: fleet.LabelTypeRegular,
+ LabelMembershipType: fleet.LabelMembershipTypeManual,
+ })
+ require.NoError(t, err)
+
+ // add hosts 1 and 2 to the label
+ err = ds.UpdateLabelMembershipByHostIDs(ctx, label1.ID, []uint{host1.ID, host2.ID})
+ require.NoError(t, err)
+
+ // expect hosts 1 and 2 to be in the label, but not 3
+
+ label, err := ds.GetLabelSpec(ctx, label1.Name)
+ require.NoError(t, err)
+ // label.Hosts contains hostnames
+ require.Len(t, label.Hosts, 2)
+ require.Equal(t, host1.Hostname, label.Hosts[0])
+ require.Equal(t, host2.Hostname, label.Hosts[1])
+
+ labels, err := ds.ListLabelsForHost(ctx, host1.ID)
+ require.NoError(t, err)
+ require.Len(t, labels, 1)
+ require.Equal(t, "label1", labels[0].Name)
+
+ labels, err = ds.ListLabelsForHost(ctx, host2.ID)
+ require.NoError(t, err)
+ require.Len(t, labels, 1)
+ require.Equal(t, "label1", labels[0].Name)
+
+ labels, err = ds.ListLabelsForHost(ctx, host3.ID)
+ require.NoError(t, err)
+ require.Len(t, labels, 0)
+
+ // modify the label to contain hosts 1 and 3, confirm
+ err = ds.UpdateLabelMembershipByHostIDs(ctx, label1.ID, []uint{host1.ID, host3.ID})
+ require.NoError(t, err)
+
+ labels, err = ds.ListLabelsForHost(ctx, host1.ID)
+ require.NoError(t, err)
+ require.Len(t, labels, 1)
+ require.Equal(t, "label1", labels[0].Name)
+
+ labels, err = ds.ListLabelsForHost(ctx, host2.ID)
+ require.NoError(t, err)
+ require.Len(t, labels, 0)
+
+ labels, err = ds.ListLabelsForHost(ctx, host3.ID)
+ require.NoError(t, err)
+ require.Len(t, labels, 1)
+ require.Equal(t, "label1", labels[0].Name)
+
+ // modify the label to contain hosts 2 and 3, confirm
+ err = ds.UpdateLabelMembershipByHostIDs(ctx, label1.ID, []uint{host2.ID, host3.ID})
+ require.NoError(t, err)
+
+ labels, err = ds.ListLabelsForHost(ctx, host1.ID)
+ require.NoError(t, err)
+ require.Len(t, labels, 0)
+
+ labels, err = ds.ListLabelsForHost(ctx, host2.ID)
+ require.NoError(t, err)
+ require.Len(t, labels, 1)
+ require.Equal(t, "label1", labels[0].Name)
+
+ labels, err = ds.ListLabelsForHost(ctx, host3.ID)
+ require.NoError(t, err)
+ require.Len(t, labels, 1)
+ require.Equal(t, "label1", labels[0].Name)
+
+ // modify the label to contain no hosts, confirm
+ err = ds.UpdateLabelMembershipByHostIDs(ctx, label1.ID, []uint{})
+ require.NoError(t, err)
+
+ labels, err = ds.ListLabelsForHost(ctx, host1.ID)
+ require.NoError(t, err)
+ require.Len(t, labels, 0)
+
+ labels, err = ds.ListLabelsForHost(ctx, host2.ID)
+ require.NoError(t, err)
+ require.Len(t, labels, 0)
+
+ labels, err = ds.ListLabelsForHost(ctx, host3.ID)
+ require.NoError(t, err)
+ require.Len(t, labels, 0)
+
+ // modify the label to contain all 3 hosts, confirm
+ err = ds.UpdateLabelMembershipByHostIDs(ctx, label1.ID, []uint{host1.ID, host2.ID, host3.ID})
+ require.NoError(t, err)
+
+ labels, err = ds.ListLabelsForHost(ctx, host1.ID)
+ require.NoError(t, err)
+ require.Len(t, labels, 1)
+ require.Equal(t, "label1", labels[0].Name)
+
+ labels, err = ds.ListLabelsForHost(ctx, host2.ID)
+ require.NoError(t, err)
+ require.Len(t, labels, 1)
+ require.Equal(t, "label1", labels[0].Name)
+
+ labels, err = ds.ListLabelsForHost(ctx, host3.ID)
+ require.NoError(t, err)
+ require.Len(t, labels, 1)
+ require.Equal(t, "label1", labels[0].Name)
+
+ label, err = ds.GetLabelSpec(ctx, label1.Name)
+ require.NoError(t, err)
+ require.Len(t, label.Hosts, 3)
+ require.Equal(t, host1.Hostname, label.Hosts[0])
+ // 2 and 3 have same name
+ require.Equal(t, host2.Hostname, label.Hosts[1])
+ require.Equal(t, host3.Hostname, label.Hosts[2])
+}
diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go
index bfc24faf455f..e05965137fcd 100644
--- a/server/fleet/datastore.go
+++ b/server/fleet/datastore.go
@@ -189,6 +189,10 @@ type Datastore interface {
// If a host is already not a member of a label then such label will be ignored.
RemoveLabelsFromHost(ctx context.Context, hostID uint, labelIDs []uint) error
+ // UpdateLabelMembershipByHostIDs updates the label membership for the given label ID with host
+ // IDs, applied in batches
+ UpdateLabelMembershipByHostIDs(ctx context.Context, labelID uint, hostIds []uint) (err error)
+
NewLabel(ctx context.Context, Label *Label, opts ...OptionalArg) (*Label, error)
// SaveLabel updates the label and returns the label and an array of host IDs
// members of this label, or an error.
diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go
index 44e0c01f610d..e5306955b41a 100644
--- a/server/mock/datastore_mock.go
+++ b/server/mock/datastore_mock.go
@@ -129,6 +129,8 @@ type ListPacksForHostFunc func(ctx context.Context, hid uint) (packs []*fleet.Pa
type ApplyLabelSpecsFunc func(ctx context.Context, specs []*fleet.LabelSpec) error
+type UpdateLabelMembershipByHostIDsFunc func(ctx context.Context, labelID uint, hostIDs []uint) (err error)
+
type GetLabelSpecsFunc func(ctx context.Context) ([]*fleet.LabelSpec, error)
type GetLabelSpecFunc func(ctx context.Context, name string) (*fleet.LabelSpec, error)
@@ -1370,6 +1372,9 @@ type DataStore struct {
ApplyLabelSpecsFunc ApplyLabelSpecsFunc
ApplyLabelSpecsFuncInvoked bool
+ UpdateLabelMembershipByHostIDsFunc UpdateLabelMembershipByHostIDsFunc
+ UpdateLabelMembershipByHostIDsFuncInvoked bool
+
GetLabelSpecsFunc GetLabelSpecsFunc
GetLabelSpecsFuncInvoked bool
@@ -3368,6 +3373,13 @@ func (s *DataStore) ApplyLabelSpecs(ctx context.Context, specs []*fleet.LabelSpe
return s.ApplyLabelSpecsFunc(ctx, specs)
}
+func (s *DataStore) UpdateLabelMembershipByHostIDs(ctx context.Context, labelID uint, hostIDs []uint) (err error) {
+ s.mu.Lock()
+ s.UpdateLabelMembershipByHostIDsFuncInvoked = true
+ s.mu.Unlock()
+ return s.UpdateLabelMembershipByHostIDsFunc(ctx, labelID, hostIDs)
+}
+
func (s *DataStore) GetLabelSpecs(ctx context.Context) ([]*fleet.LabelSpec, error) {
s.mu.Lock()
s.GetLabelSpecsFuncInvoked = true
diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go
index 861ad7e11e8b..c7fa9e838bd5 100644
--- a/server/service/integration_core_test.go
+++ b/server/service/integration_core_test.go
@@ -4243,6 +4243,25 @@ func (s *integrationTestSuite) TestLabels() {
assert.EqualValues(t, 3, modResp.Label.HostCount)
assert.Equal(t, newName, modResp.Label.Name)
+ // add a host with the same name as another host to manual label 2, confirm only one host is added
+ sameName, err := s.ds.NewHost(context.Background(), &fleet.Host{
+ HardwareSerial: "ABCDE",
+ Hostname: manualHosts[0].Hostname,
+ Platform: "darwin",
+ })
+ require.NoError(t, err)
+
+ modResp = modifyLabelResponse{}
+ s.DoJSON("PATCH", fmt.Sprintf("/api/latest/fleet/labels/%d", manualLbl2.ID),
+ &fleet.ModifyLabelPayload{Hosts: []string{sameName.HardwareSerial}}, http.StatusOK, &modResp)
+ assert.Len(t, modResp.Label.HostIDs, 1)
+ assert.NotEqual(t, manualHosts[0].ID, modResp.Label.HostIDs[0])
+ assert.Equal(t, manualLbl2.ID, modResp.Label.ID)
+ assert.Equal(t, fleet.LabelTypeRegular, modResp.Label.LabelType)
+ assert.Equal(t, fleet.LabelMembershipTypeManual, modResp.Label.LabelMembershipType)
+ assert.ElementsMatch(t, []uint{sameName.ID}, modResp.Label.HostIDs)
+ assert.EqualValues(t, 1, modResp.Label.HostCount)
+
// modify manual label 2 adding some hosts
modResp = modifyLabelResponse{}
newName = "modified_manual_label2"
diff --git a/server/service/labels.go b/server/service/labels.go
index 2b0ce6f59836..23f1362c66f1 100644
--- a/server/service/labels.go
+++ b/server/service/labels.go
@@ -169,7 +169,6 @@ func (svc *Service) ModifyLabel(ctx context.Context, id uint, payload fleet.Modi
if label.LabelType == fleet.LabelTypeBuiltIn {
return nil, nil, fleet.NewInvalidArgumentError("label_type", fmt.Sprintf("cannot modify built-in label '%s'", label.Name))
}
- originalLabelName := label.Name
if payload.Name != nil {
// Check if the new name is a reserved label name
for name := range fleet.ReservedLabelNames() {
@@ -186,37 +185,20 @@ func (svc *Service) ModifyLabel(ctx context.Context, id uint, payload fleet.Modi
return nil, nil, fleet.NewInvalidArgumentError("hosts", "cannot provide a list of hosts for a dynamic label")
}
- // if membership type is manual and the Hosts membership is provided, must
- // use ApplyLabelSpecs (as SaveLabel does not update label memberships),
- // otherwise SaveLabel works for dynamic membership. Must resolve the host
- // identifiers to hostname so that ApplySpecs can be used (it expects only
- // hostnames).
- if label.LabelMembershipType == fleet.LabelMembershipTypeManual && payload.Hosts != nil {
- spec := fleet.LabelSpec{
- Name: originalLabelName,
- Description: label.Description,
- Query: label.Query,
- Platform: label.Platform,
- LabelType: label.LabelType,
- LabelMembershipType: label.LabelMembershipType,
- }
- hostnames, err := svc.ds.HostnamesByIdentifiers(ctx, payload.Hosts)
+ // use SaveLabel to update label info, and UpdateLabelMembershipByHostIDs to update membership. Approach using label
+ // names and ApplyLabelSpecs doesn't work for multiple hosts with the same name.
+
+ if payload.Hosts != nil {
+ // get host ids for valid hosts. since this endpoint will contain hosts identified by serial
+ // number, there should be no duplicates
+
+ hostIds, err := svc.ds.HostIDsByIdentifier(ctx, filter, payload.Hosts)
if err != nil {
return nil, nil, err
}
- spec.Hosts = hostnames
- // Note: ApplyLabelSpecs cannot update label name since it uses the name as a key.
- // So, we must handle it later.
- if err := svc.ds.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{&spec}); err != nil {
+ if err := svc.ds.UpdateLabelMembershipByHostIDs(ctx, label.ID, hostIds); err != nil {
return nil, nil, err
}
- // If the label name has changed, we must update it.
- if originalLabelName != label.Name {
- return svc.ds.SaveLabel(ctx, label, filter)
- }
- // Otherwise, simply reload label to get the host counts information
- ctx = ctxdb.RequirePrimary(ctx, true)
- return svc.ds.Label(ctx, id, filter)
}
return svc.ds.SaveLabel(ctx, label, filter)
}