Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic install custom packages #25021

Merged
merged 8 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/24385-automatic-install-custom-packages
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* Added capability to automatically generate "trigger policies" for custom software packages.
1 change: 0 additions & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
---
version: "2"
services:
# To test with MariaDB, set FLEET_MYSQL_IMAGE to mariadb:10.6 or the like (note MariaDB is not
# officially supported).
Expand Down
35 changes: 28 additions & 7 deletions ee/server/service/software_installers.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ func (svc *Service) UploadSoftwareInstaller(ctx context.Context, payload *fleet.
return err
}

if payload.AutomaticInstall {
// Currently, same write permissions are applied on software and policies,
// but leaving this here in case it changes in the future.
if err := svc.authz.Authorize(ctx, &fleet.Policy{PolicyData: fleet.PolicyData{TeamID: payload.TeamID}}, fleet.ActionWrite); err != nil {
return err
}
}

// validate labels before we do anything else
validatedLabels, err := ValidateSoftwareLabels(ctx, svc, payload.LabelsIncludeAny, payload.LabelsExcludeAny)
if err != nil {
Expand All @@ -61,13 +69,29 @@ func (svc *Service) UploadSoftwareInstaller(ctx context.Context, payload *fleet.
return ctxerr.Wrap(ctx, err, "adding metadata to payload")
}

if payload.AutomaticInstall {
switch {
//
// For "msi", addMetadataToSoftwarePayload fails before this point if product code cannot be extracted.
//
case payload.Extension == "exe":
return &fleet.BadRequestError{
Message: "Couldn't add. Fleet can't create a policy to detect existing installations for .exe packages. Please add the software, add a custom policy, and enable the install software policy automation.",
}
case payload.Extension == "pkg" && payload.BundleIdentifier == "":
// For pkgs without bundle identifier the request usually fails before reaching this point,
// but addMetadataToSoftwarePayload may not fail if the package has "package IDs" but not a "bundle identifier",
// in which case we want to fail here because we cannot generate a policy without a bundle identifier.
return &fleet.BadRequestError{
Message: "Couldn't add. Policy couldn't be created because bundle identifier can't be extracted.",
}
}
}

if err := svc.storeSoftware(ctx, payload); err != nil {
return ctxerr.Wrap(ctx, err, "storing software installer")
}

// TODO: basic validation of install and post-install script (e.g., supported interpreters)?
// TODO: any validation of pre-install query?

// Update $PACKAGE_ID in uninstall script
preProcessUninstallScript(payload)

Expand All @@ -81,8 +105,6 @@ func (svc *Service) UploadSoftwareInstaller(ctx context.Context, payload *fleet.
}
level.Debug(svc.logger).Log("msg", "software installer uploaded", "installer_id", installerID)

// TODO: QA what breaks when you have a software title with no versions?

var teamName *string
if payload.TeamID != nil && *payload.TeamID != 0 {
t, err := svc.ds.Team(ctx, *payload.TeamID)
Expand All @@ -92,7 +114,6 @@ func (svc *Service) UploadSoftwareInstaller(ctx context.Context, payload *fleet.
teamName = &t.Name
}

// Create activity
actLabelsIncl, actLabelsExcl := activitySoftwareLabelsFromValidatedLabels(payload.ValidatedLabels)
if err := svc.NewActivity(ctx, vc.User, fleet.ActivityTypeAddedSoftware{
SoftwareTitle: payload.Title,
Expand Down Expand Up @@ -1235,7 +1256,7 @@ func (svc *Service) addMetadataToSoftwarePayload(ctx context.Context, payload *f

if len(meta.PackageIDs) == 0 {
return "", &fleet.BadRequestError{
Message: fmt.Sprintf("Couldn't add. Fleet couldn't read the package IDs, product code, or name from %s.", payload.Filename),
Message: "Couldn't add. Unable to extract necessary metadata.",
InternalErr: ctxerr.New(ctx, "extracting package IDs from installer metadata"),
}
}
Expand Down
116 changes: 116 additions & 0 deletions pkg/automatic_policy/automatic_policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Package automatic_policy generates "trigger policies" from metadata of software packages.
package automatic_policy

import (
"errors"
"fmt"
)

// PolicyData contains generated data for a policy to trigger installation of a software package.
type PolicyData struct {
// Name is the generated name of the policy.
Name string
// Query is the generated SQL/sqlite of the policy.
Query string
// Description is the generated description for the policy.
Description string
// Platform is the target platform for the policy.
Platform string
}

// InstallerMetadata contains the metadata of a software package used to generate the policies.
type InstallerMetadata struct {
// Title is the software title extracted from a software package.
Title string
// Extension is the extension of the software package.
Extension string
// BundleIdentifier contains the bundle identifier for 'pkg' packages.
BundleIdentifier string
// PackageIDs contains the product code for 'msi' packages.
PackageIDs []string
}

var (
// ErrExtensionNotSupported is returned if the extension is not supported to generate automatic policies.
ErrExtensionNotSupported = errors.New("extension not supported")
// ErrMissingBundleIdentifier is returned if the software extension is "pkg" and a bundle identifier was not extracted from the installer.
ErrMissingBundleIdentifier = errors.New("missing bundle identifier")
// ErrMissingProductCode is returned if the software extension is "msi" and a product code was not extracted from the installer.
ErrMissingProductCode = errors.New("missing product code")
// ErrMissingTitle is returned if a title was not extracted from the installer.
ErrMissingTitle = errors.New("missing title")
)

// Generate generates the "trigger policy" from the metadata of a software package.
func Generate(metadata InstallerMetadata) (*PolicyData, error) {
switch {
case metadata.Title == "":
return nil, ErrMissingTitle
case metadata.Extension != "pkg" && metadata.Extension != "msi" && metadata.Extension != "deb" && metadata.Extension != "rpm":
return nil, ErrExtensionNotSupported
case metadata.Extension == "pkg" && metadata.BundleIdentifier == "":
return nil, ErrMissingBundleIdentifier
case metadata.Extension == "msi" && (len(metadata.PackageIDs) == 0 || metadata.PackageIDs[0] == ""):
return nil, ErrMissingProductCode
}

name := fmt.Sprintf("[Install software] %s (%s)", metadata.Title, metadata.Extension)

description := fmt.Sprintf("Policy triggers automatic install of %s on each host that's missing this software.", metadata.Title)
if metadata.Extension == "deb" || metadata.Extension == "rpm" {
basedPrefix := "RPM"
if metadata.Extension == "rpm" {
basedPrefix = "Debian"
}
description += fmt.Sprintf(
"\nSoftware won't be installed on Linux hosts with %s-based distributions because this policy's query is written to always pass on these hosts.",
basedPrefix,
)
}

switch metadata.Extension {
case "pkg":
return &PolicyData{
Name: name,
Query: fmt.Sprintf("SELECT 1 FROM apps WHERE bundle_identifier = '%s';", metadata.BundleIdentifier),
Platform: "darwin",
Description: description,
}, nil
case "msi":
return &PolicyData{
Name: name,
Query: fmt.Sprintf("SELECT 1 FROM programs WHERE identifying_number = '%s';", metadata.PackageIDs[0]),
Platform: "windows",
Description: description,
}, nil
case "deb":
return &PolicyData{
Name: name,
Query: fmt.Sprintf(
// First inner SELECT will mark the policies as successful on non-DEB-based hosts.
`SELECT 1 WHERE EXISTS (
SELECT 1 WHERE (SELECT COUNT(*) FROM deb_packages) = 0
) OR EXISTS (
SELECT 1 FROM deb_packages WHERE name = '%s'
);`, metadata.Title,
),
Platform: "linux",
Description: description,
}, nil
case "rpm":
return &PolicyData{
Name: name,
Query: fmt.Sprintf(
// First inner SELECT will mark the policies as successful on non-RPM-based hosts.
`SELECT 1 WHERE EXISTS (
SELECT 1 WHERE (SELECT COUNT(*) FROM rpm_packages) = 0
) OR EXISTS (
SELECT 1 FROM rpm_packages WHERE name = '%s'
);`, metadata.Title),
Platform: "linux",
Description: description,
}, nil
default:
return nil, ErrExtensionNotSupported
}
}
108 changes: 108 additions & 0 deletions pkg/automatic_policy/automatic_policy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package automatic_policy

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestGenerateErrors(t *testing.T) {
_, err := Generate(InstallerMetadata{
Title: "Foobar",
Extension: "exe",
BundleIdentifier: "",
PackageIDs: []string{"Foobar"},
})
require.ErrorIs(t, err, ErrExtensionNotSupported)

_, err = Generate(InstallerMetadata{
Title: "Foobar",
Extension: "msi",
BundleIdentifier: "",
PackageIDs: []string{""},
})
require.ErrorIs(t, err, ErrMissingProductCode)
_, err = Generate(InstallerMetadata{
Title: "Foobar",
Extension: "msi",
BundleIdentifier: "",
PackageIDs: []string{},
})
require.ErrorIs(t, err, ErrMissingProductCode)

_, err = Generate(InstallerMetadata{
Title: "Foobar",
Extension: "pkg",
BundleIdentifier: "",
PackageIDs: []string{""},
})
require.ErrorIs(t, err, ErrMissingBundleIdentifier)

_, err = Generate(InstallerMetadata{
Title: "",
Extension: "deb",
BundleIdentifier: "",
PackageIDs: []string{""},
})
require.ErrorIs(t, err, ErrMissingTitle)
}

func TestGenerate(t *testing.T) {
policyData, err := Generate(InstallerMetadata{
Title: "Foobar",
Extension: "pkg",
BundleIdentifier: "com.foo.bar",
PackageIDs: []string{"com.foo.bar"},
})
require.NoError(t, err)
require.Equal(t, "[Install software] Foobar (pkg)", policyData.Name)
require.Equal(t, "Policy triggers automatic install of Foobar on each host that's missing this software.", policyData.Description)
require.Equal(t, "darwin", policyData.Platform)
require.Equal(t, "SELECT 1 FROM apps WHERE bundle_identifier = 'com.foo.bar';", policyData.Query)

policyData, err = Generate(InstallerMetadata{
Title: "Barfoo",
Extension: "msi",
BundleIdentifier: "",
PackageIDs: []string{"foo"},
})
require.NoError(t, err)
require.Equal(t, "[Install software] Barfoo (msi)", policyData.Name)
require.Equal(t, "Policy triggers automatic install of Barfoo on each host that's missing this software.", policyData.Description)
require.Equal(t, "windows", policyData.Platform)
require.Equal(t, "SELECT 1 FROM programs WHERE identifying_number = 'foo';", policyData.Query)

policyData, err = Generate(InstallerMetadata{
Title: "Zoobar",
Extension: "deb",
BundleIdentifier: "",
PackageIDs: []string{"Zoobar"},
})
require.NoError(t, err)
require.Equal(t, "[Install software] Zoobar (deb)", policyData.Name)
require.Equal(t, `Policy triggers automatic install of Zoobar on each host that's missing this software.
Software won't be installed on Linux hosts with RPM-based distributions because this policy's query is written to always pass on these hosts.`, policyData.Description)
require.Equal(t, "linux", policyData.Platform)
require.Equal(t, `SELECT 1 WHERE EXISTS (
SELECT 1 WHERE (SELECT COUNT(*) FROM deb_packages) = 0
) OR EXISTS (
SELECT 1 FROM deb_packages WHERE name = 'Zoobar'
);`, policyData.Query)

policyData, err = Generate(InstallerMetadata{
Title: "Barzoo",
Extension: "rpm",
BundleIdentifier: "",
PackageIDs: []string{"Barzoo"},
})
require.NoError(t, err)
require.Equal(t, "[Install software] Barzoo (rpm)", policyData.Name)
require.Equal(t, `Policy triggers automatic install of Barzoo on each host that's missing this software.
Software won't be installed on Linux hosts with Debian-based distributions because this policy's query is written to always pass on these hosts.`, policyData.Description)
require.Equal(t, "linux", policyData.Platform)
require.Equal(t, `SELECT 1 WHERE EXISTS (
SELECT 1 WHERE (SELECT COUNT(*) FROM rpm_packages) = 0
) OR EXISTS (
SELECT 1 FROM rpm_packages WHERE name = 'Barzoo'
);`, policyData.Query)
}
22 changes: 13 additions & 9 deletions server/datastore/mysql/policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemo
}

if p.TeamID != nil {
if err := ds.assertTeamMatches(ctx, *p.TeamID, p.SoftwareInstallerID, p.ScriptID); err != nil {
if err := assertTeamMatches(ctx, ds.writer(ctx), *p.TeamID, p.SoftwareInstallerID, p.ScriptID); err != nil {
return ctxerr.Wrap(ctx, err, "save policy")
}
}
Expand Down Expand Up @@ -185,10 +185,10 @@ var (
errMismatchedScriptTeam = &fleet.BadRequestError{Message: "script is associated with a different team"}
)

func (ds *Datastore) assertTeamMatches(ctx context.Context, teamID uint, softwareInstallerID *uint, scriptID *uint) error {
func assertTeamMatches(ctx context.Context, db sqlx.QueryerContext, teamID uint, softwareInstallerID *uint, scriptID *uint) error {
if softwareInstallerID != nil {
var softwareInstallerTeamID uint
err := sqlx.GetContext(ctx, ds.reader(ctx), &softwareInstallerTeamID, "SELECT global_or_team_id FROM software_installers WHERE id = ?", softwareInstallerID)
err := sqlx.GetContext(ctx, db, &softwareInstallerTeamID, "SELECT global_or_team_id FROM software_installers WHERE id = ?", softwareInstallerID)

if err != nil {
if errors.Is(err, sql.ErrNoRows) {
Expand All @@ -202,7 +202,7 @@ func (ds *Datastore) assertTeamMatches(ctx context.Context, teamID uint, softwar

if scriptID != nil {
var scriptTeamID uint
err := sqlx.GetContext(ctx, ds.reader(ctx), &scriptTeamID, "SELECT global_or_team_id FROM scripts WHERE id = ?", scriptID)
err := sqlx.GetContext(ctx, db, &scriptTeamID, "SELECT global_or_team_id FROM scripts WHERE id = ?", scriptID)

if err != nil {
if errors.Is(err, sql.ErrNoRows) {
Expand Down Expand Up @@ -647,8 +647,12 @@ func (ds *Datastore) PolicyQueriesForHost(ctx context.Context, host *fleet.Host)
}

func (ds *Datastore) NewTeamPolicy(ctx context.Context, teamID uint, authorID *uint, args fleet.PolicyPayload) (*fleet.Policy, error) {
return newTeamPolicy(ctx, ds.writer(ctx), teamID, authorID, args)
}

func newTeamPolicy(ctx context.Context, db sqlx.ExtContext, teamID uint, authorID *uint, args fleet.PolicyPayload) (*fleet.Policy, error) {
if args.QueryID != nil {
q, err := ds.Query(ctx, *args.QueryID)
q, err := query(ctx, db, *args.QueryID)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "fetching query from id")
}
Expand All @@ -659,7 +663,7 @@ func (ds *Datastore) NewTeamPolicy(ctx context.Context, teamID uint, authorID *u
// Check team exists.
if teamID > 0 {
var ok bool
err := ds.writer(ctx).GetContext(ctx, &ok, `SELECT COUNT(*) = 1 FROM teams WHERE id = ?`, teamID)
err := sqlx.GetContext(ctx, db, &ok, `SELECT COUNT(*) = 1 FROM teams WHERE id = ?`, teamID)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "get team id")
}
Expand All @@ -671,11 +675,11 @@ func (ds *Datastore) NewTeamPolicy(ctx context.Context, teamID uint, authorID *u
// We must normalize the name for full Unicode support (Unicode equivalence).
nameUnicode := norm.NFC.String(args.Name)

if err := ds.assertTeamMatches(ctx, teamID, args.SoftwareInstallerID, args.ScriptID); err != nil {
if err := assertTeamMatches(ctx, db, teamID, args.SoftwareInstallerID, args.ScriptID); err != nil {
return nil, ctxerr.Wrap(ctx, err, "create team policy")
}

res, err := ds.writer(ctx).ExecContext(ctx,
res, err := db.ExecContext(ctx,
fmt.Sprintf(
`INSERT INTO policies (name, query, description, team_id, resolution, author_id, platforms, critical, calendar_events_enabled, software_installer_id, script_id, checksum) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, %s)`,
policiesChecksumComputedColumn(),
Expand All @@ -695,7 +699,7 @@ func (ds *Datastore) NewTeamPolicy(ctx context.Context, teamID uint, authorID *u
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "getting last id after inserting policy")
}
return policyDB(ctx, ds.writer(ctx), uint(lastIdInt64), &teamID) //nolint:gosec // dismiss G115
return policyDB(ctx, db, uint(lastIdInt64), &teamID) //nolint:gosec // dismiss G115
}

func (ds *Datastore) ListTeamPolicies(ctx context.Context, teamID uint, opts fleet.ListOptions, iopts fleet.ListOptions) (teamPolicies, inheritedPolicies []*fleet.Policy, err error) {
Expand Down
Loading
Loading