Skip to content

Commit

Permalink
Automatic install custom packages (#25021)
Browse files Browse the repository at this point in the history
#24385

Some docs change here: #25026.

- [X] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.
See [Changes
files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/Committing-Changes.md#changes-files)
for more information.
- [X] Added/updated tests
- [X] Manual QA for all new/changed functionality
  • Loading branch information
lucasmrod authored Dec 27, 2024
1 parent 3881d0b commit 963cc7e
Show file tree
Hide file tree
Showing 18 changed files with 1,389 additions and 24 deletions.
1 change: 1 addition & 0 deletions changes/24385-automatic-install-custom-packages
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* Added capability to automatically generate "trigger policies" for custom software packages.
1 change: 0 additions & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
---
version: "2"
services:
# To test with MariaDB, set FLEET_MYSQL_IMAGE to mariadb:10.6 or the like (note MariaDB is not
# officially supported).
Expand Down
35 changes: 28 additions & 7 deletions ee/server/service/software_installers.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ func (svc *Service) UploadSoftwareInstaller(ctx context.Context, payload *fleet.
return err
}

if payload.AutomaticInstall {
// Currently, same write permissions are applied on software and policies,
// but leaving this here in case it changes in the future.
if err := svc.authz.Authorize(ctx, &fleet.Policy{PolicyData: fleet.PolicyData{TeamID: payload.TeamID}}, fleet.ActionWrite); err != nil {
return err
}
}

// validate labels before we do anything else
validatedLabels, err := ValidateSoftwareLabels(ctx, svc, payload.LabelsIncludeAny, payload.LabelsExcludeAny)
if err != nil {
Expand All @@ -61,13 +69,29 @@ func (svc *Service) UploadSoftwareInstaller(ctx context.Context, payload *fleet.
return ctxerr.Wrap(ctx, err, "adding metadata to payload")
}

if payload.AutomaticInstall {
switch {
//
// For "msi", addMetadataToSoftwarePayload fails before this point if product code cannot be extracted.
//
case payload.Extension == "exe":
return &fleet.BadRequestError{
Message: "Couldn't add. Fleet can't create a policy to detect existing installations for .exe packages. Please add the software, add a custom policy, and enable the install software policy automation.",
}
case payload.Extension == "pkg" && payload.BundleIdentifier == "":
// For pkgs without bundle identifier the request usually fails before reaching this point,
// but addMetadataToSoftwarePayload may not fail if the package has "package IDs" but not a "bundle identifier",
// in which case we want to fail here because we cannot generate a policy without a bundle identifier.
return &fleet.BadRequestError{
Message: "Couldn't add. Policy couldn't be created because bundle identifier can't be extracted.",
}
}
}

if err := svc.storeSoftware(ctx, payload); err != nil {
return ctxerr.Wrap(ctx, err, "storing software installer")
}

// TODO: basic validation of install and post-install script (e.g., supported interpreters)?
// TODO: any validation of pre-install query?

// Update $PACKAGE_ID in uninstall script
preProcessUninstallScript(payload)

Expand All @@ -81,8 +105,6 @@ func (svc *Service) UploadSoftwareInstaller(ctx context.Context, payload *fleet.
}
level.Debug(svc.logger).Log("msg", "software installer uploaded", "installer_id", installerID)

// TODO: QA what breaks when you have a software title with no versions?

var teamName *string
if payload.TeamID != nil && *payload.TeamID != 0 {
t, err := svc.ds.Team(ctx, *payload.TeamID)
Expand All @@ -92,7 +114,6 @@ func (svc *Service) UploadSoftwareInstaller(ctx context.Context, payload *fleet.
teamName = &t.Name
}

// Create activity
actLabelsIncl, actLabelsExcl := activitySoftwareLabelsFromValidatedLabels(payload.ValidatedLabels)
if err := svc.NewActivity(ctx, vc.User, fleet.ActivityTypeAddedSoftware{
SoftwareTitle: payload.Title,
Expand Down Expand Up @@ -1241,7 +1262,7 @@ func (svc *Service) addMetadataToSoftwarePayload(ctx context.Context, payload *f

if len(meta.PackageIDs) == 0 {
return "", &fleet.BadRequestError{
Message: fmt.Sprintf("Couldn't add. Fleet couldn't read the package IDs, product code, or name from %s.", payload.Filename),
Message: "Couldn't add. Unable to extract necessary metadata.",
InternalErr: ctxerr.New(ctx, "extracting package IDs from installer metadata"),
}
}
Expand Down
116 changes: 116 additions & 0 deletions pkg/automatic_policy/automatic_policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Package automatic_policy generates "trigger policies" from metadata of software packages.
package automatic_policy

import (
"errors"
"fmt"
)

// PolicyData contains generated data for a policy to trigger installation of a software package.
type PolicyData struct {
// Name is the generated name of the policy.
Name string
// Query is the generated SQL/sqlite of the policy.
Query string
// Description is the generated description for the policy.
Description string
// Platform is the target platform for the policy.
Platform string
}

// InstallerMetadata contains the metadata of a software package used to generate the policies.
type InstallerMetadata struct {
// Title is the software title extracted from a software package.
Title string
// Extension is the extension of the software package.
Extension string
// BundleIdentifier contains the bundle identifier for 'pkg' packages.
BundleIdentifier string
// PackageIDs contains the product code for 'msi' packages.
PackageIDs []string
}

var (
// ErrExtensionNotSupported is returned if the extension is not supported to generate automatic policies.
ErrExtensionNotSupported = errors.New("extension not supported")
// ErrMissingBundleIdentifier is returned if the software extension is "pkg" and a bundle identifier was not extracted from the installer.
ErrMissingBundleIdentifier = errors.New("missing bundle identifier")
// ErrMissingProductCode is returned if the software extension is "msi" and a product code was not extracted from the installer.
ErrMissingProductCode = errors.New("missing product code")
// ErrMissingTitle is returned if a title was not extracted from the installer.
ErrMissingTitle = errors.New("missing title")
)

// Generate generates the "trigger policy" from the metadata of a software package.
func Generate(metadata InstallerMetadata) (*PolicyData, error) {
switch {
case metadata.Title == "":
return nil, ErrMissingTitle
case metadata.Extension != "pkg" && metadata.Extension != "msi" && metadata.Extension != "deb" && metadata.Extension != "rpm":
return nil, ErrExtensionNotSupported
case metadata.Extension == "pkg" && metadata.BundleIdentifier == "":
return nil, ErrMissingBundleIdentifier
case metadata.Extension == "msi" && (len(metadata.PackageIDs) == 0 || metadata.PackageIDs[0] == ""):
return nil, ErrMissingProductCode
}

name := fmt.Sprintf("[Install software] %s (%s)", metadata.Title, metadata.Extension)

description := fmt.Sprintf("Policy triggers automatic install of %s on each host that's missing this software.", metadata.Title)
if metadata.Extension == "deb" || metadata.Extension == "rpm" {
basedPrefix := "RPM"
if metadata.Extension == "rpm" {
basedPrefix = "Debian"
}
description += fmt.Sprintf(
"\nSoftware won't be installed on Linux hosts with %s-based distributions because this policy's query is written to always pass on these hosts.",
basedPrefix,
)
}

switch metadata.Extension {
case "pkg":
return &PolicyData{
Name: name,
Query: fmt.Sprintf("SELECT 1 FROM apps WHERE bundle_identifier = '%s';", metadata.BundleIdentifier),
Platform: "darwin",
Description: description,
}, nil
case "msi":
return &PolicyData{
Name: name,
Query: fmt.Sprintf("SELECT 1 FROM programs WHERE identifying_number = '%s';", metadata.PackageIDs[0]),
Platform: "windows",
Description: description,
}, nil
case "deb":
return &PolicyData{
Name: name,
Query: fmt.Sprintf(
// First inner SELECT will mark the policies as successful on non-DEB-based hosts.
`SELECT 1 WHERE EXISTS (
SELECT 1 WHERE (SELECT COUNT(*) FROM deb_packages) = 0
) OR EXISTS (
SELECT 1 FROM deb_packages WHERE name = '%s'
);`, metadata.Title,
),
Platform: "linux",
Description: description,
}, nil
case "rpm":
return &PolicyData{
Name: name,
Query: fmt.Sprintf(
// First inner SELECT will mark the policies as successful on non-RPM-based hosts.
`SELECT 1 WHERE EXISTS (
SELECT 1 WHERE (SELECT COUNT(*) FROM rpm_packages) = 0
) OR EXISTS (
SELECT 1 FROM rpm_packages WHERE name = '%s'
);`, metadata.Title),
Platform: "linux",
Description: description,
}, nil
default:
return nil, ErrExtensionNotSupported
}
}
108 changes: 108 additions & 0 deletions pkg/automatic_policy/automatic_policy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package automatic_policy

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestGenerateErrors(t *testing.T) {
_, err := Generate(InstallerMetadata{
Title: "Foobar",
Extension: "exe",
BundleIdentifier: "",
PackageIDs: []string{"Foobar"},
})
require.ErrorIs(t, err, ErrExtensionNotSupported)

_, err = Generate(InstallerMetadata{
Title: "Foobar",
Extension: "msi",
BundleIdentifier: "",
PackageIDs: []string{""},
})
require.ErrorIs(t, err, ErrMissingProductCode)
_, err = Generate(InstallerMetadata{
Title: "Foobar",
Extension: "msi",
BundleIdentifier: "",
PackageIDs: []string{},
})
require.ErrorIs(t, err, ErrMissingProductCode)

_, err = Generate(InstallerMetadata{
Title: "Foobar",
Extension: "pkg",
BundleIdentifier: "",
PackageIDs: []string{""},
})
require.ErrorIs(t, err, ErrMissingBundleIdentifier)

_, err = Generate(InstallerMetadata{
Title: "",
Extension: "deb",
BundleIdentifier: "",
PackageIDs: []string{""},
})
require.ErrorIs(t, err, ErrMissingTitle)
}

func TestGenerate(t *testing.T) {
policyData, err := Generate(InstallerMetadata{
Title: "Foobar",
Extension: "pkg",
BundleIdentifier: "com.foo.bar",
PackageIDs: []string{"com.foo.bar"},
})
require.NoError(t, err)
require.Equal(t, "[Install software] Foobar (pkg)", policyData.Name)
require.Equal(t, "Policy triggers automatic install of Foobar on each host that's missing this software.", policyData.Description)
require.Equal(t, "darwin", policyData.Platform)
require.Equal(t, "SELECT 1 FROM apps WHERE bundle_identifier = 'com.foo.bar';", policyData.Query)

policyData, err = Generate(InstallerMetadata{
Title: "Barfoo",
Extension: "msi",
BundleIdentifier: "",
PackageIDs: []string{"foo"},
})
require.NoError(t, err)
require.Equal(t, "[Install software] Barfoo (msi)", policyData.Name)
require.Equal(t, "Policy triggers automatic install of Barfoo on each host that's missing this software.", policyData.Description)
require.Equal(t, "windows", policyData.Platform)
require.Equal(t, "SELECT 1 FROM programs WHERE identifying_number = 'foo';", policyData.Query)

policyData, err = Generate(InstallerMetadata{
Title: "Zoobar",
Extension: "deb",
BundleIdentifier: "",
PackageIDs: []string{"Zoobar"},
})
require.NoError(t, err)
require.Equal(t, "[Install software] Zoobar (deb)", policyData.Name)
require.Equal(t, `Policy triggers automatic install of Zoobar on each host that's missing this software.
Software won't be installed on Linux hosts with RPM-based distributions because this policy's query is written to always pass on these hosts.`, policyData.Description)
require.Equal(t, "linux", policyData.Platform)
require.Equal(t, `SELECT 1 WHERE EXISTS (
SELECT 1 WHERE (SELECT COUNT(*) FROM deb_packages) = 0
) OR EXISTS (
SELECT 1 FROM deb_packages WHERE name = 'Zoobar'
);`, policyData.Query)

policyData, err = Generate(InstallerMetadata{
Title: "Barzoo",
Extension: "rpm",
BundleIdentifier: "",
PackageIDs: []string{"Barzoo"},
})
require.NoError(t, err)
require.Equal(t, "[Install software] Barzoo (rpm)", policyData.Name)
require.Equal(t, `Policy triggers automatic install of Barzoo on each host that's missing this software.
Software won't be installed on Linux hosts with Debian-based distributions because this policy's query is written to always pass on these hosts.`, policyData.Description)
require.Equal(t, "linux", policyData.Platform)
require.Equal(t, `SELECT 1 WHERE EXISTS (
SELECT 1 WHERE (SELECT COUNT(*) FROM rpm_packages) = 0
) OR EXISTS (
SELECT 1 FROM rpm_packages WHERE name = 'Barzoo'
);`, policyData.Query)
}
22 changes: 13 additions & 9 deletions server/datastore/mysql/policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemo
}

if p.TeamID != nil {
if err := ds.assertTeamMatches(ctx, *p.TeamID, p.SoftwareInstallerID, p.ScriptID); err != nil {
if err := assertTeamMatches(ctx, ds.writer(ctx), *p.TeamID, p.SoftwareInstallerID, p.ScriptID); err != nil {
return ctxerr.Wrap(ctx, err, "save policy")
}
}
Expand Down Expand Up @@ -185,10 +185,10 @@ var (
errMismatchedScriptTeam = &fleet.BadRequestError{Message: "script is associated with a different team"}
)

func (ds *Datastore) assertTeamMatches(ctx context.Context, teamID uint, softwareInstallerID *uint, scriptID *uint) error {
func assertTeamMatches(ctx context.Context, db sqlx.QueryerContext, teamID uint, softwareInstallerID *uint, scriptID *uint) error {
if softwareInstallerID != nil {
var softwareInstallerTeamID uint
err := sqlx.GetContext(ctx, ds.reader(ctx), &softwareInstallerTeamID, "SELECT global_or_team_id FROM software_installers WHERE id = ?", softwareInstallerID)
err := sqlx.GetContext(ctx, db, &softwareInstallerTeamID, "SELECT global_or_team_id FROM software_installers WHERE id = ?", softwareInstallerID)

if err != nil {
if errors.Is(err, sql.ErrNoRows) {
Expand All @@ -202,7 +202,7 @@ func (ds *Datastore) assertTeamMatches(ctx context.Context, teamID uint, softwar

if scriptID != nil {
var scriptTeamID uint
err := sqlx.GetContext(ctx, ds.reader(ctx), &scriptTeamID, "SELECT global_or_team_id FROM scripts WHERE id = ?", scriptID)
err := sqlx.GetContext(ctx, db, &scriptTeamID, "SELECT global_or_team_id FROM scripts WHERE id = ?", scriptID)

if err != nil {
if errors.Is(err, sql.ErrNoRows) {
Expand Down Expand Up @@ -647,8 +647,12 @@ func (ds *Datastore) PolicyQueriesForHost(ctx context.Context, host *fleet.Host)
}

func (ds *Datastore) NewTeamPolicy(ctx context.Context, teamID uint, authorID *uint, args fleet.PolicyPayload) (*fleet.Policy, error) {
return newTeamPolicy(ctx, ds.writer(ctx), teamID, authorID, args)
}

func newTeamPolicy(ctx context.Context, db sqlx.ExtContext, teamID uint, authorID *uint, args fleet.PolicyPayload) (*fleet.Policy, error) {
if args.QueryID != nil {
q, err := ds.Query(ctx, *args.QueryID)
q, err := query(ctx, db, *args.QueryID)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "fetching query from id")
}
Expand All @@ -659,7 +663,7 @@ func (ds *Datastore) NewTeamPolicy(ctx context.Context, teamID uint, authorID *u
// Check team exists.
if teamID > 0 {
var ok bool
err := ds.writer(ctx).GetContext(ctx, &ok, `SELECT COUNT(*) = 1 FROM teams WHERE id = ?`, teamID)
err := sqlx.GetContext(ctx, db, &ok, `SELECT COUNT(*) = 1 FROM teams WHERE id = ?`, teamID)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "get team id")
}
Expand All @@ -671,11 +675,11 @@ func (ds *Datastore) NewTeamPolicy(ctx context.Context, teamID uint, authorID *u
// We must normalize the name for full Unicode support (Unicode equivalence).
nameUnicode := norm.NFC.String(args.Name)

if err := ds.assertTeamMatches(ctx, teamID, args.SoftwareInstallerID, args.ScriptID); err != nil {
if err := assertTeamMatches(ctx, db, teamID, args.SoftwareInstallerID, args.ScriptID); err != nil {
return nil, ctxerr.Wrap(ctx, err, "create team policy")
}

res, err := ds.writer(ctx).ExecContext(ctx,
res, err := db.ExecContext(ctx,
fmt.Sprintf(
`INSERT INTO policies (name, query, description, team_id, resolution, author_id, platforms, critical, calendar_events_enabled, software_installer_id, script_id, checksum) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, %s)`,
policiesChecksumComputedColumn(),
Expand All @@ -695,7 +699,7 @@ func (ds *Datastore) NewTeamPolicy(ctx context.Context, teamID uint, authorID *u
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "getting last id after inserting policy")
}
return policyDB(ctx, ds.writer(ctx), uint(lastIdInt64), &teamID) //nolint:gosec // dismiss G115
return policyDB(ctx, db, uint(lastIdInt64), &teamID) //nolint:gosec // dismiss G115
}

func (ds *Datastore) ListTeamPolicies(ctx context.Context, teamID uint, opts fleet.ListOptions, iopts fleet.ListOptions) (teamPolicies, inheritedPolicies []*fleet.Policy, err error) {
Expand Down
Loading

0 comments on commit 963cc7e

Please sign in to comment.