diff --git a/cmd/gcs-sidecar/main.go b/cmd/gcs-sidecar/main.go index 60f23f6928..04952899e0 100644 --- a/cmd/gcs-sidecar/main.go +++ b/cmd/gcs-sidecar/main.go @@ -15,7 +15,6 @@ import ( "github.com/Microsoft/hcsshim/internal/gcs/prot" shimlog "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/oc" - "github.com/Microsoft/hcsshim/internal/pspdriver" "github.com/Microsoft/hcsshim/pkg/securitypolicy" "github.com/sirupsen/logrus" "go.opencensus.io/trace" @@ -214,7 +213,7 @@ func main() { return } - if err := pspdriver.StartPSPDriver(ctx); err != nil { + if err := securitypolicy.StartPSPDriver(ctx); err != nil { // When error happens, pspdriver.GetPspDriverError() returns true. // In that case, gcs-sidecar should keep the initial "deny" policy // and reject all requests from the host. diff --git a/cmd/gcs/main.go b/cmd/gcs/main.go index 36ae1991b6..a17f3ae232 100644 --- a/cmd/gcs/main.go +++ b/cmd/gcs/main.go @@ -31,6 +31,7 @@ import ( "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/oc" "github.com/Microsoft/hcsshim/internal/version" + "github.com/Microsoft/hcsshim/pkg/amdsevsnp" "github.com/Microsoft/hcsshim/pkg/securitypolicy" ) @@ -362,6 +363,10 @@ func main() { b := bridge.Bridge{ Handler: mux, EnableV4: *v4, + + // For confidential containers, we protect ourselves against attacks caused + // by concurrent modifications, by processing one request at a time. + ForceSequential: amdsevsnp.IsSNP(), } h := hcsv2.NewHost(rtime, tport, initialEnforcer, logWriter) // Initialize virtual pod support in the host diff --git a/internal/gcs-sidecar/bridge.go b/internal/gcs-sidecar/bridge.go index b7bc2f9522..87472cc4a2 100644 --- a/internal/gcs-sidecar/bridge.go +++ b/internal/gcs-sidecar/bridge.go @@ -47,9 +47,6 @@ type Bridge struct { // and send responses back to hcsshim respectively. 
sendToGCSCh chan request sendToShimCh chan bridgeResponse - - // logging target - logWriter io.Writer } // SequenceID is used to correlate requests and responses. @@ -81,7 +78,7 @@ type request struct { } func NewBridge(shimConn io.ReadWriteCloser, inboxGCSConn io.ReadWriteCloser, initialEnforcer securitypolicy.SecurityPolicyEnforcer, logWriter io.Writer) *Bridge { - hostState := NewHost(initialEnforcer) + hostState := NewHost(initialEnforcer, logWriter) return &Bridge{ pending: make(map[sequenceID]chan *prot.ContainerExecuteProcessResponse), rpcHandlerList: make(map[prot.RPCProc]HandlerFunc), @@ -90,7 +87,6 @@ func NewBridge(shimConn io.ReadWriteCloser, inboxGCSConn io.ReadWriteCloser, ini inboxGCSConn: inboxGCSConn, sendToGCSCh: make(chan request), sendToShimCh: make(chan bridgeResponse), - logWriter: logWriter, } } diff --git a/internal/gcs-sidecar/handlers.go b/internal/gcs-sidecar/handlers.go index 0ea47b79e3..f46dc8b985 100644 --- a/internal/gcs-sidecar/handlers.go +++ b/internal/gcs-sidecar/handlers.go @@ -18,9 +18,11 @@ import ( hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/oc" + oci "github.com/Microsoft/hcsshim/internal/oci" "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" "github.com/Microsoft/hcsshim/internal/protocol/guestresource" "github.com/Microsoft/hcsshim/internal/windevice" + "github.com/Microsoft/hcsshim/pkg/annotations" "github.com/Microsoft/hcsshim/pkg/cimfs" "github.com/Microsoft/hcsshim/pkg/securitypolicy" "github.com/pkg/errors" @@ -80,15 +82,12 @@ func (b *Bridge) createContainer(req *request) (err error) { user := securitypolicy.IDName{ Name: spec.Process.User.Username, } - _, _, _, err := b.hostState.securityPolicyEnforcer.EnforceCreateContainerPolicyV2(req.ctx, containerID, spec.Process.Args, spec.Process.Env, spec.Process.Cwd, spec.Mounts, user, nil) + _, _, _, err := 
b.hostState.securityOptions.PolicyEnforcer.EnforceCreateContainerPolicyV2(req.ctx, containerID, spec.Process.Args, spec.Process.Env, spec.Process.Cwd, spec.Mounts, user, nil) if err != nil { return fmt.Errorf("CreateContainer operation is denied by policy: %w", err) } - if err := b.hostState.SetupSecurityContextDir(ctx, &spec); err != nil { - return err - } commandLine := len(spec.Process.Args) > 0 c := &Container{ id: containerID, @@ -109,6 +108,13 @@ func (b *Bridge) createContainer(req *request) (err error) { } }(err) + if oci.ParseAnnotationsBool(ctx, spec.Annotations, annotations.WCOWSecurityPolicyEnv, true) { + if err := b.hostState.securityOptions.WriteSecurityContextDir(&spec); err != nil { + return fmt.Errorf("failed to write security context dir: %w", err) + } + cwcowHostedSystemConfig.Spec = spec + } + // Strip the spec field hostedSystemBytes, err := json.Marshal(cwcowHostedSystem) @@ -151,20 +157,6 @@ func (b *Bridge) createContainer(req *request) (err error) { return nil } -func writeFileInDir(dir string, filename string, data []byte, perm os.FileMode) error { - st, err := os.Stat(dir) - if err != nil { - return err - } - - if !st.IsDir() { - return fmt.Errorf("not a directory %q", dir) - } - - targetFilename := filepath.Join(dir, filename) - return os.WriteFile(targetFilename, data, perm) -} - // processParamEnvToOCIEnv converts an Environment field from ProcessParameters // (a map from environment variable to value) into an array of environment // variable assignments (where each is in the form "=") which @@ -203,7 +195,7 @@ func (b *Bridge) shutdownGraceful(req *request) (err error) { return fmt.Errorf("failed to unmarshal shutdownGraceful: %w", err) } - err = b.hostState.securityPolicyEnforcer.EnforceShutdownContainerPolicy(req.ctx, r.ContainerID) + err = b.hostState.securityOptions.PolicyEnforcer.EnforceShutdownContainerPolicy(req.ctx, r.ContainerID) if err != nil { return fmt.Errorf("rpcShudownGraceful operation not allowed: %w", err) } @@ 
-247,7 +239,7 @@ func (b *Bridge) executeProcess(req *request) (err error) { if containerID == UVMContainerID { log.G(req.ctx).Tracef("Enforcing policy on external exec process") - _, _, err := b.hostState.securityPolicyEnforcer.EnforceExecExternalProcessPolicy( + _, _, err := b.hostState.securityOptions.PolicyEnforcer.EnforceExecExternalProcessPolicy( req.ctx, commandLine, processParamEnvToOCIEnv(processParams.Environment), @@ -279,7 +271,7 @@ func (b *Bridge) executeProcess(req *request) (err error) { Name: processParams.User, } log.G(req.ctx).Tracef("Enforcing policy on exec in container") - _, _, _, err = b.hostState.securityPolicyEnforcer. + _, _, _, err = b.hostState.securityOptions.PolicyEnforcer. EnforceExecInContainerPolicyV2( req.ctx, containerID, @@ -385,7 +377,7 @@ func (b *Bridge) signalProcess(req *request) (err error) { WindowsSignal: wcowOptions.Signal, WindowsCommand: commandLine, } - err = b.hostState.securityPolicyEnforcer.EnforceSignalContainerProcessPolicyV2(req.ctx, containerID, opts) + err = b.hostState.securityOptions.PolicyEnforcer.EnforceSignalContainerProcessPolicyV2(req.ctx, containerID, opts) if err != nil { return err } @@ -414,7 +406,7 @@ func (b *Bridge) getProperties(req *request) (err error) { defer span.End() defer func() { oc.SetSpanStatus(span, err) }() - if err := b.hostState.securityPolicyEnforcer.EnforceGetPropertiesPolicy(req.ctx); err != nil { + if err := b.hostState.securityOptions.PolicyEnforcer.EnforceGetPropertiesPolicy(req.ctx); err != nil { return errors.Wrapf(err, "get properties denied due to policy") } @@ -552,11 +544,14 @@ func (b *Bridge) modifySettings(req *request) (err error) { log.G(ctx).Tracef("hcsschema.MappedDirectory { %v }", settings) case guestresource.ResourceTypeSecurityPolicy: - securityPolicyRequest := modifyGuestSettingsRequest.Settings.(*guestresource.WCOWConfidentialOptions) + securityPolicyRequest := modifyGuestSettingsRequest.Settings.(*guestresource.ConfidentialOptions) 
log.G(ctx).Tracef("WCOWConfidentialOptions: { %v}", securityPolicyRequest) - err := b.hostState.SetWCOWConfidentialUVMOptions(req.ctx, securityPolicyRequest, b.logWriter) + err := b.hostState.securityOptions.SetConfidentialOptions(ctx, + securityPolicyRequest.EnforcerType, + securityPolicyRequest.EncodedSecurityPolicy, + securityPolicyRequest.EncodedUVMReference) if err != nil { - return errors.Wrap(err, "error creating enforcer") + return errors.Wrap(err, "failed to set confidential UVM options") } // Send response back to shim resp := &prot.ResponseBase{ @@ -569,12 +564,11 @@ func (b *Bridge) modifySettings(req *request) (err error) { } return nil case guestresource.ResourceTypePolicyFragment: - //Note: Reusing the same type LCOWSecurityPolicyFragment for CWCOW. - r, ok := modifyGuestSettingsRequest.Settings.(*guestresource.LCOWSecurityPolicyFragment) + r, ok := modifyGuestSettingsRequest.Settings.(*guestresource.SecurityPolicyFragment) if !ok { - return errors.New("the request settings are not of type LCOWSecurityPolicyFragment") + return errors.New("the request settings are not of type SecurityPolicyFragment") } - return b.hostState.InjectFragment(ctx, r) + return b.hostState.securityOptions.InjectFragment(ctx, r) case guestresource.ResourceTypeWCOWBlockCims: // This is request to mount the merged cim at given volumeGUID if modifyGuestSettingsRequest.RequestType == guestrequest.RequestTypeRemove { @@ -588,7 +582,6 @@ func (b *Bridge) modifySettings(req *request) (err error) { // The block device takes some time to show up. Wait for a few seconds. 
time.Sleep(2 * time.Second) - //TODO(Mahati) : test and verify CIM hashes var layerCIMs []*cimfs.BlockCIM layerHashes := make([]string, len(wcowBlockCimMounts.BlockCIMs)) layerDigests := make([][]byte, len(wcowBlockCimMounts.BlockCIMs)) @@ -624,7 +617,7 @@ func (b *Bridge) modifySettings(req *request) (err error) { hashesToVerify = layerHashes[1:] } - err := b.hostState.securityPolicyEnforcer.EnforceVerifiedCIMsPolicy(req.ctx, containerID, hashesToVerify) + err := b.hostState.securityOptions.PolicyEnforcer.EnforceVerifiedCIMsPolicy(req.ctx, containerID, hashesToVerify) if err != nil { return errors.Wrap(err, "CIM mount is denied by policy") } @@ -664,7 +657,7 @@ func (b *Bridge) modifySettings(req *request) (err error) { containerID, settings.CombinedLayers.ContainerRootPath, settings.CombinedLayers.Layers, settings.CombinedLayers.ScratchPath) //Since unencrypted scratch is not an option, always pass true - if err := b.hostState.securityPolicyEnforcer.EnforceScratchMountPolicy(ctx, settings.CombinedLayers.ContainerRootPath, true); err != nil { + if err := b.hostState.securityOptions.PolicyEnforcer.EnforceScratchMountPolicy(ctx, settings.CombinedLayers.ContainerRootPath, true); err != nil { return fmt.Errorf("scratch mounting denied by policy: %w", err) } // The following two folders are expected to be present in the scratch. 
diff --git a/internal/gcs-sidecar/host.go b/internal/gcs-sidecar/host.go index 3f51d8a7e0..86fc98a30f 100644 --- a/internal/gcs-sidecar/host.go +++ b/internal/gcs-sidecar/host.go @@ -5,7 +5,6 @@ package bridge import ( "context" - "fmt" "io" "sync" @@ -13,30 +12,20 @@ import ( hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/logfields" - oci "github.com/Microsoft/hcsshim/internal/oci" - "github.com/Microsoft/hcsshim/internal/protocol/guestresource" - "github.com/Microsoft/hcsshim/internal/pspdriver" - "github.com/Microsoft/hcsshim/pkg/annotations" "github.com/Microsoft/hcsshim/pkg/securitypolicy" - specs "github.com/opencontainers/runtime-spec/specs-go" - "github.com/pkg/errors" + oci "github.com/opencontainers/runtime-spec/specs-go" "github.com/sirupsen/logrus" ) type Host struct { + securityOptions *securitypolicy.SecurityOptions containersMutex sync.Mutex containers map[string]*Container - - // state required for the security policy enforcement - policyMutex sync.Mutex - securityPolicyEnforcer securitypolicy.SecurityPolicyEnforcer - securityPolicyEnforcerSet bool - uvmReferenceInfo string } type Container struct { id string - spec specs.Spec + spec oci.Spec processesMutex sync.Mutex processes map[uint32]*containerProcess commandLine bool @@ -52,137 +41,17 @@ type containerProcess struct { pid uint32 } -func NewHost(initialEnforcer securitypolicy.SecurityPolicyEnforcer) *Host { - return &Host{ - containers: make(map[string]*Container), - securityPolicyEnforcer: initialEnforcer, - securityPolicyEnforcerSet: false, - } -} - -// Write security policy, signed UVM reference and host AMD certificate to -// container's rootfs, so that application and sidecar containers can have -// access to it. The security policy is required by containers which need to -// extract init-time claims found in the security policy. 
The directory path -// containing the files is exposed via UVM_SECURITY_CONTEXT_DIR env var. -// It may be an error to have a security policy but not expose it to the -// container as in that case it can never be checked as correct by a verifier. -func (h *Host) SetupSecurityContextDir(ctx context.Context, spec *specs.Spec) error { - if oci.ParseAnnotationsBool(ctx, spec.Annotations, annotations.WCOWSecurityPolicyEnv, true) { - encodedPolicy := h.securityPolicyEnforcer.EncodedSecurityPolicy() - hostAMDCert := spec.Annotations[annotations.WCOWHostAMDCertificate] - if len(encodedPolicy) > 0 || len(hostAMDCert) > 0 || len(h.uvmReferenceInfo) > 0 { - // Use os.MkdirTemp to make sure that the directory is unique. - securityContextDir, err := os.MkdirTemp(spec.Root.Path, securitypolicy.SecurityContextDirTemplate) - if err != nil { - return fmt.Errorf("failed to create security context directory: %w", err) - } - // Make sure that files inside directory are readable - if err := os.Chmod(securityContextDir, 0755); err != nil { - return fmt.Errorf("failed to chmod security context directory: %w", err) - } - - if len(encodedPolicy) > 0 { - if err := writeFileInDir(securityContextDir, securitypolicy.PolicyFilename, []byte(encodedPolicy), 0777); err != nil { - return fmt.Errorf("failed to write security policy: %w", err) - } - } - if len(h.uvmReferenceInfo) > 0 { - if err := writeFileInDir(securityContextDir, securitypolicy.ReferenceInfoFilename, []byte(h.uvmReferenceInfo), 0777); err != nil { - return fmt.Errorf("failed to write UVM reference info: %w", err) - } - } - - if len(hostAMDCert) > 0 { - if err := writeFileInDir(securityContextDir, securitypolicy.HostAMDCertFilename, []byte(hostAMDCert), 0777); err != nil { - return fmt.Errorf("failed to write host AMD certificate: %w", err) - } - } - - containerCtxDir := fmt.Sprintf("/%s", filepath.Base(securityContextDir)) - secCtxEnv := fmt.Sprintf("UVM_SECURITY_CONTEXT_DIR=%s", containerCtxDir) - spec.Process.Env = 
append(spec.Process.Env, secCtxEnv) - } - } - return nil -} - -// InjectFragment extends current security policy with additional constraints -// from the incoming fragment. Note that it is base64 encoded over the bridge/ -// -// There are three checking steps: -// 1 - Unpack the cose document and check it was actually signed with the cert -// chain inside its header -// 2 - Check that the issuer field did:x509 identifier is for that cert chain -// (ie fingerprint of a non leaf cert and the subject matches the leaf cert) -// 3 - Check that this issuer/feed match the requirement of the user provided -// security policy (done in the regoby LoadFragment) -func (h *Host) InjectFragment(ctx context.Context, fragment *guestresource.LCOWSecurityPolicyFragment) (err error) { - log.G(ctx).WithField("fragment", fmt.Sprintf("%+v", fragment)).Debug("GCS Host.InjectFragment") - issuer, feed, payloadString, err := securitypolicy.ExtractAndVerifyFragment(ctx, fragment) - if err != nil { - return err - } - // now offer the payload fragment to the policy - err = h.securityPolicyEnforcer.LoadFragment(ctx, issuer, feed, payloadString) - if err != nil { - return fmt.Errorf("error loading security policy fragment: %w", err) - } - return nil -} - -func (h *Host) SetWCOWConfidentialUVMOptions(ctx context.Context, securityPolicyRequest *guestresource.WCOWConfidentialOptions, logWriter io.Writer) error { - h.policyMutex.Lock() - defer h.policyMutex.Unlock() - - if h.securityPolicyEnforcerSet { - return errors.New("security policy has already been set") - } - - if err := pspdriver.GetPspDriverError(); err != nil { - // For this case gcs-sidecar will keep initial deny policy. 
- return errors.Wrapf(err, "an error occurred while using PSP driver") - } - - // Fetch report and validate host_data - hostData, err := securitypolicy.NewSecurityPolicyDigest(securityPolicyRequest.EncodedSecurityPolicy) - if err != nil { - return err - } - - if err := pspdriver.ValidateHostData(ctx, hostData[:]); err != nil { - // For this case gcs-sidecar will keep initial deny policy. - return err - } - - // This limit ensures messages are below the character truncation limit that - // can be imposed by an orchestrator - maxErrorMessageLength := 3 * 1024 - - // Initialize security policy enforcer for a given enforcer type and - // encoded security policy. - p, err := securitypolicy.CreateSecurityPolicyEnforcer( - securityPolicyRequest.EnforcerType, - securityPolicyRequest.EncodedSecurityPolicy, - DefaultCRIMounts(), - DefaultCRIPrivilegedMounts(), - maxErrorMessageLength, +func NewHost(initialEnforcer securitypolicy.SecurityPolicyEnforcer, logWriter io.Writer) *Host { + securityPolicyOptions := securitypolicy.NewSecurityOptions( + initialEnforcer, + false, + "", + logWriter, ) - if err != nil { - return fmt.Errorf("error creating security policy enforcer: %w", err) - } - - if err = p.EnforceRuntimeLoggingPolicy(ctx); err == nil { - logrus.SetOutput(logWriter) - } else { - logrus.SetOutput(io.Discard) + return &Host{ + containers: make(map[string]*Container), + securityOptions: securityPolicyOptions, } - - h.securityPolicyEnforcer = p - h.securityPolicyEnforcerSet = true - h.uvmReferenceInfo = securityPolicyRequest.EncodedUVMReference - - return nil } func (h *Host) AddContainer(ctx context.Context, id string, c *Container) error { diff --git a/internal/gcs-sidecar/policy.go b/internal/gcs-sidecar/policy.go deleted file mode 100644 index 13b96ce64d..0000000000 --- a/internal/gcs-sidecar/policy.go +++ /dev/null @@ -1,19 +0,0 @@ -//go:build windows -// +build windows - -package bridge - -import ( - oci "github.com/opencontainers/runtime-spec/specs-go" -) - -// 
DefaultCRIMounts returns default mounts added to windows spec by containerD. -func DefaultCRIMounts() []oci.Mount { - return []oci.Mount{} -} - -// DefaultCRIPrivilegedMounts returns a slice of mounts which are added to the -// windows container spec when a container runs in a privileged mode. -func DefaultCRIPrivilegedMounts() []oci.Mount { - return []oci.Mount{} -} diff --git a/internal/gcs-sidecar/uvm.go b/internal/gcs-sidecar/uvm.go index 9578e0273c..b3a7792fd4 100644 --- a/internal/gcs-sidecar/uvm.go +++ b/internal/gcs-sidecar/uvm.go @@ -92,7 +92,7 @@ func unmarshalContainerModifySettings(req *request) (_ *prot.ContainerModifySett modifyGuestSettingsRequest.Settings = settings case guestresource.ResourceTypeSecurityPolicy: - securityPolicyRequest := &guestresource.WCOWConfidentialOptions{} + securityPolicyRequest := &guestresource.ConfidentialOptions{} if err := commonutils.UnmarshalJSONWithHresult(rawGuestRequest, securityPolicyRequest); err != nil { return nil, fmt.Errorf("invalid ResourceTypeSecurityPolicy request: %w", err) } diff --git a/internal/gcs/unrecoverable_error.go b/internal/gcs/unrecoverable_error.go new file mode 100644 index 0000000000..dbb7240266 --- /dev/null +++ b/internal/gcs/unrecoverable_error.go @@ -0,0 +1,49 @@ +//go:build linux +// +build linux + +package gcs + +import ( + "context" + "fmt" + "os" + "runtime" + "time" + + "github.com/Microsoft/hcsshim/internal/log" + "github.com/Microsoft/hcsshim/pkg/amdsevsnp" + "github.com/sirupsen/logrus" +) + +// UnrecoverableError logs the error and then puts the current thread into an +// infinite sleep loop. This is to be used instead of panicking, as the +// behaviour of GCS panics is unpredictable. This function can be extended to, +// for example, try to shutdown the VM cleanly. 
+func UnrecoverableError(err error) { + buf := make([]byte, 300*(1<<10)) + stackSize := runtime.Stack(buf, true) + stackTrace := string(buf[:stackSize]) + + errPrint := fmt.Sprintf( + "Unrecoverable error in GCS: %v\n%s", + err, stackTrace, + ) + isSnp := amdsevsnp.IsSNP() + if isSnp { + errPrint += "\nThis thread will now enter an infinite loop." + } + log.G(context.Background()).WithError(err).Logf( + logrus.FatalLevel, + "%s", + errPrint, + ) + + if !isSnp { + panic("Unrecoverable error in GCS: " + err.Error()) + } else { + fmt.Fprintf(os.Stderr, "%s\n", errPrint) + for { + time.Sleep(time.Hour) + } + } +} diff --git a/internal/guest/bridge/bridge.go b/internal/guest/bridge/bridge.go index 4ea03ed104..875def5809 100644 --- a/internal/guest/bridge/bridge.go +++ b/internal/guest/bridge/bridge.go @@ -177,6 +177,10 @@ type Bridge struct { Handler Handler // EnableV4 enables the v4+ bridge and the schema v2+ interfaces. EnableV4 bool + // Setting ForceSequential to true will force the bridge to only process one + // request at a time, except for certain long-running operations (as defined + // in asyncMessages). + ForceSequential bool // responseChan is the response channel used for both request/response // and publish notification workflows. @@ -191,6 +195,14 @@ type Bridge struct { protVer prot.ProtocolVersion } +// Messages that will be processed asynchronously even in sequential mode. Note +// that in sequential mode, these messages will still wait for any in-progress +// non-async messages to be handled before they are processed, but once they are +// "acknowledged", the rest will be done asynchronously. +var alwaysAsyncMessages map[prot.MessageIdentifier]bool = map[prot.MessageIdentifier]bool{ + prot.ComputeSystemWaitForProcessV1: true, +} + // AssignHandlers creates and assigns the appropriate bridge // events to be listen for and intercepted on `mux` before forwarding // to `gcs` for handling. 
@@ -238,6 +250,10 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser defer close(requestErrChan) defer bridgeIn.Close() + if b.ForceSequential { + log.G(context.Background()).Info("bridge: ForceSequential enabled") + } + // Receive bridge requests and schedule them to be processed. go func() { var recverr error @@ -340,30 +356,36 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser }() // Process each bridge request async and create the response writer. go func() { - for req := range requestChan { - go func(r *Request) { - br := bridgeResponse{ - ctx: r.Context, - header: &prot.MessageHeader{ - Type: prot.GetResponseIdentifier(r.Header.Type), - ID: r.Header.ID, - }, - } - resp, err := b.Handler.ServeMsg(r) - if resp == nil { - resp = &prot.MessageResponseBase{} - } - resp.Base().ActivityID = r.ActivityID - if err != nil { - span := trace.FromContext(r.Context) - if span != nil { - oc.SetSpanStatus(span, err) - } - setErrorForResponseBase(resp.Base(), err, "gcs" /* moduleName */) + doOneRequest := func(r *Request) { + br := bridgeResponse{ + ctx: r.Context, + header: &prot.MessageHeader{ + Type: prot.GetResponseIdentifier(r.Header.Type), + ID: r.Header.ID, + }, + } + resp, err := b.Handler.ServeMsg(r) + if resp == nil { + resp = &prot.MessageResponseBase{} + } + resp.Base().ActivityID = r.ActivityID + if err != nil { + span := trace.FromContext(r.Context) + if span != nil { + oc.SetSpanStatus(span, err) } - br.response = resp - b.responseChan <- br - }(req) + setErrorForResponseBase(resp.Base(), err, "gcs" /* moduleName */) + } + br.response = resp + b.responseChan <- br + } + + for req := range requestChan { + if b.ForceSequential && !alwaysAsyncMessages[req.Header.Type] { + runSequentialRequestHandler(req, doOneRequest) + } else { + go doOneRequest(req) + } } }() // Process each bridge response sync. This channel is for request/response and publish workflows. 
@@ -423,6 +445,32 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser } } +// runSequentialRequestHandler runs handleFn(r) and prints a warning if +// handleFn takes too long to return, or never returns. +func runSequentialRequestHandler(r *Request, handleFn func(*Request)) { + // Note that this is only a context used for triggering the blockage + // warning, the request processing still uses r.Context. We don't want to + // cancel the request handling itself when we reach the 5s timeout. + timeoutCtx, cancel := context.WithTimeout(r.Context, 5*time.Second) + go func() { + <-timeoutCtx.Done() + if errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) { + log.G(timeoutCtx).WithFields(logrus.Fields{ + // We want to log those even though we're providing r.Context, since if + // the request never finishes the span end log will never get written, + // and we may therefore not be able to find out about the following info + // otherwise: + "message-type": r.Header.Type.String(), + "message-id": r.Header.ID, + "activity-id": r.ActivityID, + "container-id": r.ContainerID, + }).Warnf("bridge: request processing thread in sequential mode blocked on the current request for more than 5 seconds") + } + }() + defer cancel() + handleFn(r) +} + // PublishNotification writes a specific notification to the bridge. 
func (b *Bridge) PublishNotification(n *prot.ContainerNotification) { ctx, span := oc.StartSpan(context.Background(), diff --git a/internal/guest/bridge/bridge_v2.go b/internal/guest/bridge/bridge_v2.go index 800094e549..2f105ef5b6 100644 --- a/internal/guest/bridge/bridge_v2.go +++ b/internal/guest/bridge/bridge_v2.go @@ -467,16 +467,10 @@ func (b *Bridge) deleteContainerStateV2(r *Request) (_ RequestResponse, err erro return nil, errors.Wrapf(err, "failed to unmarshal JSON in message \"%s\"", r.Message) } - c, err := b.hostState.GetCreatedContainer(request.ContainerID) + err = b.hostState.DeleteContainerState(ctx, request.ContainerID) if err != nil { return nil, err } - // remove container state regardless of delete's success - defer b.hostState.RemoveContainer(request.ContainerID) - - if err := c.Delete(ctx); err != nil { - return nil, err - } return &prot.MessageResponseBase{}, nil } diff --git a/internal/guest/network/network.go b/internal/guest/network/network.go index 68f7c1bef1..4cac108b03 100644 --- a/internal/guest/network/network.go +++ b/internal/guest/network/network.go @@ -9,6 +9,7 @@ import ( "fmt" "os" "path/filepath" + "regexp" "strings" "time" @@ -32,6 +33,18 @@ var ( // maxDNSSearches is limited to 6 in `man 5 resolv.conf` const maxDNSSearches = 6 +var validHostnameRegex = regexp.MustCompile(`^[a-zA-Z0-9_\-\.]{0,255}$`) + +// Check that the hostname is safe. This function is less strict than +// technically allowed, but ensures that when the hostname is inserted to +// /etc/hosts, it cannot lead to injection attacks. +func ValidateHostname(hostname string) error { + if !validHostnameRegex.MatchString(hostname) { + return errors.Errorf("hostname %q invalid: must match %s", hostname, validHostnameRegex.String()) + } + return nil +} + // GenerateEtcHostsContent generates a /etc/hosts file based on `hostname`. 
func GenerateEtcHostsContent(ctx context.Context, hostname string) string { _, span := oc.StartSpan(ctx, "network::GenerateEtcHostsContent") diff --git a/internal/guest/network/network_test.go b/internal/guest/network/network_test.go index 4ac6ff1f10..bb19db6974 100644 --- a/internal/guest/network/network_test.go +++ b/internal/guest/network/network_test.go @@ -7,6 +7,7 @@ import ( "context" "os" "path/filepath" + "strings" "testing" "time" ) @@ -122,6 +123,40 @@ func Test_MergeValues(t *testing.T) { } } +func Test_ValidateHostname(t *testing.T) { + validNames := []string{ + "localhost", + "my-hostname", + "my.hostname", + "my-host-name123", + "_underscores.are.allowed.too", + "", // Allow not specifying a hostname + } + + invalidNames := []string{ + "localhost\n13.104.0.1 ip6-localhost ip6-loopback localhost", + "localhost\n2603:1000::1 ip6-localhost ip6-loopback", + "hello@microsoft.com", + "has space", + "has,comma", + "\x00", + "a\nb", + strings.Repeat("a", 1000), + } + + for _, n := range validNames { + if err := ValidateHostname(n); err != nil { + t.Fatalf("expected %q to be valid, got: %v", n, err) + } + } + + for _, n := range invalidNames { + if err := ValidateHostname(n); err == nil { + t.Fatalf("expected %q to be invalid, but got nil error", n) + } + } +} + func Test_GenerateEtcHostsContent(t *testing.T) { type testcase struct { name string diff --git a/internal/guest/policy/default.go b/internal/guest/policy/default.go deleted file mode 100644 index 98c909fb35..0000000000 --- a/internal/guest/policy/default.go +++ /dev/null @@ -1,98 +0,0 @@ -//go:build linux -// +build linux - -package policy - -import ( - oci "github.com/opencontainers/runtime-spec/specs-go" - - specGuest "github.com/Microsoft/hcsshim/internal/guest/spec" - "github.com/Microsoft/hcsshim/pkg/securitypolicy" -) - -func ExtendPolicyWithNetworkingMounts(sandboxID string, enforcer securitypolicy.SecurityPolicyEnforcer, spec *oci.Spec) error { - roSpec := &oci.Spec{ - Root: spec.Root, - } - 
networkingMounts := specGuest.GenerateWorkloadContainerNetworkMounts(sandboxID, roSpec) - if err := enforcer.ExtendDefaultMounts(networkingMounts); err != nil { - return err - } - return nil -} - -// DefaultCRIMounts returns default mounts added to linux spec by containerD. -func DefaultCRIMounts() []oci.Mount { - return []oci.Mount{ - { - Destination: "/proc", - Type: "proc", - Source: "proc", - Options: []string{"nosuid", "noexec", "nodev"}, - }, - { - Destination: "/dev", - Type: "tmpfs", - Source: "tmpfs", - Options: []string{"nosuid", "strictatime", "mode=755", "size=65536k"}, - }, - { - Destination: "/dev/pts", - Type: "devpts", - Source: "devpts", - Options: []string{"nosuid", "noexec", "newinstance", "ptmxmode=0666", "mode=0620", "gid=5"}, - }, - { - Destination: "/dev/shm", - Type: "tmpfs", - Source: "shm", - Options: []string{"nosuid", "noexec", "nodev", "mode=1777", "size=65536k"}, - }, - { - Destination: "/dev/mqueue", - Type: "mqueue", - Source: "mqueue", - Options: []string{"nosuid", "noexec", "nodev"}, - }, - { - Destination: "/sys", - Type: "sysfs", - Source: "sysfs", - Options: []string{"nosuid", "noexec", "nodev", "ro"}, - }, - { - Destination: "/run", - Type: "tmpfs", - Source: "tmpfs", - Options: []string{"nosuid", "strictatime", "mode=755", "size=65536k"}, - }, - // cgroup mount is always added by default, regardless if it is present - // in the mount constraints or not. If the user chooses to override it, - // then a corresponding mount constraint should be present. - { - Source: "cgroup", - Destination: "/sys/fs/cgroup", - Type: "cgroup", - Options: []string{"nosuid", "noexec", "nodev", "relatime", "ro"}, - }, - } -} - -// DefaultCRIPrivilegedMounts returns a slice of mounts which are added to the -// linux container spec when a container runs in a privileged mode. 
-func DefaultCRIPrivilegedMounts() []oci.Mount { - return []oci.Mount{ - { - Source: "cgroup", - Destination: "/sys/fs/cgroup", - Type: "cgroup", - Options: []string{"nosuid", "noexec", "nodev", "relatime", "rw"}, - }, - { - Destination: "/sys", - Type: "sysfs", - Source: "sysfs", - Options: []string{"nosuid", "noexec", "nodev", "rw"}, - }, - } -} diff --git a/internal/guest/policy/doc.go b/internal/guest/policy/doc.go deleted file mode 100644 index 8cbf7ee3fc..0000000000 --- a/internal/guest/policy/doc.go +++ /dev/null @@ -1 +0,0 @@ -package policy diff --git a/internal/guest/prot/protocol.go b/internal/guest/prot/protocol.go index 576ac5e5f1..16e1f9daa3 100644 --- a/internal/guest/prot/protocol.go +++ b/internal/guest/prot/protocol.go @@ -583,15 +583,15 @@ func UnmarshalContainerModifySettings(b []byte) (*containerModifySettings, error } msr.Settings = cc case guestresource.ResourceTypeSecurityPolicy: - enforcer := &guestresource.LCOWConfidentialOptions{} + enforcer := &guestresource.ConfidentialOptions{} if err := commonutils.UnmarshalJSONWithHresult(msrRawSettings, enforcer); err != nil { - return &request, errors.Wrap(err, "failed to unmarshal settings as LCOWConfidentialOptions") + return &request, errors.Wrap(err, "failed to unmarshal settings as ConfidentialOptions") } msr.Settings = enforcer case guestresource.ResourceTypePolicyFragment: - fragment := &guestresource.LCOWSecurityPolicyFragment{} + fragment := &guestresource.SecurityPolicyFragment{} if err := commonutils.UnmarshalJSONWithHresult(msrRawSettings, fragment); err != nil { - return &request, errors.Wrap(err, "failed to unmarshal settings as LCOWSecurityPolicyFragment") + return &request, errors.Wrap(err, "failed to unmarshal settings as SecurityPolicyFragment") } msr.Settings = fragment default: diff --git a/internal/guest/runtime/hcsv2/container.go b/internal/guest/runtime/hcsv2/container.go index 62f8ca3e43..2a13a6054b 100644 --- a/internal/guest/runtime/hcsv2/container.go +++ 
b/internal/guest/runtime/hcsv2/container.go @@ -73,6 +73,9 @@ type Container struct { // and deal with the extra pointer dereferencing overhead. status atomic.Uint32 + // Set to true when the init process for the container has exited + terminated atomic.Bool + // scratchDirPath represents the path inside the UVM where the scratch directory // of this container is located. Usually, this is either `/run/gcs/c/` or // `/run/gcs/c//container_` if scratch is shared with UVM scratch. diff --git a/internal/guest/runtime/hcsv2/hostdata.go b/internal/guest/runtime/hcsv2/hostdata.go deleted file mode 100644 index d75463fcda..0000000000 --- a/internal/guest/runtime/hcsv2/hostdata.go +++ /dev/null @@ -1,33 +0,0 @@ -//go:build linux -// +build linux - -package hcsv2 - -import ( - "bytes" - "fmt" - - "github.com/Microsoft/hcsshim/pkg/amdsevsnp" -) - -// validateHostData fetches SNP report (if applicable) and validates `hostData` against -// HostData set at UVM launch. -func validateHostData(hostData []byte) error { - // If the UVM is not SNP, then don't try to fetch an SNP report. 
- if !amdsevsnp.IsSNP() { - return nil - } - report, err := amdsevsnp.FetchParsedSNPReport(nil) - if err != nil { - return err - } - - if !bytes.Equal(hostData, report.HostData) { - return fmt.Errorf( - "security policy digest %q doesn't match HostData provided at launch %q", - hostData, - report.HostData, - ) - } - return nil -} diff --git a/internal/guest/runtime/hcsv2/process.go b/internal/guest/runtime/hcsv2/process.go index e94c9792f6..96564cfab0 100644 --- a/internal/guest/runtime/hcsv2/process.go +++ b/internal/guest/runtime/hcsv2/process.go @@ -99,6 +99,7 @@ func newProcess(c *Container, spec *oci.Process, process runtime.Process, pid ui log.G(ctx).WithError(err).Error("failed to wait for runc process") } p.exitCode = exitCode + c.terminated.Store(true) log.G(ctx).WithField("exitCode", p.exitCode).Debug("process exited") // Free any process waiters diff --git a/internal/guest/runtime/hcsv2/sandbox_container.go b/internal/guest/runtime/hcsv2/sandbox_container.go index 7456e1462a..da29a95835 100644 --- a/internal/guest/runtime/hcsv2/sandbox_container.go +++ b/internal/guest/runtime/hcsv2/sandbox_container.go @@ -54,6 +54,9 @@ func setupSandboxContainerSpec(ctx context.Context, id string, spec *oci.Spec) ( // Write the hostname hostname := spec.Hostname + if err = network.ValidateHostname(hostname); err != nil { + return err + } if hostname == "" { var err error hostname, err = os.Hostname() diff --git a/internal/guest/runtime/hcsv2/standalone_container.go b/internal/guest/runtime/hcsv2/standalone_container.go index bb1c5ad390..296b328cf5 100644 --- a/internal/guest/runtime/hcsv2/standalone_container.go +++ b/internal/guest/runtime/hcsv2/standalone_container.go @@ -61,6 +61,9 @@ func setupStandaloneContainerSpec(ctx context.Context, id string, spec *oci.Spec }() hostname := spec.Hostname + if err = network.ValidateHostname(hostname); err != nil { + return err + } if hostname == "" { var err error hostname, err = os.Hostname() diff --git 
a/internal/guest/runtime/hcsv2/uvm.go b/internal/guest/runtime/hcsv2/uvm.go index 69cfedb68f..0052ccf8f7 100644 --- a/internal/guest/runtime/hcsv2/uvm.go +++ b/internal/guest/runtime/hcsv2/uvm.go @@ -13,8 +13,10 @@ import ( "os/exec" "path" "path/filepath" + "regexp" "strings" "sync" + "sync/atomic" "syscall" "time" @@ -24,11 +26,11 @@ import ( "github.com/opencontainers/runtime-spec/specs-go" "github.com/pkg/errors" "github.com/sirupsen/logrus" + "go.opencensus.io/trace" "golang.org/x/sys/unix" "github.com/Microsoft/hcsshim/internal/bridgeutils/gcserr" "github.com/Microsoft/hcsshim/internal/debug" - "github.com/Microsoft/hcsshim/internal/guest/policy" "github.com/Microsoft/hcsshim/internal/guest/prot" "github.com/Microsoft/hcsshim/internal/guest/runtime" specGuest "github.com/Microsoft/hcsshim/internal/guest/spec" @@ -40,8 +42,10 @@ import ( "github.com/Microsoft/hcsshim/internal/guest/storage/pmem" "github.com/Microsoft/hcsshim/internal/guest/storage/scsi" "github.com/Microsoft/hcsshim/internal/guest/transport" + "github.com/Microsoft/hcsshim/internal/guestpath" "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/logfields" + "github.com/Microsoft/hcsshim/internal/oc" "github.com/Microsoft/hcsshim/internal/oci" "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" "github.com/Microsoft/hcsshim/internal/protocol/guestresource" @@ -54,6 +58,27 @@ import ( // for V2 where the specific message is targeted at the UVM itself. const UVMContainerID = "00000000-0000-0000-0000-000000000000" +// Prevent path traversal via malformed container / sandbox IDs. Container IDs +// can be either UVMContainerID, or a 64 character hex string. This is also used +// to check that sandbox IDs (which is also used in paths) are valid, which has +// the same format. 
+const validContainerIDRegexRaw = `[0-9a-fA-F]{64}` + +var validContainerIDRegex = regexp.MustCompile("^" + validContainerIDRegexRaw + "$") + +// idType just changes the error message +func checkValidContainerID(id string, idType string) error { + if id == UVMContainerID { + return nil + } + + if !validContainerIDRegex.MatchString(id) { + return errors.Errorf("invalid %s id: %s (must match %s)", idType, id, validContainerIDRegex.String()) + } + + return nil +} + // VirtualPod represents a virtual pod that shares a UVM/Sandbox with other pods type VirtualPod struct { VirtualSandboxID string @@ -85,118 +110,52 @@ type Host struct { devNullTransport transport.Transport // state required for the security policy enforcement - policyMutex sync.Mutex - securityPolicyEnforcer securitypolicy.SecurityPolicyEnforcer - securityPolicyEnforcerSet bool - uvmReferenceInfo string + securityOptions *securitypolicy.SecurityOptions - // logging target - logWriter io.Writer // hostMounts keeps the state of currently mounted devices and file systems, - // which is used for GCS hardening. + // which is used for GCS hardening. It is only used for confidential + // containers, and is initialized in SetConfidentialUVMOptions. If this is + // nil, we do not do add any special restrictions on mounts / unmounts. hostMounts *hostMounts + // A permanent flag to indicate that further mounts, unmounts and container + // creation should not be allowed. This is set when, because of a failure + // during an unmount operation, we end up in a state where the policy + // enforcer's state is out of sync with what we have actually done, but we + // cannot safely revert its state. + // + // Not used in non-confidential mode. + mountsBroken atomic.Bool + // A user-friendly error message for why mountsBroken was set. 
+ mountsBrokenCausedBy string } func NewHost(rtime runtime.Runtime, vsock transport.Transport, initialEnforcer securitypolicy.SecurityPolicyEnforcer, logWriter io.Writer) *Host { - return &Host{ - containers: make(map[string]*Container), - externalProcesses: make(map[int]*externalProcess), - virtualPods: make(map[string]*VirtualPod), - containerToVirtualPod: make(map[string]string), - rtime: rtime, - vsock: vsock, - devNullTransport: &transport.DevNullTransport{}, - securityPolicyEnforcerSet: false, - securityPolicyEnforcer: initialEnforcer, - logWriter: logWriter, - hostMounts: newHostMounts(), - } -} - -// SetConfidentialUVMOptions takes guestresource.LCOWConfidentialOptions -// to set up our internal data structures we use to store and enforce -// security policy. The options can contain security policy enforcer type, -// encoded security policy and signed UVM reference information The security -// policy and uvm reference information can be further presented to workload -// containers for validation and attestation purposes. -func (h *Host) SetConfidentialUVMOptions(ctx context.Context, r *guestresource.LCOWConfidentialOptions) error { - h.policyMutex.Lock() - defer h.policyMutex.Unlock() - if h.securityPolicyEnforcerSet { - return errors.New("security policy has already been set") - } - - // this limit ensures messages are below the character truncation limit that - // can be imposed by an orchestrator - maxErrorMessageLength := 3 * 1024 - - // Initialize security policy enforcer for a given enforcer type and - // encoded security policy. - p, err := securitypolicy.CreateSecurityPolicyEnforcer( - r.EnforcerType, - r.EncodedSecurityPolicy, - policy.DefaultCRIMounts(), - policy.DefaultCRIPrivilegedMounts(), - maxErrorMessageLength, + securityPolicyOptions := securitypolicy.NewSecurityOptions( + initialEnforcer, + false, + "", + logWriter, ) - if err != nil { - return err - } - - // This is one of two points at which we might change our logging. 
- // At this time, we now have a policy and can determine what the policy - // author put as policy around runtime logging. - // The other point is on startup where we take a flag to set the default - // policy enforcer to use before a policy arrives. After that flag is set, - // we use the enforcer in question to set up logging as well. - if err = p.EnforceRuntimeLoggingPolicy(ctx); err == nil { - logrus.SetOutput(h.logWriter) - } else { - logrus.SetOutput(io.Discard) - } - - hostData, err := securitypolicy.NewSecurityPolicyDigest(r.EncodedSecurityPolicy) - if err != nil { - return err - } - - if err := validateHostData(hostData[:]); err != nil { - return err + return &Host{ + containers: make(map[string]*Container), + externalProcesses: make(map[int]*externalProcess), + virtualPods: make(map[string]*VirtualPod), + containerToVirtualPod: make(map[string]string), + rtime: rtime, + vsock: vsock, + devNullTransport: &transport.DevNullTransport{}, + hostMounts: nil, + securityOptions: securityPolicyOptions, + mountsBroken: atomic.Bool{}, } - - h.securityPolicyEnforcer = p - h.securityPolicyEnforcerSet = true - h.uvmReferenceInfo = r.EncodedUVMReference - - return nil } -// InjectFragment extends current security policy with additional constraints -// from the incoming fragment. 
Note that it is base64 encoded over the bridge/ -// -// There are three checking steps: -// 1 - Unpack the cose document and check it was actually signed with the cert -// chain inside its header -// 2 - Check that the issuer field did:x509 identifier is for that cert chain -// (ie fingerprint of a non leaf cert and the subject matches the leaf cert) -// 3 - Check that this issuer/feed match the requirement of the user provided -// security policy (done in the regoby LoadFragment) -func (h *Host) InjectFragment(ctx context.Context, fragment *guestresource.LCOWSecurityPolicyFragment) (err error) { - log.G(ctx).WithField("fragment", fmt.Sprintf("%+v", fragment)).Debug("GCS Host.InjectFragment") - issuer, feed, payloadString, err := securitypolicy.ExtractAndVerifyFragment(ctx, fragment) - if err != nil { - return err - } - // now offer the payload fragment to the policy - err = h.securityPolicyEnforcer.LoadFragment(ctx, issuer, feed, payloadString) - if err != nil { - return fmt.Errorf("error loading security policy fragment: %w", err) - } - return nil +func (h *Host) SecurityPolicyEnforcer() securitypolicy.SecurityPolicyEnforcer { + return h.securityOptions.PolicyEnforcer } -func (h *Host) SecurityPolicyEnforcer() securitypolicy.SecurityPolicyEnforcer { - return h.securityPolicyEnforcer +func (h *Host) SecurityOptions() *securitypolicy.SecurityOptions { + return h.securityOptions } func (h *Host) Transport() transport.Transport { @@ -324,12 +283,105 @@ func setupSandboxLogDir(sandboxID, virtualSandboxID string) error { // TODO: unify workload and standalone logic for non-sandbox features (e.g., block devices, huge pages, uVM mounts) // TODO(go1.24): use [os.Root] instead of `!strings.HasPrefix(, )` +// Returns whether this host has a security policy set, i.e. if it's running +// confidential containers. 
+func (h *Host) HasSecurityPolicy() bool { + return len(h.securityOptions.PolicyEnforcer.EncodedSecurityPolicy()) > 0 +} + +// For confidential containers, make sure that the host can't use unexpected +// bundle paths / scratch dir / rootfs +func checkContainerSettings(sandboxID, containerID string, settings *prot.VMHostedContainerSettingsV2) error { + if settings.OCISpecification == nil { + return errors.Errorf("OCISpecification is nil") + } + if settings.OCISpecification.Root == nil { + return errors.Errorf("OCISpecification.Root is nil") + } + + // matches with CreateContainer / createLinuxContainerDocument in internal/hcsoci + containerRootInUVM := path.Join(guestpath.LCOWRootPrefixInUVM, containerID) + if settings.OCIBundlePath != containerRootInUVM { + return errors.Errorf("OCIBundlePath %q must equal expected %q", + settings.OCIBundlePath, containerRootInUVM) + } + expectedContainerRootfs := path.Join(containerRootInUVM, guestpath.RootfsPath) + if settings.OCISpecification.Root.Path != expectedContainerRootfs { + return errors.Errorf("OCISpecification.Root.Path %q must equal expected %q", + settings.OCISpecification.Root.Path, expectedContainerRootfs) + } + + // matches with MountLCOWLayers + scratchDirPath := settings.ScratchDirPath + expectedScratchDirPathNonShared := path.Join(containerRootInUVM, guestpath.ScratchDir, containerID) + expectedScratchDirPathShared := path.Join(guestpath.LCOWRootPrefixInUVM, sandboxID, guestpath.ScratchDir, containerID) + if scratchDirPath != expectedScratchDirPathNonShared && + scratchDirPath != expectedScratchDirPathShared { + return errors.Errorf("ScratchDirPath %q must be either %q or %q", + scratchDirPath, expectedScratchDirPathNonShared, expectedScratchDirPathShared) + } + + if settings.OCISpecification.Hooks != nil { + return errors.Errorf("OCISpecification.Hooks must be nil.") + } + + return nil +} + +// Returns an error if h.mountsBroken is set (and we're in a confidential +// container host) +func (h *Host) 
checkMountsNotBroken() error { + if h.HasSecurityPolicy() && h.mountsBroken.Load() { + return errors.Errorf( + "Mount, unmount, container creation and deletion has been disabled in this UVM due to a previous error (%q)", + h.mountsBrokenCausedBy, + ) + } + return nil +} + +func (h *Host) setMountsBrokenIfConfidential(cause string) { + if !h.HasSecurityPolicy() { + return + } + h.mountsBroken.Store(true) + h.mountsBrokenCausedBy = cause + log.G(context.Background()).WithFields(logrus.Fields{ + "cause": cause, + }).Error("Host::mountsBroken set to true. All further mounts/unmounts, container creation and deletion will fail.") +} + +func checkExists(path string) (error, bool) { + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + return nil, false + } + return errors.Wrapf(err, "failed to determine if path '%s' exists", path), false + } + return nil, true +} + func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VMHostedContainerSettingsV2) (_ *Container, err error) { + if err = h.checkMountsNotBroken(); err != nil { + return nil, err + } + criType, isCRI := settings.OCISpecification.Annotations[annotations.KubernetesContainerType] // Check for virtual pod annotation virtualPodID, isVirtualPod := settings.OCISpecification.Annotations[annotations.VirtualPodID] + if h.HasSecurityPolicy() { + if err = checkValidContainerID(id, "container"); err != nil { + return nil, err + } + if virtualPodID != "" { + if err = checkValidContainerID(virtualPodID, "virtual pod"); err != nil { + return nil, err + } + } + } + // Special handling for virtual pod sandbox containers: // The first container in a virtual pod (containerID == virtualPodID) should be treated as a sandbox // even if the CRI annotation might indicate otherwise due to host-side UVM setup differences @@ -351,6 +403,7 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM isSandbox: criType == "sandbox", exitType: prot.NtUnexpectedExit, processes: 
make(map[uint32]*containerProcess), + terminated: atomic.Bool{}, scratchDirPath: settings.ScratchDirPath, } c.setStatus(containerCreating) @@ -466,12 +519,17 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM return nil, err } - if err := policy.ExtendPolicyWithNetworkingMounts(id, h.securityPolicyEnforcer, settings.OCISpecification); err != nil { + if err := securitypolicy.ExtendPolicyWithNetworkingMounts(id, h.securityOptions.PolicyEnforcer, settings.OCISpecification); err != nil { return nil, err } case "container": sid, ok := settings.OCISpecification.Annotations[annotations.KubernetesSandboxID] sandboxID = sid + if h.HasSecurityPolicy() { + if err = checkValidContainerID(sid, "sandbox"); err != nil { + return nil, err + } + } if !ok || sid == "" { return nil, errors.Errorf("unsupported 'io.kubernetes.cri.sandbox-id': '%s'", sid) } @@ -481,7 +539,7 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM // Add SEV device when security policy is not empty, except when privileged annotation is // set to "true", in which case all UVMs devices are added. 
- if len(h.securityPolicyEnforcer.EncodedSecurityPolicy()) > 0 && !oci.ParseAnnotationsBool(ctx, + if h.HasSecurityPolicy() && !oci.ParseAnnotationsBool(ctx, settings.OCISpecification.Annotations, annotations.LCOWPrivileged, false) { if err := specGuest.AddDevSev(ctx, settings.OCISpecification); err != nil { log.G(ctx).WithError(err).Debug("failed to add SEV device") @@ -493,7 +551,7 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM _ = os.RemoveAll(settings.OCIBundlePath) } }() - if err := policy.ExtendPolicyWithNetworkingMounts(sandboxID, h.securityPolicyEnforcer, settings.OCISpecification); err != nil { + if err := securitypolicy.ExtendPolicyWithNetworkingMounts(sandboxID, h.securityOptions.PolicyEnforcer, settings.OCISpecification); err != nil { return nil, err } default: @@ -510,7 +568,7 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM _ = os.RemoveAll(settings.OCIBundlePath) } }() - if err := policy.ExtendPolicyWithNetworkingMounts(id, h.securityPolicyEnforcer, + if err := securitypolicy.ExtendPolicyWithNetworkingMounts(id, h.securityOptions.PolicyEnforcer, settings.OCISpecification); err != nil { return nil, err } @@ -527,7 +585,13 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM }) } - user, groups, umask, err := h.securityPolicyEnforcer.GetUserInfo(settings.OCISpecification.Process, settings.OCISpecification.Root.Path) + if h.HasSecurityPolicy() { + if err = checkContainerSettings(sandboxID, id, settings); err != nil { + return nil, err + } + } + + user, groups, umask, err := h.securityOptions.PolicyEnforcer.GetUserInfo(settings.OCISpecification.Process, settings.OCISpecification.Root.Path) if err != nil { return nil, err } @@ -537,7 +601,7 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM return nil, err } - envToKeep, capsToKeep, allowStdio, err := h.securityPolicyEnforcer.EnforceCreateContainerPolicy( + envToKeep, 
capsToKeep, allowStdio, err := h.securityOptions.PolicyEnforcer.EnforceCreateContainerPolicy( ctx, sandboxID, id, @@ -595,47 +659,9 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM settings.OCISpecification.Process.Capabilities = capsToKeep } - // Write security policy, signed UVM reference and host AMD certificate to - // container's rootfs, so that application and sidecar containers can have - // access to it. The security policy is required by containers which need to - // extract init-time claims found in the security policy. The directory path - // containing the files is exposed via UVM_SECURITY_CONTEXT_DIR env var. - // It may be an error to have a security policy but not expose it to the - // container as in that case it can never be checked as correct by a verifier. if oci.ParseAnnotationsBool(ctx, settings.OCISpecification.Annotations, annotations.LCOWSecurityPolicyEnv, true) { - encodedPolicy := h.securityPolicyEnforcer.EncodedSecurityPolicy() - hostAMDCert := settings.OCISpecification.Annotations[annotations.LCOWHostAMDCertificate] - if len(encodedPolicy) > 0 || len(hostAMDCert) > 0 || len(h.uvmReferenceInfo) > 0 { - // Use os.MkdirTemp to make sure that the directory is unique. 
- securityContextDir, err := os.MkdirTemp(settings.OCISpecification.Root.Path, securitypolicy.SecurityContextDirTemplate) - if err != nil { - return nil, fmt.Errorf("failed to create security context directory: %w", err) - } - // Make sure that files inside directory are readable - if err := os.Chmod(securityContextDir, 0755); err != nil { - return nil, fmt.Errorf("failed to chmod security context directory: %w", err) - } - - if len(encodedPolicy) > 0 { - if err := writeFileInDir(securityContextDir, securitypolicy.PolicyFilename, []byte(encodedPolicy), 0744); err != nil { - return nil, fmt.Errorf("failed to write security policy: %w", err) - } - } - if len(h.uvmReferenceInfo) > 0 { - if err := writeFileInDir(securityContextDir, securitypolicy.ReferenceInfoFilename, []byte(h.uvmReferenceInfo), 0744); err != nil { - return nil, fmt.Errorf("failed to write UVM reference info: %w", err) - } - } - - if len(hostAMDCert) > 0 { - if err := writeFileInDir(securityContextDir, securitypolicy.HostAMDCertFilename, []byte(hostAMDCert), 0744); err != nil { - return nil, fmt.Errorf("failed to write host AMD certificate: %w", err) - } - } - - containerCtxDir := fmt.Sprintf("/%s", filepath.Base(securityContextDir)) - secCtxEnv := fmt.Sprintf("UVM_SECURITY_CONTEXT_DIR=%s", containerCtxDir) - settings.OCISpecification.Process.Env = append(settings.OCISpecification.Process.Env, secCtxEnv) + if err := h.securityOptions.WriteSecurityContextDir(settings.OCISpecification); err != nil { + return nil, fmt.Errorf("failed to write security context dir: %w", err) } } @@ -691,11 +717,40 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM return c, nil } +// Returns whether there is a running container that is currently using the +// given overlay (as its rootfs). 
+func (h *Host) IsOverlayInUse(overlayPath string) bool { + h.containersMutex.Lock() + defer h.containersMutex.Unlock() + + for _, c := range h.containers { + if c.terminated.Load() { + continue + } + + if c.spec.Root.Path == overlayPath { + return true + } + } + + return false +} + func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req *guestrequest.ModificationRequest) (retErr error) { + if h.HasSecurityPolicy() { + if err := checkValidContainerID(containerID, "container"); err != nil { + return err + } + } + switch req.ResourceType { case guestresource.ResourceTypeSCSIDevice: return modifySCSIDevice(ctx, req.RequestType, req.Settings.(*guestresource.SCSIDevice)) case guestresource.ResourceTypeMappedVirtualDisk: + if err := h.checkMountsNotBroken(); err != nil { + return err + } + mvd := req.Settings.(*guestresource.LCOWMappedVirtualDisk) // find the actual controller number on the bus and update the incoming request. var cNum uint8 @@ -704,47 +759,25 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * return err } mvd.Controller = cNum - // first we try to update the internal state for read-write attachments. 
- if !mvd.ReadOnly { - localCtx, cancel := context.WithTimeout(ctx, time.Second*5) - defer cancel() - source, err := scsi.GetDevicePath(localCtx, mvd.Controller, mvd.Lun, mvd.Partition) - if err != nil { - return err - } - switch req.RequestType { - case guestrequest.RequestTypeAdd: - if err := h.hostMounts.AddRWDevice(mvd.MountPath, source, mvd.Encrypted); err != nil { - return err - } - defer func() { - if retErr != nil { - _ = h.hostMounts.RemoveRWDevice(mvd.MountPath, source) - } - }() - case guestrequest.RequestTypeRemove: - if err := h.hostMounts.RemoveRWDevice(mvd.MountPath, source); err != nil { - return err - } - defer func() { - if retErr != nil { - _ = h.hostMounts.AddRWDevice(mvd.MountPath, source, mvd.Encrypted) - } - }() - } - } - return modifyMappedVirtualDisk(ctx, req.RequestType, mvd, h.securityPolicyEnforcer) + return h.modifyMappedVirtualDisk(ctx, req.RequestType, mvd) case guestresource.ResourceTypeMappedDirectory: - return modifyMappedDirectory(ctx, h.vsock, req.RequestType, req.Settings.(*guestresource.LCOWMappedDirectory), h.securityPolicyEnforcer) + if err := h.checkMountsNotBroken(); err != nil { + return err + } + + return h.modifyMappedDirectory(ctx, h.vsock, req.RequestType, req.Settings.(*guestresource.LCOWMappedDirectory)) case guestresource.ResourceTypeVPMemDevice: - return modifyMappedVPMemDevice(ctx, req.RequestType, req.Settings.(*guestresource.LCOWMappedVPMemDevice), h.securityPolicyEnforcer) + if err := h.checkMountsNotBroken(); err != nil { + return err + } + + return h.modifyMappedVPMemDevice(ctx, req.RequestType, req.Settings.(*guestresource.LCOWMappedVPMemDevice)) case guestresource.ResourceTypeCombinedLayers: - cl := req.Settings.(*guestresource.LCOWCombinedLayers) - // when cl.ScratchPath == "", we mount overlay as read-only, in which case - // we don't really care about scratch encryption, since the host already - // knows about the layers and the overlayfs. 
- encryptedScratch := cl.ScratchPath != "" && h.hostMounts.IsEncrypted(cl.ScratchPath) - return modifyCombinedLayers(ctx, req.RequestType, req.Settings.(*guestresource.LCOWCombinedLayers), encryptedScratch, h.securityPolicyEnforcer) + if err := h.checkMountsNotBroken(); err != nil { + return err + } + + return h.modifyCombinedLayers(ctx, req.RequestType, req.Settings.(*guestresource.LCOWCombinedLayers)) case guestresource.ResourceTypeNetwork: return modifyNetwork(ctx, req.RequestType, req.Settings.(*guestresource.LCOWNetworkAdapter)) case guestresource.ResourceTypeVPCIDevice: @@ -756,23 +789,44 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * } return c.modifyContainerConstraints(ctx, req.RequestType, req.Settings.(*guestresource.LCOWContainerConstraints)) case guestresource.ResourceTypeSecurityPolicy: - r, ok := req.Settings.(*guestresource.LCOWConfidentialOptions) + r, ok := req.Settings.(*guestresource.ConfidentialOptions) if !ok { - return errors.New("the request's settings are not of type LCOWConfidentialOptions") + return errors.New("the request's settings are not of type ConfidentialOptions") + } + err := h.securityOptions.SetConfidentialOptions(ctx, + r.EnforcerType, + r.EncodedSecurityPolicy, + r.EncodedUVMReference) + if err != nil { + return err + } + + // Start tracking mounts and restricting unmounts on confidential containers. + // As long as we started off with the ClosedDoorSecurityPolicyEnforcer, no + // mounts should have been allowed until this point. 
+ if h.HasSecurityPolicy() { + log.G(ctx).Debug("hostMounts initialized") + h.hostMounts = newHostMounts() } - return h.SetConfidentialUVMOptions(ctx, r) + return nil case guestresource.ResourceTypePolicyFragment: - r, ok := req.Settings.(*guestresource.LCOWSecurityPolicyFragment) + r, ok := req.Settings.(*guestresource.SecurityPolicyFragment) if !ok { - return errors.New("the request settings are not of type LCOWSecurityPolicyFragment") + return errors.New("the request settings are not of type SecurityPolicyFragment") } - return h.InjectFragment(ctx, r) + return h.securityOptions.InjectFragment(ctx, r) default: return errors.Errorf("the ResourceType %q is not supported for UVM", req.ResourceType) } } func (h *Host) modifyContainerSettings(ctx context.Context, containerID string, req *guestrequest.ModificationRequest) error { + if h.HasSecurityPolicy() { + if err := checkValidContainerID(containerID, "container"); err != nil { + return err + } + } + c, err := h.GetCreatedContainer(containerID) if err != nil { return err @@ -806,7 +860,7 @@ func (h *Host) ShutdownContainer(ctx context.Context, containerID string, gracef return err } - err = h.securityPolicyEnforcer.EnforceShutdownContainerPolicy(ctx, containerID) + err = h.securityOptions.PolicyEnforcer.EnforceShutdownContainerPolicy(ctx, containerID) if err != nil { return err } @@ -833,7 +887,7 @@ func (h *Host) SignalContainerProcess(ctx context.Context, containerID string, p signalingInitProcess := processID == c.initProcess.pid startupArgList := p.(*containerProcess).spec.Args - err = h.securityPolicyEnforcer.EnforceSignalContainerProcessPolicy(ctx, containerID, signal, signalingInitProcess, startupArgList) + err = h.securityOptions.PolicyEnforcer.EnforceSignalContainerProcessPolicy(ctx, containerID, signal, signalingInitProcess, startupArgList) if err != nil { return err } @@ -848,7 +902,7 @@ func (h *Host) ExecProcess(ctx context.Context, containerID string, params prot. 
if params.IsExternal || containerID == UVMContainerID { var envToKeep securitypolicy.EnvList var allowStdioAccess bool - envToKeep, allowStdioAccess, err = h.securityPolicyEnforcer.EnforceExecExternalProcessPolicy( + envToKeep, allowStdioAccess, err = h.securityOptions.PolicyEnforcer.EnforceExecExternalProcessPolicy( ctx, params.CommandArgs, processParamEnvToOCIEnv(params.Environment), @@ -890,12 +944,12 @@ func (h *Host) ExecProcess(ctx context.Context, containerID string, params prot. var umask string var allowStdioAccess bool - user, groups, umask, err = h.securityPolicyEnforcer.GetUserInfo(params.OCIProcess, c.spec.Root.Path) + user, groups, umask, err = h.securityOptions.PolicyEnforcer.GetUserInfo(params.OCIProcess, c.spec.Root.Path) if err != nil { return 0, err } - envToKeep, capsToKeep, allowStdioAccess, err = h.securityPolicyEnforcer.EnforceExecInContainerPolicy( + envToKeep, capsToKeep, allowStdioAccess, err = h.securityOptions.PolicyEnforcer.EnforceExecInContainerPolicy( ctx, containerID, params.OCIProcess.Args, @@ -944,7 +998,7 @@ func (h *Host) GetExternalProcess(pid int) (Process, error) { } func (h *Host) GetProperties(ctx context.Context, containerID string, query prot.PropertyQuery) (*prot.PropertiesV2, error) { - err := h.securityPolicyEnforcer.EnforceGetPropertiesPolicy(ctx) + err := h.securityOptions.PolicyEnforcer.EnforceGetPropertiesPolicy(ctx) if err != nil { return nil, errors.Wrapf(err, "get properties denied due to policy") } @@ -1000,7 +1054,7 @@ func (h *Host) GetProperties(ctx context.Context, containerID string, query prot } func (h *Host) GetStacks(ctx context.Context) (string, error) { - err := h.securityPolicyEnforcer.EnforceDumpStacksPolicy(ctx) + err := h.securityOptions.PolicyEnforcer.EnforceDumpStacksPolicy(ctx) if err != nil { return "", errors.Wrapf(err, "dump stacks denied due to policy") } @@ -1123,34 +1177,81 @@ func modifySCSIDevice( } } -func modifyMappedVirtualDisk( +func (h *Host) modifyMappedVirtualDisk( ctx 
context.Context, rt guestrequest.RequestType, mvd *guestresource.LCOWMappedVirtualDisk, - securityPolicy securitypolicy.SecurityPolicyEnforcer, ) (err error) { + ctx, span := oc.StartSpan(ctx, "gcs::Host::modifyMappedVirtualDisk") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + span.AddAttributes( + trace.StringAttribute("requestType", string(rt)), + trace.BoolAttribute("hasHostMounts", h.hostMounts != nil), + trace.Int64Attribute("controller", int64(mvd.Controller)), + trace.Int64Attribute("lun", int64(mvd.Lun)), + trace.Int64Attribute("partition", int64(mvd.Partition)), + trace.BoolAttribute("readOnly", mvd.ReadOnly), + trace.StringAttribute("mountPath", mvd.MountPath), + ) + var verityInfo *guestresource.DeviceVerityInfo + securityPolicy := h.securityOptions.PolicyEnforcer + devPath, err := scsi.GetDevicePath(ctx, mvd.Controller, mvd.Lun, mvd.Partition) + if err != nil { + return err + } + span.AddAttributes(trace.StringAttribute("devicePath", devPath)) + if mvd.ReadOnly { // The only time the policy is empty, and we want it to be empty // is when no policy is provided, and we default to open door // policy. In any other case, e.g. explicit open door or any // other rego policy we would like to mount layers with verity. - if len(securityPolicy.EncodedSecurityPolicy()) > 0 { - devPath, err := scsi.GetDevicePath(ctx, mvd.Controller, mvd.Lun, mvd.Partition) - if err != nil { - return err - } + if h.HasSecurityPolicy() { verityInfo, err = verity.ReadVeritySuperBlock(ctx, devPath) if err != nil { return err } + if mvd.Filesystem != "" && mvd.Filesystem != "ext4" { + return errors.Errorf("filesystem must be ext4 for read-only scsi mounts") + } } } + + // For confidential containers, we revert the policy metadata on both mount + // and unmount errors, but if we've actually called Unmount and it fails we + // permanently block further device operations. 
+ var rev securitypolicy.RevertableSectionHandle + rev, err = securityPolicy.StartRevertableSection() + if err != nil { + return errors.Wrapf(err, "failed to start revertable section on security policy enforcer") + } + defer h.commitOrRollbackPolicyRevSection(ctx, rev, &err) + switch rt { case guestrequest.RequestTypeAdd: mountCtx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() if mvd.MountPath != "" { + if h.HasSecurityPolicy() { + // The only option we allow if there is policy enforcement is + // "ro", and it must match the readonly field in the request. + mountOptionHasRo := false + for _, opt := range mvd.Options { + if opt == "ro" { + mountOptionHasRo = true + continue + } + return errors.Errorf("mounting scsi device controller %d lun %d onto %s: mount option %q denied by policy", mvd.Controller, mvd.Lun, mvd.MountPath, opt) + } + if mvd.ReadOnly != mountOptionHasRo { + return errors.Errorf( + "mounting scsi device controller %d lun %d onto %s with mount option %q failed due to mount option mismatch: mvd.ReadOnly=%t but mountOptionHasRo=%t", + mvd.Controller, mvd.Lun, mvd.MountPath, strings.Join(mvd.Options, ","), mvd.ReadOnly, mountOptionHasRo, + ) + } + } if mvd.ReadOnly { var deviceHash string if verityInfo != nil { @@ -1160,6 +1261,42 @@ func modifyMappedVirtualDisk( if err != nil { return errors.Wrapf(err, "mounting scsi device controller %d lun %d onto %s denied by policy", mvd.Controller, mvd.Lun, mvd.MountPath) } + if h.hostMounts != nil { + h.hostMounts.Lock() + defer h.hostMounts.Unlock() + + err = h.hostMounts.AddRODevice(mvd.MountPath, devPath) + if err != nil { + return err + } + // Note: "When a function returns, its deferred calls are + // executed in last-in-first-out order." - so we are safe to + // call RemoveRODevice in this defer. 
+ defer func() { + if err != nil { + _ = h.hostMounts.RemoveRODevice(mvd.MountPath, devPath) + } + }() + } + } else { + err = securityPolicy.EnforceRWDeviceMountPolicy(ctx, mvd.MountPath, mvd.Encrypted, mvd.EnsureFilesystem, mvd.Filesystem) + if err != nil { + return errors.Wrapf(err, "mounting scsi device controller %d lun %d onto %s denied by policy", mvd.Controller, mvd.Lun, mvd.MountPath) + } + if h.hostMounts != nil { + h.hostMounts.Lock() + defer h.hostMounts.Unlock() + + err = h.hostMounts.AddRWDevice(mvd.MountPath, devPath, mvd.Encrypted) + if err != nil { + return err + } + defer func() { + if err != nil { + _ = h.hostMounts.RemoveRWDevice(mvd.MountPath, devPath, mvd.Encrypted) + } + }() + } } config := &scsi.Config{ Encrypted: mvd.Encrypted, @@ -1168,6 +1305,12 @@ func modifyMappedVirtualDisk( Filesystem: mvd.Filesystem, BlockDev: mvd.BlockDev, } + // Since we're rolling back the policy metadata (via the revertable + // section) on failure, we need to ensure that we have reverted all + // the side effects from this failed mount attempt, otherwise the + // Rego metadata is technically still inconsistent with reality. + // Mount cleans up the created directory and dm devices on failure, + // so we're good. 
return scsi.Mount(mountCtx, mvd.Controller, mvd.Lun, mvd.Partition, mvd.MountPath, mvd.ReadOnly, mvd.Options, config) } @@ -1175,9 +1318,58 @@ func modifyMappedVirtualDisk( case guestrequest.RequestTypeRemove: if mvd.MountPath != "" { if mvd.ReadOnly { - if err := securityPolicy.EnforceDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { + if err = securityPolicy.EnforceDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { return fmt.Errorf("unmounting scsi device at %s denied by policy: %w", mvd.MountPath, err) } + if h.hostMounts != nil { + h.hostMounts.Lock() + defer h.hostMounts.Unlock() + + if err = h.hostMounts.RemoveRODevice(mvd.MountPath, devPath); err != nil { + return err + } + defer func() { + if err != nil { + _ = h.hostMounts.AddRODevice(mvd.MountPath, devPath) + } + }() + } + } else { + if err = securityPolicy.EnforceRWDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { + return fmt.Errorf("unmounting scsi device at %s denied by policy: %w", mvd.MountPath, err) + } + if h.hostMounts != nil { + h.hostMounts.Lock() + defer h.hostMounts.Unlock() + + if err = h.hostMounts.RemoveRWDevice(mvd.MountPath, devPath, mvd.Encrypted); err != nil { + return err + } + defer func() { + if err != nil { + _ = h.hostMounts.AddRWDevice(mvd.MountPath, devPath, mvd.Encrypted) + } + }() + } + } + // Check that the directory actually exists first, and if it does + // not then we just refuse to do anything, without closing the dm + // device or setting the mountsBroken flag. Policy metadata is + // still reverted to reflect the fact that we have not done + // anything. 
+ // + // Note: we should not do this check before calling the policy + // enforcer, as otherwise we might inadvertently allow the host to + // find out whether an arbitrary path (which may point to sensitive + // data within a container rootfs) exists or not + if h.HasSecurityPolicy() { + err, exists := checkExists(mvd.MountPath) + if err != nil { + return err + } + if !exists { + return errors.Errorf("unmounting scsi device at %s failed: directory does not exist", mvd.MountPath) + } } config := &scsi.Config{ Encrypted: mvd.Encrypted, @@ -1186,8 +1378,11 @@ func modifyMappedVirtualDisk( Filesystem: mvd.Filesystem, BlockDev: mvd.BlockDev, } - if err := scsi.Unmount(ctx, mvd.Controller, mvd.Lun, mvd.Partition, - mvd.MountPath, config); err != nil { + err = scsi.Unmount(ctx, mvd.Controller, mvd.Lun, mvd.Partition, mvd.MountPath, config) + if err != nil { + h.setMountsBrokenIfConfidential( + fmt.Sprintf("unmounting scsi device at %s failed: %v", mvd.MountPath, err), + ) return err } } @@ -1197,13 +1392,23 @@ func modifyMappedVirtualDisk( } } -func modifyMappedDirectory( +func (h *Host) modifyMappedDirectory( ctx context.Context, vsock transport.Transport, rt guestrequest.RequestType, md *guestresource.LCOWMappedDirectory, - securityPolicy securitypolicy.SecurityPolicyEnforcer, ) (err error) { + securityPolicy := h.securityOptions.PolicyEnforcer + // For confidential containers, we revert the policy metadata on both mount + // and unmount errors, but if we've actually called Unmount and it fails we + // permanently block further device operations. 
+ var rev securitypolicy.RevertableSectionHandle + rev, err = securityPolicy.StartRevertableSection() + if err != nil { + return errors.Wrapf(err, "failed to start revertable section on security policy enforcer") + } + defer h.commitOrRollbackPolicyRevSection(ctx, rev, &err) + switch rt { case guestrequest.RequestTypeAdd: err = securityPolicy.EnforcePlan9MountPolicy(ctx, md.MountPath) @@ -1211,6 +1416,15 @@ func modifyMappedDirectory( return errors.Wrapf(err, "mounting plan9 device at %s denied by policy", md.MountPath) } + if h.HasSecurityPolicy() { + if err = plan9.ValidateShareName(md.ShareName); err != nil { + return err + } + } + + // Similar to the reasoning in modifyMappedVirtualDisk, since we're + // rolling back the policy metadata, plan9.Mount here must clean up + // everything if it fails, which it does do. return plan9.Mount(ctx, vsock, md.MountPath, md.ShareName, uint32(md.Port), md.ReadOnly) case guestrequest.RequestTypeRemove: err = securityPolicy.EnforcePlan9UnmountPolicy(ctx, md.MountPath) @@ -1218,20 +1432,28 @@ func modifyMappedDirectory( return errors.Wrapf(err, "unmounting plan9 device at %s denied by policy", md.MountPath) } - return storage.UnmountPath(ctx, md.MountPath, true) + // Note: storage.UnmountPath is nop if path does not exist. 
+ err = storage.UnmountPath(ctx, md.MountPath, true) + if err != nil { + h.setMountsBrokenIfConfidential( + fmt.Sprintf("unmounting plan9 device at %s failed: %v", md.MountPath, err), + ) + return err + } + return nil default: return newInvalidRequestTypeError(rt) } } -func modifyMappedVPMemDevice(ctx context.Context, +func (h *Host) modifyMappedVPMemDevice(ctx context.Context, rt guestrequest.RequestType, vpd *guestresource.LCOWMappedVPMemDevice, - securityPolicy securitypolicy.SecurityPolicyEnforcer, ) (err error) { var verityInfo *guestresource.DeviceVerityInfo + securityPolicy := h.securityOptions.PolicyEnforcer var deviceHash string - if len(securityPolicy.EncodedSecurityPolicy()) > 0 { + if h.HasSecurityPolicy() { if vpd.MappingInfo != nil { return fmt.Errorf("multi mapping is not supported with verity") } @@ -1241,6 +1463,17 @@ func modifyMappedVPMemDevice(ctx context.Context, } deviceHash = verityInfo.RootDigest } + + // For confidential containers, we revert the policy metadata on both mount + // and unmount errors, but if we've actually called Unmount and it fails we + // permanently block further device operations. + var rev securitypolicy.RevertableSectionHandle + rev, err = securityPolicy.StartRevertableSection() + if err != nil { + return errors.Wrapf(err, "failed to start revertable section on security policy enforcer") + } + defer h.commitOrRollbackPolicyRevSection(ctx, rev, &err) + switch rt { case guestrequest.RequestTypeAdd: err = securityPolicy.EnforceDeviceMountPolicy(ctx, vpd.MountPath, deviceHash) @@ -1248,13 +1481,39 @@ func modifyMappedVPMemDevice(ctx context.Context, return errors.Wrapf(err, "mounting pmem device %d onto %s denied by policy", vpd.DeviceNumber, vpd.MountPath) } + // Similar to the reasoning in modifyMappedVirtualDisk, since we're + // rolling back the policy metadata, pmem.Mount here must clean up + // everything if it fails, which it does do. 
return pmem.Mount(ctx, vpd.DeviceNumber, vpd.MountPath, vpd.MappingInfo, verityInfo) case guestrequest.RequestTypeRemove: - if err := securityPolicy.EnforceDeviceUnmountPolicy(ctx, vpd.MountPath); err != nil { + if err = securityPolicy.EnforceDeviceUnmountPolicy(ctx, vpd.MountPath); err != nil { return errors.Wrapf(err, "unmounting pmem device from %s denied by policy", vpd.MountPath) } - return pmem.Unmount(ctx, vpd.DeviceNumber, vpd.MountPath, vpd.MappingInfo, verityInfo) + // Check that the directory actually exists first, and if it does not + // then we just refuse to do anything, without closing the dm-linear or + // dm-verity device or setting the mountsBroken flag. + // + // Similar to the reasoning in modifyMappedVirtualDisk, we should not do + // this check before calling the policy enforcer. + if h.HasSecurityPolicy() { + err, exists := checkExists(vpd.MountPath) + if err != nil { + return err + } + if !exists { + return errors.Errorf("unmounting pmem device at %s failed: directory does not exist", vpd.MountPath) + } + } + + err = pmem.Unmount(ctx, vpd.DeviceNumber, vpd.MountPath, vpd.MappingInfo, verityInfo) + if err != nil { + h.setMountsBrokenIfConfidential( + fmt.Sprintf("unmounting pmem device at %s failed: %v", vpd.MountPath, err), + ) + return err + } + return nil default: return newInvalidRequestTypeError(rt) } @@ -1269,15 +1528,73 @@ func modifyMappedVPCIDevice(ctx context.Context, rt guestrequest.RequestType, vp } } -func modifyCombinedLayers( +func (h *Host) modifyCombinedLayers( ctx context.Context, rt guestrequest.RequestType, cl *guestresource.LCOWCombinedLayers, - scratchEncrypted bool, - securityPolicy securitypolicy.SecurityPolicyEnforcer, ) (err error) { + ctx, span := oc.StartSpan(ctx, "gcs::Host::modifyCombinedLayers") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + span.AddAttributes( + trace.StringAttribute("requestType", string(rt)), + trace.BoolAttribute("hasHostMounts", h.hostMounts != nil), + 
trace.StringAttribute("containerRootPath", cl.ContainerRootPath), + trace.StringAttribute("scratchPath", cl.ScratchPath), + ) + + securityPolicy := h.securityOptions.PolicyEnforcer + containerID := cl.ContainerID + + // For confidential containers, we revert the policy metadata on both mount + // and unmount errors, but if we've actually called Unmount and it fails we + // permanently block further device operations. + var rev securitypolicy.RevertableSectionHandle + rev, err = securityPolicy.StartRevertableSection() + if err != nil { + return errors.Wrapf(err, "failed to start revertable section on security policy enforcer") + } + defer h.commitOrRollbackPolicyRevSection(ctx, rev, &err) + + if h.hostMounts != nil { + // We will need this in multiple places, let's take the lock once here. + h.hostMounts.Lock() + defer h.hostMounts.Unlock() + } + switch rt { case guestrequest.RequestTypeAdd: + if h.HasSecurityPolicy() { + if err := checkValidContainerID(containerID, "container"); err != nil { + return err + } + + // We check this regardless of what the policy says, as long as we're in + // confidential mode. This matches with checkContainerSettings called for + // container creation request. + expectedContainerRootfs := path.Join(guestpath.LCOWRootPrefixInUVM, containerID, guestpath.RootfsPath) + if cl.ContainerRootPath != expectedContainerRootfs { + return fmt.Errorf("combined layers target %q does not match expected path %q", + cl.ContainerRootPath, expectedContainerRootfs) + } + + if cl.ScratchPath != "" { + // At this point, we do not know what the sandbox ID would be yet, so we + // have to allow anything reasonable. 
+ scratchDirRegexStr := fmt.Sprintf( + "^%s/%s/%s/%s$", + guestpath.LCOWRootPrefixInUVM, + validContainerIDRegexRaw, + guestpath.ScratchDir, + containerID, + ) + scratchDirRegex := regexp.MustCompile(scratchDirRegexStr) + if !scratchDirRegex.MatchString(cl.ScratchPath) { + return fmt.Errorf("scratch path %q must match regex %q", + cl.ScratchPath, scratchDirRegexStr) + } + } + } layerPaths := make([]string, len(cl.Layers)) for i, layer := range cl.Layers { layerPaths[i] = layer.Path @@ -1292,23 +1609,68 @@ func modifyCombinedLayers( } else { upperdirPath = filepath.Join(cl.ScratchPath, "upper") workdirPath = filepath.Join(cl.ScratchPath, "work") + scratchEncrypted := false + if h.hostMounts != nil { + scratchEncrypted = h.hostMounts.IsEncrypted(cl.ScratchPath) + } if err := securityPolicy.EnforceScratchMountPolicy(ctx, cl.ScratchPath, scratchEncrypted); err != nil { return fmt.Errorf("scratch mounting denied by policy: %w", err) } } - if err := securityPolicy.EnforceOverlayMountPolicy(ctx, cl.ContainerID, layerPaths, cl.ContainerRootPath); err != nil { + if err = securityPolicy.EnforceOverlayMountPolicy(ctx, containerID, layerPaths, cl.ContainerRootPath); err != nil { return fmt.Errorf("overlay creation denied by policy: %w", err) } + if h.hostMounts != nil { + if err = h.hostMounts.AddOverlay(cl.ContainerRootPath, layerPaths, cl.ScratchPath); err != nil { + return err + } + defer func() { + if err != nil { + _, _ = h.hostMounts.RemoveOverlay(cl.ContainerRootPath) + } + }() + } + // Correctness for policy revertable section: + // MountLayer does two things - mkdir, then mount. On mount failure, the + // target directory is cleaned up. Therefore we're clean in terms of + // side effects. 
return overlay.MountLayer(ctx, layerPaths, upperdirPath, workdirPath, cl.ContainerRootPath, readonly) case guestrequest.RequestTypeRemove: - if err := securityPolicy.EnforceOverlayUnmountPolicy(ctx, cl.ContainerRootPath); err != nil { + // cl.ContainerID is not set on remove requests, but rego checks that we can + // only umount previously mounted targets anyway + if err = securityPolicy.EnforceOverlayUnmountPolicy(ctx, cl.ContainerRootPath); err != nil { return errors.Wrap(err, "overlay removal denied by policy") } - return storage.UnmountPath(ctx, cl.ContainerRootPath, true) + // Check that no running container is using this overlay as its rootfs. + if h.HasSecurityPolicy() && h.IsOverlayInUse(cl.ContainerRootPath) { + return fmt.Errorf("overlay %q is in use by a running container", cl.ContainerRootPath) + } + + if h.hostMounts != nil { + var undoRemoveOverlay func() + if undoRemoveOverlay, err = h.hostMounts.RemoveOverlay(cl.ContainerRootPath); err != nil { + return err + } + defer func() { + if err != nil && undoRemoveOverlay != nil { + undoRemoveOverlay() + } + }() + } + + // Note: storage.UnmountPath is a no-op if the path does not exist. 
+ err = storage.UnmountPath(ctx, cl.ContainerRootPath, true) + if err != nil { + h.setMountsBrokenIfConfidential( + fmt.Sprintf("unmounting overlay at %s failed: %v", cl.ContainerRootPath, err), + ) + return err + } + return nil default: return newInvalidRequestTypeError(rt) } @@ -1377,20 +1739,6 @@ func isPrivilegedContainerCreationRequest(ctx context.Context, spec *specs.Spec) return oci.ParseAnnotationsBool(ctx, spec.Annotations, annotations.LCOWPrivileged, false) } -func writeFileInDir(dir string, filename string, data []byte, perm os.FileMode) error { - st, err := os.Stat(dir) - if err != nil { - return err - } - - if !st.IsDir() { - return fmt.Errorf("not a directory %q", dir) - } - - targetFilename := filepath.Join(dir, filename) - return os.WriteFile(targetFilename, data, perm) -} - // Virtual Pod Management Methods // InitializeVirtualPodSupport sets up the parent cgroup for virtual pods @@ -1612,3 +1960,59 @@ func setupVirtualPodHugePageMountsPath(virtualSandboxID string) error { return storage.MountRShared(mountPath) } + +// If *err is not nil, the section is rolled back, otherwise it is committed. +func (h *Host) commitOrRollbackPolicyRevSection( + ctx context.Context, + rev securitypolicy.RevertableSectionHandle, + err *error, +) { + if !h.HasSecurityPolicy() { + // Don't produce bogus log entries if we aren't in confidential mode, + // even though rev.Rollback would have been no-op. 
+ return + } + if *err != nil { + rev.Rollback() + logrus.WithContext(ctx).WithError(*err).Warn("rolling back security policy revertable section due to error") + } else { + rev.Commit() + } +} + +func (h *Host) DeleteContainerState(ctx context.Context, containerID string) error { + if h.HasSecurityPolicy() { + if err := checkValidContainerID(containerID, "container"); err != nil { + return err + } + } + + if err := h.checkMountsNotBroken(); err != nil { + return err + } + + c, err := h.GetCreatedContainer(containerID) + if err != nil { + return err + } + if h.HasSecurityPolicy() { + if !c.terminated.Load() { + return errors.Errorf("Denied deleting state of a running container %q", containerID) + } + overlay := c.spec.Root.Path + h.hostMounts.Lock() + defer h.hostMounts.Unlock() + if h.hostMounts.HasOverlayMountedAt(overlay) { + return errors.Errorf("Denied deleting state of a container with a overlay mount still active") + } + } + + // remove container state regardless of delete's success + defer h.RemoveContainer(containerID) + + if err = c.Delete(ctx); err != nil { + return err + } + + return nil +} diff --git a/internal/guest/runtime/hcsv2/uvm_state.go b/internal/guest/runtime/hcsv2/uvm_state.go index dd1ff521f0..96e64371a2 100644 --- a/internal/guest/runtime/hcsv2/uvm_state.go +++ b/internal/guest/runtime/hcsv2/uvm_state.go @@ -4,91 +4,360 @@ package hcsv2 import ( + "context" + "errors" "fmt" "path/filepath" "strings" "sync" + + "github.com/Microsoft/hcsshim/internal/gcs" + "github.com/Microsoft/hcsshim/internal/log" + "github.com/sirupsen/logrus" +) + +type deviceType int + +const ( + DeviceTypeRW deviceType = iota + DeviceTypeRO + DeviceTypeOverlay ) -type rwDevice struct { +func (d deviceType) String() string { + switch d { + case DeviceTypeRW: + return "RW" + case DeviceTypeRO: + return "RO" + case DeviceTypeOverlay: + return "Overlay" + default: + return fmt.Sprintf("Unknown(%d)", d) + } +} + +type device struct { + // fields common to all mountPath 
string + ty deviceType + usage int sourcePath string - encrypted bool + + // rw devices + encrypted bool + + // overlay devices + referencedDevices []*device } +// hostMounts tracks the state of fs/overlay mounts and their usage +// relationship. Users of this struct must call hm.Lock() before calling any +// other methods and call hm.Unlock() when done. +// +// Since mount/unmount operations can fail, the expected way to use this struct +// is to first lock it, call the method to add/remove the device, then, with the +// lock still held, perform the actual operation. If the operation fails, the +// caller must undo the operation by calling the appropriate remove/add method +// or the returned undo function, before unlocking. type hostMounts struct { - stateMutex sync.Mutex + stateMutex sync.Mutex + stateMutexLocked bool - // Holds information about read-write devices, which can be encrypted and - // contain overlay fs upper/work directory mounts. - readWriteMounts map[string]*rwDevice + // Map from mountPath to device struct + devices map[string]*device } func newHostMounts() *hostMounts { return &hostMounts{ - readWriteMounts: map[string]*rwDevice{}, + devices: make(map[string]*device), } } -// AddRWDevice adds read-write device metadata for device mounted at `mountPath`. -// Returns an error if there's an existing device mounted at `mountPath` location. -func (hm *hostMounts) AddRWDevice(mountPath string, sourcePath string, encrypted bool) error { +func (hm *hostMounts) expectLocked() { + if !hm.stateMutexLocked { + gcs.UnrecoverableError(errors.New("hostMounts: expected stateMutex to be locked, but it was not")) + } +} + +// Locks the state mutex. This is not re-entrant, calling it twice in the same +// thread will deadlock/panic. +func (hm *hostMounts) Lock() { hm.stateMutex.Lock() - defer hm.stateMutex.Unlock() + // Since we just acquired the lock, either it was not locked before, or + // somebody just unlocked it. 
In either case, hm.stateMutexLocked should be
+	// false.
+	if hm.stateMutexLocked {
+		gcs.UnrecoverableError(errors.New("hostMounts: stateMutexLocked already true when locking stateMutex"))
+	}
+	hm.stateMutexLocked = true
+}
+
+// Unlocks the state mutex
+func (hm *hostMounts) Unlock() {
+	hm.expectLocked()
+	hm.stateMutexLocked = false
+	hm.stateMutex.Unlock()
+}
-	mountTarget := filepath.Clean(mountPath)
-	if source, ok := hm.readWriteMounts[mountTarget]; ok {
-		return fmt.Errorf("read-write with source %q and mount target %q already exists", source.sourcePath, mountPath)
+func (hm *hostMounts) findDeviceAtPath(mountPath string) *device {
+	hm.expectLocked()
+
+	if dev, ok := hm.devices[mountPath]; ok {
+		return dev
 	}
-	hm.readWriteMounts[mountTarget] = &rwDevice{
-		mountPath:  mountTarget,
-		sourcePath: sourcePath,
-		encrypted:  encrypted,
+	return nil
+}
+
+func (hm *hostMounts) addDeviceToMapChecked(dev *device) error {
+	hm.expectLocked()
+
+	if _, ok := hm.devices[dev.mountPath]; ok {
+		return fmt.Errorf("device at mount path %q already exists", dev.mountPath)
 	}
+	hm.devices[dev.mountPath] = dev
 	return nil
 }
 
-// RemoveRWDevice removes the read-write device metadata for device mounted at
-// `mountPath`.
-func (hm *hostMounts) RemoveRWDevice(mountPath string, sourcePath string) error {
-	hm.stateMutex.Lock()
-	defer hm.stateMutex.Unlock()
+func (hm *hostMounts) findDeviceContainingPath(path string) *device {
+	hm.expectLocked()
+
+	// TODO: can we refactor this function by walking each component of the path
+	// from leaf to root, each time checking if the current component is a mount
+	// point? (i.e. why do we have to use filepath.Rel?)
+
+	var foundDev *device
+	cleanPath := filepath.Clean(path)
+	for devPath, dev := range hm.devices {
+		relPath, err := filepath.Rel(devPath, cleanPath)
+		// skip further checks if an error is returned or the relative path
+		// contains "..", meaning that the `path` isn't directly nested under
+		// `devPath`.
+ if err != nil || strings.HasPrefix(relPath, "..") { + continue + } + if foundDev == nil { + foundDev = dev + } else if len(dev.mountPath) > len(foundDev.mountPath) { + // The current device is mounted on top of a previously found device. + foundDev = dev + } + } + return foundDev +} + +func (hm *hostMounts) usePath(path string) (*device, error) { + hm.expectLocked() + + // Find the device at the given path and increment its usage count. + dev := hm.findDeviceContainingPath(path) + if dev == nil { + return nil, nil + } + dev.usage++ + return dev, nil +} + +func (hm *hostMounts) releaseDeviceUsage(dev *device) { + hm.expectLocked() + + if dev.usage <= 0 { + log.G(context.Background()).WithFields(logrus.Fields{ + "device": dev.mountPath, + "deviceSource": dev.sourcePath, + "deviceType": dev.ty, + "usage": dev.usage, + }).Error("hostMounts::releaseDeviceUsage: unexpected zero usage count") + return + } + dev.usage-- +} + +// User should carefully handle side-effects of adding a device if the device +// fails to be added. 
+func (hm *hostMounts) doAddDevice(mountPath string, ty deviceType, sourcePath string) (*device, error) { + hm.expectLocked() + + dev := &device{ + mountPath: filepath.Clean(mountPath), + ty: ty, + usage: 0, + sourcePath: sourcePath, + } + + if err := hm.addDeviceToMapChecked(dev); err != nil { + return nil, err + } + return dev, nil +} + +// Once checks is called, unless it returns an error, this function will always +// succeed +func (hm *hostMounts) doRemoveDevice(mountPath string, ty deviceType, sourcePath string, checks func(*device) error) error { + hm.expectLocked() unmountTarget := filepath.Clean(mountPath) - device, ok := hm.readWriteMounts[unmountTarget] - if !ok { + device := hm.findDeviceAtPath(unmountTarget) + if device == nil { // already removed or didn't exist return nil } if device.sourcePath != sourcePath { - return fmt.Errorf("wrong sourcePath %s", sourcePath) + return fmt.Errorf("wrong sourcePath %s, expected %s", sourcePath, device.sourcePath) + } + if device.ty != ty { + return fmt.Errorf("wrong device type %s, expected %s", ty, device.ty) + } + if device.usage > 0 { + log.G(context.Background()).WithFields(logrus.Fields{ + "device": device.mountPath, + "deviceSource": device.sourcePath, + "deviceType": device.ty, + "usage": device.usage, + }).Error("hostMounts::doRemoveDevice: device still in use, refusing unmount") + return fmt.Errorf("device at %q is still in use, can't unmount", unmountTarget) + } + if checks != nil { + if err := checks(device); err != nil { + return err + } } - delete(hm.readWriteMounts, unmountTarget) + delete(hm.devices, unmountTarget) return nil } +func (hm *hostMounts) AddRODevice(mountPath string, sourcePath string) error { + hm.expectLocked() + + _, err := hm.doAddDevice(mountPath, DeviceTypeRO, sourcePath) + return err +} + +// AddRWDevice adds read-write device metadata for device mounted at `mountPath`. +// Returns an error if there's an existing device mounted at `mountPath` location. 
+func (hm *hostMounts) AddRWDevice(mountPath string, sourcePath string, encrypted bool) error { + hm.expectLocked() + + dev, err := hm.doAddDevice(mountPath, DeviceTypeRW, sourcePath) + if err != nil { + return err + } + dev.encrypted = encrypted + return nil +} + +func (hm *hostMounts) AddOverlay(mountPath string, layers []string, scratchDir string) (err error) { + hm.expectLocked() + + dev, err := hm.doAddDevice(mountPath, DeviceTypeOverlay, mountPath) + if err != nil { + return err + } + dev.referencedDevices = make([]*device, 0, len(layers)+1) + defer func() { + if err != nil { + // If we failed to use any of the paths, we need to release the ones + // that we did use. + for _, d := range dev.referencedDevices { + hm.releaseDeviceUsage(d) + } + delete(hm.devices, mountPath) + } + }() + + for _, layer := range layers { + refDev, err := hm.usePath(layer) + if err != nil { + return err + } + if refDev != nil { + dev.referencedDevices = append(dev.referencedDevices, refDev) + } + } + refDev, err := hm.usePath(scratchDir) + if err != nil { + return err + } + if refDev != nil { + dev.referencedDevices = append(dev.referencedDevices, refDev) + } + + return nil +} + +func (hm *hostMounts) RemoveRODevice(mountPath string, sourcePath string) error { + hm.expectLocked() + + return hm.doRemoveDevice(mountPath, DeviceTypeRO, sourcePath, nil) +} + +// RemoveRWDevice removes the read-write device metadata for device mounted at +// `mountPath`. 
+func (hm *hostMounts) RemoveRWDevice(mountPath string, sourcePath string, encrypted bool) error { + hm.expectLocked() + + return hm.doRemoveDevice(mountPath, DeviceTypeRW, sourcePath, func(dev *device) error { + if dev.encrypted != encrypted { + return fmt.Errorf("encrypted flag wrong, provided %v, expected %v", encrypted, dev.encrypted) + } + return nil + }) +} + +func (hm *hostMounts) RemoveOverlay(mountPath string) (undo func(), err error) { + hm.expectLocked() + + var dev *device + err = hm.doRemoveDevice(mountPath, DeviceTypeOverlay, mountPath, func(_dev *device) error { + dev = _dev + for _, refDev := range dev.referencedDevices { + hm.releaseDeviceUsage(refDev) + } + return nil + }) + if err != nil { + // If we get an error from doRemoveDevice, we have not released anything + // yet. + return nil, err + } + undo = func() { + hm.expectLocked() + + for _, refDev := range dev.referencedDevices { + refDev.usage++ + } + + if _, ok := hm.devices[mountPath]; ok { + log.G(context.Background()).WithField("mountPath", mountPath).Error( + "hostMounts::RemoveOverlay: failed to undo remove: device that was removed exists in map", + ) + return + } + + hm.devices[mountPath] = dev + } + return undo, nil +} + // IsEncrypted checks if the given path is a sub-path of an encrypted read-write // device. func (hm *hostMounts) IsEncrypted(path string) bool { - hm.stateMutex.Lock() - defer hm.stateMutex.Unlock() + hm.expectLocked() - parentPath := "" - encrypted := false - cleanPath := filepath.Clean(path) - for rwPath, rwDev := range hm.readWriteMounts { - relPath, err := filepath.Rel(rwPath, cleanPath) - // skip further checks if an error is returned or the relative path - // contains "..", meaning that the `path` isn't directly nested under - // `rwPath`. 
- if err != nil || strings.HasPrefix(relPath, "..") { - continue - } - if len(rwDev.mountPath) > len(parentPath) { - parentPath = rwDev.mountPath - encrypted = rwDev.encrypted - } + dev := hm.findDeviceContainingPath(path) + if dev == nil { + return false + } + return dev.encrypted +} + +func (hm *hostMounts) HasOverlayMountedAt(path string) bool { + hm.expectLocked() + + dev := hm.findDeviceAtPath(filepath.Clean(path)) + if dev == nil { + return false } - return encrypted + return dev.ty == DeviceTypeOverlay } diff --git a/internal/guest/runtime/hcsv2/uvm_state_test.go b/internal/guest/runtime/hcsv2/uvm_state_test.go index b708caaeba..e87a207308 100644 --- a/internal/guest/runtime/hcsv2/uvm_state_test.go +++ b/internal/guest/runtime/hcsv2/uvm_state_test.go @@ -12,10 +12,13 @@ func Test_Add_Remove_RWDevice(t *testing.T) { mountPath := "/run/gcs/c/abcd" sourcePath := "/dev/sda" + hm.Lock() + defer hm.Unlock() + if err := hm.AddRWDevice(mountPath, sourcePath, false); err != nil { t.Fatalf("unexpected error adding RW device: %s", err) } - if err := hm.RemoveRWDevice(mountPath, sourcePath); err != nil { + if err := hm.RemoveRWDevice(mountPath, sourcePath, false); err != nil { t.Fatalf("unexpected error removing RW device: %s", err) } } @@ -25,29 +28,55 @@ func Test_Cannot_AddRWDevice_Twice(t *testing.T) { mountPath := "/run/gcs/c/abc" sourcePath := "/dev/sda" + hm.Lock() if err := hm.AddRWDevice(mountPath, sourcePath, false); err != nil { t.Fatalf("unexpected error: %s", err) } + hm.Unlock() + + hm.Lock() if err := hm.AddRWDevice(mountPath, sourcePath, false); err == nil { t.Fatalf("expected error adding %q for the second time", mountPath) } + hm.Unlock() } func Test_Cannot_RemoveRWDevice_Wrong_Source(t *testing.T) { hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + mountPath := "/run/gcs/c/abcd" sourcePath := "/dev/sda" wrongSource := "/dev/sdb" if err := hm.AddRWDevice(mountPath, sourcePath, false); err != nil { t.Fatalf("unexpected error: %s", err) } - if err 
:= hm.RemoveRWDevice(mountPath, wrongSource); err == nil { + if err := hm.RemoveRWDevice(mountPath, wrongSource, false); err == nil { t.Fatalf("expected error removing wrong source %s", wrongSource) } } +func Test_Cannot_RemoveRWDevice_Wrong_Encrypted(t *testing.T) { + hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + + mountPath := "/run/gcs/c/abcd" + sourcePath := "/dev/sda" + if err := hm.AddRWDevice(mountPath, sourcePath, false); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if err := hm.RemoveRWDevice(mountPath, sourcePath, true); err == nil { + t.Fatalf("expected error removing RW device with wrong encrypted flag") + } +} + func Test_HostMounts_IsEncrypted(t *testing.T) { hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + encryptedPath := "/run/gcs/c/encrypted" encryptedSource := "/dev/sda" if err := hm.AddRWDevice(encryptedPath, encryptedSource, true); err != nil { @@ -108,3 +137,189 @@ func Test_HostMounts_IsEncrypted(t *testing.T) { }) } } + +func Test_HostMounts_AddRemoveRODevice(t *testing.T) { + hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + + mountPath := "/run/gcs/c/abcd" + sourcePath := "/dev/sda" + + if err := hm.AddRODevice(mountPath, sourcePath); err != nil { + t.Fatalf("unexpected error adding RO device: %s", err) + } + + if err := hm.RemoveRODevice(mountPath, sourcePath); err != nil { + t.Fatalf("unexpected error removing RO device: %s", err) + } +} + +func Test_HostMounts_Cannot_AddRODevice_Twice(t *testing.T) { + hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + + mountPath := "/run/gcs/c/abc" + sourcePath := "/dev/sda" + + if err := hm.AddRODevice(mountPath, sourcePath); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if err := hm.AddRODevice(mountPath, sourcePath); err == nil { + t.Fatalf("expected error adding %q for the second time", mountPath) + } +} + +func Test_HostMounts_AddRemoveOverlay(t *testing.T) { + hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + + mountPath := 
"/run/gcs/c/aaaa/rootfs" + layers := []string{ + "/run/mounts/scsi/m1", + "/run/mounts/scsi/m2", + "/run/mounts/scsi/m3", + } + for _, layer := range layers { + if err := hm.AddRODevice(layer, layer); err != nil { + t.Fatalf("unexpected error adding RO device: %s", err) + } + } + scratchDir := "/run/gcs/c/aaaa/scratch" + if err := hm.AddRWDevice(scratchDir, scratchDir, true); err != nil { + t.Fatalf("unexpected error adding RW device: %s", err) + } + if err := hm.AddOverlay(mountPath, layers, scratchDir); err != nil { + t.Fatalf("unexpected error adding overlay: %s", err) + } + undo, err := hm.RemoveOverlay(mountPath) + if err != nil { + t.Fatalf("unexpected error removing overlay: %s", err) + } + if undo == nil { + t.Fatalf("expected undo function to be non-nil") + } + undo() + if _, err = hm.RemoveOverlay(mountPath); err != nil { + t.Fatalf("unexpected error removing overlay again: %s", err) + } +} + +func Test_HostMounts_Cannot_RemoveInUseDeviceByOverlay(t *testing.T) { + hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + + mountPath := "/run/gcs/c/aaaa/rootfs" + layers := []string{ + "/run/mounts/scsi/m1", + "/run/mounts/scsi/m2", + "/run/mounts/scsi/m3", + } + for _, layer := range layers { + if err := hm.AddRODevice(layer, layer); err != nil { + t.Fatalf("unexpected error adding RO device: %s", err) + } + } + scratchDir := "/run/gcs/c/aaaa/scratch" + if err := hm.AddRWDevice(scratchDir, scratchDir, true); err != nil { + t.Fatalf("unexpected error adding RW device: %s", err) + } + if err := hm.AddOverlay(mountPath, layers, scratchDir); err != nil { + t.Fatalf("unexpected error adding overlay: %s", err) + } + + for _, layer := range layers { + if err := hm.RemoveRODevice(layer, layer); err == nil { + t.Fatalf("expected error removing RO device %s while in use by overlay", layer) + } + } + if err := hm.RemoveRWDevice(scratchDir, scratchDir, true); err == nil { + t.Fatalf("expected error removing RW device %s while in use by overlay", scratchDir) + } + + if 
_, err := hm.RemoveOverlay(mountPath); err != nil { + t.Fatalf("unexpected error removing overlay: %s", err) + } + + // now we can remove + for _, layer := range layers { + if err := hm.RemoveRODevice(layer, layer); err != nil { + t.Fatalf("unexpected error removing RO device %s: %s", layer, err) + } + } + if err := hm.RemoveRWDevice(scratchDir, scratchDir, true); err != nil { + t.Fatalf("unexpected error removing RW device %s: %s", scratchDir, err) + } +} + +func Test_HostMounts_Cannot_RemoveInUseDeviceByOverlay_MultipleUsers(t *testing.T) { + hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + + overlay1 := "/run/gcs/c/aaaa/rootfs" + overlay2 := "/run/gcs/c/bbbb/rootfs" + layers := []string{ + "/run/mounts/scsi/m1", + "/run/mounts/scsi/m2", + "/run/mounts/scsi/m3", + } + for _, layer := range layers { + if err := hm.AddRODevice(layer, layer); err != nil { + t.Fatalf("unexpected error adding RO device: %s", err) + } + } + sharedScratchMount := "/run/gcs/c/sandbox" + scratch1 := sharedScratchMount + "/scratch/aaaa" + scratch2 := sharedScratchMount + "/scratch/bbbb" + if err := hm.AddRWDevice(sharedScratchMount, sharedScratchMount, true); err != nil { + t.Fatalf("unexpected error adding RW device: %s", err) + } + if err := hm.AddOverlay(overlay1, layers, scratch1); err != nil { + t.Fatalf("unexpected error adding overlay1: %s", err) + } + + if err := hm.AddOverlay(overlay2, layers[0:2], scratch2); err != nil { + t.Fatalf("unexpected error adding overlay2: %s", err) + } + + for _, layer := range layers { + if err := hm.RemoveRODevice(layer, layer); err == nil { + t.Fatalf("expected error removing RO device %s while in use by overlay", layer) + } + } + if err := hm.RemoveRWDevice(sharedScratchMount, sharedScratchMount, true); err == nil { + t.Fatalf("expected error removing RW device %s while in use by overlay", sharedScratchMount) + } + + if _, err := hm.RemoveOverlay(overlay1); err != nil { + t.Fatalf("unexpected error removing overlay 1: %s", err) + } + + for 
_, layer := range layers[0:2] { + if err := hm.RemoveRODevice(layer, layer); err == nil { + t.Fatalf("expected error removing RO device %s (still in use by overlay 2)", layer) + } + } + if err := hm.RemoveRODevice(layers[2], layers[2]); err != nil { + t.Fatalf("unexpected error removing layers[2] which is not being used by overlay 2: %s", err) + } + if err := hm.RemoveRWDevice(sharedScratchMount, sharedScratchMount, true); err == nil { + t.Fatalf("expected error removing RW device %s while in use by overlay 2", scratch2) + } + + if _, err := hm.RemoveOverlay(overlay2); err != nil { + t.Fatalf("unexpected error removing overlay 2: %s", err) + } + for _, layer := range layers[0:2] { + if err := hm.RemoveRODevice(layer, layer); err != nil { + t.Fatalf("unexpected error removing RO device %s: %s", layer, err) + } + } + if err := hm.RemoveRWDevice(sharedScratchMount, sharedScratchMount, true); err != nil { + t.Fatalf("unexpected error removing RW device %s: %s", sharedScratchMount, err) + } +} diff --git a/internal/guest/storage/mount.go b/internal/guest/storage/mount.go index a3d10a3b25..142f0ccbbc 100644 --- a/internal/guest/storage/mount.go +++ b/internal/guest/storage/mount.go @@ -16,6 +16,7 @@ import ( "go.opencensus.io/trace" "golang.org/x/sys/unix" + "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/oc" ) @@ -126,6 +127,7 @@ func UnmountPath(ctx context.Context, target string, removeTarget bool) (err err if _, err := osStat(target); err != nil { if os.IsNotExist(err) { + log.G(ctx).WithField("target", target).Warnf("UnmountPath called for non-existent path") return nil } return errors.Wrapf(err, "failed to determine if path '%s' exists", target) diff --git a/internal/guest/storage/overlay/overlay.go b/internal/guest/storage/overlay/overlay.go index aa4877508f..84bf8fa529 100644 --- a/internal/guest/storage/overlay/overlay.go +++ b/internal/guest/storage/overlay/overlay.go @@ -56,8 +56,7 @@ func processErrNoSpace(ctx 
context.Context, path string, err error) { }).WithError(err).Warn("got ENOSPC, gathering diagnostics") } -// MountLayer first enforces the security policy for the container's layer paths -// and then calls Mount to mount the layer paths as an overlayfs. +// MountLayer calls Mount to mount the layer paths as an overlayfs. func MountLayer( ctx context.Context, layerPaths []string, diff --git a/internal/guest/storage/plan9/plan9.go b/internal/guest/storage/plan9/plan9.go index 5c1f1d74f4..44ac0f4e4e 100644 --- a/internal/guest/storage/plan9/plan9.go +++ b/internal/guest/storage/plan9/plan9.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "os" + "regexp" "syscall" "github.com/Microsoft/hcsshim/internal/guest/transport" @@ -25,6 +26,19 @@ var ( unixMount = unix.Mount ) +// c.f. v9fs_parse_options in linux/fs/9p/v9fs.c - technically anything other +// than ',' is ok (quoting is not handled), however, this name is generated from +// a counter in AddPlan9 (internal/uvm/plan9.go), and therefore we expect only +// digits from a normal hcsshim host. +var validShareNameRegex = regexp.MustCompile(`^[0-9]+$`) + +func ValidateShareName(name string) error { + if !validShareNameRegex.MatchString(name) { + return fmt.Errorf("invalid plan9 share name %q: must match regex %q", name, validShareNameRegex.String()) + } + return nil +} + // Mount dials a connection from `vsock` and mounts a Plan9 share to `target`. // // `target` will be created. On mount failure the created `target` will be diff --git a/internal/guest/storage/scsi/scsi.go b/internal/guest/storage/scsi/scsi.go index ec62636590..83c586c3eb 100644 --- a/internal/guest/storage/scsi/scsi.go +++ b/internal/guest/storage/scsi/scsi.go @@ -121,8 +121,9 @@ type Config struct { // Mount creates a mount from the SCSI device on `controller` index `lun` to // `target` // -// `target` will be created. On mount failure the created `target` will be -// automatically cleaned up. +// `target` will be created. 
On mount failure the created `target`, as well as + any associated dm-crypt or dm-verity devices will be automatically cleaned + up. + // // If the config has `encrypted` is set to true, the SCSI device will be // encrypted using dm-crypt. @@ -200,7 +201,8 @@ func Mount( var deviceFS string if config.Encrypted { cryptDeviceName := fmt.Sprintf(cryptDeviceFmt, controller, lun, partition) - encryptedSource, err := encryptDevice(spnCtx, source, cryptDeviceName) + var encryptedSource string + encryptedSource, err = encryptDevice(spnCtx, source, cryptDeviceName) if err != nil { // todo (maksiman): add better retry logic, similar to how SCSI device mounts are // retried on unix.ENOENT and unix.ENXIO. The retry should probably be on an @@ -211,6 +213,13 @@ func Mount( } } source = encryptedSource + defer func() { + if err != nil { + if err := cleanupCryptDevice(spnCtx, cryptDeviceName); err != nil { + log.G(spnCtx).WithError(err).WithField("cryptDeviceName", cryptDeviceName).Debug("failed to cleanup dm-crypt device after mount failure") + } + } + }() } else { // Get the filesystem that is already on the device (if any) and use that // as the mountType unless `Filesystem` was given. 
diff --git a/internal/guest/storage/scsi/scsi_test.go b/internal/guest/storage/scsi/scsi_test.go index ebfcf8e382..94992047bd 100644 --- a/internal/guest/storage/scsi/scsi_test.go +++ b/internal/guest/storage/scsi/scsi_test.go @@ -999,6 +999,12 @@ func Test_Mount_EncryptDevice_Mkfs_Error(t *testing.T) { } return expectedDevicePath, nil } + cleanupCryptDevice = func(_ context.Context, dmCryptName string) error { + if dmCryptName != expectedCryptTarget { + t.Fatalf("expected cleanupCryptDevice name %q got %q", expectedCryptTarget, dmCryptName) + } + return nil + } osStat = osStatNoop xfsFormat = func(arg string) error { diff --git a/internal/guestpath/paths.go b/internal/guestpath/paths.go index aab9ed1053..1852bc2454 100644 --- a/internal/guestpath/paths.go +++ b/internal/guestpath/paths.go @@ -27,15 +27,17 @@ const ( // LCOWMountPathPrefixFmt is the path format in the LCOW UVM where // non-global mounts, such as Plan9 mounts are added LCOWMountPathPrefixFmt = "/mounts/m%d" - // LCOWGlobalMountPrefixFmt is the path format in the LCOW UVM where global - // mounts are added - LCOWGlobalMountPrefixFmt = "/run/mounts/m%d" + // LCOWGlobalScsiMountPrefixFmt is the path format in the LCOW UVM where + // global disk mounts are added + LCOWGlobalScsiMountPrefixFmt = "/run/mounts/scsi/m%d" // LCOWGlobalDriverPrefixFmt is the path format in the LCOW UVM where drivers // are mounted as read/write LCOWGlobalDriverPrefixFmt = "/run/drivers/%s" - // WCOWGlobalMountPrefixFmt is the path prefix format in the WCOW UVM where - // mounts are added - WCOWGlobalMountPrefixFmt = "C:\\mounts\\m%d" + // WCOWGlobalScsiMountPrefixFmt is the path prefix format in the WCOW UVM + // where global disk mounts are added + WCOWGlobalScsiMountPrefixFmt = `c:\mounts\scsi\m%d` // RootfsPath is part of the container's rootfs path RootfsPath = "rootfs" + // ScratchDir is the name of the directory used for overlay upper and work directories + ScratchDir = "scratch" ) diff --git a/internal/layers/lcow.go 
b/internal/layers/lcow.go index dccd994e87..b1385fa4ac 100644 --- a/internal/layers/lcow.go +++ b/internal/layers/lcow.go @@ -159,7 +159,9 @@ func MountLCOWLayers( // handles the case where we want to share a scratch disk for multiple containers instead // of mounting a new one. Pass a unique value for `ScratchPath` to avoid container upper and // work directories colliding in the UVM. - containerScratchPathInUVM := ospath.Join("linux", scsiMount.GuestPath(), "scratch", containerID) + // Note that in the shared scratch case, AddVirtualDisk above is a no-op and + // will return the existing mount. + containerScratchPathInUVM := ospath.Join("linux", scsiMount.GuestPath(), guestpath.ScratchDir, containerID) defer func() { if err != nil { diff --git a/internal/protocol/guestresource/resources.go b/internal/protocol/guestresource/resources.go index b956069107..8a58949281 100644 --- a/internal/protocol/guestresource/resources.go +++ b/internal/protocol/guestresource/resources.go @@ -229,20 +229,14 @@ type SignalProcessOptionsWCOW struct { Signal guestrequest.SignalValueWCOW `json:",omitempty"` } -// LCOWConfidentialOptions is used to set various confidential container specific +// ConfidentialOptions is used to set various confidential container specific // options. 
-type LCOWConfidentialOptions struct { +type ConfidentialOptions struct { EnforcerType string `json:"EnforcerType,omitempty"` EncodedSecurityPolicy string `json:"EncodedSecurityPolicy,omitempty"` EncodedUVMReference string `json:"EncodedUVMReference,omitempty"` } -type LCOWSecurityPolicyFragment struct { +type SecurityPolicyFragment struct { Fragment string `json:"Fragment,omitempty"` } - -type WCOWConfidentialOptions struct { - EnforcerType string `json:"EnforcerType,omitempty"` - EncodedSecurityPolicy string `json:"EncodedSecurityPolicy,omitempty"` - EncodedUVMReference string `json:"EncodedUVMReference,omitempty"` -} diff --git a/internal/regopolicyinterpreter/regopolicyinterpreter.go b/internal/regopolicyinterpreter/regopolicyinterpreter.go index 047a4a27b7..ebe3fcfff4 100644 --- a/internal/regopolicyinterpreter/regopolicyinterpreter.go +++ b/internal/regopolicyinterpreter/regopolicyinterpreter.go @@ -63,6 +63,9 @@ type RegoModule struct { type regoMetadata map[string]map[string]interface{} +const metadataRootKey = "metadata" +const metadataOperationsKey = "metadata" + type regoMetadataAction string const ( @@ -81,6 +84,11 @@ type regoMetadataOperation struct { // The result from a policy query type RegoQueryResult map[string]interface{} +// An immutable, saved copy of the metadata state. +type SavedMetadata struct { + metadataRoot regoMetadata +} + // deep copy for an object func copyObject(data map[string]interface{}) (map[string]interface{}, error) { objJSON, err := json.Marshal(data) @@ -113,6 +121,24 @@ func copyValue(value interface{}) (interface{}, error) { return valueCopy, nil } +// deep copy for regoMetadata. +// We cannot use copyObject for this due to the fact that map[string]interface{} +// is a concrete type and a map of it cannot be used as a map of interface{}. 
+func copyRegoMetadata(value regoMetadata) (regoMetadata, error) { + valueJSON, err := json.Marshal(value) + if err != nil { + return nil, err + } + + var valueCopy regoMetadata + err = json.Unmarshal(valueJSON, &valueCopy) + if err != nil { + return nil, err + } + + return valueCopy, nil +} + // NewRegoPolicyInterpreter creates a new RegoPolicyInterpreter, using the code provided. // inputData is the Rego data which should be used as the initial state // of the interpreter. A deep copy is performed on it such that it will @@ -123,8 +149,8 @@ func NewRegoPolicyInterpreter(code string, inputData map[string]interface{}) (*R return nil, fmt.Errorf("unable to copy the input data: %w", err) } - if _, ok := data["metadata"]; !ok { - data["metadata"] = make(regoMetadata) + if _, ok := data[metadataRootKey]; !ok { + data[metadataRootKey] = make(regoMetadata) } policy := &RegoPolicyInterpreter{ @@ -207,7 +233,7 @@ func (r *RegoPolicyInterpreter) GetMetadata(name string, key string) (interface{ r.dataAndModulesMutex.Lock() defer r.dataAndModulesMutex.Unlock() - metadataRoot, ok := r.data["metadata"].(regoMetadata) + metadataRoot, ok := r.data[metadataRootKey].(regoMetadata) if !ok { return nil, errors.New("illegal interpreter state: invalid metadata object type") } @@ -228,6 +254,32 @@ func (r *RegoPolicyInterpreter) GetMetadata(name string, key string) (interface{ } } +// Saves a copy of the internal policy metadata state. +func (r *RegoPolicyInterpreter) SaveMetadata() (s SavedMetadata, err error) { + r.dataAndModulesMutex.Lock() + defer r.dataAndModulesMutex.Unlock() + + metadataRoot, ok := r.data[metadataRootKey].(regoMetadata) + if !ok { + return SavedMetadata{}, errors.New("illegal interpreter state: invalid metadata object type") + } + s.metadataRoot, err = copyRegoMetadata(metadataRoot) + return s, err +} + +// Restores a previously saved metadata state. 
+func (r *RegoPolicyInterpreter) RestoreMetadata(m SavedMetadata) error { + r.dataAndModulesMutex.Lock() + defer r.dataAndModulesMutex.Unlock() + + copied, err := copyRegoMetadata(m.metadataRoot) + if err != nil { + return fmt.Errorf("unable to copy metadata: %w", err) + } + r.data[metadataRootKey] = copied + return nil +} + func newRegoMetadataOperation(operation interface{}) (*regoMetadataOperation, error) { var metadataOp regoMetadataOperation @@ -286,7 +338,7 @@ func (r *RegoPolicyInterpreter) UpdateOSType(os string) error { func (r *RegoPolicyInterpreter) updateMetadata(ops []*regoMetadataOperation) error { // dataAndModulesMutex must be held before calling this - metadataRoot, ok := r.data["metadata"].(regoMetadata) + metadataRoot, ok := r.data[metadataRootKey].(regoMetadata) if !ok { return errors.New("illegal interpreter state: invalid metadata object type") } @@ -431,7 +483,7 @@ func (r *RegoPolicyInterpreter) logMetadata() { return } - contents, err := json.Marshal(r.data["metadata"]) + contents, err := json.Marshal(r.data[metadataRootKey]) if err != nil { r.metadataLogger.Printf("error marshaling metadata: %v\n", err.Error()) } else { @@ -617,7 +669,7 @@ func (r *RegoPolicyInterpreter) Query(rule string, input map[string]interface{}) r.logResult(rule, resultSet) ops := []*regoMetadataOperation{} - if rawMetadata, ok := resultSet["metadata"]; ok { + if rawMetadata, ok := resultSet[metadataOperationsKey]; ok { metadata, ok := rawMetadata.([]interface{}) if !ok { return nil, errors.New("error loading metadata array: invalid type") @@ -640,7 +692,7 @@ func (r *RegoPolicyInterpreter) Query(rule string, input map[string]interface{}) } for name, value := range resultSet { - if name == "metadata" { + if name == metadataOperationsKey { continue } else { result[name] = value diff --git a/internal/regopolicyinterpreter/regopolicyinterpreter_test.go b/internal/regopolicyinterpreter/regopolicyinterpreter_test.go index b7d86609f7..3872afff51 100644 --- 
a/internal/regopolicyinterpreter/regopolicyinterpreter_test.go +++ b/internal/regopolicyinterpreter/regopolicyinterpreter_test.go @@ -72,6 +72,37 @@ func Test_copyValue(t *testing.T) { } } +func Test_copyRegoMetadata(t *testing.T) { + f := func(orig testRegoMetadata) bool { + copy, err := copyRegoMetadata(regoMetadata(orig)) + if err != nil { + t.Error(err) + return false + } + + if len(orig) != len(copy) { + t.Errorf("original and copy have different number of objects: %d != %d", len(orig), len(copy)) + return false + } + + for name, origObject := range orig { + if copyObject, ok := copy[name]; ok { + if !assertObjectsEqual(origObject, copyObject) { + t.Errorf("original and copy differ on key %s", name) + } + } else { + t.Errorf("copy missing object %s", name) + } + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 30, Rand: testRand}); err != nil { + t.Errorf("Test_copyRegoMetadata: %v", err) + } +} + //go:embed test.rego var testCode string @@ -364,6 +395,107 @@ func Test_Metadata_Remove(t *testing.T) { } } +func Test_Metadata_SaveRestore(t *testing.T) { + rego, err := setupRego() + if err != nil { + t.Fatal(err) + } + + f := func(pairs1before, pairs1after intPairArray, name1 metadataName, pairs2before, pairs2after intPairArray, name2 metadataName) bool { + if name1 == name2 { + t.Fatalf("generated two identical names: %s", name1) + } + + err := appendAll(rego, pairs1before, name1) + if err != nil { + t.Errorf("error appending pairs1before: %v", err) + return false + } + err = appendAll(rego, pairs2before, name2) + if err != nil { + t.Errorf("error appending pairs2before: %v", err) + return false + } + + saved, err := rego.SaveMetadata() + if err != nil { + t.Errorf("unable to save metadata: %v", err) + return false + } + + beforeSum1 := getExpectedGapFromPairs(pairs1before) + err = computeGap(rego, name1, beforeSum1) + if err != nil { + t.Error(err) + return false + } + + beforeSum2 := getExpectedGapFromPairs(pairs2before) + err = 
computeGap(rego, name2, beforeSum2) + if err != nil { + t.Error(err) + return false + } + + // computeGap would have cleared the list, so we restore it. + err = rego.RestoreMetadata(saved) + if err != nil { + t.Errorf("unable to restore metadata: %v", err) + return false + } + + err = appendAll(rego, pairs1after, name1) + if err != nil { + t.Errorf("error appending pairs1after: %v", err) + return false + } + + err = appendAll(rego, pairs2after, name2) + if err != nil { + t.Errorf("error appending pairs2after: %v", err) + return false + } + + afterSum1 := beforeSum1 + getExpectedGapFromPairs(pairs1after) + err = computeGap(rego, name1, afterSum1) + if err != nil { + t.Error(err) + return false + } + + afterSum2 := beforeSum2 + getExpectedGapFromPairs(pairs2after) + err = computeGap(rego, name2, afterSum2) + if err != nil { + t.Error(err) + return false + } + + err = rego.RestoreMetadata(saved) + if err != nil { + t.Errorf("unable to restore metadata: %v", err) + return false + } + + err = computeGap(rego, name1, beforeSum1) + if err != nil { + t.Errorf("computeGap failed for name1 after restore: %v", err) + return false + } + + err = computeGap(rego, name2, beforeSum2) + if err != nil { + t.Errorf("computeGap failed for name2 after restore: %v", err) + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 100, Rand: testRand}); err != nil { + t.Errorf("Test_Metadata_SaveRestore: %v", err) + } +} + //go:embed module.rego var moduleCode string @@ -508,6 +640,7 @@ type testValue struct { } type testArray []interface{} type testObject map[string]interface{} +type testRegoMetadata regoMetadata type testValueType int @@ -580,6 +713,16 @@ func (testObject) Generate(r *rand.Rand, _ int) reflect.Value { return reflect.ValueOf(value) } +func (testRegoMetadata) Generate(r *rand.Rand, _ int) reflect.Value { + numObjects := r.Intn(maxNumberOfFields) + metadata := make(testRegoMetadata) + for i := 0; i < numObjects; i++ { + name := 
uniqueString(r) + metadata[name] = generateObject(r, 0) + } + return reflect.ValueOf(metadata) +} + func getResult(r *RegoPolicyInterpreter, p intPair, rule string) (RegoQueryResult, error) { input := map[string]interface{}{"a": p.a, "b": p.b} result, err := r.Query("data.test."+rule, input) @@ -640,6 +783,27 @@ func appendLists(r *RegoPolicyInterpreter, p intPair, name metadataName) error { return nil } +func appendAll(r *RegoPolicyInterpreter, pairs intPairArray, name metadataName) error { + for _, pair := range pairs { + if err := appendLists(r, pair, name); err != nil { + return fmt.Errorf("error appending pair %v: %w", pair, err) + } + } + return nil +} + +func getExpectedGapFromPairs(pairs intPairArray) int { + expected := 0 + for _, pair := range pairs { + if pair.a >= pair.b { + expected += pair.a - pair.b + } else { + expected += pair.b - pair.a + } + } + return expected +} + func computeGap(r *RegoPolicyInterpreter, name metadataName, expected int) error { input := map[string]interface{}{"name": string(name)} result, err := r.Query("data.test.compute_gap", input) diff --git a/internal/uvm/security_policy.go b/internal/uvm/security_policy.go index 0dcf4fe693..3fa47e87b1 100644 --- a/internal/uvm/security_policy.go +++ b/internal/uvm/security_policy.go @@ -16,11 +16,11 @@ import ( "github.com/Microsoft/hcsshim/pkg/ctrdtaskapi" ) -type ConfidentialUVMOpt func(ctx context.Context, r *guestresource.LCOWConfidentialOptions) error +type ConfidentialUVMOpt func(ctx context.Context, r *guestresource.ConfidentialOptions) error // WithSecurityPolicy sets the desired security policy for the resource. 
func WithSecurityPolicy(policy string) ConfidentialUVMOpt { - return func(ctx context.Context, r *guestresource.LCOWConfidentialOptions) error { + return func(ctx context.Context, r *guestresource.ConfidentialOptions) error { r.EncodedSecurityPolicy = policy return nil } @@ -28,75 +28,12 @@ func WithSecurityPolicy(policy string) ConfidentialUVMOpt { // WithSecurityPolicyEnforcer sets the desired enforcer type for the resource. func WithSecurityPolicyEnforcer(enforcer string) ConfidentialUVMOpt { - return func(ctx context.Context, r *guestresource.LCOWConfidentialOptions) error { + return func(ctx context.Context, r *guestresource.ConfidentialOptions) error { r.EnforcerType = enforcer return nil } } -// TODO (Mahati): Move this block out later -type WCOWConfidentialUVMOpt func(ctx context.Context, r *guestresource.WCOWConfidentialOptions) error - -// WithSecurityPolicy sets the desired security policy for the resource. -func WithWCOWSecurityPolicy(policy string) WCOWConfidentialUVMOpt { - return func(ctx context.Context, r *guestresource.WCOWConfidentialOptions) error { - r.EncodedSecurityPolicy = policy - return nil - } -} - -// WithSecurityPolicyEnforcer sets the desired enforcer type for the resource. -func WithWCOWSecurityPolicyEnforcer(enforcer string) WCOWConfidentialUVMOpt { - return func(ctx context.Context, r *guestresource.WCOWConfidentialOptions) error { - r.EnforcerType = enforcer - return nil - } -} - -// WithUVMReferenceInfo reads UVM reference info file and base64 encodes the -// content before setting it for the resource. This is no-op if the -// path is empty or the file doesn't exist. 
-func WithWCOWUVMReferenceInfo(path string) WCOWConfidentialUVMOpt { - return func(ctx context.Context, r *guestresource.WCOWConfidentialOptions) error { - encoded, err := base64EncodeFileContents(path) - if err != nil { - if os.IsNotExist(err) { - log.G(ctx).WithField("filePath", path).Debug("UVM reference info file not found") - return nil - } - return fmt.Errorf("failed to read UVM reference info file: %w", err) - } - r.EncodedUVMReference = encoded - return nil - } -} - -func (uvm *UtilityVM) SetWCOWConfidentialUVMOptions(ctx context.Context, opts ...WCOWConfidentialUVMOpt) error { - if uvm.operatingSystem != "windows" { - return errNotSupported - } - uvm.m.Lock() - defer uvm.m.Unlock() - confOpts := &guestresource.WCOWConfidentialOptions{} - for _, o := range opts { - if err := o(ctx, confOpts); err != nil { - return err - } - } - modification := &hcsschema.ModifySettingRequest{ - RequestType: guestrequest.RequestTypeAdd, - GuestRequest: guestrequest.ModificationRequest{ - ResourceType: guestresource.ResourceTypeSecurityPolicy, - RequestType: guestrequest.RequestTypeAdd, - Settings: *confOpts, - }, - } - if err := uvm.modify(ctx, modification); err != nil { - return fmt.Errorf("uvm::Policy: failed to modify utility VM configuration: %w", err) - } - return nil -} - func base64EncodeFileContents(filePath string) (string, error) { if filePath == "" { return "", nil @@ -112,7 +49,7 @@ func base64EncodeFileContents(filePath string) (string, error) { // content before setting it for the resource. This is no-op if the // `referenceName` is empty or the file doesn't exist. 
func WithUVMReferenceInfo(referenceRoot string, referenceName string) ConfidentialUVMOpt { - return func(ctx context.Context, r *guestresource.LCOWConfidentialOptions) error { + return func(ctx context.Context, r *guestresource.ConfidentialOptions) error { if referenceName == "" { return nil } @@ -137,14 +74,10 @@ func WithUVMReferenceInfo(referenceRoot string, referenceName string) Confidenti // This has to happen before we start mounting things or generally changing // the state of the UVM after is has been measured at startup func (uvm *UtilityVM) SetConfidentialUVMOptions(ctx context.Context, opts ...ConfidentialUVMOpt) error { - if uvm.operatingSystem != "linux" { - return errNotSupported - } - uvm.m.Lock() defer uvm.m.Unlock() - confOpts := &guestresource.LCOWConfidentialOptions{} + confOpts := &guestresource.ConfidentialOptions{} for _, o := range opts { if err := o(ctx, confOpts); err != nil { return err @@ -174,7 +107,7 @@ func (uvm *UtilityVM) InjectPolicyFragment(ctx context.Context, fragment *ctrdta GuestRequest: guestrequest.ModificationRequest{ ResourceType: guestresource.ResourceTypePolicyFragment, RequestType: guestrequest.RequestTypeAdd, - Settings: guestresource.LCOWSecurityPolicyFragment{ + Settings: guestresource.SecurityPolicyFragment{ Fragment: fragment.Fragment, }, }, diff --git a/internal/uvm/start.go b/internal/uvm/start.go index 321f5af67a..df99e275a5 100644 --- a/internal/uvm/start.go +++ b/internal/uvm/start.go @@ -19,6 +19,7 @@ import ( "github.com/Microsoft/hcsshim/internal/gcs" "github.com/Microsoft/hcsshim/internal/gcs/prot" + "github.com/Microsoft/hcsshim/internal/guestpath" "github.com/Microsoft/hcsshim/internal/hcs" "github.com/Microsoft/hcsshim/internal/hcs/schema1" hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" @@ -310,9 +311,9 @@ func (uvm *UtilityVM) Start(ctx context.Context) (err error) { } else { gb = scsi.NewHCSGuestBackend(uvm.hcsSystem, uvm.OS()) } - guestMountFmt := `c:\mounts\scsi\m%d` + guestMountFmt := 
guestpath.WCOWGlobalScsiMountPrefixFmt if uvm.OS() == "linux" { - guestMountFmt = "/run/mounts/scsi/m%d" + guestMountFmt = guestpath.LCOWGlobalScsiMountPrefixFmt } mgr, err := scsi.NewManager( scsi.NewHCSHostBackend(uvm.hcsSystem), @@ -326,28 +327,29 @@ func (uvm *UtilityVM) Start(ctx context.Context) (err error) { } uvm.SCSIManager = mgr - if uvm.confidentialUVMOptions != nil && uvm.OS() == "linux" { + var policy, enforcer, referenceInfoFileRoot, referenceInfoFilePath string + + if uvm.confidentialUVMOptions != nil || uvm.HasConfidentialPolicy() { + if uvm.confidentialUVMOptions != nil && uvm.OS() == "linux" { + policy = uvm.confidentialUVMOptions.SecurityPolicy + enforcer = uvm.confidentialUVMOptions.SecurityPolicyEnforcer + referenceInfoFilePath = uvm.confidentialUVMOptions.UVMReferenceInfoFile + referenceInfoFileRoot = defaultLCOWOSBootFilesPath() + } else if uvm.HasConfidentialPolicy() && uvm.OS() == "windows" { + policy = uvm.createOpts.(*OptionsWCOW).SecurityPolicy + enforcer = uvm.createOpts.(*OptionsWCOW).SecurityPolicyEnforcer + referenceInfoFilePath = uvm.createOpts.(*OptionsWCOW).UVMReferenceInfoFile + } copts := []ConfidentialUVMOpt{ - WithSecurityPolicy(uvm.confidentialUVMOptions.SecurityPolicy), - WithSecurityPolicyEnforcer(uvm.confidentialUVMOptions.SecurityPolicyEnforcer), - WithUVMReferenceInfo(defaultLCOWOSBootFilesPath(), uvm.confidentialUVMOptions.UVMReferenceInfoFile), + WithSecurityPolicy(policy), + WithSecurityPolicyEnforcer(enforcer), + WithUVMReferenceInfo(referenceInfoFileRoot, referenceInfoFilePath), } if err := uvm.SetConfidentialUVMOptions(ctx, copts...); err != nil { return err } } - if uvm.HasConfidentialPolicy() && uvm.OS() == "windows" { - copts := []WCOWConfidentialUVMOpt{ - WithWCOWSecurityPolicy(uvm.createOpts.(*OptionsWCOW).SecurityPolicy), - WithWCOWSecurityPolicyEnforcer(uvm.createOpts.(*OptionsWCOW).SecurityPolicyEnforcer), - WithWCOWUVMReferenceInfo(uvm.createOpts.(*OptionsWCOW).UVMReferenceInfoFile), - } - if err := 
uvm.SetWCOWConfidentialUVMOptions(ctx, copts...); err != nil { - return err - } - } - return nil } diff --git a/pkg/securitypolicy/api.rego b/pkg/securitypolicy/api.rego index 36a197ebc2..e7bc653ac4 100644 --- a/pkg/securitypolicy/api.rego +++ b/pkg/securitypolicy/api.rego @@ -3,22 +3,24 @@ package api version := "@@API_VERSION@@" enforcement_points := { - "mount_device": {"introducedVersion": "0.1.0", "default_results": {"allowed": false}}, - "mount_overlay": {"introducedVersion": "0.1.0", "default_results": {"allowed": false}}, - "mount_cims": {"introducedVersion": "0.11.0", "default_results": {"allowed": false}}, - "create_container": {"introducedVersion": "0.1.0", "default_results": {"allowed": false, "env_list": null, "allow_stdio_access": false}}, - "unmount_device": {"introducedVersion": "0.2.0", "default_results": {"allowed": true}}, - "unmount_overlay": {"introducedVersion": "0.6.0", "default_results": {"allowed": true}}, - "exec_in_container": {"introducedVersion": "0.2.0", "default_results": {"allowed": true, "env_list": null}}, - "exec_external": {"introducedVersion": "0.3.0", "default_results": {"allowed": true, "env_list": null, "allow_stdio_access": false}}, - "shutdown_container": {"introducedVersion": "0.4.0", "default_results": {"allowed": true}}, - "signal_container_process": {"introducedVersion": "0.5.0", "default_results": {"allowed": true}}, - "plan9_mount": {"introducedVersion": "0.6.0", "default_results": {"allowed": true}}, - "plan9_unmount": {"introducedVersion": "0.6.0", "default_results": {"allowed": true}}, - "get_properties": {"introducedVersion": "0.7.0", "default_results": {"allowed": true}}, - "dump_stacks": {"introducedVersion": "0.7.0", "default_results": {"allowed": true}}, - "runtime_logging": {"introducedVersion": "0.8.0", "default_results": {"allowed": true}}, - "load_fragment": {"introducedVersion": "0.9.0", "default_results": {"allowed": false, "add_module": false}}, - "scratch_mount": {"introducedVersion": "0.10.0", 
"default_results": {"allowed": true}}, - "scratch_unmount": {"introducedVersion": "0.10.0", "default_results": {"allowed": true}}, + "mount_device": {"introducedVersion": "0.1.0", "default_results": {"allowed": false}, "use_framework": false}, + "rw_mount_device": {"introducedVersion": "0.11.0", "default_results": {}, "use_framework": true}, + "mount_overlay": {"introducedVersion": "0.1.0", "default_results": {"allowed": false}, "use_framework": false}, + "mount_cims": {"introducedVersion": "0.11.0", "default_results": {"allowed": false}, "use_framework": false}, + "create_container": {"introducedVersion": "0.1.0", "default_results": {"allowed": false, "env_list": null, "allow_stdio_access": false}, "use_framework": false}, + "unmount_device": {"introducedVersion": "0.2.0", "default_results": {"allowed": true}, "use_framework": false}, + "rw_unmount_device": {"introducedVersion": "0.11.0", "default_results": {}, "use_framework": true}, + "unmount_overlay": {"introducedVersion": "0.6.0", "default_results": {"allowed": true}, "use_framework": false}, + "exec_in_container": {"introducedVersion": "0.2.0", "default_results": {"allowed": true, "env_list": null}, "use_framework": false}, + "exec_external": {"introducedVersion": "0.3.0", "default_results": {"allowed": true, "env_list": null, "allow_stdio_access": false}, "use_framework": false}, + "shutdown_container": {"introducedVersion": "0.4.0", "default_results": {"allowed": true}, "use_framework": false}, + "signal_container_process": {"introducedVersion": "0.5.0", "default_results": {"allowed": true}, "use_framework": false}, + "plan9_mount": {"introducedVersion": "0.6.0", "default_results": {"allowed": true}, "use_framework": false}, + "plan9_unmount": {"introducedVersion": "0.6.0", "default_results": {"allowed": true}, "use_framework": false}, + "get_properties": {"introducedVersion": "0.7.0", "default_results": {"allowed": true}, "use_framework": false}, + "dump_stacks": {"introducedVersion": "0.7.0", 
"default_results": {"allowed": true}, "use_framework": false}, + "runtime_logging": {"introducedVersion": "0.8.0", "default_results": {"allowed": true}, "use_framework": false}, + "load_fragment": {"introducedVersion": "0.9.0", "default_results": {"allowed": false, "add_module": false}, "use_framework": false}, + "scratch_mount": {"introducedVersion": "0.10.0", "default_results": {"allowed": true}, "use_framework": false}, + "scratch_unmount": {"introducedVersion": "0.10.0", "default_results": {"allowed": true}, "use_framework": false}, } diff --git a/pkg/securitypolicy/api_test.rego b/pkg/securitypolicy/api_test.rego index 2d2de733c6..767c506e58 100644 --- a/pkg/securitypolicy/api_test.rego +++ b/pkg/securitypolicy/api_test.rego @@ -3,8 +3,8 @@ package api version := "0.0.2" enforcement_points := { - "__fixture_for_future_test__": {"introducedVersion": "100.0.0", "default_results": {"allowed": true}}, - "__fixture_for_allowed_test_true__": {"introducedVersion": "0.0.2", "default_results": {"allowed": true}}, - "__fixture_for_allowed_test_false__": {"introducedVersion": "0.0.2", "default_results": {"allowed": false}}, - "__fixture_for_allowed_extra__": {"introducedVersion": "0.0.1", "default_results": {"allowed": false, "__test__": "test"}} + "__fixture_for_future_test__": {"introducedVersion": "100.0.0", "default_results": {"allowed": true}, "use_framework": false}, + "__fixture_for_allowed_test_true__": {"introducedVersion": "0.0.2", "default_results": {"allowed": true}, "use_framework": false}, + "__fixture_for_allowed_test_false__": {"introducedVersion": "0.0.2", "default_results": {"allowed": false}, "use_framework": false}, + "__fixture_for_allowed_extra__": {"introducedVersion": "0.0.1", "default_results": {"allowed": false, "__test__": "test"}, "use_framework": false} } diff --git a/pkg/securitypolicy/framework.rego b/pkg/securitypolicy/framework.rego index 8a28f3e312..ca6721c5dc 100644 --- a/pkg/securitypolicy/framework.rego +++ 
b/pkg/securitypolicy/framework.rego @@ -5,10 +5,28 @@ import future.keywords.in version := "@@FRAMEWORK_VERSION@@" +# Add ^ and $ to regex patterns that doesn't have them. +# This forces the regex to match the entire string, which is safer. +# Policies should include .* explicitly at the beginning or end if partial +# matches are to be allowed. + +anchor_pattern(p) := p { + startswith(p, "^") + endswith(p, "$") +} else := concat("", ["^", p]) { + endswith(p, "$") +} else := concat("", [p, "$"]) { + startswith(p, "^") +} else := concat("", ["^", p, "$"]) + device_mounted(target) { data.metadata.devices[target] } +device_mounted(target) { + data.metadata.rw_devices[target] +} + default deviceHash_ok := false # test if a device hash exists as a layer in a policy container @@ -27,9 +45,14 @@ deviceHash_ok { default mount_device := {"allowed": false} +mount_target_ok { + regex.match(anchor_pattern(input.mountPathRegex), input.target) +} + mount_device := {"metadata": [addDevice], "allowed": true} { not device_mounted(input.target) deviceHash_ok + mount_target_ok addDevice := { "name": "devices", "action": "add", @@ -38,10 +61,38 @@ mount_device := {"metadata": [addDevice], "allowed": true} { } } +allowed_scratch_fs("ext4") +allowed_scratch_fs("xfs") + +rwmount_device_encrypt_ok { + input.encrypted +} + +rwmount_device_encrypt_ok { + allow_unencrypted_scratch +} + +default rw_mount_device := {"allowed": false} + +rw_mount_device := {"metadata": [addDevice], "allowed": true} { + not device_mounted(input.target) + rwmount_device_encrypt_ok + input.ensureFilesystem + allowed_scratch_fs(input.filesystem) + mount_target_ok + addDevice := { + "name": "rw_devices", + "action": "add", + "key": input.target, + "value": true, + } +} + default unmount_device := {"allowed": false} unmount_device := {"metadata": [removeDevice], "allowed": true} { - device_mounted(input.unmountTarget) + data.metadata.devices[input.unmountTarget] + removeDevice := { "name": "devices", "action": 
"remove", @@ -49,6 +100,18 @@ unmount_device := {"metadata": [removeDevice], "allowed": true} { } } +default rw_unmount_device := {"allowed": false} + +rw_unmount_device := {"metadata": [removeRWDevice], "allowed": true} { + data.metadata.rw_devices[input.unmountTarget] + + removeRWDevice := { + "name": "rw_devices", + "action": "remove", + "key": input.unmountTarget, + } +} + layerPaths_ok(layers) { length := count(layers) count(input.layerPaths) == length @@ -127,6 +190,10 @@ default mount_overlay := {"allowed": false} mount_overlay := {"metadata": [addMatches, addOverlayTarget], "allowed": true} { not overlay_exists + # sanity check, but due to checks in the Go code, this should always pass if + # `not overlay_exists` passes. + not overlay_mounted(input.target) + containers := [container | container := candidate_containers[_] layerPaths_ok(container.layers) @@ -171,30 +238,7 @@ env_ok(pattern, "string", value) { } env_ok(pattern, "re2", value) { - anchored := anchor_pattern(pattern) - regex.match(anchored, value) -} - -anchor_pattern(p) := anchored { - startswith_leading := startswith(p, "^") - endswith_trailing := endswith(p, "$") - - anchored = sprintf("%s%s%s", [ - add_leading_trailing_chars(startswith_leading, "", "^"), # Add ^ only if missing - p, - add_leading_trailing_chars(endswith_trailing, "", "$") # Add $ only if missing - ]) -} - -# Function to return one of two values depending on a boolean condition -add_leading_trailing_chars(cond, ifTrue, ifFalse) := result { - cond - result = ifTrue -} - -add_leading_trailing_chars(cond, ifTrue, ifFalse) := result { - not cond - result = ifFalse + regex.match(anchor_pattern(pattern), value) } rule_ok(rule, env) { @@ -316,7 +360,7 @@ idName_ok(pattern, "name", value) { } idName_ok(pattern, "re2", value) { - regex.match(pattern, value.name) + regex.match(anchor_pattern(pattern), value.name) } user_ok(user) { @@ -682,13 +726,13 @@ security_ok(current_container) { mountSource_ok(constraint, source) { 
startswith(constraint, data.sandboxPrefix) newConstraint := replace(constraint, data.sandboxPrefix, input.sandboxDir) - regex.match(newConstraint, source) + regex.match(anchor_pattern(newConstraint), source) } mountSource_ok(constraint, source) { startswith(constraint, data.hugePagesPrefix) newConstraint := replace(constraint, data.hugePagesPrefix, input.hugePagesDir) - regex.match(newConstraint, source) + regex.match(anchor_pattern(newConstraint), source) } mountSource_ok(constraint, source) { @@ -857,7 +901,7 @@ exec_in_container := {"metadata": [updateMatches], default shutdown_container := {"allowed": false} -shutdown_container := {"started": remove, "metadata": [remove], "allowed": true} { +shutdown_container := {"metadata": [remove], "allowed": true} { container_started remove := { "name": "matches", @@ -918,7 +962,7 @@ default plan9_mount := {"allowed": false} plan9_mount := {"metadata": [addPlan9Target], "allowed": true} { not plan9_mounted(input.target) some containerID, _ in data.metadata.matches - pattern := concat("", [input.rootPrefix, "/", containerID, input.mountPathPrefix]) + pattern := concat("", ["^", input.rootPrefix, "/", containerID, input.mountPathPrefix, "$"]) regex.match(pattern, input.target) addPlan9Target := { "name": "p9mounts", @@ -940,20 +984,28 @@ plan9_unmount := {"metadata": [removePlan9Target], "allowed": true} { } -default enforcement_point_info := {"available": false, "default_results": {"allow": false}, "unknown": true, "invalid": false, "version_missing": false} +default enforcement_point_info := { + "available": false, + "default_results": {"allow": false}, + "unknown": true, + "invalid": false, + "version_missing": false, + "use_framework": false +} -enforcement_point_info := {"available": false, "default_results": {"allow": false}, "unknown": false, "invalid": false, "version_missing": true} { +enforcement_point_info := {"available": false, "default_results": {"allow": false}, "unknown": false, "invalid": false, 
"version_missing": true, "use_framework": false} { policy_api_version == null } -enforcement_point_info := {"available": available, "default_results": default_results, "unknown": false, "invalid": false, "version_missing": false} { +enforcement_point_info := {"available": available, "default_results": default_results, "unknown": false, "invalid": false, "version_missing": false, "use_framework": use_framework} { enforcement_point := data.api.enforcement_points[input.name] semver.compare(data.api.version, enforcement_point.introducedVersion) >= 0 available := semver.compare(policy_api_version, enforcement_point.introducedVersion) >= 0 default_results := enforcement_point.default_results + use_framework := enforcement_point.use_framework } -enforcement_point_info := {"available": false, "default_results": {"allow": false}, "unknown": false, "invalid": true, "version_missing": false} { +enforcement_point_info := {"available": false, "default_results": {"allow": false}, "unknown": false, "invalid": true, "version_missing": false, "use_framework": false} { enforcement_point := data.api.enforcement_points[input.name] semver.compare(data.api.version, enforcement_point.introducedVersion) < 0 } @@ -1157,8 +1209,6 @@ candidate_fragments := fragments { fragments := array.concat(policy_fragments, fragment_fragments) } -default load_fragment := {"allowed": false} - svn_ok(svn, minimum_svn) { # deprecated semver.is_valid(svn) @@ -1170,15 +1220,32 @@ svn_ok(svn, minimum_svn) { to_number(svn) >= to_number(minimum_svn) } -fragment_ok(fragment) { +fragment_issuer_feed_ok(fragment) { input.issuer == fragment.issuer input.feed == fragment.feed - svn_ok(data[input.namespace].svn, fragment.minimum_svn) +} + +default load_fragment := {"allowed": false} + +# load_fragment gets called twice - first before loading the fragment as a Rego +# module, with input.fragment_loaded set to false, in which case we do not yet +# have access to anything under data[fragment.namespace] yet, and so we 
only +# check that the fragment issuer and feed is valid, but does not actually load +# the fragment into metadata. It will then be called a second time, at which +# point we can check the SVN defined in the fragment is valid, and if +# successful, add the fragment to the metadata. + +load_fragment := {"allowed": true} { + not input.fragment_loaded + some fragment in candidate_fragments + fragment_issuer_feed_ok(fragment) } load_fragment := {"metadata": [updateIssuer], "add_module": add_module, "allowed": true} { + input.fragment_loaded some fragment in candidate_fragments - fragment_ok(fragment) + fragment_issuer_feed_ok(fragment) + svn_ok(data[input.namespace].svn, fragment.minimum_svn) issuer := update_issuer(fragment.includes) updateIssuer := { @@ -1246,13 +1313,54 @@ errors["deviceHash not found"] { } errors["device already mounted at path"] { - input.rule == "mount_device" + input.rule in ["mount_device", "rw_mount_device"] device_mounted(input.target) } +errors["mountpoint invalid"] { + input.rule in ["mount_device", "rw_mount_device"] + not mount_target_ok +} + errors["no device at path to unmount"] { input.rule == "unmount_device" - not device_mounted(input.unmountTarget) + not data.metadata.devices[input.unmountTarget] + not data.metadata.rw_devices[input.unmountTarget] +} + +errors["received read-only unmount request, but device provided is read-write"] { + input.rule == "unmount_device" + not data.metadata.devices[input.unmountTarget] + data.metadata.rw_devices[input.unmountTarget] +} + +errors["no device at path to unmount"] { + input.rule == "rw_unmount_device" + not data.metadata.devices[input.unmountTarget] + not data.metadata.rw_devices[input.unmountTarget] +} + +errors["received read-write unmount request, but device provided is read-only"] { + input.rule == "rw_unmount_device" + not data.metadata.rw_devices[input.unmountTarget] + data.metadata.devices[input.unmountTarget] +} + +# Error string tested in azcri-containerd 
Test_RunPodSandboxNotAllowed_WithPolicy_EncryptedScratchPolicy +errors["unencrypted scratch not allowed, non-readonly mount request for SCSI disk must request encryption"] { + input.rule == "rw_mount_device" + not allow_unencrypted_scratch + not input.encrypted +} + +errors["ensureFilesystem must be set on rw device mounts"] { + input.rule == "rw_mount_device" + not input.ensureFilesystem +} + +errors["rw device mounts uses a filesystem that is not allowed"] { + input.rule == "rw_mount_device" + not allowed_scratch_fs(input.filesystem) } errors["container already started"] { @@ -1548,6 +1656,7 @@ default fragment_version_is_valid := false fragment_version_is_valid { some fragment in candidate_fragments + input.fragment_loaded fragment.issuer == input.issuer fragment.feed == input.feed svn_ok(data[input.namespace].svn, fragment.minimum_svn) @@ -1559,6 +1668,7 @@ svn_mismatch { some fragment in candidate_fragments fragment.issuer == input.issuer fragment.feed == input.feed + input.fragment_loaded to_number(data[input.namespace].svn) semver.is_valid(fragment.minimum_svn) } @@ -1567,6 +1677,7 @@ svn_mismatch { some fragment in candidate_fragments fragment.issuer == input.issuer fragment.feed == input.feed + input.fragment_loaded semver.is_valid(data[input.namespace].svn) to_number(fragment.minimum_svn) } @@ -1574,6 +1685,7 @@ svn_mismatch { errors["fragment svn is below the specified minimum"] { input.rule == "load_fragment" fragment_feed_matches + input.fragment_loaded not svn_mismatch not fragment_version_is_valid } @@ -1581,6 +1693,7 @@ errors["fragment svn is below the specified minimum"] { errors["fragment svn and the specified minimum are different types"] { input.rule == "load_fragment" fragment_feed_matches + input.fragment_loaded svn_mismatch } @@ -1611,12 +1724,16 @@ errors[framework_version_error] { } errors[fragment_framework_version_error] { + input.rule == "load_fragment" + input.fragment_loaded input.namespace fragment_framework_version == null 
fragment_framework_version_error := concat(" ", ["fragment framework_version is missing. Current version:", version]) } errors[fragment_framework_version_error] { + input.rule == "load_fragment" + input.fragment_loaded input.namespace semver.compare(fragment_framework_version, version) > 0 fragment_framework_version_error := concat(" ", ["fragment framework_version is ahead of the current version:", fragment_framework_version, "is greater than", version]) diff --git a/pkg/securitypolicy/open_door.rego b/pkg/securitypolicy/open_door.rego index a8e283092d..23c35f9b04 100644 --- a/pkg/securitypolicy/open_door.rego +++ b/pkg/securitypolicy/open_door.rego @@ -3,10 +3,12 @@ package policy api_version := "@@API_VERSION@@" mount_device := {"allowed": true} +rw_mount_device := {"allowed": true} mount_overlay := {"allowed": true} create_container := {"allowed": true, "env_list": null, "allow_stdio_access": true} mount_cims := {"allowed": true} unmount_device := {"allowed": true} +rw_unmount_device := {"allowed": true} unmount_overlay := {"allowed": true} exec_in_container := {"allowed": true, "env_list": null} exec_external := {"allowed": true, "env_list": null, "allow_stdio_access": true} diff --git a/pkg/securitypolicy/policy.rego b/pkg/securitypolicy/policy.rego index 9414116c19..03a71094bd 100644 --- a/pkg/securitypolicy/policy.rego +++ b/pkg/securitypolicy/policy.rego @@ -6,7 +6,9 @@ framework_version := "@@FRAMEWORK_VERSION@@" @@OBJECTS@@ mount_device := data.framework.mount_device +rw_mount_device := data.framework.rw_mount_device unmount_device := data.framework.unmount_device +rw_unmount_device := data.framework.rw_unmount_device mount_overlay := data.framework.mount_overlay unmount_overlay := data.framework.unmount_overlay mount_cims:= data.framework.mount_cims diff --git a/pkg/securitypolicy/policy_v0.10.0_api_test.rego b/pkg/securitypolicy/policy_v0.10.0_api_test.rego new file mode 100644 index 0000000000..407c3ee8ff --- /dev/null +++ 
b/pkg/securitypolicy/policy_v0.10.0_api_test.rego @@ -0,0 +1,84 @@ +package policy + +api_version := "0.10.0" +framework_version := "0.3.0" + +fragments := [ + { + "feed": "@@FRAGMENT_FEED@@", + "includes": [ + "containers", + "fragments" + ], + "issuer": "@@FRAGMENT_ISSUER@@", + "minimum_svn": "0" + } +] + + +containers := [ + { + "allow_elevated": false, + "allow_stdio_access": true, + "capabilities": { + "ambient": [], + "bounding": [], + "effective": [], + "inheritable": [], + "permitted": [] + }, + "command": [ "bash" ], + "env_rules": [], + "exec_processes": [], + "layers": [ + "@@CONTAINER_LAYER_HASH@@", + ], + "mounts": [], + "no_new_privileges": false, + "seccomp_profile_sha256": "", + "signals": [], + "user": { + "group_idnames": [ + { + "pattern": "", + "strategy": "any" + } + ], + "umask": "0022", + "user_idname": { + "pattern": "", + "strategy": "any" + } + }, + "working_dir": "/" + } +] + +allow_properties_access := true +allow_dump_stacks := false +allow_runtime_logging := false +allow_environment_variable_dropping := true +allow_unencrypted_scratch := false +allow_capability_dropping := true + +mount_device := data.framework.mount_device +unmount_device := data.framework.unmount_device +mount_overlay := data.framework.mount_overlay +unmount_overlay := data.framework.unmount_overlay +create_container := data.framework.create_container +exec_in_container := data.framework.exec_in_container +exec_external := {"allowed": true, + "allow_stdio_access": true, + "env_list": input.envList} +shutdown_container := data.framework.shutdown_container +signal_container_process := data.framework.signal_container_process +plan9_mount := data.framework.plan9_mount +plan9_unmount := data.framework.plan9_unmount +get_properties := data.framework.get_properties +dump_stacks := data.framework.dump_stacks +runtime_logging := data.framework.runtime_logging +load_fragment := data.framework.load_fragment +scratch_mount := data.framework.scratch_mount +scratch_unmount := 
data.framework.scratch_unmount + +reason := {"errors": data.framework.errors} diff --git a/pkg/securitypolicy/policy_v0.10.0_api_test_allow_all.rego b/pkg/securitypolicy/policy_v0.10.0_api_test_allow_all.rego new file mode 100644 index 0000000000..dccdba0dec --- /dev/null +++ b/pkg/securitypolicy/policy_v0.10.0_api_test_allow_all.rego @@ -0,0 +1,22 @@ +package policy + +api_version := "0.10.0" +framework_version := "0.3.0" + +mount_device := {"allowed": true} +mount_overlay := {"allowed": true} +create_container := {"allowed": true, "env_list": null, "allow_stdio_access": true} +unmount_device := {"allowed": true} +unmount_overlay := {"allowed": true} +exec_in_container := {"allowed": true, "env_list": null} +exec_external := {"allowed": true, "env_list": null, "allow_stdio_access": true} +shutdown_container := {"allowed": true} +signal_container_process := {"allowed": true} +plan9_mount := {"allowed": true} +plan9_unmount := {"allowed": true} +get_properties := {"allowed": true} +dump_stacks := {"allowed": true} +runtime_logging := {"allowed": true} +load_fragment := {"allowed": true} +scratch_mount := {"allowed": true} +scratch_unmount := {"allowed": true} diff --git a/internal/pspdriver/pspdriver.go b/pkg/securitypolicy/pspdriver.go similarity index 95% rename from internal/pspdriver/pspdriver.go rename to pkg/securitypolicy/pspdriver.go index db41384853..869eaeb04c 100644 --- a/internal/pspdriver/pspdriver.go +++ b/pkg/securitypolicy/pspdriver.go @@ -1,7 +1,7 @@ //go:build windows // +build windows -package pspdriver +package securitypolicy import ( "bytes" @@ -217,7 +217,7 @@ func GetPspDriverError() error { } // IsSNPMode() returns true if it's in SNP mode. -func IsSNPMode(ctx context.Context) (bool, error) { +func IsSNPMode() (bool, error) { if pspDriverError != nil { return false, pspDriverError @@ -249,7 +249,7 @@ func IsSNPMode(ctx context.Context) (bool, error) { } // FetchRawSNPReport returns attestation report bytes. 
-func FetchRawSNPReport(ctx context.Context, reportData []byte) ([]byte, error) { +func FetchRawSNPReport(reportData []byte) ([]byte, error) { if pspDriverError != nil { return nil, pspDriverError } @@ -291,8 +291,8 @@ func FetchRawSNPReport(ctx context.Context, reportData []byte) ([]byte, error) { } // FetchParsedSNPReport parses raw attestation response into proper structs. -func FetchParsedSNPReport(ctx context.Context, reportData []byte) (Report, error) { - rawBytes, err := FetchRawSNPReport(ctx, reportData) +func FetchParsedSNPReport(reportData []byte) (Report, error) { + rawBytes, err := FetchRawSNPReport(reportData) if err != nil { return Report{}, err } @@ -308,16 +308,16 @@ func FetchParsedSNPReport(ctx context.Context, reportData []byte) (Report, error // TODO: Based on internal\guest\runtime\hcsv2\hostdata.go and it's duplicated. // ValidateHostData fetches SNP report (if applicable) and validates `hostData` against // HostData set at UVM launch. -func ValidateHostData(ctx context.Context, hostData []byte) error { +func ValidateHostDataPSP(hostData []byte) error { // If the UVM is not SNP, then don't try to fetch an SNP report. 
- isSnpMode, err := IsSNPMode(ctx) + isSnpMode, err := IsSNPMode() if err != nil { return err } if !isSnpMode { return nil } - report, err := FetchParsedSNPReport(ctx, nil) + report, err := FetchParsedSNPReport(nil) if err != nil { return err } diff --git a/pkg/securitypolicy/rego_utils_test.go b/pkg/securitypolicy/rego_utils_test.go index dbe016098d..3adbb6a2b6 100644 --- a/pkg/securitypolicy/rego_utils_test.go +++ b/pkg/securitypolicy/rego_utils_test.go @@ -6,6 +6,7 @@ package securitypolicy import ( "context" _ "embed" + "encoding/hex" "encoding/json" "fmt" "math/rand" @@ -15,6 +16,7 @@ import ( "sort" "strconv" "strings" + "sync/atomic" "syscall" "testing" "time" @@ -34,7 +36,6 @@ const ( maxExternalProcessesInGeneratedConstraints = 16 maxFragmentsInGeneratedConstraints = 4 maxGeneratedExternalProcesses = 12 - maxGeneratedSandboxIDLength = 32 maxGeneratedEnforcementPointLength = 64 maxGeneratedPlan9Mounts = 8 maxGeneratedFragmentFeedLength = 256 @@ -46,7 +47,6 @@ const ( minStringLength = 10 maxContainersInGeneratedConstraints = 32 maxLayersInGeneratedContainer = 32 - maxGeneratedContainerID = 1000000 maxGeneratedCommandLength = 128 maxGeneratedCommandArgs = 12 maxGeneratedEnvironmentVariables = 16 @@ -347,23 +347,40 @@ type regoPlan9MountTestConfig struct { } func mountImageForContainer(policy *regoEnforcer, container *securityPolicyContainer) (string, error) { - ctx := context.Background() containerID := testDataGenerator.uniqueContainerID() + if err := mountImageForContainerWithID(policy, container, containerID); err != nil { + return "", err + } + return containerID, nil +} + +func mountImageForContainerWithID(policy *regoEnforcer, container *securityPolicyContainer, containerID string) error { + ctx := context.Background() layerPaths, err := testDataGenerator.createValidOverlayForContainer(policy, container) if err != nil { - return "", fmt.Errorf("error creating valid overlay: %w", err) + return fmt.Errorf("error creating valid overlay: %w", err) } + 
scratchDisk := getScratchDiskMountTarget(containerID) + err = policy.EnforceRWDeviceMountPolicy(ctx, scratchDisk, true, true, "xfs") + if err != nil { + return fmt.Errorf("error mounting scratch disk: %w", err) + } + + overlayTarget := getOverlayMountTarget(containerID) + // see NOTE_TESTCOPY - err = policy.EnforceOverlayMountPolicy(ctx, containerID, copyStrings(layerPaths), testDataGenerator.uniqueMountTarget()) + err = policy.EnforceOverlayMountPolicy( + ctx, containerID, copyStrings(layerPaths), overlayTarget) if err != nil { - return "", fmt.Errorf("error mounting filesystem: %w", err) + return fmt.Errorf("error mounting filesystem: %w", err) } - return containerID, nil + return nil } + func buildMountSpecFromMountArray(mounts []mountInternal, sandboxID string, r *rand.Rand) *oci.Spec { mountSpec := new(oci.Spec) @@ -1333,7 +1350,8 @@ func selectFragmentsFromConstraints(gc *generatedConstraints, numFragments int, } func generateSandboxID(r *rand.Rand) string { - return randVariableString(r, maxGeneratedSandboxIDLength) + // Sandbox IDs has the same format as container IDs + return generateContainerID(r) } func generateEnforcementPoint(r *rand.Rand) string { @@ -1394,6 +1412,10 @@ func setupRegoCreateContainerTest(gc *generatedConstraints, testContainer *secur return nil, err } + return createTestContainerSpec(gc, containerID, testContainer, privilegedError, policy, defaultMounts, privilegedMounts) +} + +func createTestContainerSpec(gc *generatedConstraints, containerID string, testContainer *securityPolicyContainer, privilegedError bool, policy *regoEnforcer, defaultMounts, privilegedMounts []mountInternal) (*regoContainerTestConfig, error) { envList := buildEnvironmentVariablesFromEnvRules(testContainer.EnvRules, testRand) sandboxID := testDataGenerator.uniqueSandboxID() @@ -1615,6 +1637,20 @@ func copyStrings(values []string) []string { //go:embed api_test.rego var apiTestCode string +//go:embed policy_v0.10.0_api_test.rego +var policyWith_0_10_0_apiTestCode 
string + +//go:embed policy_v0.10.0_api_test_allow_all.rego +var policyWith_0_10_0_apiTestAllowAllCode string + +func getPolicyCode_0_10_0(layerHash, fragmentIssuer, fragmentFeed string) string { + s := policyWith_0_10_0_apiTestCode + s = strings.Replace(s, "@@CONTAINER_LAYER_HASH@@", layerHash, 1) + s = strings.Replace(s, "@@FRAGMENT_ISSUER@@", fragmentIssuer, 1) + s = strings.Replace(s, "@@FRAGMENT_FEED@@", fragmentFeed, 1) + return s +} + func (p *regoEnforcer) injectTestAPI() error { p.rego.RemoveModule("api.rego") p.rego.AddModule("api.rego", &rpi.RegoModule{Namespace: "api", Code: apiTestCode}) @@ -2030,7 +2066,7 @@ func assertDecisionJSONContains(t *testing.T, err error, expectedValues ...strin for _, expected := range expectedValues { if !strings.Contains(policyDecision, expected) { - t.Errorf("expected error to contain %q", expected) + t.Errorf("expected error to contain %q, but got %q", expected, policyDecision) return false } } @@ -2492,7 +2528,6 @@ func buildEnvironmentVariablesFromEnvRules(rules []EnvRuleConfig, r *rand.Rand) // Build in all required rules, this isn't a setup method of "missing item" // tests for _, rule := range rules { - if rule.Required { if rule.Strategy != EnvVarRuleRegex { vars = append(vars, rule.Rule) @@ -2529,12 +2564,14 @@ func buildEnvironmentVariablesFromEnvRules(rules []EnvRuleConfig, r *rand.Rand) usedIndexes[anIndex] = struct{}{} } numberOfMatches-- - } return vars } +// Only used for random mount targets or for the standard enforcer. Rego policy +// enforces proper targets that are e.g. 
created from +// guestpath.LCOWGlobalScsiMountPrefixFmt func generateMountTarget(r *rand.Rand) string { return randVariableString(r, maxGeneratedMountTargetLength) } @@ -2563,8 +2600,12 @@ func selectRootHashFromConstraints(constraints *generatedConstraints, r *rand.Ra } func generateContainerID(r *rand.Rand) string { - id := atLeastOneAtMost(r, maxGeneratedContainerID) - return strconv.FormatInt(int64(id), 10) + idbytes := make([]byte, 32) + _, err := r.Read(idbytes) + if err != nil { + panic(fmt.Errorf("failed to generate random container ID: %w", err)) + } + return hex.EncodeToString(idbytes) } func generateMounts(r *rand.Rand) []mountInternal { @@ -2654,26 +2695,28 @@ func selectContainerFromContainerList(containers []*securityPolicyContainer, r * } type dataGenerator struct { - rng *rand.Rand - mountTargets stringSet - containerIDs stringSet - sandboxIDs stringSet - enforcementPoints stringSet - fragmentIssuers stringSet - fragmentFeeds stringSet - fragmentNamespaces stringSet + rng *rand.Rand + layerMountTarget stringSet + nextLayerMountTarget atomic.Uint64 + containerIDs stringSet + sandboxIDs stringSet + enforcementPoints stringSet + fragmentIssuers stringSet + fragmentFeeds stringSet + fragmentNamespaces stringSet } func newDataGenerator(rng *rand.Rand) *dataGenerator { return &dataGenerator{ - rng: rng, - mountTargets: make(stringSet), - containerIDs: make(stringSet), - sandboxIDs: make(stringSet), - enforcementPoints: make(stringSet), - fragmentIssuers: make(stringSet), - fragmentFeeds: make(stringSet), - fragmentNamespaces: make(stringSet), + rng: rng, + layerMountTarget: make(stringSet), + nextLayerMountTarget: atomic.Uint64{}, + containerIDs: make(stringSet), + sandboxIDs: make(stringSet), + enforcementPoints: make(stringSet), + fragmentIssuers: make(stringSet), + fragmentFeeds: make(stringSet), + fragmentNamespaces: make(stringSet), } } @@ -2687,21 +2730,36 @@ func (s *stringSet) randUnique(r *rand.Rand, generator func(*rand.Rand) string) } } -func 
(gen *dataGenerator) uniqueMountTarget() string { - return gen.mountTargets.randUnique(gen.rng, generateMountTarget) +// Generate a purely random mount target. This will be rejected by rego. +func (gen *dataGenerator) uniqueRandomMountTarget() string { + return gen.layerMountTarget.randUnique(gen.rng, generateMountTarget) } func (gen *dataGenerator) uniqueContainerID() string { return gen.containerIDs.randUnique(gen.rng, generateContainerID) } +func (gen *dataGenerator) uniqueLayerMountTarget() string { + idx := gen.nextLayerMountTarget.Add(1) + return fmt.Sprintf(guestpath.LCOWGlobalScsiMountPrefixFmt, idx) +} + +func getScratchDiskMountTarget(containerID string) string { + return path.Join(guestpath.LCOWRootPrefixInUVM, containerID) +} + +// Returns the roofs of a container. +func getOverlayMountTarget(containerID string) string { + return path.Join(guestpath.LCOWRootPrefixInUVM, containerID, guestpath.RootfsPath) +} + func (gen *dataGenerator) createValidOverlayForContainer(enforcer SecurityPolicyEnforcer, container *securityPolicyContainer) ([]string, error) { ctx := context.Background() // storage for our mount paths overlay := make([]string, len(container.Layers)) for i := 0; i < len(container.Layers); i++ { - mount := gen.uniqueMountTarget() + mount := gen.uniqueLayerMountTarget() err := enforcer.EnforceDeviceMountPolicy(ctx, mount, container.Layers[i]) if err != nil { return overlay, err @@ -2714,14 +2772,16 @@ func (gen *dataGenerator) createValidOverlayForContainer(enforcer SecurityPolicy } func (gen *dataGenerator) createInvalidOverlayForContainer(enforcer SecurityPolicyEnforcer, container *securityPolicyContainer) ([]string, error) { - method := gen.rng.Intn(3) + method := gen.rng.Intn(4) switch method { case 0: return gen.invalidOverlaySameSizeWrongMounts(enforcer, container) case 1: return gen.invalidOverlayCorrectDevicesWrongOrderSomeMissing(enforcer, container) - default: + case 2: return gen.invalidOverlayRandomJunk(enforcer, container) + default: 
+ return gen.invalidOverlayRandomNoMount(enforcer, container) } } @@ -2731,14 +2791,14 @@ func (gen *dataGenerator) invalidOverlaySameSizeWrongMounts(enforcer SecurityPol overlay := make([]string, len(container.Layers)) for i := 0; i < len(container.Layers); i++ { - mount := gen.uniqueMountTarget() + mount := gen.uniqueLayerMountTarget() err := enforcer.EnforceDeviceMountPolicy(ctx, mount, container.Layers[i]) if err != nil { return overlay, err } // generate a random new mount point to cause an error - overlay[len(overlay)-i-1] = gen.uniqueMountTarget() + overlay[len(overlay)-i-1] = gen.uniqueLayerMountTarget() } return overlay, nil @@ -2754,7 +2814,7 @@ func (gen *dataGenerator) invalidOverlayCorrectDevicesWrongOrderSomeMissing(enfo var overlay []string for i := 0; i < len(container.Layers); i++ { - mount := gen.uniqueMountTarget() + mount := gen.uniqueLayerMountTarget() err := enforcer.EnforceDeviceMountPolicy(ctx, mount, container.Layers[i]) if err != nil { return overlay, err @@ -2775,12 +2835,12 @@ func (gen *dataGenerator) invalidOverlayRandomJunk(enforcer SecurityPolicyEnforc overlay := make([]string, layersToCreate) for i := 0; i < int(layersToCreate); i++ { - overlay[i] = gen.uniqueMountTarget() + overlay[i] = generateMountTarget(gen.rng) } // setup entirely different and "correct" expected mounting for i := 0; i < len(container.Layers); i++ { - mount := gen.uniqueMountTarget() + mount := gen.uniqueLayerMountTarget() err := enforcer.EnforceDeviceMountPolicy(ctx, mount, container.Layers[i]) if err != nil { return overlay, err @@ -2790,6 +2850,17 @@ func (gen *dataGenerator) invalidOverlayRandomJunk(enforcer SecurityPolicyEnforc return overlay, nil } +func (gen *dataGenerator) invalidOverlayRandomNoMount(enforcer SecurityPolicyEnforcer, container *securityPolicyContainer) ([]string, error) { + layersToCreate := gen.rng.Int31n(maxLayersInGeneratedContainer) + overlay := make([]string, layersToCreate) + + for i := 0; i < int(layersToCreate); i++ { + 
overlay[i] = gen.uniqueLayerMountTarget() + } + + return overlay, nil +} + func randVariableString(r *rand.Rand, maxLen int32) string { return randString(r, atLeastOneAtMost(r, maxLen)) } @@ -2935,3 +3006,19 @@ type containerInitProcess struct { WorkingDir string AllowStdioAccess bool } + +func startRevertableSection(t *testing.T, policy *regoEnforcer) RevertableSectionHandle { + rev, err := policy.StartRevertableSection() + if err != nil { + t.Fatalf("Failed to start revertable section: %v", err) + } + return rev +} + +func commitOrRollback(rev RevertableSectionHandle, shouldCommit bool) { + if shouldCommit { + rev.Commit() + } else { + rev.Rollback() + } +} diff --git a/pkg/securitypolicy/regopolicy_linux_test.go b/pkg/securitypolicy/regopolicy_linux_test.go index 94cfd8d031..0ddd933bac 100644 --- a/pkg/securitypolicy/regopolicy_linux_test.go +++ b/pkg/securitypolicy/regopolicy_linux_test.go @@ -9,6 +9,7 @@ import ( "fmt" "math/rand" "os" + "path" "path/filepath" "slices" "strconv" @@ -106,7 +107,7 @@ func Test_MarshalRego_Policy(t *testing.T) { _, err = newRegoPolicy(expected, defaultMounts, privilegedMounts, testOSType) if err != nil { - t.Errorf("unable to convert policy to rego: %v", err) + t.Errorf("cannot make rego policy from constraints: %v", err) return false } @@ -193,11 +194,11 @@ func Test_Rego_EnforceDeviceMountPolicy_No_Matches(t *testing.T) { policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) if err != nil { - t.Errorf("unable to convert policy to rego: %v", err) + t.Errorf("cannot make rego policy from constraints: %v", err) return false } - target := testDataGenerator.uniqueMountTarget() + target := testDataGenerator.uniqueLayerMountTarget() rootHash := generateInvalidRootHash(testRand) err = policy.EnforceDeviceMountPolicy(p.ctx, target, rootHash) @@ -219,11 +220,11 @@ func Test_Rego_EnforceDeviceMountPolicy_Matches(t *testing.T) { policy, err := newRegoPolicy(securityPolicy.marshalRego(), 
[]oci.Mount{}, []oci.Mount{}, testOSType) if err != nil { - t.Errorf("unable to convert policy to rego: %v", err) + t.Errorf("cannot make rego policy from constraints: %v", err) return false } - target := testDataGenerator.uniqueMountTarget() + target := testDataGenerator.uniqueLayerMountTarget() rootHash := selectRootHashFromConstraints(p, testRand) err = policy.EnforceDeviceMountPolicy(p.ctx, target, rootHash) @@ -237,7 +238,7 @@ func Test_Rego_EnforceDeviceMountPolicy_Matches(t *testing.T) { } } -func Test_Rego_EnforceDeviceUmountPolicy_Removes_Device_Entries(t *testing.T) { +func Test_Rego_EnforceDeviceUnmountPolicy_Removes_Device_Entries(t *testing.T) { f := func(p *generatedConstraints) bool { securityPolicy := p.toPolicy() policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) @@ -247,7 +248,7 @@ func Test_Rego_EnforceDeviceUmountPolicy_Removes_Device_Entries(t *testing.T) { return false } - target := testDataGenerator.uniqueMountTarget() + target := testDataGenerator.uniqueLayerMountTarget() rootHash := selectRootHashFromConstraints(p, testRand) err = policy.EnforceDeviceMountPolicy(p.ctx, target, rootHash) @@ -272,7 +273,36 @@ func Test_Rego_EnforceDeviceUmountPolicy_Removes_Device_Entries(t *testing.T) { } if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { - t.Errorf("Test_Rego_EnforceDeviceUmountPolicy_Removes_Device_Entries failed: %v", err) + t.Errorf("Test_Rego_EnforceDeviceUnmountPolicy_Removes_Device_Entries failed: %v", err) + } +} + +func Test_Rego_EnforceDeviceUnmountPolicy_No_Matches(t *testing.T) { + f := func(p *generatedConstraints) bool { + securityPolicy := p.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Error(err) + return false + } + + target := testDataGenerator.uniqueLayerMountTarget() + err = policy.EnforceDeviceUnmountPolicy(p.ctx, target) + if 
!assertDecisionJSONContains(t, err, "no device at path to unmount") { + return false + } + + target = getScratchDiskMountTarget(testDataGenerator.uniqueContainerID()) + err = policy.EnforceRWDeviceUnmountPolicy(p.ctx, target) + if !assertDecisionJSONContains(t, err, "no device at path to unmount") { + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceDeviceUnmountPolicy_No_Matches failed: %v", err) } } @@ -282,11 +312,11 @@ func Test_Rego_EnforceDeviceMountPolicy_Duplicate_Device_Target(t *testing.T) { policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) if err != nil { - t.Errorf("unable to convert policy to rego: %v", err) + t.Errorf("cannot make rego policy from constraints: %v", err) return false } - target := testDataGenerator.uniqueMountTarget() + target := testDataGenerator.uniqueLayerMountTarget() rootHash := selectRootHashFromConstraints(p, testRand) err = policy.EnforceDeviceMountPolicy(p.ctx, target, rootHash) if err != nil { @@ -309,6 +339,331 @@ func Test_Rego_EnforceDeviceMountPolicy_Duplicate_Device_Target(t *testing.T) { } } +func Test_Rego_EnforceDeviceMountPolicy_InvalidMountTarget(t *testing.T) { + f := func(p *generatedConstraints) bool { + securityPolicy := p.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false + } + + target := testDataGenerator.uniqueRandomMountTarget() + rootHash := selectRootHashFromConstraints(p, testRand) + + err = policy.EnforceDeviceMountPolicy(p.ctx, target, rootHash) + + return assertDecisionJSONContains(t, err, "mountpoint invalid") + } + + if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceDeviceMountPolicy_InvalidMountTarget failed: %v", err) + } +} 
+ +func Test_Rego_EnforceDeviceMountPolicy_InvalidMountTarget_PathTraversal(t *testing.T) { + p := generateConstraints(testRand, 1) + securityPolicy := p.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return + } + + target := testDataGenerator.uniqueLayerMountTarget() + "/../../../../.." + rootHash := selectRootHashFromConstraints(p, testRand) + + err = policy.EnforceDeviceMountPolicy(p.ctx, target, rootHash) + + assertDecisionJSONContains(t, err, "mountpoint invalid") +} + +func deviceMountUnmountTest(t *testing.T, p *generatedConstraints, policy *regoEnforcer, mountScratchFirst, unmountScratchFirst, testDenials bool) bool { + container := selectContainerFromContainerList(p.containers, testRand) + containerID := testDataGenerator.uniqueContainerID() + rotarget := testDataGenerator.uniqueLayerMountTarget() + rwtarget := getScratchDiskMountTarget(containerID) + + var err error + + mountScratch := func() bool { + err = policy.EnforceRWDeviceMountPolicy(p.ctx, rwtarget, true, true, "xfs") + if err != nil { + t.Errorf("unable to mount rw device: %v", err) + return false + } + return true + } + + mountLayer := func() bool { + err = policy.EnforceDeviceMountPolicy(p.ctx, rotarget, container.Layers[0]) + if err != nil { + t.Errorf("unable to mount ro device: %v", err) + return false + } + return true + } + + if mountScratchFirst { + if !mountScratch() || !mountLayer() { + return false + } + } else { + if !mountLayer() || !mountScratch() { + return false + } + } + + if testDenials { + err = policy.EnforceRWDeviceMountPolicy(p.ctx, rwtarget, true, true, "xfs") + if !assertDecisionJSONContains(t, err, "device already mounted at path") { + return false + } + + err = policy.EnforceDeviceMountPolicy(p.ctx, rotarget, container.Layers[0]) + if !assertDecisionJSONContains(t, err, "device already mounted at path") { + return false + 
} + } + + unmountScratch := func() bool { + err = policy.EnforceRWDeviceUnmountPolicy(p.ctx, rwtarget) + if err != nil { + t.Errorf("unable to unmount rw device: %v", err) + return false + } + return true + } + + unmountLayer := func() bool { + err = policy.EnforceDeviceUnmountPolicy(p.ctx, rotarget) + if err != nil { + t.Errorf("unable to unmount ro device: %v", err) + return false + } + return true + } + + if unmountScratchFirst { + if !unmountScratch() || !unmountLayer() { + return false + } + } else { + if !unmountLayer() || !unmountScratch() { + return false + } + } + + if testDenials { + err = policy.EnforceDeviceUnmountPolicy(p.ctx, rotarget) + if !assertDecisionJSONContains(t, err, "no device at path to unmount") { + return false + } + + err = policy.EnforceRWDeviceUnmountPolicy(p.ctx, rwtarget) + if !assertDecisionJSONContains(t, err, "no device at path to unmount") { + return false + } + } + + return true +} + +func Test_Rego_EnforceRWDeviceMountPolicy_MountAndUnmount(t *testing.T) { + f := func(p *generatedConstraints, mountScratchFirst, unmountScratchFirst bool) bool { + securityPolicy := p.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false + } + + return deviceMountUnmountTest(t, p, policy, mountScratchFirst, unmountScratchFirst, true) + } + if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceRWDeviceMountPolicy_MountAndUnmount failed: %v", err) + } +} + +func Test_Rego_EnforceRWDeviceMountPolicy_InvalidTarget(t *testing.T) { + f := func(p *generatedConstraints, encrypted bool, ensureFileSystem bool) bool { + securityPolicy := p.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false 
+ } + + target := testDataGenerator.uniqueRandomMountTarget() + filesystem := "xfs" + + err = policy.EnforceRWDeviceMountPolicy(p.ctx, target, encrypted, ensureFileSystem, filesystem) + + return assertDecisionJSONContains(t, err, "mountpoint invalid") + } + + if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceRWDeviceMountPolicy_Matches failed: %v", err) + } +} + +func Test_Rego_EnforceRWDeviceMountPolicy_MissingEnsureFilesystem(t *testing.T) { + f := func(p *generatedConstraints, encrypted bool) bool { + p.allowUnencryptedScratch = !encrypted + securityPolicy := p.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false + } + + target := getScratchDiskMountTarget(testDataGenerator.uniqueContainerID()) + filesystem := "xfs" + + err = policy.EnforceRWDeviceMountPolicy(p.ctx, target, encrypted, false, filesystem) + + return assertDecisionJSONContains(t, err, "ensureFilesystem must be set on rw device mounts") + } + + if err := quick.Check(f, &quick.Config{MaxCount: 10, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceRWDeviceMountPolicy_Matches failed: %v", err) + } +} + +func Test_Rego_EnforceRWDeviceMountPolicy_DontAllowUnencrypted(t *testing.T) { + p := generateConstraints(testRand, 1) + p.allowUnencryptedScratch = false + securityPolicy := p.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return + } + + target := getScratchDiskMountTarget(testDataGenerator.uniqueContainerID()) + filesystem := "xfs" + + err = policy.EnforceRWDeviceMountPolicy(p.ctx, target, false, true, filesystem) + + assertDecisionJSONContains(t, err, "unencrypted scratch not allowed, non-readonly mount request for SCSI disk 
must request encryption") +} + +func Test_Rego_EnforceRWDeviceMountPolicy_InvalidFilesystem(t *testing.T) { + p := generateConstraints(testRand, 1) + securityPolicy := p.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return + } + + target := getScratchDiskMountTarget(testDataGenerator.uniqueContainerID()) + dangerousFilesystems := []string{ + "9p", + "overlay", + "nfs", + "cifs", + } + + for _, filesystem := range dangerousFilesystems { + err = policy.EnforceRWDeviceMountPolicy(p.ctx, target, true, true, filesystem) + assertDecisionJSONContains(t, err, "rw device mounts uses a filesystem that is not allowed") + } +} + +// Test that for an older allow all policy (api version < 0.11.0) that does not +// have rw_mount_device, the use_framework passthrough is done correctly, +// allowing enforcing rw mounts. +func Test_Rego_EnforceRWDeviceMountPolicy_Compat_0_10_0_allow_all(t *testing.T) { + p := generateConstraints(testRand, 1) + regoPolicy := policyWith_0_10_0_apiTestAllowAllCode + for _, b1 := range []bool{true, false} { + for _, b2 := range []bool{true, false} { + policy, err := newRegoPolicy(regoPolicy, []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot compile rego policy: %v", err) + return + } + + t.Run(fmt.Sprintf("mountScratchFirst=%t, unmountScratchFirst=%t", b1, b2), func(t *testing.T) { + deviceMountUnmountTest(t, p, policy, b1, b2, false) + }) + } + } +} + +// Test that for an older policy (api version < 0.11.0) that does not have +// rw_mount_device, the use_framework passthrough is done correctly, allowing +// enforcing rw mounts. 
+func Test_Rego_EnforceRWDeviceMountPolicy_Compat_0_10_0(t *testing.T) { + p := generateConstraints(testRand, 1) + regoPolicy := getPolicyCode_0_10_0(p.containers[0].Layers[0], testDataGenerator.uniqueFragmentIssuer(), testDataGenerator.uniqueFragmentFeed()) + for _, b1 := range []bool{true, false} { + for _, b2 := range []bool{true, false} { + policy, err := newRegoPolicy(regoPolicy, []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot compile rego policy: %v", err) + return + } + + t.Run(fmt.Sprintf("mountScratchFirst=%t, unmountScratchFirst=%t", b1, b2), func(t *testing.T) { + deviceMountUnmountTest(t, p, policy, b1, b2, true) + }) + } + } + + policy, err := newRegoPolicy(regoPolicy, []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot compile rego policy: %v", err) + return + } + + // Invalid mount target + target := testDataGenerator.uniqueRandomMountTarget() + filesystem := "xfs" + encrypted := true + ensureFileSystem := true + err = policy.EnforceRWDeviceMountPolicy(p.ctx, target, encrypted, ensureFileSystem, filesystem) + assertDecisionJSONContains(t, err, "mountpoint invalid") + + // Missing ensureFilesystem + ensureFileSystem = false + target = getScratchDiskMountTarget(testDataGenerator.uniqueContainerID()) + err = policy.EnforceRWDeviceMountPolicy(p.ctx, target, encrypted, ensureFileSystem, filesystem) + assertDecisionJSONContains(t, err, "ensureFilesystem must be set on rw device mounts") + + // Unencrypted scratch not allowed + ensureFileSystem = true + encrypted = false + err = policy.EnforceRWDeviceMountPolicy(p.ctx, target, encrypted, ensureFileSystem, filesystem) + assertDecisionJSONContains(t, err, "unencrypted scratch not allowed, non-readonly mount request for SCSI disk must request encryption") +} + +func Test_Rego_EnforceRWDeviceMountPolicy_OpenDoor(t *testing.T) { + p := generateConstraints(testRand, 1) + policy, err := newRegoPolicy(openDoorRego, []oci.Mount{}, []oci.Mount{}, testOSType) + 
if err != nil { + t.Errorf("cannot compile open door rego policy: %v", err) + return + } + + deviceMountUnmountTest(t, p, policy, true, true, false) + + ensureFileSystem := false + encrypted := false + filesystem := "zfs" + target := "/bin" + err = policy.EnforceRWDeviceMountPolicy(p.ctx, target, encrypted, ensureFileSystem, filesystem) + if err != nil { + t.Errorf("unexpected error mounting rw device: %v", err) + } + + err = policy.EnforceRWDeviceUnmountPolicy(p.ctx, target) + if err != nil { + t.Errorf("unexpected error unmounting rw device: %v", err) + } +} + // Verify that RegoSecurityPolicyEnforcer.EnforceOverlayMountPolicy will // return an error when there's no matching overlay targets. func Test_Rego_EnforceOverlayMountPolicy_No_Matches(t *testing.T) { @@ -319,7 +674,8 @@ func Test_Rego_EnforceOverlayMountPolicy_No_Matches(t *testing.T) { return false } - err = tc.policy.EnforceOverlayMountPolicy(p.ctx, tc.containerID, tc.layers, testDataGenerator.uniqueMountTarget()) + err = tc.policy.EnforceOverlayMountPolicy( + p.ctx, tc.containerID, tc.layers, getOverlayMountTarget(tc.containerID)) if err == nil { return false @@ -348,7 +704,8 @@ func Test_Rego_EnforceOverlayMountPolicy_Matches(t *testing.T) { return false } - err = tc.policy.EnforceOverlayMountPolicy(p.ctx, tc.containerID, tc.layers, testDataGenerator.uniqueMountTarget()) + err = tc.policy.EnforceOverlayMountPolicy( + p.ctx, tc.containerID, tc.layers, getOverlayMountTarget(tc.containerID)) // getting an error means something is broken return err == nil @@ -388,7 +745,8 @@ func Test_Rego_EnforceOverlayMountPolicy_Layers_With_Same_Root_Hash(t *testing.T t.Fatalf("error creating valid overlay: %v", err) } - err = policy.EnforceOverlayMountPolicy(constraints.ctx, containerID, layers, testDataGenerator.uniqueMountTarget()) + err = policy.EnforceOverlayMountPolicy( + constraints.ctx, containerID, layers, getOverlayMountTarget(containerID)) if err != nil { t.Fatalf("Unable to create an overlay where root 
hashes are the same") } @@ -428,7 +786,7 @@ func Test_Rego_EnforceOverlayMountPolicy_Layers_Shared_Layers(t *testing.T) { sharedMount := "" for i := 0; i < len(containerOne.Layers); i++ { - mount := testDataGenerator.uniqueMountTarget() + mount := testDataGenerator.uniqueLayerMountTarget() err := policy.EnforceDeviceMountPolicy(constraints.ctx, mount, containerOne.Layers[i]) if err != nil { t.Fatalf("Unexpected error mounting overlay device: %v", err) @@ -440,13 +798,14 @@ func Test_Rego_EnforceOverlayMountPolicy_Layers_Shared_Layers(t *testing.T) { containerOneOverlay[len(containerOneOverlay)-i-1] = mount } - err = policy.EnforceOverlayMountPolicy(constraints.ctx, containerID, containerOneOverlay, testDataGenerator.uniqueMountTarget()) + err = policy.EnforceOverlayMountPolicy( + constraints.ctx, containerID, containerOneOverlay, getOverlayMountTarget(containerID)) if err != nil { t.Fatalf("Unexpected error mounting overlay: %v", err) } // - // Mount our second contaniers overlay. This should all work. + // Mount our second container overlay. This should all work. 
// containerID = testDataGenerator.uniqueContainerID() @@ -456,7 +815,7 @@ func Test_Rego_EnforceOverlayMountPolicy_Layers_Shared_Layers(t *testing.T) { for i := 0; i < len(containerTwo.Layers); i++ { var mount string if i != sharedLayerIndex { - mount = testDataGenerator.uniqueMountTarget() + mount = testDataGenerator.uniqueLayerMountTarget() err := policy.EnforceDeviceMountPolicy(constraints.ctx, mount, containerTwo.Layers[i]) if err != nil { @@ -469,7 +828,8 @@ func Test_Rego_EnforceOverlayMountPolicy_Layers_Shared_Layers(t *testing.T) { containerTwoOverlay[len(containerTwoOverlay)-i-1] = mount } - err = policy.EnforceOverlayMountPolicy(constraints.ctx, containerID, containerTwoOverlay, testDataGenerator.uniqueMountTarget()) + err = policy.EnforceOverlayMountPolicy( + constraints.ctx, containerID, containerTwoOverlay, getOverlayMountTarget(containerID)) if err != nil { t.Fatalf("Unexpected error mounting overlay: %v", err) } @@ -490,12 +850,16 @@ func Test_Rego_EnforceOverlayMountPolicy_Overlay_Single_Container_Twice(t *testi return false } - if err := tc.policy.EnforceOverlayMountPolicy(p.ctx, tc.containerID, tc.layers, testDataGenerator.uniqueMountTarget()); err != nil { + overlayTarget := getOverlayMountTarget(tc.containerID) + + if err := tc.policy.EnforceOverlayMountPolicy( + p.ctx, tc.containerID, tc.layers, overlayTarget); err != nil { t.Errorf("expected nil error got: %v", err) return false } - if err := tc.policy.EnforceOverlayMountPolicy(p.ctx, tc.containerID, tc.layers, testDataGenerator.uniqueMountTarget()); err == nil { + if err := tc.policy.EnforceOverlayMountPolicy( + p.ctx, tc.containerID, tc.layers, overlayTarget); err == nil { t.Errorf("able to create overlay for the same container twice") return false } else { @@ -536,7 +900,8 @@ func Test_Rego_EnforceOverlayMountPolicy_Reusing_ID_Across_Overlays(t *testing.T t.Fatalf("Unexpected error creating valid overlay: %v", err) } - err = policy.EnforceOverlayMountPolicy(constraints.ctx, containerID, 
layerPaths, testDataGenerator.uniqueMountTarget()) + err = policy.EnforceOverlayMountPolicy( + constraints.ctx, containerID, layerPaths, getOverlayMountTarget(containerID)) if err != nil { t.Fatalf("Unexpected error mounting overlay filesystem: %v", err) } @@ -547,7 +912,8 @@ func Test_Rego_EnforceOverlayMountPolicy_Reusing_ID_Across_Overlays(t *testing.T t.Fatalf("Unexpected error creating valid overlay: %v", err) } - err = policy.EnforceOverlayMountPolicy(constraints.ctx, containerID, layerPaths, testDataGenerator.uniqueMountTarget()) + err = policy.EnforceOverlayMountPolicy( + constraints.ctx, containerID, layerPaths, getOverlayMountTarget(containerID)) if err == nil { t.Fatalf("Unexpected success mounting overlay filesystem") } @@ -588,7 +954,8 @@ func Test_Rego_EnforceOverlayMountPolicy_Multiple_Instances_Same_Container(t *te } id := testDataGenerator.uniqueContainerID() - err = policy.EnforceOverlayMountPolicy(constraints.ctx, id, layerPaths, testDataGenerator.uniqueMountTarget()) + err = policy.EnforceOverlayMountPolicy( + constraints.ctx, id, layerPaths, getOverlayMountTarget(id)) if err != nil { t.Fatalf("failed with %d containers", containersToCreate) } @@ -596,6 +963,90 @@ func Test_Rego_EnforceOverlayMountPolicy_Multiple_Instances_Same_Container(t *te } } +func Test_Rego_EnforceOverlayMountPolicy_MountFail(t *testing.T) { + f := func(gc *generatedConstraints, commitOnEnforcementFailure bool) bool { + securityPolicy := gc.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false + } + tc := selectContainerFromContainerList(gc.containers, testRand) + tid := testDataGenerator.uniqueContainerID() + scratchTarget := getScratchDiskMountTarget(tid) + + rev := startRevertableSection(t, policy) + err = policy.EnforceRWDeviceMountPolicy(gc.ctx, scratchTarget, true, true, "xfs") + if err != nil { + 
t.Errorf("failed to EnforceRWDeviceMountPolicy: %v", err) + return false + } + rev.Commit() + + layerToErr := testRand.Intn(len(tc.Layers)) + errLayerPathIndex := len(tc.Layers) - layerToErr - 1 + layerPaths := make([]string, len(tc.Layers)) + for i, layerHash := range tc.Layers { + rev := startRevertableSection(t, policy) + target := testDataGenerator.uniqueLayerMountTarget() + layerPaths[len(tc.Layers)-i-1] = target + err = policy.EnforceDeviceMountPolicy(gc.ctx, target, layerHash) + if err != nil { + t.Errorf("failed to EnforceDeviceMountPolicy: %v", err) + return false + } + if i == layerToErr { + // Simulate a mount failure at this point, which will cause us to rollback + rev.Rollback() + } else { + rev.Commit() + } + } + + rev = startRevertableSection(t, policy) + overlayTarget := getOverlayMountTarget(tid) + err = policy.EnforceOverlayMountPolicy(gc.ctx, tid, layerPaths, overlayTarget) + if !assertDecisionJSONContains(t, err, append(slices.Clone(layerPaths), "no matching containers for overlay")...) { + return false + } + commitOrRollback(rev, commitOnEnforcementFailure) + + layerPathsWithoutErr := make([]string, 0) + for i, layerPath := range layerPaths { + if i != errLayerPathIndex { + layerPathsWithoutErr = append(layerPathsWithoutErr, layerPath) + } + } + + rev = startRevertableSection(t, policy) + err = policy.EnforceOverlayMountPolicy(gc.ctx, tid, layerPathsWithoutErr, overlayTarget) + if !assertDecisionJSONContains(t, err, append(slices.Clone(layerPathsWithoutErr), "no matching containers for overlay")...) 
{ + return false + } + commitOrRollback(rev, commitOnEnforcementFailure) + + retryTarget := layerPaths[errLayerPathIndex] + rev = startRevertableSection(t, policy) + err = policy.EnforceDeviceMountPolicy(gc.ctx, retryTarget, tc.Layers[layerToErr]) + if err != nil { + t.Errorf("failed to EnforceDeviceMountPolicy again after one previous reverted failure: %v", err) + return false + } + rev.Commit() + err = policy.EnforceOverlayMountPolicy(gc.ctx, tid, layerPaths, overlayTarget) + if err != nil { + t.Errorf("failed to EnforceOverlayMountPolicy after one previous reverted failure: %v", err) + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceOverlayMountPolicy_MountFail: %v", err) + } +} + func Test_Rego_EnforceOverlayUnmountPolicy(t *testing.T) { f := func(p *generatedConstraints) bool { tc, err := setupRegoOverlayTest(p, true) @@ -604,7 +1055,7 @@ func Test_Rego_EnforceOverlayUnmountPolicy(t *testing.T) { return false } - target := testDataGenerator.uniqueMountTarget() + target := getOverlayMountTarget(tc.containerID) err = tc.policy.EnforceOverlayMountPolicy(p.ctx, tc.containerID, tc.layers, target) if err != nil { t.Errorf("Failure setting up overlay for testing: %v", err) @@ -633,14 +1084,14 @@ func Test_Rego_EnforceOverlayUnmountPolicy_No_Matches(t *testing.T) { return false } - target := testDataGenerator.uniqueMountTarget() + target := getOverlayMountTarget(tc.containerID) err = tc.policy.EnforceOverlayMountPolicy(p.ctx, tc.containerID, tc.layers, target) if err != nil { t.Errorf("Failure setting up overlay for testing: %v", err) return false } - badTarget := testDataGenerator.uniqueMountTarget() + badTarget := getOverlayMountTarget(generateContainerID(testRand)) err = tc.policy.EnforceOverlayUnmountPolicy(p.ctx, badTarget) if err == nil { t.Errorf("Unexpected policy enforcement success: %v", err) @@ -774,6 +1225,125 @@ func 
Test_Rego_EnforceEnvironmentVariablePolicy_NotAllMatches(t *testing.T) { } } +func Test_Rego_EnforceEnvironmentVariablePolicy_RegexPatterns(t *testing.T) { + testCases := []struct { + rule string + expectMatches []string + expectNotMatches []string + skipAddAnchors bool + }{ + { + rule: "PREFIX_.+=.+", + expectMatches: []string{"PREFIX_FOO=BAR"}, + expectNotMatches: []string{"PREFIX_FOO=", "SOMETHING=ELSE", "SOMETHING_PREFIX_FOO=BAR"}, + }, + { + rule: "PREFIX_.+=.+BAR", + expectMatches: []string{"PREFIX_FOO=FOO_BAR"}, + expectNotMatches: []string{"PREFIX_FOO=BAR_FOO"}, + }, + { + rule: "SIMPLE_VAR=.+", + expectMatches: []string{"SIMPLE_VAR=FOO"}, + expectNotMatches: []string{"SIMPLE_VAR=", "SOMETHING=ELSE", "SOMETHING=ELSE:SIMPLE_VAR=FOO", "SIMPLE_VAR_FOO=BAR", "SIMPLE_VAR"}, + }, + { + rule: "SIMPLE_VAR=.*", + expectMatches: []string{"SIMPLE_VAR=FOO", "SIMPLE_VAR="}, + expectNotMatches: []string{"SIMPLE_VAR"}, + }, + { + rule: "SIMPLE_VAR=", + expectMatches: []string{"SIMPLE_VAR="}, + expectNotMatches: []string{"SIMPLE_VAR", "SIMPLE_VAR=FOO"}, + }, + { + rule: "", + expectMatches: []string{}, + expectNotMatches: []string{"ANYTHING", "ANYTHING=ELSE"}, + }, + { + rule: "(^PREFIX1|^PREFIX2)=.+$", + expectMatches: []string{"PREFIX1=FOO", "PREFIX2=BAR"}, + expectNotMatches: []string{"PREFIX3_FOO=BAR", "PREFIX1=", "SOMETHING=ELSE", ""}, + skipAddAnchors: true, + }, + } + + testRule := func(rule string, expectMatches, expectNotMatches []string) { + testName := rule + if testName == "" { + testName = "(empty)" + } + t.Run(testName, func(t *testing.T) { + gc := generateConstraints(testRand, 1) + container := selectContainerFromContainerList(gc.containers, testRand) + container.EnvRules = append(container.EnvRules, EnvRuleConfig{ + Strategy: EnvVarRuleRegex, + Rule: rule, + }) + gc.allowEnvironmentVariableDropping = false + + for _, env := range expectMatches { + tc, err := setupRegoCreateContainerTest(gc, container, false) + if err != nil { + t.Error(err) + return + } + + 
tc.envList = append(tc.envList, env) + envsToKeep, _, _, err := tc.policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp) + + // getting an error means something is broken + if err != nil { + t.Errorf("Expected container creation to be allowed for env %s. It wasn't: %v", env, err) + return + } + + if !areStringArraysEqual(envsToKeep, tc.envList) { + t.Errorf("Expected env %s to be kept, but it was not in the returned envs: %v", env, envsToKeep) + return + } + } + + for _, env := range expectNotMatches { + tc, err := setupRegoCreateContainerTest(gc, container, false) + if err != nil { + t.Error(err) + return + } + + tc.envList = append(tc.envList, env) + _, _, _, err = tc.policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp) + + // not getting an error means something is broken + if err == nil { + t.Errorf("Expected container creation not to be allowed for env %s. 
It was allowed: %v", env, err) + return + } + + envName := strings.Split(env, "=")[0] + assertDecisionJSONContains(t, err, "invalid env list", envName) + } + }) + } + + for _, testCase := range testCases { + if !testCase.skipAddAnchors { + for _, rule := range []string{ + testCase.rule, + "^" + testCase.rule, + testCase.rule + "$", + "^" + testCase.rule + "$", + } { + testRule(rule, testCase.expectMatches, testCase.expectNotMatches) + } + } else { + testRule(testCase.rule, testCase.expectMatches, testCase.expectNotMatches) + } + } +} + func Test_Rego_EnforceEnvironmentVariablePolicy_DropEnvs(t *testing.T) { testFunc := func(gc *generatedConstraints) bool { gc.allowEnvironmentVariableDropping = true @@ -1880,8 +2450,8 @@ func Test_Rego_Enforcement_Point_Allowed(t *testing.T) { t.Fatal(err) } - input := make(map[string]interface{}) - results, err := policy.applyDefaults("__fixture_for_allowed_test_false__", input) + results := make(rpi.RegoQueryResult) + results, err = policy.applyDefaults("__fixture_for_allowed_test_false__", nil, results) if err != nil { t.Fatalf("applied defaults for an enforcement point receieved an error: %v", err) } @@ -1896,8 +2466,8 @@ func Test_Rego_Enforcement_Point_Allowed(t *testing.T) { t.Fatal("result of allowed for an available enforcement point was not the specified default (false)") } - input = make(map[string]interface{}) - results, err = policy.applyDefaults("__fixture_for_allowed_test_true__", input) + results = make(rpi.RegoQueryResult) + results, err = policy.applyDefaults("__fixture_for_allowed_test_true__", nil, results) if err != nil { t.Fatalf("applied defaults for an enforcement point receieved an error: %v", err) } @@ -3326,9 +3896,7 @@ func Test_Rego_Plan9MountPolicy_No_Matches(t *testing.T) { tc.seccomp, ) - if err == nil { - t.Fatal("Policy enforcement unexpectedly was allowed") - } + assertDecisionJSONContains(t, err, "invalid mount list") } func Test_Rego_Plan9MountPolicy_Invalid(t *testing.T) { @@ -3346,6 +3914,21 @@ 
func Test_Rego_Plan9MountPolicy_Invalid(t *testing.T) { } } +func Test_Rego_Plan9MountPolicy_Invalid_PathTraversal(t *testing.T) { + gc := generateConstraints(testRand, maxContainersInGeneratedConstraints) + + tc, err := setupPlan9MountTest(gc) + if err != nil { + t.Fatalf("unable to setup test: %v", err) + } + + mount := tc.uvmPathForShare + "/../../bin" + err = tc.policy.EnforcePlan9MountPolicy(gc.ctx, mount) + if err == nil { + t.Fatal("Policy enforcement unexpectedly was allowed", err) + } +} + func Test_Rego_Plan9UnmountPolicy(t *testing.T) { gc := generateConstraints(testRand, maxContainersInGeneratedConstraints) @@ -3381,9 +3964,7 @@ func Test_Rego_Plan9UnmountPolicy(t *testing.T) { tc.seccomp, ) - if err == nil { - t.Fatal("Policy enforcement unexpectedly was allowed") - } + assertDecisionJSONContains(t, err, "invalid mount list") } func Test_Rego_Plan9UnmountPolicy_No_Matches(t *testing.T) { @@ -3402,9 +3983,7 @@ func Test_Rego_Plan9UnmountPolicy_No_Matches(t *testing.T) { badMount := randString(testRand, maxPlan9MountTargetLength) err = tc.policy.EnforcePlan9UnmountPolicy(gc.ctx, badMount) - if err == nil { - t.Fatalf("Policy enforcement unexpectedly was allowed") - } + assertDecisionJSONContains(t, err, "no device at path to unmount") } func Test_Rego_GetPropertiesPolicy_On(t *testing.T) { @@ -3525,7 +4104,129 @@ func Test_EnforceRuntimeLogging_Not_Allowed(t *testing.T) { } } -func Test_Rego_LoadFragment_Container(t *testing.T) { +func Test_Rego_LoadFragment_Container(t *testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupRegoFragmentTestConfigWithIncludes(p, []string{"containers"}) + if err != nil { + t.Error(err) + return false + } + + fragment := tc.fragments[0] + container := tc.containers[0] + + err = tc.policy.LoadFragment(p.ctx, fragment.info.issuer, fragment.info.feed, fragment.code) + if err != nil { + t.Error("unable to load fragment: %w", err) + return false + } + + containerID, err := mountImageForContainer(tc.policy, 
container.container) + if err != nil { + t.Error("unable to mount image for fragment container: %w", err) + return false + } + + _, _, _, err = tc.policy.EnforceCreateContainerPolicy(p.ctx, + container.sandboxID, + containerID, + copyStrings(container.container.Command), + copyStrings(container.envList), + container.container.WorkingDir, + copyMounts(container.mounts), + false, + container.container.NoNewPrivileges, + container.user, + container.groups, + container.container.User.Umask, + container.capabilities, + container.seccomp, + ) + + if err != nil { + t.Error("unable to create container from fragment: %w", err) + return false + } + + if tc.policy.rego.IsModuleActive(rpi.ModuleID(fragment.info.issuer, fragment.info.feed)) { + t.Error("module not removed after load") + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_Container: %v", err) + } +} + +// Make sure we don't break fragment loading for old policies +func Test_Rego_LoadFragment_Container_Compat_0_10_0(t *testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupRegoFragmentTestConfigWithIncludes(p, []string{"containers"}) + if err != nil { + t.Error(err) + return false + } + + fragment := tc.fragments[0] + container := tc.containers[0] + rego := getPolicyCode_0_10_0(container.container.Layers[0], fragment.info.issuer, fragment.info.feed) + policy, err := newRegoPolicy(rego, []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Fatalf("unable to create Rego policy: %v", err) + } + tc.policy = policy + + err = tc.policy.LoadFragment(p.ctx, fragment.info.issuer, fragment.info.feed, fragment.code) + if err != nil { + t.Error("unable to load fragment: %w", err) + return false + } + + containerID, err := mountImageForContainer(tc.policy, container.container) + if err != nil { + t.Error("unable to mount image for fragment container: %w", err) + return false + } + + _, _, 
_, err = tc.policy.EnforceCreateContainerPolicy(p.ctx, + container.sandboxID, + containerID, + copyStrings(container.container.Command), + copyStrings(container.envList), + container.container.WorkingDir, + copyMounts(container.mounts), + false, + container.container.NoNewPrivileges, + container.user, + container.groups, + container.container.User.Umask, + container.capabilities, + container.seccomp, + ) + + if err != nil { + t.Error("unable to create container from fragment: %w", err) + return false + } + + if tc.policy.rego.IsModuleActive(rpi.ModuleID(fragment.info.issuer, fragment.info.feed)) { + t.Error("module not removed after load") + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_Container_Compat_0_10_0: %v", err) + } +} + +// Make sure we don't break fragment loading for old allow all policies +func Test_Rego_LoadFragment_Container_Compat_0_10_0_allow_all(t *testing.T) { f := func(p *generatedConstraints) bool { tc, err := setupRegoFragmentTestConfigWithIncludes(p, []string{"containers"}) if err != nil { @@ -3533,9 +4234,15 @@ func Test_Rego_LoadFragment_Container(t *testing.T) { return false } + rego := policyWith_0_10_0_apiTestAllowAllCode + policy, err := newRegoPolicy(rego, []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Fatalf("unable to create Rego policy: %v", err) + } + tc.policy = policy + fragment := tc.fragments[0] container := tc.containers[0] - err = tc.policy.LoadFragment(p.ctx, fragment.info.issuer, fragment.info.feed, fragment.code) if err != nil { t.Error("unable to load fragment: %w", err) @@ -3578,7 +4285,7 @@ func Test_Rego_LoadFragment_Container(t *testing.T) { } if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { - t.Errorf("Test_Rego_LoadFragment_Container: %v", err) + t.Errorf("Test_Rego_LoadFragment_Container_Compat_0_10_0: %v", err) } } @@ -3730,6 +4437,173 @@ func 
Test_Rego_LoadFragment_BadFeed(t *testing.T) { } } +func Test_Rego_parseNamespace(t *testing.T) { + type testCase struct { + inputs []string + expected string + expectFail bool + } + testCases := []testCase{ + { + inputs: []string{ + "package a\nanything-else", + "package a \n\n", + "package a ", + }, + expected: "a", + }, + { + inputs: []string{ + "package aaa", + "package aaa ", + "package aaa\n# anything", + }, + expected: "aaa", + }, + { + inputs: []string{ + "package", + "package\n", + "package ", + "package ", + "package$", + "package aa#bb\nframework", + "package\naa\n", + }, + expectFail: true, + }, + { + inputs: []string{ + "package framework", + "package api", + }, + expectFail: true, + }, + } + + for _, tc := range testCases { + for _, input := range tc.inputs { + result, err := parseNamespace(input) + if tc.expectFail && err == nil { + t.Errorf("Expected failure for input %q, but got success", input) + } else if !tc.expectFail && err != nil { + t.Errorf("Unexpected error for input %q: %v", input, err) + } else if !tc.expectFail && result != tc.expected { + t.Errorf("Expected to parse namespace %q for input %q, but got %q", tc.expected, input, result) + } + } + } +} + +func expectFragmentNotLoaded(t *testing.T, policy *regoEnforcer, issuer, feed string) bool { + if policy.rego.IsModuleActive(rpi.ModuleID(issuer, feed)) { + t.Errorf("fragment module is present") + return false + } + mtdIssuer, err := policy.rego.GetMetadata("issuers", issuer) + if err != nil && !strings.Contains(err.Error(), "value not found") && + !strings.Contains(err.Error(), "metadata not found for name issuers") { + t.Errorf("unexpected error when checking issuer metadata: %v", err) + return false + } + if mtdIssuer != nil || err == nil { + t.Errorf("fragment issuer metadata is present") + return false + } + return true +} + +func Test_Rego_LoadFragment_BadNamespace(t *testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupSimpleRegoFragmentTestConfig(p) + if err 
!= nil { + t.Error(err) + return false + } + + fragment := tc.fragments[0] + code := fmt.Sprintf(`package framework + +svn := "%s" +framework_version := "%s" + +load_fragment := {"allowed": true, "add_module": true} +enforcement_point_info := { + "available": true, + "unknown": false, + "invalid": false, + "version_missing": false, + "default_results": {"allowed": true}, + "use_framework": true +} +`, fragment.info.minimumSVN, frameworkVersion) + + err = tc.policy.LoadFragment(p.ctx, fragment.info.issuer, fragment.info.feed, code) + + if err == nil { + t.Error("expected to be unable to load fragment due to bad namespace") + return false + } + + if !strings.Contains(err.Error(), "namespace \"framework\" is reserved") { + t.Errorf("expected error string to contain 'namespace \"framework\" is reserved', but got %q", err.Error()) + return false + } + + if !expectFragmentNotLoaded(t, tc.policy, fragment.info.issuer, fragment.info.feed) { + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_BadNamespace: %v", err) + } +} + +func Test_Rego_LoadFragment_BadNamespace2(t *testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupSimpleRegoFragmentTestConfig(p) + if err != nil { + t.Error(err) + return false + } + + fragment := tc.fragments[0] + code := fmt.Sprintf(`package #aa +framework + +svn := "%s" +framework_version := "%s" + +load_fragment := {"allowed": true, "add_module": true} +`, fragment.info.minimumSVN, frameworkVersion) + + err = tc.policy.LoadFragment(p.ctx, fragment.info.issuer, fragment.info.feed, code) + + if err == nil { + t.Error("expected to be unable to load fragment due to invalid namespace") + return false + } + + if !strings.Contains(err.Error(), "valid package definition required on first line") { + t.Errorf("expected error string to contain 'valid package definition required on first line', but got %q", err.Error()) + return 
false + } + + if !expectFragmentNotLoaded(t, tc.policy, fragment.info.issuer, fragment.info.feed) { + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_BadNamespace: %v", err) + } +} + func Test_Rego_LoadFragment_InvalidSVN(t *testing.T) { f := func(p *generatedConstraints) bool { tc, err := setupRegoFragmentSVNErrorTestConfig(p) @@ -3750,7 +4624,7 @@ func Test_Rego_LoadFragment_InvalidSVN(t *testing.T) { return false } - if tc.policy.rego.IsModuleActive(rpi.ModuleID(fragment.info.issuer, fragment.info.feed)) { + if !expectFragmentNotLoaded(t, tc.policy, fragment.info.issuer, fragment.info.feed) { t.Error("module not removed upon failure") return false } @@ -3863,7 +4737,7 @@ func Test_Rego_LoadFragment_SVNMismatch(t *testing.T) { return false } - if tc.policy.rego.IsModuleActive(rpi.ModuleID(fragment.info.issuer, fragment.info.feed)) { + if !expectFragmentNotLoaded(t, tc.policy, fragment.info.issuer, fragment.info.feed) { t.Error("module not removed upon failure") return false } @@ -4056,146 +4930,353 @@ func Test_Rego_LoadFragment_ExcludedContainer(t *testing.T) { return false } - _, err = mountImageForContainer(tc.policy, container.container) - if err == nil { - t.Error("expected to be unable to mount image for fragment container") + _, err = mountImageForContainer(tc.policy, container.container) + if err == nil { + t.Error("expected to be unable to mount image for fragment container") + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 15, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_ExcludedContainer: %v", err) + } +} + +func Test_Rego_LoadFragment_ExcludedFragment(t *testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupRegoFragmentTestConfigWithExcludes(p, []string{"fragments"}) + if err != nil { + t.Error(err) + return false + } + + fragment := tc.fragments[0] + 
subFragment := tc.subFragments[0] + + err = tc.policy.LoadFragment(p.ctx, fragment.info.issuer, fragment.info.feed, fragment.code) + if err != nil { + t.Error("unable to load fragment: %w", err) + return false + } + + err = tc.policy.LoadFragment(p.ctx, subFragment.info.issuer, subFragment.info.feed, subFragment.code) + if err == nil { + t.Error("expected to be unable to load a sub-fragment from a fragment") + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 15, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_ExcludedFragment: %v", err) + } +} + +func Test_Rego_LoadFragment_ExcludedExternalProcess(t *testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupRegoFragmentTestConfigWithExcludes(p, []string{"external_processes"}) + if err != nil { + t.Error(err) + return false + } + + fragment := tc.fragments[0] + process := tc.externalProcesses[0] + + err = tc.policy.LoadFragment(p.ctx, fragment.info.issuer, fragment.info.feed, fragment.code) + if err != nil { + t.Error("unable to load fragment: %w", err) + return false + } + + envList := buildEnvironmentVariablesFromEnvRules(process.envRules, testRand) + + _, _, err = tc.policy.EnforceExecExternalProcessPolicy(p.ctx, process.command, envList, process.workingDir) + if err == nil { + t.Error("expected to be unable to execute external process from a fragment") + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_ExcludedExternalProcess: %v", err) + } +} + +func Test_Rego_LoadFragment_FragmentNamespace(t *testing.T) { + ctx := context.Background() + deviceHash := generateRootHash(testRand) + key := randVariableString(testRand, 32) + value := randVariableString(testRand, 32) + fragmentCode := fmt.Sprintf(`package fragment + +svn := 1 +framework_version := "%s" + +layer := "%s" + +mount_device := {"allowed": allowed, "metadata": [addCustom]} 
{ + allowed := input.deviceHash == layer + addCustom := { + "name": "custom", + "action": "add", + "key": "%s", + "value": "%s" + } +}`, frameworkVersion, deviceHash, key, value) + + issuer := testDataGenerator.uniqueFragmentIssuer() + feed := testDataGenerator.uniqueFragmentFeed() + policyCode := fmt.Sprintf(`package policy + +api_version := "%s" +framework_version := "%s" + +default load_fragment := {"allowed": false} + +check_svn_if_loaded { + not input.fragment_loaded +} else { + data[input.namespace].svn >= 1 +} + +load_fragment := {"allowed": true, "add_module": true} { + input.issuer == "%s" + input.feed == "%s" + check_svn_if_loaded +} + +mount_device := data.fragment.mount_device + `, apiVersion, frameworkVersion, issuer, feed) + + policy, err := newRegoPolicy(policyCode, []oci.Mount{}, []oci.Mount{}, testOSType) + + if err != nil { + t.Fatalf("unable to create Rego policy: %v", err) + } + + err = policy.LoadFragment(ctx, issuer, feed, fragmentCode) + if err != nil { + t.Fatalf("unable to load fragment: %v", err) + } + + err = policy.EnforceDeviceMountPolicy(ctx, "/mnt/foo", deviceHash) + if err != nil { + t.Fatalf("unable to mount device: %v", err) + } + + if test, err := policy.rego.GetMetadata("custom", key); err == nil { + if test != value { + t.Error("incorrect metadata value stored by fragment") + } + } else { + t.Errorf("unable to located metadata key stored by fragment: %v", err) + } +} + +func Test_Rego_LoadFragment_BadIssuer_AttemptOverrideFrameworkItems(t *testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupSimpleRegoFragmentTestConfig(p) + if err != nil { + t.Error(err) + return false + } + + fragment := tc.fragments[0] + expectedIssuer := fragment.info.issuer + actualIssuer := testDataGenerator.uniqueFragmentIssuer() + code := fmt.Sprintf(`package fragment + +svn := "%s" +framework_version := "%s" + +load_fragment := {"allowed": true, "add_module": true} +data.framework.load_fragment := {"allowed": true, "add_module": 
true} +input.issuer := "%s" +data.framework.input.issuer := "%s" +`, fragment.info.minimumSVN, frameworkVersion, expectedIssuer, expectedIssuer) + + err = tc.policy.LoadFragment(p.ctx, actualIssuer, fragment.info.feed, code) + + if !assertDecisionJSONContains(t, err, "invalid fragment issuer") { + return false + } + + if !expectFragmentNotLoaded(t, tc.policy, fragment.info.issuer, fragment.info.feed) { return false } return true } - if err := quick.Check(f, &quick.Config{MaxCount: 15, Rand: testRand}); err != nil { - t.Errorf("Test_Rego_LoadFragment_ExcludedContainer: %v", err) + if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_BadIssuer_AttemptOverrideFrameworkItems: %v", err) } } -func Test_Rego_LoadFragment_ExcludedFragment(t *testing.T) { +// The intent of this test is really to check that Rego module names are +// case-sensitive, since we do not deny a fragment from having a namespace +// "Framework" or the like. We use svn mismatch here since otherwise the +// enforcer will not even try to load the fragment module at all if issuer or +// feed is wrong. But in reality, if an attacker can sign fragments with the +// correct issuer, they can make the fragment have any SVN they want. 
+func Test_Rego_LoadFragment_BadSvn_FrameworkNamespaceCaseConfusion(t *testing.T) { f := func(p *generatedConstraints) bool { - tc, err := setupRegoFragmentTestConfigWithExcludes(p, []string{"fragments"}) + tc, err := setupRegoFragmentSVNErrorTestConfig(p) if err != nil { t.Error(err) return false } fragment := tc.fragments[0] - subFragment := tc.subFragments[0] + code := fmt.Sprintf(`package Framework - err = tc.policy.LoadFragment(p.ctx, fragment.info.issuer, fragment.info.feed, fragment.code) - if err != nil { - t.Error("unable to load fragment: %w", err) +svn := "%s" +framework_version := "%s" + +load_fragment := {"allowed": true, "add_module": true} +enforcement_point_info := { + "available": true, + "unknown": false, + "invalid": false, + "version_missing": false, + "default_results": {"allowed": true}, + "use_framework": true +} +data.framework.load_fragment := load_fragment +`, fragment.constraints.svn, frameworkVersion) + + err = tc.policy.LoadFragment(p.ctx, fragment.info.issuer, fragment.info.feed, code) + + if !assertDecisionJSONContains(t, err, "fragment svn is below the specified minimum") { return false } - err = tc.policy.LoadFragment(p.ctx, subFragment.info.issuer, subFragment.info.feed, subFragment.code) - if err == nil { - t.Error("expected to be unable to load a sub-fragment from a fragment") + if !expectFragmentNotLoaded(t, tc.policy, fragment.info.issuer, fragment.info.feed) { return false } return true } - if err := quick.Check(f, &quick.Config{MaxCount: 15, Rand: testRand}); err != nil { - t.Errorf("Test_Rego_LoadFragment_ExcludedFragment: %v", err) + if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_BadSvn_FrameworkNamespaceCaseConfusion: %v", err) } } -func Test_Rego_LoadFragment_ExcludedExternalProcess(t *testing.T) { +func Test_Rego_LoadFragment_BadIssuer_MustNotTryToLoadRego(t *testing.T) { f := func(p *generatedConstraints) bool { - tc, err := 
setupRegoFragmentTestConfigWithExcludes(p, []string{"external_processes"}) + tc, err := setupSimpleRegoFragmentTestConfig(p) if err != nil { t.Error(err) return false } fragment := tc.fragments[0] - process := tc.externalProcesses[0] + actualIssuer := testDataGenerator.uniqueFragmentIssuer() + code := "package fragment\n!invalid!rego" - err = tc.policy.LoadFragment(p.ctx, fragment.info.issuer, fragment.info.feed, fragment.code) - if err != nil { - t.Error("unable to load fragment: %w", err) + err = tc.policy.LoadFragment(p.ctx, actualIssuer, fragment.info.feed, code) + + if strings.Contains(err.Error(), "error when compiling module") || + !assertDecisionJSONDoesNotContain(t, err, "error when compiling module") { + t.Errorf("expected error to not contain 'error when compiling module', got: %s", err.Error()) return false } - - envList := buildEnvironmentVariablesFromEnvRules(process.envRules, testRand) - - _, _, err = tc.policy.EnforceExecExternalProcessPolicy(p.ctx, process.command, envList, process.workingDir) - if err == nil { - t.Error("expected to be unable to execute external process from a fragment") + if !assertDecisionJSONDoesNotContain(t, err, "fragment framework_version is missing") { + return false + } + if !assertDecisionJSONContains(t, err, "invalid fragment issuer") { + return false + } + if !expectFragmentNotLoaded(t, tc.policy, actualIssuer, fragment.info.feed) { return false } - return true } if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { - t.Errorf("Test_Rego_LoadFragment_ExcludedExternalProcess: %v", err) + t.Errorf("Test_Rego_LoadFragment_BadIssuer_MustNotTryToLoadRego: %v", err) } } -func Test_Rego_LoadFragment_FragmentNamespace(t *testing.T) { - ctx := context.Background() - deviceHash := generateRootHash(testRand) - key := randVariableString(testRand, 32) - value := randVariableString(testRand, 32) - fragmentCode := fmt.Sprintf(`package fragment +func Test_Rego_LoadFragment_BadFeed_MustNotTryToLoadRego(t 
*testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupSimpleRegoFragmentTestConfig(p) + if err != nil { + t.Error(err) + return false + } -svn := 1 -framework_version := "%s" + fragment := tc.fragments[0] + actualFeed := testDataGenerator.uniqueFragmentFeed() + code := "package fragment\n!invalid!rego" -layer := "%s" + err = tc.policy.LoadFragment(p.ctx, fragment.info.issuer, actualFeed, code) -mount_device := {"allowed": allowed, "metadata": [addCustom]} { - allowed := input.deviceHash == layer - addCustom := { - "name": "custom", - "action": "add", - "key": "%s", - "value": "%s" + if strings.Contains(err.Error(), "error when compiling module") || + !assertDecisionJSONDoesNotContain(t, err, "error when compiling module") { + t.Errorf("expected error to not contain 'error when compiling module', got: %s", err.Error()) + return false + } + if !assertDecisionJSONDoesNotContain(t, err, "fragment framework_version is missing") { + return false + } + if !assertDecisionJSONContains(t, err, "invalid fragment feed") { + return false + } + if !expectFragmentNotLoaded(t, tc.policy, fragment.info.issuer, actualFeed) { + return false + } + return true } -}`, frameworkVersion, deviceHash, key, value) - - issuer := testDataGenerator.uniqueFragmentIssuer() - feed := testDataGenerator.uniqueFragmentFeed() - policyCode := fmt.Sprintf(`package policy - -api_version := "%s" -framework_version := "%s" - -default load_fragment := {"allowed": false} -load_fragment := {"allowed": true, "add_module": true} { - input.issuer == "%s" - input.feed == "%s" - data[input.namespace].svn >= 1 + if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_BadFeed_MustNotTryToLoadRego: %v", err) + } } -mount_device := data.fragment.mount_device - `, apiVersion, frameworkVersion, issuer, feed) - - policy, err := newRegoPolicy(policyCode, []oci.Mount{}, []oci.Mount{}, testOSType) +func 
Test_Rego_LoadFragment_BadIssuer_MustNotTryToLoadRego_Compat_0_10_0(t *testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupSimpleRegoFragmentTestConfig(p) + if err != nil { + t.Error(err) + return false + } + rego := getPolicyCode_0_10_0(tc.containers[0].container.Layers[0], tc.fragments[0].info.issuer, tc.fragments[0].info.feed) + policy, err := newRegoPolicy(rego, []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Fatalf("unable to create Rego policy: %v", err) + } + tc.policy = policy - if err != nil { - t.Fatalf("unable to create Rego policy: %v", err) - } + fragment := tc.fragments[0] + actualIssuer := testDataGenerator.uniqueFragmentIssuer() + code := "package fragment\n!invalid!rego" - err = policy.LoadFragment(ctx, issuer, feed, fragmentCode) - if err != nil { - t.Fatalf("unable to load fragment: %v", err) - } + err = tc.policy.LoadFragment(p.ctx, actualIssuer, fragment.info.feed, code) - err = policy.EnforceDeviceMountPolicy(ctx, "/mnt/foo", deviceHash) - if err != nil { - t.Fatalf("unable to mount device: %v", err) + if strings.Contains(err.Error(), "error when compiling module") || + !assertDecisionJSONDoesNotContain(t, err, "error when compiling module") { + t.Errorf("expected error to not contain 'error when compiling module', got: %s", err.Error()) + return false + } + if !assertDecisionJSONDoesNotContain(t, err, "fragment framework_version is missing") { + return false + } + if !assertDecisionJSONContains(t, err, "invalid fragment issuer") { + return false + } + return true } - if test, err := policy.rego.GetMetadata("custom", key); err == nil { - if test != value { - t.Error("incorrect metadata value stored by fragment") - } - } else { - t.Errorf("unable to located metadata key stored by fragment: %v", err) + if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_BadIssuer_MustNotTryToLoadRego_Compat_0_10_0: %v", err) } } @@ -4226,6 +5307,9 @@ func 
Test_Rego_Scratch_Mount_Policy(t *testing.T) { failureExpected: false, }, } { + + filesystem := "xfs" + t.Run(fmt.Sprintf("UnencryptedAllowed_%t_And_Encrypted_%t", tc.unencryptedAllowed, tc.encrypted), func(t *testing.T) { gc := generateConstraints(testRand, maxContainersInGeneratedConstraints) smConfig, err := setupRegoScratchMountTest(gc, tc.unencryptedAllowed) @@ -4233,15 +5317,29 @@ func Test_Rego_Scratch_Mount_Policy(t *testing.T) { t.Fatalf("unable to setup test: %s", err) } - scratchPath := generateMountTarget(testRand) + containerId := testDataGenerator.uniqueContainerID() + scratchDiskMount := getScratchDiskMountTarget(containerId) + + err = smConfig.policy.EnforceRWDeviceMountPolicy(gc.ctx, scratchDiskMount, tc.encrypted, true, filesystem) + if tc.failureExpected { + if err == nil { + t.Fatal("mounting should've been denied") + } + } else { + if err != nil { + t.Fatalf("mounting unexpectedly was denied: %s", err) + } + } + + scratchPath := path.Join(scratchDiskMount, guestpath.ScratchDir, containerId) err = smConfig.policy.EnforceScratchMountPolicy(gc.ctx, scratchPath, tc.encrypted) if tc.failureExpected { if err == nil { - t.Fatal("policy enforcement should've been denied") + t.Fatal("scratch mount should've been denied") } } else { if err != nil { - t.Fatalf("policy enforcement unexpectedly was denied: %s", err) + t.Fatalf("scratch mount unexpectedly was denied: %s", err) } } }) @@ -4280,7 +5378,15 @@ func Test_Rego_Scratch_Unmount_Policy(t *testing.T) { t.Fatalf("unable to setup test: %s", err) } - scratchPath := generateMountTarget(testRand) + containerId := testDataGenerator.uniqueContainerID() + scratchDiskMount := getScratchDiskMountTarget(containerId) + + err = smConfig.policy.EnforceRWDeviceMountPolicy(gc.ctx, scratchDiskMount, tc.encrypted, true, "xfs") + if err != nil { + t.Fatalf("mounting unexpectedly was denied: %s", err) + } + + scratchPath := path.Join(scratchDiskMount, guestpath.ScratchDir, containerId) err = 
smConfig.policy.EnforceScratchMountPolicy(gc.ctx, scratchPath, tc.encrypted) if err != nil { t.Fatalf("scratch_mount policy enforcement unexpectedly was denied: %s", err) @@ -4290,6 +5396,11 @@ func Test_Rego_Scratch_Unmount_Policy(t *testing.T) { if err != nil { t.Fatalf("scratch_unmount policy enforcement unexpectedly was denied: %s", err) } + + err = smConfig.policy.EnforceRWDeviceUnmountPolicy(gc.ctx, scratchDiskMount) + if err != nil { + t.Fatalf("device_unmount policy enforcement unexpectedly was denied: %s", err) + } }) } } @@ -4798,7 +5909,7 @@ func Test_FrameworkVersion_Missing(t *testing.T) { layerPaths, err := testDataGenerator.createValidOverlayForContainer(tc.policy, c) - err = tc.policy.EnforceOverlayMountPolicy(gc.ctx, containerID, layerPaths, testDataGenerator.uniqueMountTarget()) + err = tc.policy.EnforceOverlayMountPolicy(gc.ctx, containerID, layerPaths, testDataGenerator.uniqueLayerMountTarget()) if err == nil { t.Error("unexpected success. Missing framework_version should trigger an error.") } @@ -4834,7 +5945,8 @@ func Test_FrameworkVersion_In_Future(t *testing.T) { layerPaths, err := testDataGenerator.createValidOverlayForContainer(tc.policy, c) - err = tc.policy.EnforceOverlayMountPolicy(gc.ctx, containerID, layerPaths, testDataGenerator.uniqueMountTarget()) + err = tc.policy.EnforceOverlayMountPolicy( + gc.ctx, containerID, layerPaths, getOverlayMountTarget(containerID)) if err == nil { t.Error("unexpected success. 
Future framework_version should trigger an error.") } @@ -5075,6 +6187,195 @@ func Test_Rego_Enforce_CreateContainer_RequiredEnvMissingHasErrorMessage(t *test } } +func Test_Rego_EnforceCreateContainer_RejectRevertedOverlayMount(t *testing.T) { + f := func(gc *generatedConstraints, commitOnEnforcementFailure bool) bool { + container := selectContainerFromContainerList(gc.containers, testRand) + securityPolicy := gc.toPolicy() + defaultMounts := generateMounts(testRand) + privilegedMounts := generateMounts(testRand) + + policy, err := newRegoPolicy(securityPolicy.marshalRego(), + toOCIMounts(defaultMounts), + toOCIMounts(privilegedMounts), testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false + } + + containerID := testDataGenerator.uniqueContainerID() + tc, err := createTestContainerSpec(gc, containerID, container, false, policy, defaultMounts, privilegedMounts) + if err != nil { + t.Fatal(err) + } + + layers, err := testDataGenerator.createValidOverlayForContainer(policy, container) + if err != nil { + t.Errorf("Failed to createValidOverlayForContainer: %v", err) + return false + } + + scratchMountTarget := getScratchDiskMountTarget(containerID) + rev := startRevertableSection(t, policy) + err = policy.EnforceRWDeviceMountPolicy(gc.ctx, scratchMountTarget, true, true, "xfs") + if err != nil { + t.Errorf("Failed to EnforceRWDeviceMountPolicy: %v", err) + return false + } + rev.Commit() + + rev = startRevertableSection(t, policy) + overlayTarget := getOverlayMountTarget(containerID) + err = policy.EnforceOverlayMountPolicy(gc.ctx, containerID, layers, overlayTarget) + if err != nil { + t.Errorf("Failed to EnforceOverlayMountPolicy: %v", err) + return false + } + // Simulate a failure by rolling back the overlay mount + rev.Rollback() + + rev = startRevertableSection(t, policy) + _, _, _, err = policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, 
tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp) + if err == nil { + t.Errorf("EnforceCreateContainerPolicy should have failed due to missing (reverted) overlay mount") + return false + } + commitOrRollback(rev, commitOnEnforcementFailure) + + // "Retry" overlay mount + rev = startRevertableSection(t, policy) + err = policy.EnforceOverlayMountPolicy(gc.ctx, tc.containerID, layers, overlayTarget) + if err != nil { + t.Errorf("Failed to EnforceOverlayMountPolicy: %v", err) + return false + } + rev.Commit() + + rev = startRevertableSection(t, policy) + _, _, _, err = policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp) + if err != nil { + t.Errorf("Failed to EnforceCreateContainerPolicy after retrying overlay mount: %v", err) + return false + } + rev.Commit() + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceCreateContainerPolicy_RejectRevertedOverlayMount: %v", err) + } +} + +func Test_Rego_EnforceCreateContainer_RetryEverything(t *testing.T) { + f := func(gc *generatedConstraints, + newContainerID, failScratchMount, testDenyInvalidContainerCreation bool, + ) bool { + container := selectContainerFromContainerList(gc.containers, testRand) + securityPolicy := gc.toPolicy() + defaultMounts := generateMounts(testRand) + privilegedMounts := generateMounts(testRand) + + policy, err := newRegoPolicy(securityPolicy.marshalRego(), + toOCIMounts(defaultMounts), + toOCIMounts(privilegedMounts), testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false + } + + containerID := testDataGenerator.uniqueContainerID() + tc, err := createTestContainerSpec(gc, containerID, container, false, policy, defaultMounts, privilegedMounts) + if err != nil { + 
t.Fatal(err) + } + + scratchMountTarget := getScratchDiskMountTarget(containerID) + rev := startRevertableSection(t, policy) + err = policy.EnforceRWDeviceMountPolicy(gc.ctx, scratchMountTarget, true, true, "xfs") + if err != nil { + t.Errorf("Failed to EnforceRWDeviceMountPolicy: %v", err) + return false + } + + succeedLayerPaths := make([]string, 0) + + if failScratchMount { + rev.Rollback() + } else { + rev.Commit() + + // Simulate one of the layers failing to mount, after which the outside + // gives up on this container and starts over. + layerToErr := testRand.Intn(len(container.Layers)) + for i, layerHash := range container.Layers { + rev := startRevertableSection(t, policy) + target := testDataGenerator.uniqueLayerMountTarget() + err = policy.EnforceDeviceMountPolicy(gc.ctx, target, layerHash) + if err != nil { + t.Errorf("failed to EnforceDeviceMountPolicy: %v", err) + return false + } + if i == layerToErr { + // Simulate a mount failure at this point, which will cause us to rollback + rev.Rollback() + break + } else { + rev.Commit() + succeedLayerPaths = append(succeedLayerPaths, target) + } + } + + for _, layerPath := range succeedLayerPaths { + rev := startRevertableSection(t, policy) + err = policy.EnforceDeviceUnmountPolicy(gc.ctx, layerPath) + if err != nil { + t.Errorf("Failed to EnforceDeviceUnmountPolicy: %v", err) + return false + } + rev.Commit() + } + + rev = startRevertableSection(t, policy) + err = policy.EnforceRWDeviceUnmountPolicy(gc.ctx, scratchMountTarget) + if err != nil { + t.Errorf("Failed to EnforceRWDeviceUnmountPolicy: %v", err) + return false + } + rev.Commit() + } + + if testDenyInvalidContainerCreation { + rev = startRevertableSection(t, policy) + _, _, _, err = policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp) + if err == nil { + t.Errorf("EnforceCreateContainerPolicy 
should have failed due to missing (reverted) overlay mount") + } + rev.Rollback() + } + + if newContainerID { + tc.containerID = testDataGenerator.uniqueContainerID() + } + + err = mountImageForContainerWithID(policy, container, tc.containerID) + if err != nil { + t.Errorf("Failed to mount image for container after reverting and retrying: %v", err) + return false + } + _, _, _, err = policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp) + if err != nil { + t.Errorf("Failed to EnforceCreateContainerPolicy after retrying: %v", err) + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceCreateContainerPolicy_RejectRevertedOverlayMount: %v", err) + } +} + func Test_Rego_ExecInContainerPolicy_RequiredEnvMissingHasErrorMessage(t *testing.T) { constraints := generateConstraints(testRand, 1) container := selectContainerFromContainerList(constraints.containers, testRand) @@ -5799,7 +7100,8 @@ func Test_Rego_ErrorTruncation_Unable(t *testing.T) { maxErrorMessageLength := 32 tc.policy.maxErrorMessageLength = maxErrorMessageLength - err = tc.policy.EnforceOverlayMountPolicy(gc.ctx, tc.containerID, tc.layers, testDataGenerator.uniqueMountTarget()) + err = tc.policy.EnforceOverlayMountPolicy( + gc.ctx, tc.containerID, tc.layers, getOverlayMountTarget(tc.containerID)) if err == nil { t.Fatal("Policy did not throw the expected error") diff --git a/pkg/securitypolicy/securitypolicy_linux.go b/pkg/securitypolicy/securitypolicy_linux.go index cb04e03d92..278038ac67 100644 --- a/pkg/securitypolicy/securitypolicy_linux.go +++ b/pkg/securitypolicy/securitypolicy_linux.go @@ -4,12 +4,14 @@ package securitypolicy import ( + "bytes" "fmt" "os" "path/filepath" "strconv" specInternal "github.com/Microsoft/hcsshim/internal/guest/spec" + 
"github.com/Microsoft/hcsshim/pkg/amdsevsnp" "github.com/moby/sys/user" oci "github.com/opencontainers/runtime-spec/specs-go" "github.com/pkg/errors" @@ -18,6 +20,114 @@ import ( //nolint:unused const osType = "linux" +// validateHostData fetches SNP report (if applicable) and validates `hostData` against +// HostData set at UVM launch. +func validateHostData(hostData []byte) error { + // If the UVM is not SNP, then don't try to fetch an SNP report. + if !amdsevsnp.IsSNP() { + return nil + } + report, err := amdsevsnp.FetchParsedSNPReport(nil) + if err != nil { + return err + } + + if !bytes.Equal(hostData, report.HostData) { + return fmt.Errorf( + "security policy digest %q doesn't match HostData provided at launch %q", + hostData, + report.HostData, + ) + } + return nil +} + +func ExtendPolicyWithNetworkingMounts(sandboxID string, enforcer SecurityPolicyEnforcer, spec *oci.Spec) error { + roSpec := &oci.Spec{ + Root: spec.Root, + } + networkingMounts := specInternal.GenerateWorkloadContainerNetworkMounts(sandboxID, roSpec) + if err := enforcer.ExtendDefaultMounts(networkingMounts); err != nil { + return err + } + return nil +} + +func DefaultCRIMounts() []oci.Mount { + return []oci.Mount{ + { + Destination: "/proc", + Type: "proc", + Source: "proc", + Options: []string{"nosuid", "noexec", "nodev"}, + }, + { + Destination: "/dev", + Type: "tmpfs", + Source: "tmpfs", + Options: []string{"nosuid", "strictatime", "mode=755", "size=65536k"}, + }, + { + Destination: "/dev/pts", + Type: "devpts", + Source: "devpts", + Options: []string{"nosuid", "noexec", "newinstance", "ptmxmode=0666", "mode=0620", "gid=5"}, + }, + { + Destination: "/dev/shm", + Type: "tmpfs", + Source: "shm", + Options: []string{"nosuid", "noexec", "nodev", "mode=1777", "size=65536k"}, + }, + { + Destination: "/dev/mqueue", + Type: "mqueue", + Source: "mqueue", + Options: []string{"nosuid", "noexec", "nodev"}, + }, + { + Destination: "/sys", + Type: "sysfs", + Source: "sysfs", + Options: 
[]string{"nosuid", "noexec", "nodev", "ro"}, + }, + { + Destination: "/run", + Type: "tmpfs", + Source: "tmpfs", + Options: []string{"nosuid", "strictatime", "mode=755", "size=65536k"}, + }, + // cgroup mount is always added by default, regardless if it is present + // in the mount constraints or not. If the user chooses to override it, + // then a corresponding mount constraint should be present. + { + Source: "cgroup", + Destination: "/sys/fs/cgroup", + Type: "cgroup", + Options: []string{"nosuid", "noexec", "nodev", "relatime", "ro"}, + }, + } +} + +// DefaultCRIPrivilegedMounts returns a slice of mounts which are added to the +// linux container spec when a container runs in a privileged mode. +func DefaultCRIPrivilegedMounts() []oci.Mount { + return []oci.Mount{ + { + Source: "cgroup", + Destination: "/sys/fs/cgroup", + Type: "cgroup", + Options: []string{"nosuid", "noexec", "nodev", "relatime", "rw"}, + }, + { + Destination: "/sys", + Type: "sysfs", + Source: "sysfs", + Options: []string{"nosuid", "noexec", "nodev", "rw"}, + }, + } +} + // SandboxMountsDir returns sandbox mounts directory inside UVM/host. 
func SandboxMountsDir(sandboxID string) string { return specInternal.SandboxMountsDir((sandboxID)) diff --git a/pkg/securitypolicy/securitypolicy_options.go b/pkg/securitypolicy/securitypolicy_options.go new file mode 100644 index 0000000000..b2b469e0bb --- /dev/null +++ b/pkg/securitypolicy/securitypolicy_options.go @@ -0,0 +1,227 @@ +package securitypolicy + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "fmt" + "io" + "os" + "path/filepath" + "sync" + "time" + + "github.com/Microsoft/cosesign1go/pkg/cosesign1" + didx509resolver "github.com/Microsoft/didx509go/pkg/did-x509-resolver" + "github.com/Microsoft/hcsshim/internal/log" + "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + "github.com/Microsoft/hcsshim/pkg/annotations" + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +type SecurityOptions struct { + // state required for the security policy enforcement + PolicyEnforcer SecurityPolicyEnforcer + PolicyEnforcerSet bool + UvmReferenceInfo string + policyMutex sync.Mutex + logWriter io.Writer +} + +func NewSecurityOptions(enforcer SecurityPolicyEnforcer, enforcerSet bool, uvmReferenceInfo string, logWriter io.Writer) *SecurityOptions { + return &SecurityOptions{ + PolicyEnforcer: enforcer, + PolicyEnforcerSet: enforcerSet, + UvmReferenceInfo: uvmReferenceInfo, + logWriter: logWriter, + } +} + +// SetConfidentialOptions takes guestresource.ConfidentialOptions +// to set up our internal data structures we use to store and enforce +// security policy. The options can contain security policy enforcer type, +// encoded security policy and signed UVM reference information The security +// policy and uvm reference information can be further presented to workload +// containers for validation and attestation purposes. 
+func (s *SecurityOptions) SetConfidentialOptions(ctx context.Context, enforcerType string, encodedSecurityPolicy string, encodedUVMReference string) error { + s.policyMutex.Lock() + defer s.policyMutex.Unlock() + + if s.PolicyEnforcerSet { + return errors.New("security policy has already been set") + } + + hostData, err := NewSecurityPolicyDigest(encodedSecurityPolicy) + if err != nil { + return err + } + + if err := validateHostData(hostData[:]); err != nil { + return err + } + + // This limit ensures messages are below the character truncation limit that + // can be imposed by an orchestrator + maxErrorMessageLength := 3 * 1024 + + // Initialize security policy enforcer for a given enforcer type and + // encoded security policy. + p, err := CreateSecurityPolicyEnforcer( + enforcerType, + encodedSecurityPolicy, + DefaultCRIMounts(), + DefaultCRIPrivilegedMounts(), + maxErrorMessageLength, + ) + if err != nil { + return fmt.Errorf("error creating security policy enforcer: %w", err) + } + + // This is one of two points at which we might change our logging. + // At this time, we now have a policy and can determine what the policy + // author put as policy around runtime logging. + // The other point is on startup where we take a flag to set the default + // policy enforcer to use before a policy arrives. After that flag is set, + // we use the enforcer in question to set up logging as well. + if err = s.PolicyEnforcer.EnforceRuntimeLoggingPolicy(ctx); err == nil { + logrus.SetOutput(s.logWriter) + } else { + logrus.SetOutput(io.Discard) + } + + s.PolicyEnforcer = p + s.PolicyEnforcerSet = true + s.UvmReferenceInfo = encodedUVMReference + + return nil +} + +// Fragment extends current security policy with additional constraints +// from the incoming fragment. 
Note that it is base64 encoded over the bridge/ +// +// There are three checking steps: +// 1 - Unpack the cose document and check it was actually signed with the cert +// chain inside its header +// 2 - Check that the issuer field did:x509 identifier is for that cert chain +// (ie fingerprint of a non leaf cert and the subject matches the leaf cert) +// 3 - Check that this issuer/feed match the requirement of the user provided +// security policy (done in the regoby LoadFragment) +func (s *SecurityOptions) InjectFragment(ctx context.Context, fragment *guestresource.SecurityPolicyFragment) (err error) { + log.G(ctx).WithField("fragment", fmt.Sprintf("%+v", fragment)).Debug("VerifyAndExtractFragment") + + raw, err := base64.StdEncoding.DecodeString(fragment.Fragment) + if err != nil { + return fmt.Errorf("failed to decode fragment: %w", err) + } + blob := []byte(fragment.Fragment) + // keep a copy of the fragment, so we can manually figure out what went wrong + // will be removed eventually. Give it a unique name to avoid any potential + // race conditions. + sha := sha256.New() + sha.Write(blob) + timestamp := time.Now() + fragmentPath := fmt.Sprintf("fragment-%x-%d.blob", sha.Sum(nil), timestamp.UnixMilli()) + _ = os.WriteFile(filepath.Join(os.TempDir(), fragmentPath), blob, 0644) + + unpacked, err := cosesign1.UnpackAndValidateCOSE1CertChain(raw) + if err != nil { + return fmt.Errorf("InjectFragment failed COSE validation: %w", err) + } + + payloadString := string(unpacked.Payload[:]) + issuer := unpacked.Issuer + feed := unpacked.Feed + chainPem := unpacked.ChainPem + + log.G(ctx).WithFields(logrus.Fields{ + "issuer": issuer, // eg the DID:x509:blah.... 
+ "feed": feed, + "cty": unpacked.ContentType, + "chainPem": chainPem, + }).Debugf("unpacked COSE1 cert chain") + + log.G(ctx).WithFields(logrus.Fields{ + "payload": payloadString, + }).Tracef("unpacked COSE1 payload") + + if len(issuer) == 0 || len(feed) == 0 { // must both be present + return fmt.Errorf("either issuer and feed must both be provided in the COSE_Sign1 protected header") + } + + // Resolve returns a did doc that we don't need + // we only care if there was an error or not + _, err = didx509resolver.Resolve(unpacked.ChainPem, issuer, true) + if err != nil { + log.G(ctx).Printf("Badly formed fragment - did resolver failed to match fragment did:x509 from chain with purported issuer %s, feed %s - err %s", issuer, feed, err.Error()) + return fmt.Errorf("failed to resolve DID: %w", err) + } + + // now offer the payload fragment to the policy + err = s.PolicyEnforcer.LoadFragment(ctx, issuer, feed, payloadString) + if err != nil { + return fmt.Errorf("error loading security policy fragment: %w", err) + } + return nil +} + +func writeFileInDir(dir string, filename string, data []byte, perm os.FileMode) error { + st, err := os.Stat(dir) + if err != nil { + return err + } + + if !st.IsDir() { + return fmt.Errorf("not a directory %q", dir) + } + + targetFilename := filepath.Join(dir, filename) + return os.WriteFile(targetFilename, data, perm) +} + +// Write security policy, signed UVM reference and host AMD certificate to +// container's rootfs, so that application and sidecar containers can have +// access to it. The security policy is required by containers which need to +// extract init-time claims found in the security policy. The directory path +// containing the files is exposed via UVM_SECURITY_CONTEXT_DIR env var. +// It may be an error to have a security policy but not expose it to the +// container as in that case it can never be checked as correct by a verifier. 
+func (s *SecurityOptions) WriteSecurityContextDir(spec *specs.Spec) error { + encodedPolicy := s.PolicyEnforcer.EncodedSecurityPolicy() + hostAMDCert := spec.Annotations[annotations.WCOWHostAMDCertificate] + if len(encodedPolicy) > 0 || len(hostAMDCert) > 0 || len(s.UvmReferenceInfo) > 0 { + // Use os.MkdirTemp to make sure that the directory is unique. + securityContextDir, err := os.MkdirTemp(spec.Root.Path, SecurityContextDirTemplate) + if err != nil { + return fmt.Errorf("failed to create security context directory: %w", err) + } + // Make sure that files inside directory are readable + if err := os.Chmod(securityContextDir, 0755); err != nil { + return fmt.Errorf("failed to chmod security context directory: %w", err) + } + + if len(encodedPolicy) > 0 { + if err := writeFileInDir(securityContextDir, PolicyFilename, []byte(encodedPolicy), 0777); err != nil { + return fmt.Errorf("failed to write security policy: %w", err) + } + } + if len(s.UvmReferenceInfo) > 0 { + if err := writeFileInDir(securityContextDir, ReferenceInfoFilename, []byte(s.UvmReferenceInfo), 0777); err != nil { + return fmt.Errorf("failed to write UVM reference info: %w", err) + } + } + + if len(hostAMDCert) > 0 { + if err := writeFileInDir(securityContextDir, HostAMDCertFilename, []byte(hostAMDCert), 0777); err != nil { + return fmt.Errorf("failed to write host AMD certificate: %w", err) + } + } + + containerCtxDir := fmt.Sprintf("/%s", filepath.Base(securityContextDir)) + secCtxEnv := fmt.Sprintf("UVM_SECURITY_CONTEXT_DIR=%s", containerCtxDir) + spec.Process.Env = append(spec.Process.Env, secCtxEnv) + + } + return nil +} diff --git a/pkg/securitypolicy/securitypolicy_windows.go b/pkg/securitypolicy/securitypolicy_windows.go index 6f873fef26..1582a756af 100644 --- a/pkg/securitypolicy/securitypolicy_windows.go +++ b/pkg/securitypolicy/securitypolicy_windows.go @@ -3,11 +3,29 @@ package securitypolicy -import oci "github.com/opencontainers/runtime-spec/specs-go" +import ( + oci 
"github.com/opencontainers/runtime-spec/specs-go" + "github.com/pkg/errors" +) //nolint:unused const osType = "windows" +// validateHostData fetches SNP report (if applicable) and validates `hostData` against +// HostData set at UVM launch. +func validateHostData(hostData []byte) error { + if err := GetPspDriverError(); err != nil { + // For this case gcs-sidecar will keep initial deny policy. + return errors.Wrapf(err, "an error occurred while using PSP driver") + } + + if err := ValidateHostDataPSP(hostData[:]); err != nil { + // For this case gcs-sidecar will keep initial deny policy. + return err + } + return nil +} + // SandboxMountsDir returns sandbox mounts directory inside UVM/host. func SandboxMountsDir(sandboxID string) string { return "" @@ -21,3 +39,14 @@ func HugePagesMountsDir(sandboxID string) string { func GetAllUserInfo(process *oci.Process, rootPath string) (IDName, []IDName, string, error) { return IDName{}, []IDName{}, "", nil } + +// DefaultCRIMounts returns default mounts added to windows spec by containerD. +func DefaultCRIMounts() []oci.Mount { + return []oci.Mount{} +} + +// DefaultCRIPrivilegedMounts returns a slice of mounts which are added to the +// windows container spec when a container runs in a privileged mode. 
+func DefaultCRIPrivilegedMounts() []oci.Mount { + return []oci.Mount{} +} diff --git a/pkg/securitypolicy/securitypolicyenforcer.go b/pkg/securitypolicy/securitypolicyenforcer.go index 127b679ae7..4014640ef7 100644 --- a/pkg/securitypolicy/securitypolicyenforcer.go +++ b/pkg/securitypolicy/securitypolicyenforcer.go @@ -2,22 +2,12 @@ package securitypolicy import ( "context" - "crypto/sha256" - "encoding/base64" "fmt" - "os" - "path/filepath" "syscall" - "time" - "github.com/Microsoft/cosesign1go/pkg/cosesign1" - didx509resolver "github.com/Microsoft/didx509go/pkg/did-x509-resolver" - "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" - "github.com/Microsoft/hcsshim/internal/protocol/guestresource" oci "github.com/opencontainers/runtime-spec/specs-go" "github.com/pkg/errors" - "github.com/sirupsen/logrus" ) type createEnforcerFunc func(base64EncodedPolicy string, criMounts, criPrivilegedMounts []oci.Mount, maxErrorMessageLength int) (SecurityPolicyEnforcer, error) @@ -40,7 +30,6 @@ type CreateContainerOptions struct { Capabilities *oci.LinuxCapabilities SeccompProfileSHA256 string } - type SignalContainerOptions struct { IsInitProcess bool // One of these will be set depending on platform @@ -68,9 +57,19 @@ func init() { registeredEnforcers[openDoorEnforcerName] = createOpenDoorEnforcer } +// Represents an in-progress revertable section. To ensure state is consistent, +// Commit() and Rollback() must not fail, so they do not return anything, and if +// an error does occur they should panic. 
+type RevertableSectionHandle interface { + Commit() + Rollback() +} + type SecurityPolicyEnforcer interface { EnforceDeviceMountPolicy(ctx context.Context, target string, deviceHash string) (err error) + EnforceRWDeviceMountPolicy(ctx context.Context, target string, encrypted, ensureFilesystem bool, filesystem string) (err error) EnforceDeviceUnmountPolicy(ctx context.Context, unmountTarget string) (err error) + EnforceRWDeviceUnmountPolicy(ctx context.Context, unmountTarget string) (err error) EnforceOverlayMountPolicy(ctx context.Context, containerID string, layerPaths []string, target string) (err error) EnforceOverlayUnmountPolicy(ctx context.Context, target string) (err error) EnforceCreateContainerPolicy( @@ -136,6 +135,7 @@ type SecurityPolicyEnforcer interface { EnforceScratchUnmountPolicy(ctx context.Context, scratchPath string) (err error) GetUserInfo(spec *oci.Process, rootPath string) (IDName, []IDName, string, error) EnforceVerifiedCIMsPolicy(ctx context.Context, containerID string, layerHashes []string) (err error) + StartRevertableSection() (RevertableSectionHandle, error) } //nolint:unused @@ -152,69 +152,6 @@ func (s stringSet) contains(item string) bool { return contains } -// Fragment extends current security policy with additional constraints -// from the incoming fragment. 
Note that it is base64 encoded over the bridge/ -// -// There are three checking steps: -// 1 - Unpack the cose document and check it was actually signed with the cert -// chain inside its header -// 2 - Check that the issuer field did:x509 identifier is for that cert chain -// (ie fingerprint of a non leaf cert and the subject matches the leaf cert) -// 3 - Check that this issuer/feed match the requirement of the user provided -// security policy (done in the regoby LoadFragment) -func ExtractAndVerifyFragment(ctx context.Context, fragment *guestresource.LCOWSecurityPolicyFragment) (issuer string, feed string, payloadString string, err error) { - log.G(ctx).WithField("fragment", fmt.Sprintf("%+v", fragment)).Debug("VerifyAndExtractFragment") - - raw, err := base64.StdEncoding.DecodeString(fragment.Fragment) - if err != nil { - return "", "", "", fmt.Errorf("failed to decode fragment: %w", err) - } - blob := []byte(fragment.Fragment) - // keep a copy of the fragment, so we can manually figure out what went wrong - // will be removed eventually. Give it a unique name to avoid any potential - // race conditions. - sha := sha256.New() - sha.Write(blob) - timestamp := time.Now() - fragmentPath := fmt.Sprintf("fragment-%x-%d.blob", sha.Sum(nil), timestamp.UnixMilli()) - _ = os.WriteFile(filepath.Join(os.TempDir(), fragmentPath), blob, 0644) - - unpacked, err := cosesign1.UnpackAndValidateCOSE1CertChain(raw) - if err != nil { - return "", "", "", fmt.Errorf("InjectFragment failed COSE validation: %w", err) - } - - payloadString = string(unpacked.Payload[:]) - issuer = unpacked.Issuer - feed = unpacked.Feed - chainPem := unpacked.ChainPem - - log.G(ctx).WithFields(logrus.Fields{ - "issuer": issuer, // eg the DID:x509:blah.... 
- "feed": feed, - "cty": unpacked.ContentType, - "chainPem": chainPem, - }).Debugf("unpacked COSE1 cert chain") - - log.G(ctx).WithFields(logrus.Fields{ - "payload": payloadString, - }).Tracef("unpacked COSE1 payload") - - if len(issuer) == 0 || len(feed) == 0 { // must both be present - return "", "", "", fmt.Errorf("either issuer and feed must both be provided in the COSE_Sign1 protected header") - } - - // Resolve returns a did doc that we don't need - // we only care if there was an error or not - _, err = didx509resolver.Resolve(unpacked.ChainPem, issuer, true) - if err != nil { - log.G(ctx).Printf("Badly formed fragment - did resolver failed to match fragment did:x509 from chain with purported issuer %s, feed %s - err %s", issuer, feed, err.Error()) - return "", "", "", err - } - - return issuer, feed, payloadString, nil -} - // CreateSecurityPolicyEnforcer returns an appropriate enforcer for input // parameters. Returns an error if the requested `enforcer` implementation // isn't registered. 
@@ -253,6 +190,11 @@ func CreateSecurityPolicyEnforcer( } } +type nopRevertableSectionHandle struct{} + +func (nopRevertableSectionHandle) Commit() {} +func (nopRevertableSectionHandle) Rollback() {} + type OpenDoorSecurityPolicyEnforcer struct { encodedSecurityPolicy string } @@ -274,10 +216,18 @@ func (OpenDoorSecurityPolicyEnforcer) EnforceDeviceMountPolicy(context.Context, return nil } +func (OpenDoorSecurityPolicyEnforcer) EnforceRWDeviceMountPolicy(context.Context, string, bool, bool, string) error { + return nil +} + func (OpenDoorSecurityPolicyEnforcer) EnforceDeviceUnmountPolicy(context.Context, string) error { return nil } +func (OpenDoorSecurityPolicyEnforcer) EnforceRWDeviceUnmountPolicy(context.Context, string) error { + return nil +} + func (OpenDoorSecurityPolicyEnforcer) EnforceOverlayMountPolicy(context.Context, string, []string, string) error { return nil } @@ -383,6 +333,10 @@ func (OpenDoorSecurityPolicyEnforcer) EnforceVerifiedCIMsPolicy(ctx context.Cont return nil } +func (*OpenDoorSecurityPolicyEnforcer) StartRevertableSection() (RevertableSectionHandle, error) { + return nopRevertableSectionHandle{}, nil +} + type ClosedDoorSecurityPolicyEnforcer struct{} var _ SecurityPolicyEnforcer = (*ClosedDoorSecurityPolicyEnforcer)(nil) @@ -391,10 +345,18 @@ func (ClosedDoorSecurityPolicyEnforcer) EnforceDeviceMountPolicy(context.Context return errors.New("mounting is denied by policy") } +func (ClosedDoorSecurityPolicyEnforcer) EnforceRWDeviceMountPolicy(context.Context, string, bool, bool, string) error { + return errors.New("Read-write device mounting is denied by policy") +} + func (ClosedDoorSecurityPolicyEnforcer) EnforceDeviceUnmountPolicy(context.Context, string) error { return errors.New("unmounting is denied by policy") } +func (ClosedDoorSecurityPolicyEnforcer) EnforceRWDeviceUnmountPolicy(context.Context, string) error { + return errors.New("Read-write device unmounting is denied by policy") +} + func (ClosedDoorSecurityPolicyEnforcer) 
EnforceOverlayMountPolicy(context.Context, string, []string, string) error { return errors.New("creating an overlay fs is denied by policy") } @@ -499,3 +461,7 @@ func (ClosedDoorSecurityPolicyEnforcer) GetUserInfo(spec *oci.Process, rootPath func (ClosedDoorSecurityPolicyEnforcer) EnforceVerifiedCIMsPolicy(ctx context.Context, containerID string, layerHashes []string) error { return nil } + +func (*ClosedDoorSecurityPolicyEnforcer) StartRevertableSection() (RevertableSectionHandle, error) { + return nopRevertableSectionHandle{}, nil +} diff --git a/pkg/securitypolicy/securitypolicyenforcer_rego.go b/pkg/securitypolicy/securitypolicyenforcer_rego.go index bb2fc27530..1393d5109a 100644 --- a/pkg/securitypolicy/securitypolicyenforcer_rego.go +++ b/pkg/securitypolicy/securitypolicyenforcer_rego.go @@ -9,9 +9,13 @@ import ( "encoding/base64" "encoding/json" "fmt" + "regexp" + "slices" "strings" + "sync" "syscall" + "github.com/Microsoft/hcsshim/internal/gcs" "github.com/Microsoft/hcsshim/internal/guestpath" "github.com/Microsoft/hcsshim/internal/log" rpi "github.com/Microsoft/hcsshim/internal/regopolicyinterpreter" @@ -55,6 +59,10 @@ type regoEnforcer struct { maxErrorMessageLength int // OS type osType string + // Mutex to ensure only one revertable section is active + revertableSectionLock sync.Mutex + // Saved metadata for the revertable section + savedMetadata rpi.SavedMetadata } var _ SecurityPolicyEnforcer = (*regoEnforcer)(nil) @@ -170,7 +178,7 @@ func newRegoPolicy(code string, defaultMounts []oci.Mount, privilegedMounts []oc return policy, nil } -func (policy *regoEnforcer) applyDefaults(enforcementPoint string, results rpi.RegoQueryResult) (rpi.RegoQueryResult, error) { +func (policy *regoEnforcer) applyDefaults(enforcementPoint string, input inputData, results rpi.RegoQueryResult) (rpi.RegoQueryResult, error) { deny := rpi.RegoQueryResult{"allowed": false} info, err := policy.queryEnforcementPoint(enforcementPoint) if err != nil { @@ -182,12 +190,22 @@ func 
(policy *regoEnforcer) applyDefaults(enforcementPoint string, results rpi.R return deny, fmt.Errorf("rule for %s is missing from policy", enforcementPoint) } + if results.IsEmpty() && info.useFramework { + rule := "data.framework." + enforcementPoint + result, err := policy.rego.Query(rule, input) + if err != nil { + result = nil + } + return result, err + } + return info.defaultResults.Union(results), nil } type enforcementPointInfo struct { availableByPolicyVersion bool defaultResults rpi.RegoQueryResult + useFramework bool } func (policy *regoEnforcer) queryEnforcementPoint(enforcementPoint string) (*enforcementPointInfo, error) { @@ -230,17 +248,23 @@ func (policy *regoEnforcer) queryEnforcementPoint(enforcementPoint string) (*enf defaultResults, err := result.Object("default_results") if err != nil { - return nil, errors.New("enforcement point result missing defaults") + return nil, fmt.Errorf("enforcement point %s result missing defaults", enforcementPoint) } availableByPolicyVersion, err := result.Bool("available") if err != nil { - return nil, errors.New("enforcement point result missing availability info") + return nil, fmt.Errorf("enforcement point %s result missing availability info", enforcementPoint) + } + + useFramework, err := result.Bool("use_framework") + if err != nil { + return nil, fmt.Errorf("enforcement point %s result missing use_framework info", enforcementPoint) } return &enforcementPointInfo{ availableByPolicyVersion: availableByPolicyVersion, defaultResults: defaultResults, + useFramework: useFramework, }, nil } @@ -251,7 +275,7 @@ func (policy *regoEnforcer) enforce(ctx context.Context, enforcementPoint string return nil, policy.denyWithError(ctx, err, input) } - result, err = policy.applyDefaults(enforcementPoint, result) + result, err = policy.applyDefaults(enforcementPoint, input, result) if err != nil { return result, policy.denyWithError(ctx, err, input) } @@ -486,15 +510,34 @@ func (policy *regoEnforcer) redactSensitiveData(input 
inputData) inputData { } func (policy *regoEnforcer) EnforceDeviceMountPolicy(ctx context.Context, target string, deviceHash string) error { + mountPathRegex := strings.Replace(guestpath.LCOWGlobalScsiMountPrefixFmt, "%d", "[0-9]+", 1) input := inputData{ - "target": target, - "deviceHash": deviceHash, + "target": target, + "deviceHash": deviceHash, + "mountPathRegex": mountPathRegex, } _, err := policy.enforce(ctx, "mount_device", input) return err } +func (policy *regoEnforcer) EnforceRWDeviceMountPolicy(ctx context.Context, target string, encrypted, ensureFilesystem bool, filesystem string) error { + // At this point we do not know what the container ID would be, so we allow + // any valid IDs. + containerIdRegex := "[0-9a-fA-F]{64}" + mountPathRegex := guestpath.LCOWRootPrefixInUVM + "/" + containerIdRegex + input := inputData{ + "target": target, + "encrypted": encrypted, + "ensureFilesystem": ensureFilesystem, + "filesystem": filesystem, + "mountPathRegex": mountPathRegex, + } + + _, err := policy.enforce(ctx, "rw_mount_device", input) + return err +} + func (policy *regoEnforcer) EnforceOverlayMountPolicy(ctx context.Context, containerID string, layerPaths []string, target string) error { input := inputData{ "containerID": containerID, @@ -768,6 +811,15 @@ func (policy *regoEnforcer) EnforceDeviceUnmountPolicy(ctx context.Context, unmo return err } +func (policy *regoEnforcer) EnforceRWDeviceUnmountPolicy(ctx context.Context, unmountTarget string) error { + input := inputData{ + "unmountTarget": unmountTarget, + } + + _, err := policy.enforce(ctx, "rw_unmount_device", input) + return err +} + func appendMountData(mountData []interface{}, mounts []oci.Mount) []interface{} { for _, mount := range mounts { mountData = append(mountData, inputData{ @@ -1000,14 +1052,45 @@ func (policy *regoEnforcer) EnforceRuntimeLoggingPolicy(ctx context.Context) err return err } +// Rego identifier is a letter or underscore, followed by any number of letters, +// underscores, 
or digits. See open-policy-agent/opa +// ast/internal/scanner/scanner.go :: scanIdentifier, isLetter +// Technically it also allows other unicode digit characters (but not letters) +// but we do not allow those, for simplicity. +var validNamespaceRegex = `[a-zA-Z_][a-zA-Z0-9_]*` + +// First line of the fragment Rego source code must be a package definition +// without any potential for confusion attacks. We thus limit it to exactly +// "package" followed by one or more spaces, then a valid Rego identifier, then +// optionally more spaces. We do not check if the namespace is a Rego keyword +// (e.g. "in", "every" etc) but it would fail Rego compilation anyway. +var validFirstLine = regexp.MustCompile(`^package +(` + validNamespaceRegex + `)\s*$`) + +// These namespaces must not be overridden by a fragment +var reservedNamespaces []string = []string{ + // Built-in modules + "framework", + "api", + "policy", + // This is not a module, but to prevent confusion since framework uses + // data.metadata to access those, we block it as well. 
+ "metadata", +} + func parseNamespace(rego string) (string, error) { lines := strings.Split(rego, "\n") - parts := strings.Split(lines[0], " ") - if parts[0] != "package" { - return "", errors.New("package definition required on first line") + if lines[0] == "" { + return "", errors.New("Fragment Rego is empty") } - - return strings.TrimSpace(parts[1]), nil + match := validFirstLine.FindStringSubmatch(lines[0]) + if match == nil { + return "", errors.Errorf("valid package definition required on first line, got %q", lines[0]) + } + namespace := match[1] + if slices.Contains(reservedNamespaces, namespace) { + return "", errors.Errorf("namespace %q is reserved and cannot be used for fragments", namespace) + } + return namespace, nil } func (policy *regoEnforcer) LoadFragment(ctx context.Context, issuer string, feed string, rego string) error { @@ -1023,22 +1106,42 @@ func (policy *regoEnforcer) LoadFragment(ctx context.Context, issuer string, fee Namespace: namespace, } - policy.rego.AddModule(fragment.ID(), fragment) - input := inputData{ - "issuer": issuer, - "feed": feed, - "namespace": namespace, + "issuer": issuer, + "feed": feed, + "namespace": namespace, + "fragment_loaded": false, + } + + // Check that the fragment is signed by the expected issuer before loading + // its Rego code. + _, err = policy.enforce(ctx, "load_fragment", input) + if err != nil { + return err } + // At this point we need to add the fragment code as a new Rego module in + // order for the framework (or any user defined policies) to check the SVN, + // and potentially other information defined by its Rego code. We've already + // checked that the fragment is signed correctly, and the namespace is safe + // to load (won't override framework or other built-in modules). Once we + // added the module, we must make sure the module is removed if we return + // with error (or if add_module returned from Rego is false). 
+ policy.rego.AddModule(fragment.ID(), fragment) + input["fragment_loaded"] = true + results, err := policy.enforce(ctx, "load_fragment", input) + if err != nil { + policy.rego.RemoveModule(fragment.ID()) + return err + } addModule, _ := results.Bool("add_module") if !addModule { policy.rego.RemoveModule(fragment.ID()) } - return err + return nil } func (policy *regoEnforcer) EnforceScratchMountPolicy(ctx context.Context, scratchPath string, encrypted bool) error { @@ -1078,3 +1181,81 @@ func (policy *regoEnforcer) EnforceVerifiedCIMsPolicy(ctx context.Context, conta func (policy *regoEnforcer) GetUserInfo(process *oci.Process, rootPath string) (IDName, []IDName, string, error) { return GetAllUserInfo(process, rootPath) } + +type revertableSectionHandle struct { + // policy is cleared once this struct is "used", to prevent accidental + // duplicate Commit/Rollback calls. + policy *regoEnforcer +} + +func (policy *regoEnforcer) inRevertableSection() bool { + succ := policy.revertableSectionLock.TryLock() + if succ { + // since nobody else has the lock, we're not in fact in a revertable + // section. + policy.revertableSectionLock.Unlock() + return false + } + // somebody else (i.e. the caller) has the lock, so we're in a revertable + // section. Don't unlock it here! + return true +} + +// Starts a revertable section by saving the current policy state. If another +// revertable section is already active, this will wait until that one is +// finished. +func (policy *regoEnforcer) StartRevertableSection() (RevertableSectionHandle, error) { + policy.revertableSectionLock.Lock() + var err error + policy.savedMetadata, err = policy.rego.SaveMetadata() + if err != nil { + err = errors.Wrapf(err, "unable to save metadata for revertable section") + policy.revertableSectionLock.Unlock() + return &revertableSectionHandle{}, err + } + // Keep policy.revertableSectionLock locked until the end of the section. 
+ sh := &revertableSectionHandle{ + policy: policy, + } + return sh, nil +} + +func (sh *revertableSectionHandle) Commit() { + if sh.policy == nil { + gcs.UnrecoverableError(errors.New("revertable section handle already used")) + } + + policy := sh.policy + sh.policy = nil + lockSucc := policy.revertableSectionLock.TryLock() + if lockSucc { + gcs.UnrecoverableError(errors.New("not in a revertable section")) + } else { + // somebody else (i.e. the caller) has the lock, so we're in a revertable + // section. Clear the saved metadata just in case, then unlock to exit the + // section. + policy.savedMetadata = rpi.SavedMetadata{} + policy.revertableSectionLock.Unlock() + } +} + +func (sh *revertableSectionHandle) Rollback() { + if sh.policy == nil { + gcs.UnrecoverableError(errors.New("revertable section handle already used")) + } + + policy := sh.policy + sh.policy = nil + lockSucc := policy.revertableSectionLock.TryLock() + if lockSucc { + gcs.UnrecoverableError(errors.New("not in a revertable section")) + } else { + // somebody else (i.e. the caller) has the lock, so we're in a revertable + // section. Restore the saved metadata, then unlock to exit the section. 
+ err := policy.rego.RestoreMetadata(policy.savedMetadata) + if err != nil { + gcs.UnrecoverableError(errors.Wrap(err, "unable to restore metadata for revertable section")) + } + policy.revertableSectionLock.Unlock() + } +} diff --git a/test/functional/lcow_policy_test.go b/test/functional/lcow_policy_test.go index 43c5fd2090..cf0c857e08 100644 --- a/test/functional/lcow_policy_test.go +++ b/test/functional/lcow_policy_test.go @@ -4,7 +4,9 @@ package functional import ( "context" + "encoding/hex" "fmt" + "math/rand" "testing" ctrdoci "github.com/containerd/containerd/v2/pkg/oci" @@ -22,6 +24,15 @@ import ( testuvm "github.com/Microsoft/hcsshim/test/pkg/uvm" ) +func genValidContainerID(t *testing.T, rng *rand.Rand) string { + t.Helper() + randBytes := make([]byte, 32) + if _, err := rng.Read(randBytes); err != nil { + t.Fatalf("failed to generate random bytes for container ID: %v", err) + } + return hex.EncodeToString(randBytes) +} + func setupScratchTemplate(ctx context.Context, tb testing.TB) string { tb.Helper() opts := defaultLCOWOptions(ctx, tb) @@ -43,6 +54,8 @@ func TestGetProperties_WithPolicy(t *testing.T) { ctx := util.Context(namespacedContext(context.Background()), t) scratchPath := setupScratchTemplate(ctx, t) + rng := rand.New(rand.NewSource(0)) + ls := linuxImageLayers(ctx, t) for _, allowProperties := range []bool{true, false} { t.Run(fmt.Sprintf("AllowPropertiesAccess_%t", allowProperties), func(t *testing.T) { @@ -61,21 +74,24 @@ func TestGetProperties_WithPolicy(t *testing.T) { ) opts.SecurityPolicyEnforcer = "rego" opts.SecurityPolicy = policy + // VPMem is not currently supported for C-LCOW. 
+ opts.VPMemDeviceCount = 0 - cleanName := util.CleanName(t) + containerID := genValidContainerID(t, rng) vm := testuvm.CreateAndStartLCOWFromOpts(ctx, t, opts) spec := testoci.CreateLinuxSpec( ctx, t, - cleanName, + containerID, testoci.DefaultLinuxSpecOpts( "", ctrdoci.WithProcessArgs("/bin/sh", "-c", testoci.TailNullArgs), + ctrdoci.WithEnv(testoci.DefaultUnixEnv), testoci.WithWindowsLayerFolders(append(ls, scratchPath)), )..., ) - c, _, cleanup := testcontainer.Create(ctx, t, vm, spec, cleanName, hcsOwner) + c, _, cleanup := testcontainer.Create(ctx, t, vm, spec, containerID, hcsOwner) t.Cleanup(cleanup) init := testcontainer.Start(ctx, t, c, nil) diff --git a/test/gcs/main_test.go b/test/gcs/main_test.go index f4b32b34c8..ce399e0767 100644 --- a/test/gcs/main_test.go +++ b/test/gcs/main_test.go @@ -22,7 +22,6 @@ import ( "github.com/Microsoft/hcsshim/internal/guest/transport" "github.com/Microsoft/hcsshim/internal/guestpath" "github.com/Microsoft/hcsshim/internal/oc" - "github.com/Microsoft/hcsshim/internal/protocol/guestresource" "github.com/Microsoft/hcsshim/pkg/securitypolicy" "github.com/Microsoft/hcsshim/test/internal/util" @@ -167,9 +166,9 @@ func getHost(_ context.Context, tb testing.TB, rt runtime.Runtime) *hcsv2.Host { func getHostErr(rt runtime.Runtime, tp transport.Transport) (*hcsv2.Host, error) { h := hcsv2.NewHost(rt, tp, &securitypolicy.OpenDoorSecurityPolicyEnforcer{}, os.Stdout) - if err := h.SetConfidentialUVMOptions( + if err := h.SecurityOptions().SetConfidentialOptions( context.Background(), - &guestresource.LCOWConfidentialOptions{}, + "", "", "", ); err != nil { return nil, fmt.Errorf("could not set host security policy: %w", err) }