Skip to content

Commit

Permalink
Team user should not access OS version on another team. (#17347)
Browse files Browse the repository at this point in the history
#17117 
For `fleet/os_versions` and `/fleet/os_versions/[id]`, team users can no
longer access os versions on hosts from other teams.

### Team admin /os_versions - only returns os versions for the user's
team(s)
GET https://localhost:8080/api/v1/fleet/os_versions

### Team admin /os_versions/:id on 'No Team' - 403
GET https://localhost:8080/api/v1/fleet/os_versions/5

### Global admin /os_versions/:id?team_id does not exist anywhere - 404
GET https://localhost:8080/api/v1/fleet/os_versions/999999?team_id=1

# Checklist for submitter

<!-- Note that API documentation changes are now addressed by the
product design team. -->

- [x] Changes file added for user-visible changes in `changes/` or
`orbit/changes/`.
See [Changes
files](https://fleetdm.com/docs/contributing/committing-changes#changes-files)
for more information.
- [x] Added/updated tests
- [x] Manual QA for all new/changed functionality
  • Loading branch information
getvictor authored Mar 13, 2024
1 parent 8d8181e commit ad5c0a9
Show file tree
Hide file tree
Showing 12 changed files with 207 additions and 78 deletions.
1 change: 1 addition & 0 deletions changes/17347-team-user-os-version-restrict
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
For GET fleet/os_versions and GET fleet/os_versions/[id], team users no longer have access to os versions on hosts from other teams.
8 changes: 6 additions & 2 deletions cmd/fleet/serve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,9 @@ func TestCronVulnerabilitiesCreatesDatabasesPath(t *testing.T) {
// we should not get this far before we see the directory being created
return nil, errors.New("shouldn't happen")
}
ds.OSVersionsFunc = func(ctx context.Context, teamID *uint, platform *string, name *string, version *string) (*fleet.OSVersions, error) {
ds.OSVersionsFunc = func(
ctx context.Context, teamFilter *fleet.TeamFilter, platform *string, name *string, version *string,
) (*fleet.OSVersions, error) {
return &fleet.OSVersions{}, nil
}
ds.SyncHostsSoftwareFunc = func(ctx context.Context, updatedAt time.Time) error {
Expand Down Expand Up @@ -452,7 +454,9 @@ func TestScanVulnerabilities(t *testing.T) {
ds.DeleteOutOfDateVulnerabilitiesFunc = func(ctx context.Context, source fleet.VulnerabilitySource, duration time.Duration) error {
return nil
}
ds.OSVersionsFunc = func(ctx context.Context, teamID *uint, platform *string, name *string, version *string) (*fleet.OSVersions, error) {
ds.OSVersionsFunc = func(
ctx context.Context, teamFilter *fleet.TeamFilter, platform *string, name *string, version *string,
) (*fleet.OSVersions, error) {
return &fleet.OSVersions{
CountsUpdatedAt: time.Now(),
OSVersions: []fleet.OSVersion{
Expand Down
44 changes: 26 additions & 18 deletions server/datastore/mysql/hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -4464,8 +4464,10 @@ func (ds *Datastore) UpdateHost(ctx context.Context, host *fleet.Host) error {
)
}

func (ds *Datastore) OSVersion(ctx context.Context, osVersionID uint, teamID *uint) (*fleet.OSVersion, *time.Time, error) {
jsonValue, updatedAt, err := ds.executeOSVersionQuery(ctx, teamID)
func (ds *Datastore) OSVersion(ctx context.Context, osVersionID uint, teamFilter *fleet.TeamFilter) (
*fleet.OSVersion, *time.Time, error,
) {
jsonValue, updatedAt, err := ds.executeOSVersionQuery(ctx, teamFilter)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil, notFound("OSVersion")
Expand Down Expand Up @@ -4510,15 +4512,17 @@ func (ds *Datastore) OSVersion(ctx context.Context, osVersionID uint, teamID *ui
// counts for the same macOS version on x86_64 and arm64 architectures are counted together.
// Results can be filtered using the following optional criteria: team id, platform, or name and
// version. Name cannot be used without version, and conversely, version cannot be used without name.
func (ds *Datastore) OSVersions(ctx context.Context, teamID *uint, platform *string, name *string, version *string) (*fleet.OSVersions, error) {
func (ds *Datastore) OSVersions(
ctx context.Context, teamFilter *fleet.TeamFilter, platform *string, name *string, version *string,
) (*fleet.OSVersions, error) {
if name != nil && version == nil {
return nil, errors.New("invalid usage: cannot filter by name without version")
}
if name == nil && version != nil {
return nil, errors.New("invalid usage: cannot filter by version without name")
}

jsonValue, updatedAt, err := ds.executeOSVersionQuery(ctx, teamID)
jsonValue, updatedAt, err := ds.executeOSVersionQuery(ctx, teamFilter)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -4568,30 +4572,34 @@ func (ds *Datastore) OSVersions(ctx context.Context, teamID *uint, platform *str
return res, nil
}

func (ds *Datastore) executeOSVersionQuery(ctx context.Context, teamID *uint) (*json.RawMessage, time.Time, error) {
func (ds *Datastore) executeOSVersionQuery(ctx context.Context, teamFilter *fleet.TeamFilter) (
*json.RawMessage, time.Time, error,
) {
query := `
SELECT
json_value,
updated_at
FROM aggregated_stats
WHERE
id = ? AND
global_stats = ? AND
type = ?
WHERE type = ?
`
args := []interface{}{aggregatedStatsTypeOSVersions}
switch {
case teamFilter != nil && teamFilter.TeamID != nil:
query += " AND id = ? AND global_stats = ?"
args = append(args, *teamFilter.TeamID, false)
case teamFilter != nil:
query += " AND " + ds.whereFilterGlobalOrTeamIDByTeamsWithSqlFilter(
*teamFilter, "global_stats = 1 AND id = 0", "global_stats = 0 AND id",
)
default:
query += " AND id = ? AND global_stats = ?"
args = append(args, 0, true)
}
var row struct {
JSONValue *json.RawMessage `db:"json_value"`
UpdatedAt time.Time `db:"updated_at"`
}

id := uint(0)
globalStats := true
if teamID != nil {
id = *teamID
globalStats = false
}

err := sqlx.GetContext(ctx, ds.reader(ctx), &row, query, id, globalStats, aggregatedStatsTypeOSVersions)
err := sqlx.GetContext(ctx, ds.reader(ctx), &row, query, args...)
if err != nil {
if err == sql.ErrNoRows {
return nil, time.Time{}, ctxerr.Wrap(ctx, notFound("OSVersion"))
Expand Down
32 changes: 23 additions & 9 deletions server/datastore/mysql/hosts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6275,7 +6275,8 @@ func testOSVersions(t *testing.T, ds *Datastore) {
require.Equal(t, &expected[0], osVersion)

// team 1
osVersions, err = ds.OSVersions(ctx, &team1.ID, nil, nil, nil)
userAdmin := &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}
osVersions, err = ds.OSVersions(ctx, &fleet.TeamFilter{TeamID: &team1.ID, User: userAdmin}, nil, nil, nil)
require.NoError(t, err)

expected = []fleet.OSVersion{
Expand All @@ -6284,16 +6285,25 @@ func testOSVersions(t *testing.T, ds *Datastore) {
}
require.Equal(t, expected, osVersions.OSVersions)

osVersion, _, err = ds.OSVersion(ctx, 5, &team1.ID)
osVersion, _, err = ds.OSVersion(ctx, 5, &fleet.TeamFilter{TeamID: &team1.ID})
require.NoError(t, err)
require.Equal(t, &expected[0], osVersion)

osVersion, _, err = ds.OSVersion(ctx, 2, &team1.ID)
osVersion, _, err = ds.OSVersion(ctx, 2, &fleet.TeamFilter{TeamID: &team1.ID, User: userAdmin})
require.NoError(t, err)
require.Equal(t, &expected[1], osVersion)

userTeam1 := &fleet.User{Teams: []fleet.UserTeam{{Team: *team1, Role: fleet.RoleAdmin}}}
osVersions, err = ds.OSVersions(ctx, &fleet.TeamFilter{User: userTeam1}, nil, nil, nil)
require.NoError(t, err)
require.Equal(t, expected, osVersions.OSVersions)

osVersion, _, err = ds.OSVersion(ctx, 2, &fleet.TeamFilter{User: userTeam1})
require.NoError(t, err)
require.Equal(t, &expected[1], osVersion)

// team 2
osVersions, err = ds.OSVersions(ctx, &team2.ID, nil, nil, nil)
osVersions, err = ds.OSVersions(ctx, &fleet.TeamFilter{TeamID: &team2.ID}, nil, nil, nil)
require.NoError(t, err)

expected = []fleet.OSVersion{
Expand All @@ -6302,26 +6312,30 @@ func testOSVersions(t *testing.T, ds *Datastore) {
}
require.Equal(t, expected, osVersions.OSVersions)

osVersion, _, err = ds.OSVersion(ctx, 2, &team2.ID)
osVersion, _, err = ds.OSVersion(ctx, 2, &fleet.TeamFilter{TeamID: &team2.ID})
require.NoError(t, err)
require.Equal(t, &expected[0], osVersion)

osVersion, _, err = ds.OSVersion(ctx, 3, &team2.ID)
osVersion, _, err = ds.OSVersion(ctx, 3, &fleet.TeamFilter{TeamID: &team2.ID})
require.NoError(t, err)
require.Equal(t, &expected[1], osVersion)

// Wrong team
_, _, err = ds.OSVersion(ctx, 3, &fleet.TeamFilter{User: userTeam1})
require.True(t, fleet.IsNotFound(err))

// team 3 (no hosts assigned to team)
osVersions, err = ds.OSVersions(ctx, &team3.ID, nil, nil, nil)
osVersions, err = ds.OSVersions(ctx, &fleet.TeamFilter{TeamID: &team3.ID}, nil, nil, nil)
require.NoError(t, err)
expected = []fleet.OSVersion{}
require.Equal(t, expected, osVersions.OSVersions)

osVersion, _, err = ds.OSVersion(ctx, 2, &team3.ID)
osVersion, _, err = ds.OSVersion(ctx, 2, &fleet.TeamFilter{TeamID: &team3.ID})
require.Error(t, err)
require.Nil(t, osVersion)

// non-existent team
_, err = ds.OSVersions(ctx, ptr.Uint(404), nil, nil, nil)
_, err = ds.OSVersions(ctx, &fleet.TeamFilter{TeamID: ptr.Uint(404)}, nil, nil, nil)
require.Error(t, err)

// new host with arm64
Expand Down
14 changes: 11 additions & 3 deletions server/datastore/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,14 @@ func (ds *Datastore) whereFilterHostsByTeams(filter fleet.TeamFilter, hostKey st
// filterTableAlias is the name/alias of the table to use in generating the
// SQL.
func (ds *Datastore) whereFilterGlobalOrTeamIDByTeams(filter fleet.TeamFilter, filterTableAlias string) string {
globalFilter := fmt.Sprintf("%s.team_id = 0", filterTableAlias)
teamIDFilter := fmt.Sprintf("%s.team_id", filterTableAlias)
return ds.whereFilterGlobalOrTeamIDByTeamsWithSqlFilter(filter, globalFilter, teamIDFilter)
}

func (ds *Datastore) whereFilterGlobalOrTeamIDByTeamsWithSqlFilter(
filter fleet.TeamFilter, globalSqlFilter string, teamIDSqlFilter string,
) string {
if filter.User == nil {
// This is likely unintentional, however we would like to return no
// results rather than panicking or returning some other error. At least
Expand All @@ -897,9 +905,9 @@ func (ds *Datastore) whereFilterGlobalOrTeamIDByTeams(filter fleet.TeamFilter, f
return "FALSE"
}

defaultAllowClause := fmt.Sprintf("%s.team_id = 0", filterTableAlias)
defaultAllowClause := globalSqlFilter
if filter.TeamID != nil {
defaultAllowClause = fmt.Sprintf("%s.team_id = %d", filterTableAlias, *filter.TeamID)
defaultAllowClause = fmt.Sprintf("%s = %d", teamIDSqlFilter, *filter.TeamID)
}

if filter.User.GlobalRole != nil {
Expand Down Expand Up @@ -944,7 +952,7 @@ func (ds *Datastore) whereFilterGlobalOrTeamIDByTeams(filter fleet.TeamFilter, f
return "FALSE"
}

return fmt.Sprintf("%s.team_id IN (%s)", filterTableAlias, strings.Join(idStrs, ","))
return fmt.Sprintf("%s IN (%s)", teamIDSqlFilter, strings.Join(idStrs, ","))
}

// whereFilterTeams returns the appropriate condition to use in the WHERE
Expand Down
6 changes: 5 additions & 1 deletion server/datastore/mysql/vulnerabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,11 @@ func (ds *Datastore) Vulnerability(ctx context.Context, cve string, teamID *uint
}

func (ds *Datastore) OSVersionsByCVE(ctx context.Context, cve string, teamID *uint) (vos []*fleet.VulnerableOS, updatedAt time.Time, err error) {
osvs, err := ds.OSVersions(ctx, teamID, nil, nil, nil)
var teamFilter *fleet.TeamFilter
if teamID != nil {
teamFilter = &fleet.TeamFilter{TeamID: teamID}
}
osvs, err := ds.OSVersions(ctx, teamFilter, nil, nil, nil)
if err != nil && !fleet.IsNotFound(err) {
return nil, updatedAt, ctxerr.Wrap(ctx, err, "fetching team OS versions")
}
Expand Down
6 changes: 4 additions & 2 deletions server/fleet/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,12 @@ type Datastore interface {
GetMunkiIssue(ctx context.Context, munkiIssueID uint) (*MunkiIssue, error)
GetMDMSolution(ctx context.Context, mdmID uint) (*MDMSolution, error)

OSVersions(ctx context.Context, teamID *uint, platform *string, name *string, version *string) (*OSVersions, error)
OSVersions(ctx context.Context, teamFilter *TeamFilter, platform *string, name *string, version *string) (*OSVersions, error)
OSVersionsByCVE(ctx context.Context, cve string, teamID *uint) ([]*VulnerableOS, time.Time, error)
SoftwareByCVE(ctx context.Context, cve string, teamID *uint) ([]*VulnerableSoftware, time.Time, error)
OSVersion(ctx context.Context, osVersionID uint, teamID *uint) (*OSVersion, *time.Time, error)
// OSVersion returns the OSVersion with the provided ID. If teamFilter is not nil, then the OSVersion is filtered.
// The returned OSVersion is accompanied by the time it was last updated.
OSVersion(ctx context.Context, osVersionID uint, teamFilter *TeamFilter) (*OSVersion, *time.Time, error)
UpdateOSVersions(ctx context.Context) error

///////////////////////////////////////////////////////////////////////////////
Expand Down
12 changes: 6 additions & 6 deletions server/mock/datastore_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,13 @@ type GetMunkiIssueFunc func(ctx context.Context, munkiIssueID uint) (*fleet.Munk

type GetMDMSolutionFunc func(ctx context.Context, mdmID uint) (*fleet.MDMSolution, error)

type OSVersionsFunc func(ctx context.Context, teamID *uint, platform *string, name *string, version *string) (*fleet.OSVersions, error)
type OSVersionsFunc func(ctx context.Context, teamFilter *fleet.TeamFilter, platform *string, name *string, version *string) (*fleet.OSVersions, error)

type OSVersionsByCVEFunc func(ctx context.Context, cve string, teamID *uint) ([]*fleet.VulnerableOS, time.Time, error)

type SoftwareByCVEFunc func(ctx context.Context, cve string, teamID *uint) ([]*fleet.VulnerableSoftware, time.Time, error)

type OSVersionFunc func(ctx context.Context, osVersionID uint, teamID *uint) (*fleet.OSVersion, *time.Time, error)
type OSVersionFunc func(ctx context.Context, osVersionID uint, teamFilter *fleet.TeamFilter) (*fleet.OSVersion, *time.Time, error)

type UpdateOSVersionsFunc func(ctx context.Context) error

Expand Down Expand Up @@ -2908,11 +2908,11 @@ func (s *DataStore) GetMDMSolution(ctx context.Context, mdmID uint) (*fleet.MDMS
return s.GetMDMSolutionFunc(ctx, mdmID)
}

func (s *DataStore) OSVersions(ctx context.Context, teamID *uint, platform *string, name *string, version *string) (*fleet.OSVersions, error) {
func (s *DataStore) OSVersions(ctx context.Context, teamFilter *fleet.TeamFilter, platform *string, name *string, version *string) (*fleet.OSVersions, error) {
s.mu.Lock()
s.OSVersionsFuncInvoked = true
s.mu.Unlock()
return s.OSVersionsFunc(ctx, teamID, platform, name, version)
return s.OSVersionsFunc(ctx, teamFilter, platform, name, version)
}

func (s *DataStore) OSVersionsByCVE(ctx context.Context, cve string, teamID *uint) ([]*fleet.VulnerableOS, time.Time, error) {
Expand All @@ -2929,11 +2929,11 @@ func (s *DataStore) SoftwareByCVE(ctx context.Context, cve string, teamID *uint)
return s.SoftwareByCVEFunc(ctx, cve, teamID)
}

func (s *DataStore) OSVersion(ctx context.Context, osVersionID uint, teamID *uint) (*fleet.OSVersion, *time.Time, error) {
func (s *DataStore) OSVersion(ctx context.Context, osVersionID uint, teamFilter *fleet.TeamFilter) (*fleet.OSVersion, *time.Time, error) {
s.mu.Lock()
s.OSVersionFuncInvoked = true
s.mu.Unlock()
return s.OSVersionFunc(ctx, osVersionID, teamID)
return s.OSVersionFunc(ctx, osVersionID, teamFilter)
}

func (s *DataStore) UpdateOSVersions(ctx context.Context) error {
Expand Down
61 changes: 49 additions & 12 deletions server/service/hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -1816,17 +1816,33 @@ func (svc *Service) OSVersions(ctx context.Context, teamID *uint, platform *stri
return nil, count, nil, &fleet.BadRequestError{Message: "Invalid order key"}
}

osVersions, err := svc.ds.OSVersions(ctx, teamID, platform, name, version)
if err != nil && fleet.IsNotFound(err) {
// differentiate case where team was added after UpdateOSVersions last ran
if teamID != nil && *teamID > 0 {
// most of the time, team should exist so checking here saves unnecessary db calls
_, err := svc.ds.Team(ctx, *teamID)
if err != nil {
return nil, count, nil, err
}
if teamID != nil {
// This auth check ensures we return 403 if the user doesn't have access to the team
if err := svc.authz.Authorize(ctx, &fleet.AuthzSoftwareInventory{TeamID: teamID}, fleet.ActionRead); err != nil {
return nil, count, nil, err
}
exists, err := svc.ds.TeamExists(ctx, *teamID)
if err != nil {
return nil, count, nil, ctxerr.Wrap(ctx, err, "checking if team exists")
} else if !exists {
return nil, count, nil, fleet.NewInvalidArgumentError("team_id", fmt.Sprintf("team %d does not exist", *teamID)).
WithStatus(http.StatusNotFound)
}
// if team exists but stats have not yet been gathered, return empty JSON array
}

vc, ok := viewer.FromContext(ctx)
if !ok {
return nil, count, nil, fleet.ErrNoContext
}
osVersions, err := svc.ds.OSVersions(
ctx, &fleet.TeamFilter{
User: vc.User,
IncludeObserver: true,
TeamID: teamID,
}, platform, name, version,
)
if err != nil && fleet.IsNotFound(err) {
// It is possible that os exists, but aggregation job has not run yet.
osVersions = &fleet.OSVersions{}
} else if err != nil {
return nil, count, nil, err
Expand Down Expand Up @@ -1913,15 +1929,36 @@ func (svc *Service) OSVersion(ctx context.Context, osID uint, teamID *uint, incl
}

if teamID != nil {
// This auth check ensures we return 403 if the user doesn't have access to the team
if err := svc.authz.Authorize(ctx, &fleet.AuthzSoftwareInventory{TeamID: teamID}, fleet.ActionRead); err != nil {
return nil, nil, err
}
exists, err := svc.ds.TeamExists(ctx, *teamID)
if err != nil {
return nil, nil, ctxerr.Wrap(ctx, err, "checking if team exists")
} else if !exists {
return nil, nil, authz.ForbiddenWithInternal("team does not exist", nil, nil, nil)
return nil, nil, fleet.NewInvalidArgumentError("team_id", fmt.Sprintf("team %d does not exist", *teamID)).
WithStatus(http.StatusNotFound)
}
}
osVersion, updateTime, err := svc.ds.OSVersion(ctx, osID, teamID)

vc, ok := viewer.FromContext(ctx)
if !ok {
return nil, nil, fleet.ErrNoContext
}
osVersion, updateTime, err := svc.ds.OSVersion(
ctx, osID, &fleet.TeamFilter{
User: vc.User,
IncludeObserver: true,
TeamID: teamID,
},
)
if err != nil {
if fleet.IsNotFound(err) {
// We return an empty result here to be consistent with the fleet/os_versions behavior.
// It is possible the os version exists, but the aggregation job has not run yet.
return nil, nil, nil
}
return nil, nil, err
}

Expand Down
Loading

0 comments on commit ad5c0a9

Please sign in to comment.