diff --git a/internal/session/query.go b/internal/session/query.go index da18eef887..4aad9c2d4f 100644 --- a/internal/session/query.go +++ b/internal/session/query.go @@ -329,38 +329,17 @@ where and closed_reason is null returning public_id; ` - orphanedConnectionsCte = ` --- Find connections that are not closed so we can reference those IDs -with - unclosed_connections as ( - select public_id - from session_connection - where - -- It's not closed - upper(connected_time_range) > now() or - connected_time_range is null - -- It's not in limbo between when it moved into this state and when - -- it started being reported by the worker, which is roughly every - -- 2-3 seconds - and update_time < wt_sub_seconds_from_now(@worker_state_delay_seconds) - ), - connections_to_close as ( - select public_id - from session_connection - where - -- Related to the worker that just reported to us - worker_id = @worker_id - -- Only unclosed ones - and public_id in (select public_id from unclosed_connections) - -- These are connection IDs that just got reported to us by the given - -- worker, so they should not be considered closed. - %s - ) + closeOrphanedConnections = ` update session_connection - set - closed_reason = 'system error' - where - public_id in (select public_id from connections_to_close) + set closed_reason = 'system error' + where worker_id = @worker_id + and update_time < wt_sub_seconds_from_now(@worker_state_delay_seconds) + and ( + connected_time_range is null + or + upper(connected_time_range) > now() + ) + %s returning public_id; ` deleteTerminated = ` diff --git a/internal/session/repository_connection.go b/internal/session/repository_connection.go index 9c6f830e6b..339e36dde0 100644 --- a/internal/session/repository_connection.go +++ b/internal/session/repository_connection.go @@ -399,12 +399,13 @@ func (r *ConnectionRepository) closeOrphanedConnections(ctx context.Context, wor notInClause = fmt.Sprintf(notInClause, strings.Join(params, ",")) } + query := fmt.Sprintf(closeOrphanedConnections, notInClause) _, err := r.writer.DoTx( ctx, db.StdRetryCnt, db.ExpBackoff{}, func(_ db.Reader, w db.Writer) error { - rows, err := w.Query(ctx, fmt.Sprintf(orphanedConnectionsCte, notInClause), args) + rows, err := w.Query(ctx, query, args) if err != nil { return errors.Wrap(ctx, err, op) }