diff --git a/internal/db/schema/migrations/oss/postgres/0/50_session.up.sql b/internal/db/schema/migrations/oss/postgres/0/50_session.up.sql index 62db5118df..1c2b2f5664 100644 --- a/internal/db/schema/migrations/oss/postgres/0/50_session.up.sql +++ b/internal/db/schema/migrations/oss/postgres/0/50_session.up.sql @@ -395,6 +395,7 @@ begin; references session_state (session_id, end_time) ); + -- Replaced in 91/06_session_state_tstzrange.up.sql create trigger immutable_columns before update on session_state for each row execute procedure immutable_columns('session_id', 'state', 'start_time', 'previous_end_time'); diff --git a/internal/db/schema/migrations/oss/postgres/15/01_wh_rename_key_columns.up.sql b/internal/db/schema/migrations/oss/postgres/15/01_wh_rename_key_columns.up.sql index 16b9a8e67f..7ecc4a146e 100644 --- a/internal/db/schema/migrations/oss/postgres/15/01_wh_rename_key_columns.up.sql +++ b/internal/db/schema/migrations/oss/postgres/15/01_wh_rename_key_columns.up.sql @@ -397,6 +397,7 @@ begin; end; $$ language plpgsql; + -- Replaced in 91/06_session_state_tstzrange.up.sql create trigger wh_insert_session_state after insert on session_state for each row execute function wh_insert_session_state(); diff --git a/internal/db/schema/migrations/oss/postgres/28/02_prior_session_trigger.up.sql b/internal/db/schema/migrations/oss/postgres/28/02_prior_session_trigger.up.sql index 831728b3eb..72f82fac8e 100644 --- a/internal/db/schema/migrations/oss/postgres/28/02_prior_session_trigger.up.sql +++ b/internal/db/schema/migrations/oss/postgres/28/02_prior_session_trigger.up.sql @@ -77,6 +77,7 @@ begin end; $$; -- Replaces trigger from 0/50_session.up.sql +-- Replaced in 91/06_session_state_tstzrange.up.sql -- Update insert session state transition trigger drop trigger insert_session_state on session_state; drop function insert_session_state(); diff --git a/internal/db/schema/migrations/oss/postgres/72/03_session_list_perf_fix.up.sql b/internal/db/schema/migrations/oss/postgres/72/03_session_list_perf_fix.up.sql index fcd4a2d1c4..93906a3422 100644 --- a/internal/db/schema/migrations/oss/postgres/72/03_session_list_perf_fix.up.sql +++ b/internal/db/schema/migrations/oss/postgres/72/03_session_list_perf_fix.up.sql @@ -4,6 +4,7 @@ begin; -- Replaces the view created in 69/02_session_worker_protocol.up.sql + -- drop view session_list; create view session_list as select s.public_id, diff --git a/internal/db/schema/migrations/oss/postgres/84/02_wh_upsert_user_refact.up.sql b/internal/db/schema/migrations/oss/postgres/84/02_wh_upsert_user_refact.up.sql index 4f52d710ba..a55e0aeed9 100644 --- a/internal/db/schema/migrations/oss/postgres/84/02_wh_upsert_user_refact.up.sql +++ b/internal/db/schema/migrations/oss/postgres/84/02_wh_upsert_user_refact.up.sql @@ -69,6 +69,7 @@ begin; 'for the user that corresponds to the provided auth_token_id.'; -- Replaces function from 60/03_wh_sessions.up.sql + -- Replaced in 91/06_session_state_tstzrange.up.sql create function wh_insert_session() returns trigger as $$ declare diff --git a/internal/db/schema/migrations/oss/postgres/91/06_session_state_tstzrange.up.sql b/internal/db/schema/migrations/oss/postgres/91/06_session_state_tstzrange.up.sql new file mode 100644 index 0000000000..bd046d2871 --- /dev/null +++ b/internal/db/schema/migrations/oss/postgres/91/06_session_state_tstzrange.up.sql @@ -0,0 +1,184 @@ +-- Copyright (c) HashiCorp, Inc. +-- SPDX-License-Identifier: BUSL-1.1 + +begin; + -- Add new active_time_range column that will replace two start_time, end_time columns. + -- Also drop a number of constraints on the start_time, end_time columns. This will allow + -- from dropping these columns after the new column has been set with the correct data. + alter table session_state + add column active_time_range tstzrange not null default tstzrange(now(), null, '[]'), + drop constraint end_times_in_sequence, + drop constraint previous_end_time_and_start_time_in_sequence, + drop constraint start_and_end_times_in_sequence, + drop constraint session_state_session_id_previous_end_time_fkey; + + -- Set the new active_time_range column for any existing rows using start_time and end_time. + update session_state + set active_time_range = tstzrange(start_time, end_time, '[)'); + + -- Replaces view from 72/03/session_list_perf_fix.up.sql + -- Switch view to tuse the new column. This also eliminates the previous_end_time column + -- from the view, since it also will be dropped. + drop view session_list; + create view session_list as + select s.public_id, + s.user_id, + shsh.host_id, + s.target_id, + shsh.host_set_id, + s.auth_token_id, + s.project_id, + s.certificate, + s.certificate_private_key, + s.expiration_time, + s.connection_limit, + s.tofu_token, + s.key_id, + s.termination_reason, + s.version, + s.create_time, + s.update_time, + s.endpoint, + s.worker_filter, + s.egress_worker_filter, + s.ingress_worker_filter, + ss.state, + lower(ss.active_time_range) as start_time, + upper(ss.active_time_range) as end_time + from session s + join session_state ss on s.public_id = ss.session_id + left join session_host_set_host shsh on s.public_id = shsh.session_id; + + -- Now we can finally drop the old columns and add a constraint on the new column + -- that ensures there are no overlaps on the active_time_range for a given session. + alter table session_state + drop column start_time, + drop column end_time, + drop column previous_end_time, + add constraint session_state_active_time_range_excl + exclude using gist (session_id with =, + active_time_range with &&), + add constraint active_time_range_not_empty + check (not isempty(active_time_range)); + + -- There are still a number of functions that reference the old columns. + -- These all need to be updated to use the new column instead. + + -- Replaces trigger from 0/50_session.up.sql + drop trigger immutable_columns on session_state; + create trigger immutable_columns before update on session_state + for each row execute procedure immutable_columns('session_id', 'state'); + + -- Replaces function from 28/02_prior_session_trigger.up.sql + drop trigger insert_session_state on session_state; + drop function insert_session_state(); + create function insert_session_state() returns trigger + as $$ + declare + old_col_state text; + begin + update session_state + set active_time_range = tstzrange(lower(active_time_range), now(), '[)') + where session_id = new.session_id + and upper(active_time_range) is null + returning state + into old_col_state; + + if not found then + new.prior_state = 'pending'; + else + new.prior_state = old_col_state; + end if; + + new.active_time_range = tstzrange(now(), null, '[]'); + + return new; + end; + $$ language plpgsql; + + create trigger insert_session_state before insert on session_state + for each row execute procedure insert_session_state(); + + -- Replaces function from 84/02_wh_upsert_user_refact.up.sql + drop trigger wh_insert_session on session; + drop function wh_insert_session; + create function wh_insert_session() returns trigger + as $$ + declare + new_row wh_session_accumulating_fact%rowtype; + begin + with + pending_timestamp (date_dim_key, time_dim_key, ts) as ( + select wh_date_key(lower(active_time_range)), wh_time_key(lower(active_time_range)), lower(active_time_range) + from session_state + where session_id = new.public_id + and state = 'pending' + ) + insert into wh_session_accumulating_fact ( + session_id, + auth_token_id, + host_key, + user_key, + credential_group_key, + session_pending_date_key, + session_pending_time_key, + session_pending_time + ) + select new.public_id, + new.auth_token_id, + 'no host source', -- will be updated by wh_upsert_host + wh_upsert_user(new.auth_token_id), + 'no credentials', -- will be updated by wh_upsert_credential_group + pending_timestamp.date_dim_key, + pending_timestamp.time_dim_key, + pending_timestamp.ts + from pending_timestamp + returning * into strict new_row; + return null; + end; + $$ language plpgsql; + + create trigger wh_insert_session after insert on session + for each row execute procedure wh_insert_session(); + + -- Replaces function from 15/01_wh_rename_key_columns.up.sql + drop trigger wh_insert_session_state on session_state; + drop function wh_insert_session_state; + + create function wh_insert_session_state() returns trigger + as $$ + declare + date_col text; + time_col text; + ts_col text; + q text; + session_row wh_session_accumulating_fact%rowtype; + begin + if new.state = 'pending' then + -- The pending state is the first state which is handled by the + -- wh_insert_session trigger. The update statement in this trigger will + -- fail for the pending state because the row for the session has not yet + -- been inserted into the wh_session_accumulating_fact table. + return null; + end if; + + date_col = 'session_' || new.state || '_date_key'; + time_col = 'session_' || new.state || '_time_key'; + ts_col = 'session_' || new.state || '_time'; + + q = format(' update wh_session_accumulating_fact + set (%I, %I, %I) = (select wh_date_key(%L), wh_time_key(%L), %L::timestamptz) + where session_id = %L + returning *', + date_col, time_col, ts_col, + lower(new.active_time_range), lower(new.active_time_range), lower(new.active_time_range), + new.session_id); + execute q into strict session_row; + + return null; + end; + $$ language plpgsql; + + create trigger wh_insert_session_state after insert on session_state + for each row execute function wh_insert_session_state(); +commit; diff --git a/internal/session/immutable_fields_test.go b/internal/session/immutable_fields_test.go index d5efdbe497..4082f5c238 100644 --- a/internal/session/immutable_fields_test.go +++ b/internal/session/immutable_fields_test.go @@ -5,6 +5,7 @@ package session import ( "context" + "fmt" "testing" "github.com/hashicorp/boundary/internal/db" @@ -97,15 +98,35 @@ func TestState_ImmutableFields(t *testing.T) { wrapper := db.TestWrapper(t) iamRepo := iam.TestRepo(t, conn, wrapper) - ts := timestamp.Timestamp{Timestamp: ×tamppb.Timestamp{Seconds: 0, Nanos: 0}} - _, _ = iam.TestScopes(t, iam.TestRepo(t, conn, wrapper)) session := TestDefaultSession(t, conn, wrapper, iamRepo) state := TestState(t, conn, session.PublicId, StatusActive) - var new State - err := rw.LookupWhere(context.Background(), &new, "session_id = ? and state = ?", []any{state.SessionId, state.Status}) - require.NoError(t, err) + fetchSession := func(ctx context.Context, rw *db.Db, sessionId string, startTime *timestamp.Timestamp) (*State, error) { + const selectQuery = ` +select session_id, + state, + lower(active_time_range) as start_time, + upper(active_time_range) as end_time + from session_state + where session_id = ? + and lower(active_time_range) = ?;` + var states []*State + rows, err := rw.Query(ctx, selectQuery, []any{sessionId, startTime}) + if err != nil { + return nil, err + } + defer rows.Close() + for rows.Next() { + if err := rw.ScanRows(ctx, rows, &states); err != nil { + return nil, err + } + } + if len(states) != 1 { + return nil, fmt.Errorf("found %d states, expected 1", len(states)) + } + return states[0], nil + } tests := []struct { name string @@ -115,7 +136,7 @@ func TestState_ImmutableFields(t *testing.T) { { name: "session_id", update: func() *State { - s := new.Clone().(*State) + s := state.Clone().(*State) s.SessionId = "s_thisIsNotAValidId" return s }(), @@ -124,47 +145,28 @@ func TestState_ImmutableFields(t *testing.T) { { name: "status", update: func() *State { - s := new.Clone().(*State) + s := state.Clone().(*State) s.Status = "canceling" return s }(), fieldMask: []string{"Status"}, }, - { - name: "start time", - update: func() *State { - s := new.Clone().(*State) - s.StartTime = &ts - return s - }(), - fieldMask: []string{"StartTime"}, - }, - { - name: "previous_end_time", - update: func() *State { - s := new.Clone().(*State) - s.PreviousEndTime = &ts - return s - }(), - fieldMask: []string{"PreviousEndTime"}, - }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() assert, require := assert.New(t), require.New(t) - orig := new.Clone() - err := rw.LookupWhere(context.Background(), orig, "session_id = ? and start_time = ?", []any{new.SessionId, new.StartTime}) + orig, err := fetchSession(ctx, rw, state.SessionId, state.StartTime) require.NoError(err) rowsUpdated, err := rw.Update(context.Background(), tt.update, tt.fieldMask, nil, db.WithSkipVetForWrite(true)) require.Error(err) assert.Equal(0, rowsUpdated) - after := new.Clone() - err = rw.LookupWhere(context.Background(), after, "session_id = ? and start_time = ?", []any{new.SessionId, new.StartTime}) + after, err := fetchSession(ctx, rw, state.SessionId, state.StartTime) require.NoError(err) - assert.Equal(orig.(*State), after) + assert.Equal(orig, after) }) } } diff --git a/internal/session/query.go b/internal/session/query.go index da18eef887..c743092d98 100644 --- a/internal/session/query.go +++ b/internal/session/query.go @@ -90,7 +90,7 @@ active_session as ( where ss.session_id in (select * from unexpired_session) and ss.state = 'active' and - ss.end_time is null + upper(ss.active_time_range) is null ) insert into session_connection ( session_id, @@ -150,7 +150,7 @@ from where ss.session_id = @public_id and ss.state = 'canceling' and - ss.end_time is null + upper(ss.active_time_range) is null ) update session us set version = version +1, @@ -226,7 +226,7 @@ with canceling_session(session_id) as session_state ss where ss.state = 'canceling' and - ss.end_time is null + upper(ss.active_time_range) is null ) update session us set termination_reason = @@ -371,7 +371,7 @@ where and session_state.state = 'terminated' and - session_state.start_time < wt_sub_seconds_from_now(@threshold_seconds) + lower(session_state.active_time_range) < wt_sub_seconds_from_now(@threshold_seconds) ; ` sessionCredentialRewrapQuery = ` @@ -451,6 +451,16 @@ order by update_time desc, public_id desc; ` estimateCountSessions = ` select reltuples::bigint as estimate from pg_class where oid in ('session'::regclass) +` + + selectStates = ` + select session_id, + state, + lower(active_time_range) as start_time, + upper(active_time_range) as end_time + from session_state + where session_id = ? +order by active_time_range desc; ` ) diff --git a/internal/session/repository.go b/internal/session/repository.go index b7e98b56ef..94ffafe56e 100644 --- a/internal/session/repository.go +++ b/internal/session/repository.go @@ -166,11 +166,10 @@ func (r *Repository) convertToSessions(ctx context.Context, sessionList []*sessi if _, ok := states[sv.EndTime]; !ok { states[sv.EndTime] = &State{ - SessionId: sv.PublicId, - Status: Status(sv.Status), - PreviousEndTime: sv.PreviousEndTime, - StartTime: sv.StartTime, - EndTime: sv.EndTime, + SessionId: sv.PublicId, + Status: Status(sv.Status), + StartTime: sv.StartTime, + EndTime: sv.EndTime, } } diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index 5cff87003b..864db5c26d 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -212,7 +212,7 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, opt .. if err := read.LookupById(ctx, &session); err != nil { return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed for %s", sessionId))) } - states, err := fetchStates(ctx, read, sessionId, db.WithOrder("start_time desc")) + states, err := fetchStates(ctx, read, sessionId) if err != nil { return errors.Wrap(ctx, err, op) } @@ -618,7 +618,7 @@ func (r *Repository) fetchActivatedSessionStatesTx(ctx context.Context, reader d var txErr error var returnedStates []*State - returnedStates, txErr = fetchStates(ctx, reader, sessionId, db.WithOrder("start_time desc")) + returnedStates, txErr = fetchStates(ctx, reader, sessionId) if txErr != nil { return nil, errors.Wrap(ctx, txErr, op) } @@ -823,7 +823,7 @@ func (r *Repository) updateState(ctx context.Context, sessionId string, sessionV if rowsAffected != 0 && rowsAffected != 1 { return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("updated session %s to state %s and %d rows inserted (should be 0 or 1)", sessionId, s.String(), rowsAffected)) } - returnedStates, err = fetchStates(ctx, reader, sessionId, db.WithOrder("start_time desc")) + returnedStates, err = fetchStates(ctx, reader, sessionId) if err != nil { return errors.Wrap(ctx, err, op) } @@ -854,7 +854,7 @@ func (r *Repository) updateState(ctx context.Context, sessionId string, sessionV // non-active state, i.e. "canceling" or "terminated" It returns a *StateReport // object for each session that is not active, with its current status. func (r *Repository) CheckIfNotActive(ctx context.Context, reportedSessions []string) ([]*StateReport, error) { - const op = "session.(Repository).listSessionIdAndState" + const op = "session.(Repository).CheckIfNotActive" notActive := make([]*StateReport, 0, len(reportedSessions)) if len(reportedSessions) <= 0 { @@ -872,7 +872,7 @@ func (r *Repository) CheckIfNotActive(ctx context.Context, reportedSessions []st db.ExpBackoff{}, func(reader db.Reader, _ db.Writer) error { var states []*State - err := reader.SearchWhere(ctx, &states, "end_time is null and session_id in (?)", []any{reportedSessions}) + err := reader.SearchWhere(ctx, &states, "upper(active_time_range) is null and session_id in (?)", []any{reportedSessions}) if err != nil { return errors.Wrap(ctx, err, op) } @@ -926,9 +926,16 @@ func (r *Repository) deleteSessionsTerminatedBefore(ctx context.Context, thresho func fetchStates(ctx context.Context, r db.Reader, sessionId string, opt ...db.Option) ([]*State, error) { const op = "session.fetchStates" var states []*State - if err := r.SearchWhere(ctx, &states, "session_id = ?", []any{sessionId}, opt...); err != nil { + rows, err := r.Query(ctx, selectStates, []any{sessionId}, opt...) + if err != nil { return nil, errors.Wrap(ctx, err, op) } + defer rows.Close() + for rows.Next() { + if err := r.ScanRows(ctx, rows, &states); err != nil { + return nil, errors.Wrap(ctx, err, op) + } + } if len(states) == 0 { return nil, nil } diff --git a/internal/session/session.go b/internal/session/session.go index 063b460eac..ba79394894 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -599,10 +599,9 @@ type sessionListView struct { ProtocolWorkerId string `json:"protocol_worker_id,omitempty" gorm:"default:null"` // State fields - Status string `json:"state,omitempty" gorm:"column:state"` - PreviousEndTime *timestamp.Timestamp `json:"previous_end_time,omitempty" gorm:"default:current_timestamp"` - StartTime *timestamp.Timestamp `json:"start_time,omitempty" gorm:"default:current_timestamp;primary_key"` - EndTime *timestamp.Timestamp `json:"end_time,omitempty" gorm:"default:current_timestamp"` + Status string `json:"state,omitempty" gorm:"column:state"` + StartTime *timestamp.Timestamp `json:"start_time,omitempty" gorm:"default:current_timestamp;primary_key"` + EndTime *timestamp.Timestamp `json:"end_time,omitempty" gorm:"default:current_timestamp"` } // TableName returns the tablename to override the default gorm table name diff --git a/internal/session/state.go b/internal/session/state.go index 445994206c..a7b0cfac6e 100644 --- a/internal/session/state.go +++ b/internal/session/state.go @@ -54,8 +54,6 @@ type State struct { SessionId string `json:"session_id,omitempty" gorm:"primary_key"` // status of the session Status Status `json:"status,omitempty" gorm:"column:state"` - // PreviousEndTime from the RDBMS - PreviousEndTime *timestamp.Timestamp `json:"previous_end_time,omitempty" gorm:"default:current_timestamp"` // StartTime from the RDBMS StartTime *timestamp.Timestamp `json:"start_time,omitempty" gorm:"default:current_timestamp;primary_key"` // EndTime from the RDBMS @@ -95,15 +93,6 @@ func (s *State) Clone() any { SessionId: s.SessionId, Status: s.Status, } - if s.PreviousEndTime != nil { - clone.PreviousEndTime = ×tamp.Timestamp{ - Timestamp: ×tamppb.Timestamp{ - Seconds: s.PreviousEndTime.Timestamp.Seconds, - Nanos: s.PreviousEndTime.Timestamp.Nanos, - }, - } - } - if s.StartTime != nil { clone.StartTime = ×tamp.Timestamp{ Timestamp: ×tamppb.Timestamp{ @@ -163,8 +152,5 @@ func (s *State) validate(ctx context.Context) error { if s.EndTime != nil { return errors.New(ctx, errors.InvalidParameter, op, "end time is not settable") } - if s.PreviousEndTime != nil { - return errors.New(ctx, errors.InvalidParameter, op, "previous end time is not settable") - } return nil } diff --git a/internal/session/state_test.go b/internal/session/state_test.go index 25e73f1df4..29a9a86b45 100644 --- a/internal/session/state_test.go +++ b/internal/session/state_test.go @@ -4,160 +4,14 @@ package session import ( - "context" "testing" "github.com/hashicorp/boundary/internal/db" - "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/iam" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestState_Create(t *testing.T) { - t.Parallel() - conn, _ := db.TestSetup(t, "postgres") - wrapper := db.TestWrapper(t) - iamRepo := iam.TestRepo(t, conn, wrapper) - session := TestDefaultSession(t, conn, wrapper, iamRepo) - - type args struct { - sessionId string - status Status - } - - tests := []struct { - name string - args args - want *State - wantErr bool - wantIsErr errors.Code - create bool - wantCreateErr bool - }{ - { - name: "valid", - args: args{ - sessionId: session.PublicId, - status: StatusActive, - }, - want: &State{ - SessionId: session.PublicId, - Status: StatusActive, - }, - create: true, - }, - { - name: "empty-sessionId", - args: args{ - status: StatusPending, - }, - wantErr: true, - wantIsErr: errors.InvalidParameter, - }, - { - name: "empty-status", - args: args{ - sessionId: session.PublicId, - }, - wantErr: true, - wantIsErr: errors.InvalidParameter, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - got, err := NewState(context.Background(), tt.args.sessionId, tt.args.status) - if tt.wantErr { - require.Error(err) - assert.True(errors.Match(errors.T(tt.wantIsErr), err)) - return - } - require.NoError(err) - assert.Equal(tt.want, got) - if tt.create { - err = db.New(conn).Create(context.Background(), got) - if tt.wantCreateErr { - assert.Error(err) - return - } else { - assert.NoError(err) - } - } - }) - } -} - -func TestState_Delete(t *testing.T) { - t.Parallel() - ctx := context.Background() - conn, _ := db.TestSetup(t, "postgres") - rw := db.New(conn) - wrapper := db.TestWrapper(t) - iamRepo := iam.TestRepo(t, conn, wrapper) - session := TestDefaultSession(t, conn, wrapper, iamRepo) - session2 := TestDefaultSession(t, conn, wrapper, iamRepo) - - tests := []struct { - name string - state *State - deleteStateId string - wantRowsDeleted int - wantErr bool - wantErrMsg string - }{ - { - name: "valid", - state: TestState(t, conn, session.PublicId, StatusTerminated), - wantErr: false, - wantRowsDeleted: 1, - }, - { - name: "bad-id", - state: TestState(t, conn, session2.PublicId, StatusTerminated), - deleteStateId: func() string { - id, err := db.NewPublicId(ctx, StatePrefix) - require.NoError(t, err) - return id - }(), - wantErr: false, - wantRowsDeleted: 0, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - - var initialState State - err := rw.LookupWhere(ctx, &initialState, "session_id = ? and state = ?", []any{tt.state.SessionId, tt.state.Status}) - require.NoError(err) - - deleteState := allocState() - if tt.deleteStateId != "" { - deleteState.SessionId = tt.deleteStateId - } else { - deleteState.SessionId = tt.state.SessionId - } - deleteState.StartTime = initialState.StartTime - deletedRows, err := rw.Delete(ctx, &deleteState) - if tt.wantErr { - require.Error(err) - return - } - require.NoError(err) - if tt.wantRowsDeleted == 0 { - assert.Equal(tt.wantRowsDeleted, deletedRows) - return - } - assert.Equal(tt.wantRowsDeleted, deletedRows) - foundState := allocState() - err = rw.LookupWhere(ctx, &foundState, "session_id = ? and start_time = ?", []any{tt.state.SessionId, initialState.StartTime}) - require.Error(err) - assert.True(errors.IsNotFoundError(err)) - }) - } -} - func TestState_Clone(t *testing.T) { t.Parallel() conn, _ := db.TestSetup(t, "postgres") diff --git a/internal/session/testing.go b/internal/session/testing.go index 9bd1d60c17..fb9e15548a 100644 --- a/internal/session/testing.go +++ b/internal/session/testing.go @@ -46,13 +46,22 @@ func TestConnection(t testing.TB, conn *db.DB, sessionId, clientTcpAddr string, // TestState creates a test state for the sessionId in the repository. func TestState(t testing.TB, conn *db.DB, sessionId string, state Status) *State { + const insertSessionState = ` +insert into session_state (session_id, state, active_time_range) + values ($1, $2, tstzrange($3, null, '[]')) + returning lower(active_time_range) as start_time +;` t.Helper() require := require.New(t) rw := db.New(conn) s, err := NewState(context.Background(), sessionId, state) require.NoError(err) - err = rw.Create(context.Background(), s) + rows, err := rw.Query(context.Background(), insertSessionState, []any{s.SessionId, s.Status, s.StartTime}) require.NoError(err) + defer rows.Close() + for rows.Next() { + rows.Scan(&s.StartTime) + } return s } @@ -142,7 +151,7 @@ func TestSession(t testing.TB, conn *db.DB, rootWrapper wrapping.Wrapper, c Comp require.NoError(err) } - ss, err := fetchStates(ctx, rw, s.PublicId, append(opts.withDbOpts, db.WithOrder("start_time desc"))...) + ss, err := fetchStates(ctx, rw, s.PublicId, opts.withDbOpts...) require.NoError(err) s.States = ss