Skip to content

Commit

Permalink
refact(session connection): remove session connection state table
Browse files Browse the repository at this point in the history
  • Loading branch information
irenarindos committed Apr 30, 2024
1 parent e030629 commit ea418af
Show file tree
Hide file tree
Showing 26 changed files with 555 additions and 708 deletions.
16 changes: 5 additions & 11 deletions internal/daemon/cluster/handlers/worker_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -612,16 +612,13 @@ func (ws *workerServiceServer) AuthorizeConnection(ctx context.Context, req *pbs
return nil, status.Errorf(codes.NotFound, "worker not found with name %q", req.GetWorkerId())
}

connectionInfo, connStates, err := connectionRepo.AuthorizeConnection(ctx, req.GetSessionId(), w.GetPublicId())
connectionInfo, err := connectionRepo.AuthorizeConnection(ctx, req.GetSessionId(), w.GetPublicId())
if err != nil {
return nil, err
}
if connectionInfo == nil {
return nil, status.Error(codes.Internal, "Invalid authorize connection response.")
}
if len(connStates) == 0 {
return nil, status.Error(codes.Internal, "Invalid connection state in authorize response.")
}

sessInfo, authzSummary, err := sessionRepo.LookupSession(ctx, req.GetSessionId())
if err != nil {
Expand All @@ -638,7 +635,7 @@ func (ws *workerServiceServer) AuthorizeConnection(ctx context.Context, req *pbs

ret := &pbs.AuthorizeConnectionResponse{
ConnectionId: connectionInfo.GetPublicId(),
Status: connStates[0].Status.ProtoVal(),
Status: session.ConnectionStatusFromString(connectionInfo.ConnectionStatus).ProtoVal(),
ConnectionsLeft: authzSummary.ConnectionLimit,
Route: route,
}
Expand Down Expand Up @@ -670,7 +667,7 @@ func (ws *workerServiceServer) ConnectConnection(ctx context.Context, req *pbs.C
return nil, status.Errorf(codes.Internal, "error getting session repo: %v", err)
}

connectionInfo, connStates, err := connRepo.ConnectConnection(ctx, session.ConnectWith{
connectionInfo, err := connRepo.ConnectConnection(ctx, session.ConnectWith{
ConnectionId: req.GetConnectionId(),
ClientTcpAddress: req.GetClientTcpAddress(),
ClientTcpPort: req.GetClientTcpPort(),
Expand All @@ -686,7 +683,7 @@ func (ws *workerServiceServer) ConnectConnection(ctx context.Context, req *pbs.C
}

return &pbs.ConnectConnectionResponse{
Status: connStates[0].Status.ProtoVal(),
Status: session.ConnectionStatusFromString(connectionInfo.ConnectionStatus).ProtoVal(),
}, nil
}

Expand Down Expand Up @@ -732,12 +729,9 @@ func (ws *workerServiceServer) CloseConnection(ctx context.Context, req *pbs.Clo
if v.Connection == nil {
return nil, status.Errorf(codes.Internal, "No connection found while closing one of the connection IDs: %v", closeIds)
}
if len(v.ConnectionStates) == 0 {
return nil, status.Errorf(codes.Internal, "No connection states found while closing one of the connection IDs: %v", closeIds)
}
closeData = append(closeData, &pbs.CloseConnectionResponseData{
ConnectionId: v.Connection.GetPublicId(),
Status: v.ConnectionStates[0].Status.ProtoVal(),
Status: v.ConnectionState.ProtoVal(),
})
}

Expand Down
18 changes: 8 additions & 10 deletions internal/daemon/cluster/handlers/worker_service_status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func TestStatus(t *testing.T) {
tofu := session.TestTofu(t)
canceledSess, _, err = repo.ActivateSession(ctx, canceledSess.PublicId, canceledSess.Version, tofu)
require.NoError(t, err)
canceledConn, _, err := connRepo.AuthorizeConnection(ctx, canceledSess.PublicId, worker1.PublicId)
canceledConn, err := connRepo.AuthorizeConnection(ctx, canceledSess.PublicId, worker1.PublicId)
require.NoError(t, err)

canceledSess, err = repo.CancelSession(ctx, canceledSess.PublicId, canceledSess.Version)
Expand All @@ -119,7 +119,7 @@ func TestStatus(t *testing.T) {
s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce)
require.NotNil(t, s)

connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId)
connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId)
require.NoError(t, err)

cases := []struct {
Expand Down Expand Up @@ -560,7 +560,7 @@ func TestStatusSessionClosed(t *testing.T) {
s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce)
require.NotNil(t, s)

connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId)
connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId)
require.NoError(t, err)

cases := []struct {
Expand Down Expand Up @@ -751,9 +751,9 @@ func TestStatusDeadConnection(t *testing.T) {
s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce)
require.NotNil(t, s)

connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId)
connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId)
require.NoError(t, err)
deadConn, _, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker1.PublicId)
deadConn, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker1.PublicId)
require.NoError(t, err)
require.NotEqual(t, deadConn.PublicId, connection.PublicId)

