 import arrow
 from aws_library.ec2 import AWSTagKey, EC2InstanceData
+from aws_library.ec2._models import AWSTagValue
 from fastapi import FastAPI
 from models_library.users import UserID
 from models_library.wallets import WalletID
 from pydantic import parse_obj_as
 from servicelib.logging_utils import log_catch
-
+from servicelib.utils import limited_gather
+
+from ..constants import (
+    DOCKER_STACK_DEPLOY_COMMAND_EC2_TAG_KEY,
+    DOCKER_STACK_DEPLOY_COMMAND_NAME,
+    ROLE_TAG_KEY,
+    USER_ID_TAG_KEY,
+    WALLET_ID_TAG_KEY,
+    WORKER_ROLE_TAG_VALUE,
+)
 from ..core.settings import get_application_settings
 from ..modules.clusters import (
     delete_clusters,
     get_all_clusters,
     get_cluster_workers,
     set_instance_heartbeat,
 )
+from ..utils.clusters import create_deploy_cluster_stack_script
 from ..utils.dask import get_scheduler_auth, get_scheduler_url
-from ..utils.ec2 import HEARTBEAT_TAG_KEY
+from ..utils.ec2 import (
+    HEARTBEAT_TAG_KEY,
+    get_cluster_name,
+    user_id_from_instance_tags,
+    wallet_id_from_instance_tags,
+)
 from .dask import is_scheduler_busy, ping_scheduler
+from .ec2 import get_ec2_client
+from .ssm import get_ssm_client

 _logger = logging.getLogger(__name__)

@@ -42,8 +60,8 @@ def _get_instance_last_heartbeat(instance: EC2InstanceData) -> datetime.datetime
 async def _get_all_associated_worker_instances(
     app: FastAPI,
     primary_instances: Iterable[EC2InstanceData],
-) -> list[EC2InstanceData]:
-    worker_instances = []
+) -> set[EC2InstanceData]:
+    worker_instances: set[EC2InstanceData] = set()
     for instance in primary_instances:
         assert "user_id" in instance.tags  # nosec
         user_id = UserID(instance.tags[_USER_ID_TAG_KEY])
@@ -55,20 +73,20 @@ async def _get_all_associated_worker_instances(
             else None
         )

-        worker_instances.extend(
+        worker_instances.update(
             await get_cluster_workers(app, user_id=user_id, wallet_id=wallet_id)
         )
     return worker_instances


 async def _find_terminateable_instances(
     app: FastAPI, instances: Iterable[EC2InstanceData]
-) -> list[EC2InstanceData]:
+) -> set[EC2InstanceData]:
     app_settings = get_application_settings(app)
     assert app_settings.CLUSTERS_KEEPER_PRIMARY_EC2_INSTANCES  # nosec

     # get the corresponding ec2 instance data
-    terminateable_instances: list[EC2InstanceData] = []
+    terminateable_instances: set[EC2InstanceData] = set()

     time_to_wait_before_termination = (
         app_settings.CLUSTERS_KEEPER_MAX_MISSED_HEARTBEATS_BEFORE_CLUSTER_TERMINATION
@@ -82,7 +100,7 @@ async def _find_terminateable_instances(
             elapsed_time_since_heartbeat = arrow.utcnow().datetime - last_heartbeat
             allowed_time_to_wait = time_to_wait_before_termination
             if elapsed_time_since_heartbeat >= allowed_time_to_wait:
-                terminateable_instances.append(instance)
+                terminateable_instances.add(instance)
             else:
                 _logger.info(
                     "%s has still %ss before being terminateable",
@@ -93,14 +111,14 @@ async def _find_terminateable_instances(
             elapsed_time_since_startup = arrow.utcnow().datetime - instance.launch_time
             allowed_time_to_wait = startup_delay
             if elapsed_time_since_startup >= allowed_time_to_wait:
-                terminateable_instances.append(instance)
+                terminateable_instances.add(instance)

     # get all terminateable instances associated worker instances
     worker_instances = await _get_all_associated_worker_instances(
         app, terminateable_instances
     )

-    return terminateable_instances + worker_instances
+    return terminateable_instances.union(worker_instances)


 async def check_clusters(app: FastAPI) -> None:
@@ -112,6 +130,7 @@ async def check_clusters(app: FastAPI) -> None:
         if await ping_scheduler(get_scheduler_url(instance), get_scheduler_auth(app))
     }

+    # set instance heartbeat if scheduler is busy
     for instance in connected_intances:
         with log_catch(_logger, reraise=False):
             # NOTE: some connected instance could in theory break between these 2 calls, therefore this is silenced and will
@@ -124,6 +143,7 @@ async def check_clusters(app: FastAPI) -> None:
                 f"{instance.id=} for {instance.tags=}",
             )
             await set_instance_heartbeat(app, instance=instance)
+    # clean any cluster that is not doing anything
     if terminateable_instances := await _find_terminateable_instances(
         app, connected_intances
     ):
@@ -138,7 +158,7 @@ async def check_clusters(app: FastAPI) -> None:
         for instance in disconnected_instances
         if _get_instance_last_heartbeat(instance) is None
     }
-
+    # remove instances that were starting for too long
    if terminateable_instances := await _find_terminateable_instances(
         app, starting_instances
     ):
@@ -149,7 +169,72 @@ async def check_clusters(app: FastAPI) -> None:
         )
         await delete_clusters(app, instances=terminateable_instances)

-    # the other instances are broken (they were at some point connected but now not anymore)
+    # NOTE: transmit command to start docker swarm/stack if needed
+    # once the instance is connected to the SSM server,
+    # use the ssm client to send the command to these instances;
+    # we send a command that contains:
+    # the docker-compose file in binary,
+    # the call to init the docker swarm and the call to deploy the stack
+    instances_in_need_of_deployment = {
+        i
+        for i in starting_instances - terminateable_instances
+        if DOCKER_STACK_DEPLOY_COMMAND_EC2_TAG_KEY not in i.tags
+    }
+
+    if instances_in_need_of_deployment:
+        app_settings = get_application_settings(app)
+        ssm_client = get_ssm_client(app)
+        ec2_client = get_ec2_client(app)
+        instances_in_need_of_deployment_ssm_connection_state = await limited_gather(
+            *[
+                ssm_client.is_instance_connected_to_ssm_server(i.id)
+                for i in instances_in_need_of_deployment
+            ],
+            reraise=False,
+            log=_logger,
+            limit=20,
+        )
+        ec2_connected_to_ssm_server = [
+            i
+            for i, c in zip(
+                instances_in_need_of_deployment,
+                instances_in_need_of_deployment_ssm_connection_state,
+                strict=True,
+            )
+            if c is True
+        ]
+        started_instances_ready_for_command = ec2_connected_to_ssm_server
+        if started_instances_ready_for_command:
+            # we need to send 1 command per machine here, as the user_id/wallet_id changes
+            for i in started_instances_ready_for_command:
+                ssm_command = await ssm_client.send_command(
+                    [i.id],
+                    command=create_deploy_cluster_stack_script(
+                        app_settings,
+                        cluster_machines_name_prefix=get_cluster_name(
+                            app_settings,
+                            user_id=user_id_from_instance_tags(i.tags),
+                            wallet_id=wallet_id_from_instance_tags(i.tags),
+                            is_manager=False,
+                        ),
+                        additional_custom_tags={
+                            USER_ID_TAG_KEY: i.tags[USER_ID_TAG_KEY],
+                            WALLET_ID_TAG_KEY: i.tags[WALLET_ID_TAG_KEY],
+                            ROLE_TAG_KEY: WORKER_ROLE_TAG_VALUE,
+                        },
+                    ),
+                    command_name=DOCKER_STACK_DEPLOY_COMMAND_NAME,
+                )
+            await ec2_client.set_instances_tags(
+                started_instances_ready_for_command,
+                tags={
+                    DOCKER_STACK_DEPLOY_COMMAND_EC2_TAG_KEY: AWSTagValue(
+                        ssm_command.command_id
+                    ),
+                },
+            )
+
+    # the remaining instances are broken (they were at some point connected but now not anymore)
     broken_instances = disconnected_instances - starting_instances
     if terminateable_instances := await _find_terminateable_instances(
         app, broken_instances
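
Note on the SSM-readiness filter added above: the connection probes are run concurrently with limited_gather(reraise=False), so a failed probe comes back as an exception result instead of aborting the loop, and zip(..., strict=True) pairs each result back to its instance so only confirmed-ready instances receive the deploy command this iteration (the others are retried on the next pass). A minimal, self-contained sketch of that pattern follows; it uses plain asyncio.gather(return_exceptions=True) as a stand-in for servicelib's limited_gather, and _is_ready is a hypothetical probe standing in for ssm_client.is_instance_connected_to_ssm_server.

import asyncio


async def _is_ready(instance_id: str) -> bool:
    # hypothetical probe; stands in for ssm_client.is_instance_connected_to_ssm_server
    await asyncio.sleep(0)
    return True


async def filter_ready(instance_ids: list[str]) -> list[str]:
    # run all probes concurrently, keeping exceptions as results instead of raising
    results = await asyncio.gather(
        *(_is_ready(i) for i in instance_ids), return_exceptions=True
    )
    # keep only instances whose probe returned True; exceptions and False are dropped
    return [i for i, ok in zip(instance_ids, results, strict=True) if ok is True]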