Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add options to populate users and labels on list hosts endpoint #25621

Merged
merged 5 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/22464-list-hosts-populate-users-labels
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- Added option to populate users and labels on list hosts endpoint
9 changes: 9 additions & 0 deletions server/datastore/mysql/hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,15 @@ func loadHostUsersDB(ctx context.Context, db sqlx.QueryerContext, hostID uint) (
return users, nil
}

func (ds *Datastore) ListHostUsers(ctx context.Context, hostID uint) ([]fleet.HostUser, error) {
users, err := loadHostUsersDB(ctx, ds.reader(ctx), hostID)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "loading host users")
}

return users, nil
}

// hostRefs are the tables referenced by hosts.
//
// Defined here for testing purposes.
Expand Down
3 changes: 3 additions & 0 deletions server/fleet/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ type Datastore interface {
// the implementation for the exact list.
ListHostsLiteByIDs(ctx context.Context, ids []uint) ([]*Host, error)

// ListHostUsers returns a list of users that are currently on the host
ListHostUsers(ctx context.Context, hostID uint) ([]HostUser, error)

MarkHostsSeen(ctx context.Context, hostIDs []uint, t time.Time) error
SearchHosts(ctx context.Context, filter TeamFilter, query string, omit ...uint) ([]*Host, error)
// EnrolledHostIDs returns the full list of enrolled host IDs.
Expand Down
2 changes: 1 addition & 1 deletion server/fleet/hostresponse.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type HostResponse struct {
Status HostStatus `json:"status" csv:"status"`
DisplayText string `json:"display_text" csv:"display_text"`
DisplayName string `json:"display_name" csv:"display_name"`
Labels []Label `json:"labels,omitempty" csv:"-"`
Labels []*Label `json:"labels,omitempty" csv:"-"`
Geolocation *GeoLocation `json:"geolocation,omitempty" csv:"-"`
CSVDeviceMapping string `json:"-" db:"-" csv:"device_mapping"`
}
Expand Down
6 changes: 6 additions & 0 deletions server/fleet/hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,12 @@ type HostListOptions struct {
// PopulatePolicies adds the `Policies` array field to all Hosts returned.
PopulatePolicies bool

// PopulateUsers adds the `Users` array field to all Hosts returned
PopulateUsers bool

// PopulateLabels adds the `Labels` array field to all host responses returned
PopulateLabels bool

// VulnerabilityFilter filters the hosts by the presence of a vulnerability (CVE)
VulnerabilityFilter *string

Expand Down
3 changes: 3 additions & 0 deletions server/fleet/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ type Service interface {
// ListHostsInLabel returns a slice of hosts in the label with the given ID.
ListHostsInLabel(ctx context.Context, lid uint, opt HostListOptions) ([]*Host, error)

// ListLabelsForHost returns a slice of labels for a given host
ListLabelsForHost(ctx context.Context, hostID uint) ([]*Label, error)

// BatchValidateLabels validates that each of the provided label names exists. The returned map
// is keyed by label name. Caller must ensure that appropirate authorization checks are
// performed prior to calling this method.
Expand Down
12 changes: 12 additions & 0 deletions server/mock/datastore_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ type ListHostsLiteByUUIDsFunc func(ctx context.Context, filter fleet.TeamFilter,

type ListHostsLiteByIDsFunc func(ctx context.Context, ids []uint) ([]*fleet.Host, error)

type ListHostUsersFunc func(ctx context.Context, hostID uint) ([]fleet.HostUser, error)

type MarkHostsSeenFunc func(ctx context.Context, hostIDs []uint, t time.Time) error

type SearchHostsFunc func(ctx context.Context, filter fleet.TeamFilter, query string, omit ...uint) ([]*fleet.Host, error)
Expand Down Expand Up @@ -1452,6 +1454,9 @@ type DataStore struct {
ListHostsLiteByIDsFunc ListHostsLiteByIDsFunc
ListHostsLiteByIDsFuncInvoked bool

ListHostUsersFunc ListHostUsersFunc
ListHostUsersFuncInvoked bool

MarkHostsSeenFunc MarkHostsSeenFunc
MarkHostsSeenFuncInvoked bool

Expand Down Expand Up @@ -3559,6 +3564,13 @@ func (s *DataStore) ListHostsLiteByIDs(ctx context.Context, ids []uint) ([]*flee
return s.ListHostsLiteByIDsFunc(ctx, ids)
}

func (s *DataStore) ListHostUsers(ctx context.Context, hostID uint) ([]fleet.HostUser, error) {
s.mu.Lock()
s.ListHostUsersFuncInvoked = true
s.mu.Unlock()
return s.ListHostUsersFunc(ctx, hostID)
}

func (s *DataStore) MarkHostsSeen(ctx context.Context, hostIDs []uint, t time.Time) error {
s.mu.Lock()
s.MarkHostsSeenFuncInvoked = true
Expand Down
27 changes: 27 additions & 0 deletions server/service/hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,16 @@ func listHostsEndpoint(ctx context.Context, request interface{}, svc fleet.Servi
for i, host := range hosts {
h := fleet.HostResponseForHost(ctx, svc, host)
hostResponses[i] = *h

if req.Opts.PopulateLabels {
labels, err := svc.ListLabelsForHost(ctx, h.ID)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, fmt.Sprintf("failed to list labels for host %d", h.ID))
}
hostResponses[i].Labels = labels
}
}

return listHostsResponse{
Hosts: hostResponses,
Software: software,
Expand Down Expand Up @@ -228,6 +237,16 @@ func (svc *Service) ListHosts(ctx context.Context, opt fleet.HostListOptions) ([
}
}

if opt.PopulateUsers {
for _, host := range hosts {
hu, err := svc.ds.ListHostUsers(ctx, host.ID)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, fmt.Sprintf("get users for host %d", host.ID))
}
host.Users = hu
}
}

return hosts, nil
}

Expand Down Expand Up @@ -1947,6 +1966,14 @@ func hostsReportEndpoint(ctx context.Context, request interface{}, svc fleet.Ser
return hostsReportResponse{Columns: cols, Hosts: hostResps}, nil
}

func (svc *Service) ListLabelsForHost(ctx context.Context, hostID uint) ([]*fleet.Label, error) {
// require list hosts permission to view this information
if err := svc.authz.Authorize(ctx, &fleet.Host{}, fleet.ActionList); err != nil {
return nil, err
}
return svc.ds.ListLabelsForHost(ctx, hostID)
dantecatalfamo marked this conversation as resolved.
Show resolved Hide resolved
}

type osVersionsRequest struct {
fleet.ListOptions
TeamID *uint `query:"team_id,optional"`
Expand Down
55 changes: 55 additions & 0 deletions server/service/integration_core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1950,6 +1950,61 @@ func (s *integrationTestSuite) TestListHosts() {
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "query", " local0 ")
require.Len(t, resp.Hosts, 1)
require.Contains(t, resp.Hosts[0].Hostname, "local0")

// Add users to hosts
users := []fleet.HostUser{
{
Uid: 1,
Username: "root",
Type: "local",
GroupName: "root",
Shell: "/bin/sh",
},
{
Uid: 1001,
Username: "username",
Type: "local",
GroupName: "usergroup",
Shell: "/bin/sh",
},
}
err = s.ds.SaveHostUsers(ctx, host0.ID, users)
require.NoError(t, err)

// Add labels to host
label1, err := s.ds.NewLabel(ctx, &fleet.Label{Name: "First Label"})
require.NoError(t, err)
label2, err := s.ds.NewLabel(ctx, &fleet.Label{Name: "Second Label"})
require.NoError(t, err)

err = s.ds.AddLabelsToHost(ctx, host0.ID, []uint{label1.ID, label2.ID})
require.NoError(t, err)

// Without "populate_users" and "populate_labels" query params, no users or labels
resp = listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "query", "local0")
require.Len(t, resp.Hosts, 1)
require.Contains(t, resp.Hosts[0].Hostname, "local0")
require.Empty(t, resp.Hosts[0].Users)
require.Empty(t, resp.Hosts[0].Labels)

// With "populate_users" query param
resp = listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "query", "local0", "populate_users", "true")
require.Len(t, resp.Hosts, 1)
require.Contains(t, resp.Hosts[0].Hostname, "local0")
require.Len(t, resp.Hosts[0].Users, 2)
require.EqualValues(t, resp.Hosts[0].Users[0], users[0])
require.EqualValues(t, resp.Hosts[0].Users[1], users[1])

// With "populate_labels" query param
resp = listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "query", "local0", "populate_labels", "true")
require.Len(t, resp.Hosts, 1)
require.Contains(t, resp.Hosts[0].Hostname, "local0")
require.Len(t, resp.Hosts[0].Labels, 2)
require.Equal(t, label1.Name, resp.Hosts[0].Labels[0].Name)
require.Equal(t, label2.Name, resp.Hosts[0].Labels[1].Name)
}

func (s *integrationTestSuite) TestInvites() {
Expand Down
22 changes: 22 additions & 0 deletions server/service/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,28 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error)
hopt.PopulatePolicies = pp
}

populateUsers := r.URL.Query().Get("populate_users")
if populateUsers != "" {
pu, err := strconv.ParseBool(populateUsers)
if err != nil {
return hopt, ctxerr.Wrap(
r.Context(), badRequest(fmt.Sprintf("Invalid boolean parameter populate_users: %s", populateUsers)),
)
}
hopt.PopulateUsers = pu
}

populateLabels := r.URL.Query().Get("populate_labels")
if populateLabels != "" {
pl, err := strconv.ParseBool(populateLabels)
if err != nil {
return hopt, ctxerr.Wrap(
r.Context(), badRequest(fmt.Sprintf("Invalid boolean parameter populate_labels: %s", populateLabels)),
)
}
hopt.PopulateLabels = pl
}

// cannot combine software_id, software_version_id, and software_title_id
var softwareErrorLabel []string
if hopt.SoftwareIDFilter != nil {
Expand Down
Loading