Expand Down Expand Up @@ -817,12 +817,10 @@ func TestStatusDeadConnection(t *testing.T) {
),
)

gotConn, states, err := connRepo.LookupConnection(ctx, deadConn.PublicId)
gotConn, err := connRepo.LookupConnection(ctx, deadConn.PublicId)
require.NoError(t, err)
assert.Equal(t, session.ConnectionSystemError, session.ClosedReason(gotConn.ClosedReason))
assert.Equal(t, 2, len(states))
assert.Nil(t, states[0].EndTime)
assert.Equal(t, session.StatusClosed, states[0].Status)
assert.Equal(t, session.StatusClosed, session.ConnectionStatusFromString(gotConn.ConnectionStatus))
}

func TestStatusWorkerWithKeyId(t *testing.T) {
Expand Down Expand Up @@ -918,7 +916,7 @@ func TestStatusWorkerWithKeyId(t *testing.T) {
s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce)
require.NotNil(t, s)

connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId)
connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId)
require.NoError(t, err)

cases := []struct {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ begin;
create trigger insert_new_session_state after insert on session
for each row execute procedure insert_new_session_state();

-- Updated in 87/01_remove_session_connection_state
-- update_connection_state_on_closed_reason() is used in an update insert trigger on the
-- session_connection table. it will valiadate that all the session's
-- connections are closed, and then insert a state of "closed" in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ begin;
create trigger default_create_time_column before insert on session_connection
for each row execute procedure default_create_time();

-- Removed in 86/01_remove_session_connection_state.up.sql
-- insert_new_connection_state() is used in an after insert trigger on the
-- session_connection table. it will insert a state of "authorized" in
-- session_connection_state for the new session connection.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ begin;
drop trigger wh_insert_session_connection_state on session_connection_state;
drop function wh_insert_session_connection_state;

-- Updated in 86/01_remove_session_connection_state.up.sql
create function wh_insert_session_connection_state() returns trigger
as $$
declare
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ begin;
drop trigger update_connection_state_on_closed_reason on session_connection;
drop function update_connection_state_on_closed_reason();

-- Removed in 86/01_remove_session_connection_state.up.sql
create function update_connection_state_on_closed_reason() returns trigger
as $$
begin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ begin;
drop trigger wh_insert_session_connection on session_connection;
drop function wh_insert_session_connection();

-- Updated in 87/01_remove_session_connection_state
create function wh_insert_session_connection() returns trigger
as $$
declare
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
-- Copyright (c) HashiCorp, Inc.
-- SPDX-License-Identifier: BUSL-1.1

begin;

-- Remove the session_connection_state table and any related triggers
drop trigger update_connection_state_on_closed_reason on session_connection;
drop function update_connection_state_on_closed_reason();

drop trigger insert_session_connection_state on session_connection_state;
drop function insert_session_connection_state();

drop trigger update_session_state_on_termination_reason on session;
drop function update_session_state_on_termination_reason();

drop trigger insert_new_connection_state on session_connection;
drop function insert_new_connection_state();

drop trigger immutable_columns on session_connection_state;

drop trigger wh_insert_session_connection_state on session_connection_state;
drop function wh_insert_session_connection_state();

drop trigger wh_insert_session_connection on session_connection;
drop function wh_insert_session_connection();

drop table session_connection_state;
drop table session_connection_state_enm;

-- If the connected_time_range is null, it means the connection is authorized but not connected.
-- If the upper value of connected_time_range is > now() (upper range is infinity) then the state is connected.
-- If the upper value of connected_time_range is <= now() then the connection is closed.
alter table session_connection
add column connected_time_range tstzrange;

-- Insert on session_connection creates the connection entry, leaving the connected_time_range to null, indicating the connection is authorized
-- "Connected" is handled by the function ConnectConnection, which sets the connected_time_range lower bound to now() and upper bound to infinity
-- "Closed" is handled by the trigger function, update_connected_time_range_on_closed_reason, which sets the connected_time_range upper bound to now()
-- State transitions are guarded by the trigger function, check_connection_state_transition, which ensures that the state transitions are valid

create or replace function check_connection_state_transition() returns trigger
as $$
begin
-- Authorized state
if new.connected_time_range is null then
return new;
end if;
-- If the old state was authorized, any transition is valid
if old.connected_time_range is null then
return new;
end if;
-- Prevent transitions from connected to connected
if upper(old.connected_time_range) = 'infinity' and upper(new.connected_time_range) = 'infinity' then
raise exception 'Invalid state transition from connected to connected';
end if;
-- Prevent transitions from closed to connected
if lower(new.connected_time_range) >= upper(old.connected_time_range) then
raise exception 'Invalid state transition from closed to connected';
end if;
return new;
end;
$$ language plpgsql;

create trigger check_connection_state_transition before update of connected_time_range on session_connection
for each row execute procedure check_connection_state_transition();

create or replace function update_connected_time_range_on_closed_reason() returns trigger
as $$
begin
if new.closed_reason is not null then
perform from
session_connection cs
where
cs.public_id = new.public_id and
-- If the connection is already closed, there's no need to update the connected_time_range
upper(cs.connected_time_range) <= now();
if not found then
update session_connection
set
connected_time_range = tstzrange(lower(connected_time_range), now())
where
public_id = new.public_id;
end if;
end if;
return new;
end;
$$ language plpgsql;

create trigger update_connected_time_range_closed_reason after update of closed_reason on session_connection
for each row execute procedure update_connected_time_range_on_closed_reason();

create or replace function update_session_state_on_termination_reason() returns trigger
as $$
begin
if new.termination_reason is not null then
perform from
session
where
public_id = new.public_id and
public_id not in (
select session_id
from session_connection
where
-- open connections will have a connected_time_range with an upper bound of infinity
upper(connected_time_range) > now()
);
if not found then
raise 'session %s has open connections', new.public_id;
end if;
-- check to see if there's a terminated state already, before inserting a
-- new one.
perform from
session_state ss
where
ss.session_id = new.public_id and
ss.state = 'terminated';
if found then
return new;
end if;
insert into session_state (session_id, state)
values (new.public_id, 'terminated');
end if;
return new;
end;
$$ language plpgsql;

create trigger update_session_state_on_termination_reason after update of termination_reason on session
for each row execute procedure update_session_state_on_termination_reason();

create or replace function wh_insert_session_connection() returns trigger
as $$
declare
new_row wh_session_connection_accumulating_fact%rowtype;
begin
with
authorized_timestamp (date_dim_key, time_dim_key, ts) as (
select wh_date_key(create_time), wh_time_key(create_time), create_time
from session_connection
where public_id = new.public_id
and connected_time_range is null
),
session_dimension (host_dim_key, user_dim_key, credential_group_dim_key) as (
select host_key, user_key, credential_group_key
from wh_session_accumulating_fact
where session_id = new.session_id
)
insert into wh_session_connection_accumulating_fact (
connection_id,
session_id,
host_key,
user_key,
credential_group_key,
connection_authorized_date_key,
connection_authorized_time_key,
connection_authorized_time,
client_tcp_address,
client_tcp_port_number,
endpoint_tcp_address,
endpoint_tcp_port_number,
bytes_up,
bytes_down
)
select new.public_id,
new.session_id,
session_dimension.host_dim_key,
session_dimension.user_dim_key,
session_dimension.credential_group_dim_key,
authorized_timestamp.date_dim_key,
authorized_timestamp.time_dim_key,
authorized_timestamp.ts,
new.client_tcp_address,
new.client_tcp_port,
new.endpoint_tcp_address,
new.endpoint_tcp_port,
new.bytes_up,
new.bytes_down
from authorized_timestamp,
session_dimension
returning * into strict new_row;
return null;
end;
$$ language plpgsql;

create trigger wh_insert_session_connection after insert on session_connection
for each row execute function wh_insert_session_connection();

create function wh_insert_session_connection_state() returns trigger
as $$
declare
state text;
date_col text;
time_col text;
ts_col text;
q text;
connection_row wh_session_connection_accumulating_fact%rowtype;
begin
if new.connected_time_range is null then
-- Indicates authorized connection. The update statement in this
-- trigger will fail for the authorized state because the row for the
-- session connection has not yet been inserted into the
-- wh_session_connection_accumulating_fact table.
return null;
end if;

if upper(new.connected_time_range) > now() then
state = 'connected';
else
state = 'closed';
end if;

date_col = 'connection_' || state || '_date_key';
time_col = 'connection_' || state || '_time_key';
ts_col = 'connection_' || state || '_time';

q = format('update wh_session_connection_accumulating_fact
set (%I, %I, %I) = (select wh_date_key(%L), wh_time_key(%L), %L::timestamptz)
where connection_id = %L
returning *',
date_col, time_col, ts_col,
new.update_time, new.update_time, new.update_time,
new.public_id);
execute q into strict connection_row;

return null;
end;
$$ language plpgsql;

create trigger wh_insert_session_connection_state after update of connected_time_range on session_connection
for each row execute function wh_insert_session_connection_state();

commit;
Loading

0 comments on commit ea418af

Please sign in to comment.