Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Cherrypick commits to rel 0.13.x #3531

Merged
merged 2 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions internal/bsr/bsr.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func persistBsrSessionKeys(ctx context.Context, keys *kms.Keys, c *container) er
// Signature and checksum files will then be verified.
// Fields on the underlying container will be populated so that the returned Session can be used for BSR
// playback and conversion to formats such as asciinema
func OpenSession(ctx context.Context, sessionRecordingId string, f storage.FS, keyUnwrapFn kms.KeyUnwrapCallbackFunc) (*Session, error) {
func OpenSession(ctx context.Context, sessionRecordingId string, f storage.FS, keyUnwrapFn kms.KeyUnwrapCallbackFunc) (s *Session, err error) {
const op = "bsr.OpenSession"
switch {
case sessionRecordingId == "":
Expand All @@ -199,8 +199,15 @@ func OpenSession(ctx context.Context, sessionRecordingId string, f storage.FS, k
// Load and verify recording metadata
sha256Reader, err := crypto.NewSha256SumReader(ctx, cc.metaFile)
if err != nil {
cc.metaFile.Close()
return nil, fmt.Errorf("%s: %w", op, err)
}
defer func() {
if closeErr := sha256Reader.Close(); closeErr != nil {
err = errors.Join(err, fmt.Errorf("%s: %w", op, closeErr))
}
}()

meta, err := decodeSessionRecordingMeta(ctx, sha256Reader)
if err != nil {
return nil, err
Expand All @@ -226,7 +233,10 @@ func OpenSession(ctx context.Context, sessionRecordingId string, f storage.FS, k

// Close closes the Session container.
func (s *Session) Close(ctx context.Context) error {
return s.container.close(ctx)
if !is.Nil(s.container) {
return s.container.close(ctx)
}
return nil
}

// Connection is a container in a bsr for a specific connection in a session
Expand Down Expand Up @@ -273,7 +283,7 @@ func (s *Session) NewConnection(ctx context.Context, meta *ConnectionRecordingMe
}

// OpenConnection will open and validate a BSR connection
func (s *Session) OpenConnection(ctx context.Context, connId string) (*Connection, error) {
func (s *Session) OpenConnection(ctx context.Context, connId string) (conn *Connection, err error) {
const op = "bsr.(Session).OpenConnection"
switch {
case connId == "":
Expand Down Expand Up @@ -303,8 +313,15 @@ func (s *Session) OpenConnection(ctx context.Context, connId string) (*Connectio
// Load and verify connection metadata
sha256Reader, err := crypto.NewSha256SumReader(ctx, cc.metaFile)
if err != nil {
cc.metaFile.Close()
return nil, fmt.Errorf("%s: %w", op, err)
}
defer func() {
if closeErr := sha256Reader.Close(); closeErr != nil {
err = errors.Join(err, fmt.Errorf("%s: %w", op, closeErr))
}
}()

sm, err := decodeConnectionRecordingMeta(ctx, sha256Reader)
if err != nil {
return nil, err
Expand Down Expand Up @@ -362,7 +379,7 @@ func (c *Connection) NewChannel(ctx context.Context, meta *ChannelRecordingMeta)
}

// OpenChannel will open and validate a BSR channel
func (c *Connection) OpenChannel(ctx context.Context, chanId string) (*Channel, error) {
func (c *Connection) OpenChannel(ctx context.Context, chanId string) (ch *Channel, err error) {
const op = "bsr.OpenChannel"
switch {
case chanId == "":
Expand Down Expand Up @@ -391,8 +408,15 @@ func (c *Connection) OpenChannel(ctx context.Context, chanId string) (*Channel,
// Load and verify channel metadata
sha256Reader, err := crypto.NewSha256SumReader(ctx, cc.metaFile)
if err != nil {
cc.metaFile.Close()
return nil, fmt.Errorf("%s: %w", op, err)
}
defer func() {
if closeErr := sha256Reader.Close(); closeErr != nil {
err = errors.Join(err, fmt.Errorf("%s: %w", op, closeErr))
}
}()

sm, err := decodeChannelRecordingMeta(ctx, sha256Reader)
if err != nil {
return nil, err
Expand Down Expand Up @@ -456,7 +480,10 @@ func (c *Connection) NewRequestsWriter(ctx context.Context, dir Direction) (io.W

// Close closes the Connection container.
func (c *Connection) Close(ctx context.Context) error {
return c.container.close(ctx)
if !is.Nil(c.container) {
return c.container.close(ctx)
}
return nil
}

// Channel is a container in a bsr for a specific channel in a session
Expand All @@ -469,7 +496,10 @@ type Channel struct {

// Close closes the Channel container.
func (c *Channel) Close(ctx context.Context) error {
return c.container.close(ctx)
if !is.Nil(c.container) {
return c.container.close(ctx)
}
return nil
}

// NewMessagesWriter creates a writer for recording channel messages.
Expand Down
99 changes: 99 additions & 0 deletions internal/bsr/bsr_open_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package bsr

import (
"context"
"fmt"
"testing"

"github.com/hashicorp/boundary/internal/bsr/internal/fstest"
Expand Down Expand Up @@ -293,3 +294,101 @@ func TestOpenChannel(t *testing.T) {
})
}
}

func TestCloseBSRMethods(t *testing.T) {
ctx := context.Background()

protocol := Protocol("TEST_CLOSED_FILE")
sessionRecordingId := "sr_012344567890"
sessionId := "s_012344567890"
connectionId := "connection"
channelId := "channel"

keys, err := kms.CreateKeys(ctx, kms.TestWrapper(t), "session")
require.NoError(t, err)

require.NoError(t, err)

f := &fstest.MemFS{}
srm := &SessionRecordingMeta{
Id: sessionRecordingId,
Protocol: protocol,
}
sessionMeta := TestSessionMeta(sessionId)

sesh, err := NewSession(ctx, srm, sessionMeta, f, keys, WithSupportsMultiplex(true))
require.NoError(t, err)
require.NotNil(t, sesh)

connMeta := &ConnectionRecordingMeta{Id: connectionId}
conn, err := sesh.NewConnection(ctx, connMeta)
require.NoError(t, err)
require.NotNil(t, conn)

chanMeta := &ChannelRecordingMeta{
Id: channelId,
Type: "chan",
}
ch, err := conn.NewChannel(ctx, chanMeta)
require.NoError(t, err)
require.NotNil(t, ch)

ch.Close(ctx)
conn.Close(ctx)
sesh.Close(ctx)

keyFn := func(w kms.WrappedKeys) (kms.UnwrappedKeys, error) {
u := kms.UnwrappedKeys{
BsrKey: keys.BsrKey,
PrivKey: keys.PrivKey,
}
return u, nil
}

opSesh, err := OpenSession(ctx, srm.Id, f, keyFn)
require.NoError(t, err)
require.NotNil(t, opSesh)

opConn, err := opSesh.OpenConnection(ctx, connectionId)
require.NoError(t, err)
require.NotNil(t, opConn)

opChan, err := opConn.OpenChannel(ctx, channelId)
require.NoError(t, err)
require.NotNil(t, opChan)

// Close all opened containers
require.NoError(t, opChan.Close(ctx))
require.NoError(t, opConn.Close(ctx))
require.NoError(t, opSesh.Close(ctx))

// Get session container
sessionContainer := f.Containers[fmt.Sprintf(bsrFileNameTemplate, sessionRecordingId)]
require.NotNil(t, sessionContainer)
assert.True(t, sessionContainer.Closed)

// Ensure all session files are closed
for _, file := range sessionContainer.Files {
assert.True(t, file.Closed)
}

// Get connection container
connectionContainer := sessionContainer.Sub[fmt.Sprintf(connectionFileNameTemplate, connectionId)]
require.NotNil(t, connectionContainer)
assert.True(t, connectionContainer.Closed)

// Ensure all connection files are closed
for _, file := range connectionContainer.Files {
assert.True(t, file.Closed)
}

// Get channel container
channelContainer := connectionContainer.Sub[fmt.Sprintf(channelFileNameTemplate, channelId)]
require.NotNil(t, channelContainer)
assert.True(t, channelContainer.Closed)

// Ensure all channel files are closed
for _, file := range channelContainer.Files {
assert.True(t, file.Closed)
}
}
64 changes: 49 additions & 15 deletions internal/bsr/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,20 @@ func openContainer(ctx context.Context, t containerType, c storage.Container, ke
return cc, nil
}

func (c *container) loadChecksums(ctx context.Context) error {
func (c *container) loadChecksums(ctx context.Context) (err error) {
const op = "bsr.(container).loadChecksums"

// Open and extract checksum signature
checksumsSigFile, err := c.container.OpenFile(ctx, sigFileName)
if err != nil {
return err
}
defer func() {
if closeErr := checksumsSigFile.Close(); closeErr != nil {
err = errors.Join(err, fmt.Errorf("%s: %w", op, closeErr))
}
}()

checksumSigBytes := new(bytes.Buffer)
_, err = checksumSigBytes.ReadFrom(checksumsSigFile)
if err != nil {
Expand All @@ -186,6 +192,12 @@ func (c *container) loadChecksums(ctx context.Context) error {
if err != nil {
return err
}
defer func() {
if closeErr := checksumsFile.Close(); closeErr != nil {
err = errors.Join(err, fmt.Errorf("%s: %w", op, closeErr))
}
}()

var checksumsBuffer bytes.Buffer
cTee := io.TeeReader(checksumsFile, &checksumsBuffer)

Expand Down Expand Up @@ -213,11 +225,16 @@ func (c *container) loadChecksums(ctx context.Context) error {
return nil
}

func (c *container) loadKey(ctx context.Context, keyFileName string) (*wrapping.KeyInfo, error) {
func (c *container) loadKey(ctx context.Context, keyFileName string) (k *wrapping.KeyInfo, err error) {
keyFile, err := c.container.OpenFile(ctx, keyFileName)
if err != nil {
return nil, err
}
defer func() {
if closeErr := keyFile.Close(); closeErr != nil {
err = errors.Join(err, closeErr)
}
}()

keyBytes := new(bytes.Buffer)
_, err = keyBytes.ReadFrom(keyFile)
Expand All @@ -234,11 +251,16 @@ func (c *container) loadKey(ctx context.Context, keyFileName string) (*wrapping.
return key, nil
}

func (c *container) loadSignature(ctx context.Context, sigFileName string) (*wrapping.SigInfo, error) {
func (c *container) loadSignature(ctx context.Context, sigFileName string) (s *wrapping.SigInfo, err error) {
sigFile, err := c.container.OpenFile(ctx, sigFileName)
if err != nil {
return nil, err
}
defer func() {
if closeErr := sigFile.Close(); closeErr != nil {
err = errors.Join(err, closeErr)
}
}()

sigBytes := new(bytes.Buffer)
_, err = sigBytes.ReadFrom(sigFile)
Expand Down Expand Up @@ -488,28 +510,40 @@ func (c *container) close(_ context.Context) error {

var closeError error

if err := c.meta.Close(); err != nil {
closeError = errors.Join(closeError, fmt.Errorf("%s: %w", op, err))
if !is.Nil(c.meta) {
if err := c.meta.Close(); err != nil {
closeError = errors.Join(closeError, fmt.Errorf("%s: %w", op, err))
}
}

if err := c.sum.Close(); err != nil {
closeError = errors.Join(closeError, fmt.Errorf("%s: %w", op, err))
if !is.Nil(c.sum) {
if err := c.sum.Close(); err != nil {
closeError = errors.Join(closeError, fmt.Errorf("%s: %w", op, err))
}
}

if err := c.checksums.Close(); err != nil {
closeError = errors.Join(closeError, fmt.Errorf("%s: %w", op, err))
if !is.Nil(c.checksums) {
if err := c.checksums.Close(); err != nil {
closeError = errors.Join(closeError, fmt.Errorf("%s: %w", op, err))
}
}

if err := c.sigs.Close(); err != nil {
closeError = errors.Join(closeError, fmt.Errorf("%s: %w", op, err))
if !is.Nil(c.sigs) {
if err := c.sigs.Close(); err != nil {
closeError = errors.Join(closeError, fmt.Errorf("%s: %w", op, err))
}
}

if err := c.journal.Close(); err != nil {
closeError = errors.Join(closeError, fmt.Errorf("%s: %w", op, err))
if !is.Nil(c.journal) {
if err := c.journal.Close(); err != nil {
closeError = errors.Join(closeError, fmt.Errorf("%s: %w", op, err))
}
}

if err := c.container.Close(); err != nil {
closeError = errors.Join(closeError, fmt.Errorf("%s: %w", op, err))
if !is.Nil(c.container) {
if err := c.container.Close(); err != nil {
closeError = errors.Join(closeError, fmt.Errorf("%s: %w", op, err))
}
}

return closeError
Expand Down
4 changes: 4 additions & 0 deletions internal/bsr/convert/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,13 @@ func ToAsciicast(ctx context.Context, session *bsr.Session, tmp storage.TempFile
if err != nil {
return nil, fmt.Errorf("%s: %w", op, err)
}
defer conn.Close(ctx)

ch, err := conn.OpenChannel(ctx, chanId)
if err != nil {
return nil, fmt.Errorf("%s: %w", op, err)
}
defer ch.Close(ctx)

// TODO sanity checks before getting the data files:
// - check connection summary to see if there was an exec or shell request
Expand All @@ -62,11 +64,13 @@ func ToAsciicast(ctx context.Context, session *bsr.Session, tmp storage.TempFile
if err != nil {
return nil, fmt.Errorf("%s: %w", op, err)
}
defer reqScanner.Close()

msgScanner, err := ch.OpenMessageScanner(ctx, bsr.Outbound)
if err != nil {
return nil, fmt.Errorf("%s: %w", op, err)
}
defer msgScanner.Close()

return sshChannelToAsciicast(ctx, reqScanner, msgScanner, tmp, options...)
default:
Expand Down
Loading