From e5259c2eac789282bbf68f35e83668aa2124dd6f Mon Sep 17 00:00:00 2001 From: Elim Tsiagbey Date: Fri, 28 Jul 2023 00:02:11 -0400 Subject: [PATCH] Address PR Feedback - Break up functions - Update struct and return values - Use pointers - Address other small errors --- internal/bsr/bsr.go | 259 +++++++++++++++++------------- internal/bsr/bsr_validate_test.go | 63 +++----- 2 files changed, 171 insertions(+), 151 deletions(-) diff --git a/internal/bsr/bsr.go b/internal/bsr/bsr.go index 9a0bc2da24a..59b5c85eda9 100644 --- a/internal/bsr/bsr.go +++ b/internal/bsr/bsr.go @@ -16,6 +16,7 @@ import ( "github.com/hashicorp/boundary/internal/bsr/kms" "github.com/hashicorp/boundary/internal/storage" "github.com/hashicorp/go-kms-wrapping/v2/extras/crypto" + "github.com/hashicorp/go-multierror" "google.golang.org/protobuf/proto" ) @@ -238,18 +239,20 @@ func OpenSession(ctx context.Context, sessionRecordingId string, f storage.FS, k return session, nil } +// Validation provides the results from validating a bsr. type Validation struct { - Name string - Type ContainerType - Error error - ChecksumValidation ChecksumValidation - SubContainer []Validation `json:"sub_container,omitempty"` + SessionRecordingId string + Valid bool + SessionRecordingValidation *ContainerValidation } -type ValidationSummary struct { - SessionRecordingId string - Valid bool - FailedChecksums map[ContainerType]map[string]ChecksumValidation +// ContainerValidation contains the results from validating a container in a bsr. +type ContainerValidation struct { + Name string + ContainerType ContainerType + Error *multierror.Error + FileChecksumValidations ChecksumValidation + SubContainers []*ContainerValidation } // Validate retrieves a BSR from storage using the sessionRecordingId and validates the BSR. @@ -257,152 +260,186 @@ type ValidationSummary struct { // each file in SHA256SUM file. // // Validation will continue even if there's an error encountered during validation. -// The validation error will be added to the Validation struct Error field. -func Validate(ctx context.Context, sessionRecordingId string, f storage.FS, keyUnwrapFn kms.KeyUnwrapCallbackFunc) (Validation, ValidationSummary, error) { +// The validation error will be added to the ContainerValidation struct "Error" field for that container. +func Validate(ctx context.Context, sessionRecordingId string, f storage.FS, keyUnwrapFn kms.KeyUnwrapCallbackFunc) (*Validation, error) { const op = "bsr.Validate" - validation := Validation{ - Name: sessionRecordingId, - Type: SessionContainer, - SubContainer: []Validation{}, - } - - validationSummary := ValidationSummary{ + validation := &Validation{ SessionRecordingId: sessionRecordingId, Valid: true, } switch { case sessionRecordingId == "": - validationSummary.Valid = false - return validation, validationSummary, fmt.Errorf("%s: missing session recording id: %w", op, ErrInvalidParameter) + return nil, fmt.Errorf("%s: missing session recording id: %w", op, ErrInvalidParameter) case f == nil: - validationSummary.Valid = false - return validation, validationSummary, fmt.Errorf("%s: missing storage: %w", op, ErrInvalidParameter) + return nil, fmt.Errorf("%s: missing storage: %w", op, ErrInvalidParameter) case keyUnwrapFn == nil: - validationSummary.Valid = false - return validation, validationSummary, fmt.Errorf("%s: missing key unwrap function: %w", op, ErrInvalidParameter) + return nil, fmt.Errorf("%s: missing key unwrap function: %w", op, ErrInvalidParameter) + } + + session, sessionContainerValidation, err := validation.ValidateSession(ctx, sessionRecordingId, f, keyUnwrapFn) + if err != nil { + validation.Valid = false + return validation, fmt.Errorf("%s: failed to valid session for %s: %w", op, sessionRecordingId, err) + } + + // Valid all connections under session + for connId := range session.Meta.connections { + connection, connectionContainerValidation, err := validation.ValidateConnection(ctx, connId, session) + sessionContainerValidation.SubContainers = append(sessionContainerValidation.SubContainers, connectionContainerValidation) + if err != nil { + validation.Valid = false + continue + } + + // Valid all channels under current connection + for chId := range connection.Meta.channels { + _, channelContainerValidation, err := validation.ValidateChannel(ctx, chId, connection) + connectionContainerValidation.SubContainers = append(connectionContainerValidation.SubContainers, channelContainerValidation) + if err != nil { + validation.Valid = false + continue + } + + } + } + + return validation, nil +} + +// ValidateConnection opens session and validates the checksums of all files in the container +func (v *Validation) ValidateSession(ctx context.Context, sessionRecordingId string, f storage.FS, keyUnwrapFn kms.KeyUnwrapCallbackFunc) (*Session, *ContainerValidation, error) { + const op = "bsr.ValidateConnection" + + sessionContainerValidation := &ContainerValidation{ + Name: sessionRecordingId, + ContainerType: SessionContainer, } + v.SessionRecordingValidation = sessionContainerValidation + session, err := OpenSession(ctx, sessionRecordingId, f, keyUnwrapFn) if err != nil { - validationSummary.Valid = false + v.Valid = false validationError := fmt.Errorf("%s: failed to retrieve session for %s: %w", op, sessionRecordingId, err) - validation.Error = errors.Join(validation.Error, validationError) - return validation, validationSummary, validationError + sessionContainerValidation.Error = multierror.Append(sessionContainerValidation.Error, validationError) + return nil, sessionContainerValidation, validationError } - // Validate session sessionChecksumValidation, err := session.container.ValidateChecksums(ctx) if err != nil { - validationSummary.Valid = false - validation.Error = errors.Join(validation.Error, fmt.Errorf("%s: failed to validate session for %s: %w", op, sessionRecordingId, err)) + v.Valid = false + sessionContainerValidation.Error = multierror.Append( + sessionContainerValidation.Error, fmt.Errorf("%s: failed to validate session for %s: %w", op, sessionRecordingId, err)) } - validationSummary = updateValidationSummary(ctx, SessionContainer, sessionRecordingId, sessionChecksumValidation, validationSummary) + sessionContainerValidation.FileChecksumValidations = sessionChecksumValidation - validation.ChecksumValidation = sessionChecksumValidation + v.updateValidationStatus(ctx, sessionContainerValidation) - conns := session.Meta.connections + return session, sessionContainerValidation, nil +} - for connId := range conns { - // Get connection id - connKey := connId - lastDotIndex := strings.LastIndex(connId, ".") - if lastDotIndex != -1 { - connKey = connId[:lastDotIndex] - } +// ValidateConnection opens a connection under a session and validates the checksums of all files in the container +func (v *Validation) ValidateConnection(ctx context.Context, connId string, session *Session) (*Connection, *ContainerValidation, error) { + const op = "bsr.ValidateConnection" - connectionValidation := Validation{ - Name: connId, - Type: ConnectionContainer, - SubContainer: []Validation{}, - } + connectionContainerValidation := &ContainerValidation{ + Name: connId, + ContainerType: ConnectionContainer, + } - conn, err := session.OpenConnection(ctx, connKey) - if err != nil { - validationSummary.Valid = false - connectionValidation.Error = errors.Join(connectionValidation.Error, fmt.Errorf("%s: failed to retrieve connection for %s: %w", op, connId, err)) - continue - } + // Get connection id + connKey := connId + lastDotIndex := strings.LastIndex(connId, ".connection") + if lastDotIndex == -1 { + v.Valid = false + validationError := fmt.Errorf("%s: malformed BSR for: %s", op, connId) + connectionContainerValidation.Error = multierror.Append(connectionContainerValidation.Error, validationError) + return nil, connectionContainerValidation, validationError + } - // Validate current connection - connectionChecksumValidation, err := conn.container.ValidateChecksums(ctx) - if err != nil { - validationSummary.Valid = false - connectionValidation.Error = errors.Join(connectionValidation.Error, fmt.Errorf("%s: failed to validate connection for %s: %w", op, connId, err)) - continue - } + connKey = connId[:lastDotIndex] - validationSummary = updateValidationSummary(ctx, ConnectionContainer, connId, connectionChecksumValidation, validationSummary) + conn, err := session.OpenConnection(ctx, connKey) + if err != nil { + v.Valid = false + validationError := fmt.Errorf("%s: failed to retrieve connection for %s: %w", op, connId, err) + connectionContainerValidation.Error = multierror.Append(connectionContainerValidation.Error, validationError) + return nil, connectionContainerValidation, validationError + } - connectionValidation.ChecksumValidation = connectionChecksumValidation + // Validate current connection + connectionChecksumValidation, err := conn.container.ValidateChecksums(ctx) + if err != nil { + v.Valid = false + connectionContainerValidation.Error = multierror.Append( + connectionContainerValidation.Error, fmt.Errorf("%s: failed to validate connection for %s: %w", op, connId, err)) - chs := conn.Meta.channels + // Return nil error for any iteration to validation to continue with other connections and channels + return nil, connectionContainerValidation, nil + } - for chId := range chs { - // Get channel id - chsKey := chId - lastDotIndex := strings.LastIndex(chId, ".") - if lastDotIndex != -1 { - chsKey = chId[:lastDotIndex] - } + connectionContainerValidation.FileChecksumValidations = connectionChecksumValidation - channelValidation := Validation{ - Name: chId, - Type: ChannelContainer, - SubContainer: []Validation{}, - } + v.updateValidationStatus(ctx, connectionContainerValidation) - ch, err := conn.OpenChannel(ctx, chsKey) - if err != nil { - validationSummary.Valid = false - channelValidation.Error = errors.Join(channelValidation.Error, fmt.Errorf("%s: failed to retrieve channel for %s: %w", op, chId, err)) - continue - } + return conn, connectionContainerValidation, nil +} - // Validate current channel - channelChecksumValidation, err := ch.container.ValidateChecksums(ctx) - if err != nil { - validationSummary.Valid = false - channelValidation.Error = errors.Join(channelValidation.Error, fmt.Errorf("%s: failed to validate channel for %s: %w", op, chId, err)) - continue - } +// ValidateChannel opens a channel under a connection and validates the checksums of all files in the container +func (v *Validation) ValidateChannel(ctx context.Context, chId string, conn *Connection) (*Channel, *ContainerValidation, error) { + const op = "bsr.ValidateChannel" - validationSummary = updateValidationSummary(ctx, ChannelContainer, chId, channelChecksumValidation, validationSummary) - channelValidation.ChecksumValidation = channelChecksumValidation + channelContainerValidation := &ContainerValidation{ + Name: chId, + ContainerType: ChannelContainer, + } - connectionValidation.SubContainer = append(connectionValidation.SubContainer, channelValidation) - } + // Get channel id + chKey := chId + lastDotIndex := strings.LastIndex(chId, ".channel") + if lastDotIndex == -1 { + v.Valid = false + validationError := fmt.Errorf("%s: malformed BSR for: %s", op, chId) + channelContainerValidation.Error = multierror.Append(channelContainerValidation.Error, validationError) + return nil, channelContainerValidation, validationError + } - validation.SubContainer = append(validation.SubContainer, connectionValidation) + chKey = chId[:lastDotIndex] + + ch, err := conn.OpenChannel(ctx, chKey) + if err != nil { + v.Valid = false + validationError := fmt.Errorf("%s: failed to retrieve channel for %s: %w", op, chId, err) + channelContainerValidation.Error = multierror.Append(channelContainerValidation.Error, validationError) + return nil, channelContainerValidation, validationError } - return validation, validationSummary, nil -} + channelChecksumValidation, err := ch.container.ValidateChecksums(ctx) + if err != nil { + v.Valid = false + channelContainerValidation.Error = multierror.Append(channelContainerValidation.Error, fmt.Errorf("%s: failed to validate channel for %s: %w", op, chId, err)) -// updateValidationSummary updates the ValidationSummary based on the failed checksums. -// The updated ValidationSummary is then returned -func updateValidationSummary(ctx context.Context, c ContainerType, id string, cv ChecksumValidation, vs ValidationSummary) ValidationSummary { - failedChecksums := cv.GetFailedItems() + // Return nil error for any iteration for validation to continue on other connections and channels + return ch, channelContainerValidation, nil + } - // Update validation summary only if there are failed tests - if len(failedChecksums) > 0 { - if vs.FailedChecksums == nil { - vs.FailedChecksums = make(map[ContainerType]map[string]ChecksumValidation) - } + channelContainerValidation.FileChecksumValidations = channelChecksumValidation - containerType, ok := vs.FailedChecksums[c] - if !ok { - containerType = make(map[string]ChecksumValidation) - } + v.updateValidationStatus(ctx, channelContainerValidation) - vs.Valid = false + return ch, channelContainerValidation, nil +} - containerType[id] = failedChecksums - vs.FailedChecksums[c] = containerType - } +// updateValidationSummary updates the value of "Valid" field in Validation struct based on if there are failed checksums +func (v *Validation) updateValidationStatus(ctx context.Context, cv *ContainerValidation) { + failedChecksums := cv.FileChecksumValidations.GetFailedItems() - return vs + // Update validation filed only if there are failed checksums + if len(failedChecksums) > 0 { + v.Valid = false + } } // Close closes the Session container. diff --git a/internal/bsr/bsr_validate_test.go b/internal/bsr/bsr_validate_test.go index 433b418b44d..5a805bef0fe 100644 --- a/internal/bsr/bsr_validate_test.go +++ b/internal/bsr/bsr_validate_test.go @@ -63,7 +63,6 @@ func TestBSR_Validate_Valid(t *testing.T) { expectedSessionContainerSize int expectedConnectionContainerSize int expectedChannelContainerSize int - validationSummary ValidationSummary wantErr error }{ { @@ -236,10 +235,6 @@ func TestBSR_Validate_Valid(t *testing.T) { expectedSessionContainerSize: 2, expectedConnectionContainerSize: 2, expectedChannelContainerSize: 0, - validationSummary: ValidationSummary{ - SessionRecordingId: "sr_01234567881", - Valid: true, - }, }, { name: "Failed checksum", @@ -351,55 +346,43 @@ func TestBSR_Validate_Valid(t *testing.T) { expectedSessionContainerSize: 2, expectedConnectionContainerSize: 2, expectedChannelContainerSize: 0, - validationSummary: ValidationSummary{ - SessionRecordingId: "sr_21234567881", - Valid: false, - FailedChecksums: map[ContainerType]map[string]ChecksumValidation{ - SessionContainer: { - "sr_21234567881": ChecksumValidation{ - "Test": &FileChecksumValidation{ - Filename: "Test", - Passed: false, - Error: errors.New("checksum mismatch"), - }, - }, - }, - }, - }, + }, + { + name: "missing session recording id parameter", + wantErr: errors.New("bsr.Validate: missing session recording id: invalid parameter"), }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - validation, summary, err := Validate(ctx, tc.sessionRecordingId, tc.storage, keyFn) + validation, err := Validate(ctx, tc.sessionRecordingId, tc.storage, keyFn) + if tc.wantErr != nil { + assert.EqualError(t, err, tc.wantErr.Error()) + return + } require.NoError(t, err) require.NotNil(t, validation) - require.NotNil(t, summary) - - assert.Equal(t, tc.validationSummary.SessionRecordingId, summary.SessionRecordingId) - assert.Equal(t, tc.validationSummary.Valid, summary.Valid) - assert.Equal(t, tc.validationSummary.FailedChecksums, summary.FailedChecksums) // Validate Session - assert.Equal(t, tc.sessionRecordingId, validation.Name) - assert.Equal(t, SessionContainer, validation.Type) - assert.Equal(t, len(tc.expectedSessionChecksums), len(validation.ChecksumValidation)) - assert.Equal(t, tc.expectedSessionChecksums, validation.ChecksumValidation) + assert.Equal(t, tc.sessionRecordingId, validation.SessionRecordingId) + assert.Equal(t, SessionContainer, validation.SessionRecordingValidation.ContainerType) + assert.Equal(t, len(tc.expectedSessionChecksums), len(validation.SessionRecordingValidation.FileChecksumValidations)) + assert.Equal(t, tc.expectedSessionChecksums, validation.SessionRecordingValidation.FileChecksumValidations) // Validate Multiple Connections - for _, connection := range validation.SubContainer { + for _, connection := range validation.SessionRecordingValidation.SubContainers { require.NotNil(t, connection) - assert.Equal(t, ConnectionContainer, connection.Type) - assert.Equal(t, tc.expectedConnectionContainerSize, len(connection.SubContainer)) - assert.Equal(t, len(tc.expectedConnectionChecksums), len(connection.ChecksumValidation)) - assert.Equal(t, tc.expectedConnectionChecksums, connection.ChecksumValidation) + assert.Equal(t, ConnectionContainer, connection.ContainerType) + assert.Equal(t, tc.expectedConnectionContainerSize, len(connection.SubContainers)) + assert.Equal(t, len(tc.expectedConnectionChecksums), len(connection.FileChecksumValidations)) + assert.Equal(t, tc.expectedConnectionChecksums, connection.FileChecksumValidations) // Validate Multiple Channels - for _, channel := range connection.SubContainer { + for _, channel := range connection.SubContainers { require.NotNil(t, channel) - assert.Equal(t, ChannelContainer, channel.Type) - assert.Equal(t, tc.expectedChannelContainerSize, len(channel.SubContainer)) - assert.Equal(t, len(tc.expectedChannelChecksums), len(channel.ChecksumValidation)) - assert.Equal(t, tc.expectedChannelChecksums, channel.ChecksumValidation) + assert.Equal(t, ChannelContainer, channel.ContainerType) + assert.Equal(t, tc.expectedChannelContainerSize, len(channel.SubContainers)) + assert.Equal(t, len(tc.expectedChannelChecksums), len(channel.FileChecksumValidations)) + assert.Equal(t, tc.expectedChannelChecksums, channel.FileChecksumValidations) } } })