4
4
import tarfile
5
5
import traceback
6
6
from abc import ABC , abstractmethod
7
- from typing import Dict
7
+ from functools import lru_cache
8
+ from typing import Dict , List
8
9
9
10
import fsspec
10
11
import nbconvert
11
12
import nbformat
12
13
from nbconvert .preprocessors import CellExecutionError , ExecutePreprocessor
14
+ from prefect import flow , task
15
+ from prefect .futures import as_completed
16
+ from prefect_dask .task_runners import DaskTaskRunner
13
17
14
18
from jupyter_scheduler .models import DescribeJob , JobFeature , Status
15
- from jupyter_scheduler .orm import Job , create_session
19
+ from jupyter_scheduler .orm import Job , Workflow , create_session
16
20
from jupyter_scheduler .parameterize import add_parameters
17
21
from jupyter_scheduler .utils import get_utc_timestamp
22
+ from jupyter_scheduler .workflows import DescribeWorkflow
18
23
19
24
20
25
class ExecutionManager (ABC ):
@@ -29,14 +34,29 @@ class ExecutionManager(ABC):
29
34
_model = None
30
35
_db_session = None
31
36
32
- def __init__ (self , job_id : str , root_dir : str , db_url : str , staging_paths : Dict [str , str ]):
37
+ def __init__ (
38
+ self ,
39
+ job_id : str ,
40
+ workflow_id : str ,
41
+ root_dir : str ,
42
+ db_url : str ,
43
+ staging_paths : Dict [str , str ],
44
+ ):
33
45
self .job_id = job_id
46
+ self .workflow_id = workflow_id
34
47
self .staging_paths = staging_paths
35
48
self .root_dir = root_dir
36
49
self .db_url = db_url
37
50
38
51
@property
39
52
def model (self ):
53
+ if self .workflow_id :
54
+ with self .db_session () as session :
55
+ workflow = (
56
+ session .query (Workflow ).filter (Workflow .workflow_id == self .workflow_id ).first ()
57
+ )
58
+ self ._model = DescribeWorkflow .from_orm (workflow )
59
+ return self ._model
40
60
if self ._model is None :
41
61
with self .db_session () as session :
42
62
job = session .query (Job ).filter (Job .job_id == self .job_id ).first ()
@@ -65,6 +85,18 @@ def process(self):
65
85
else :
66
86
self .on_complete ()
67
87
88
+ def process_workflow (self ):
89
+
90
+ self .before_start_workflow ()
91
+ try :
92
+ self .execute_workflow ()
93
+ except CellExecutionError as e :
94
+ self .on_failure_workflow (e )
95
+ except Exception as e :
96
+ self .on_failure_workflow (e )
97
+ else :
98
+ self .on_complete_workflow ()
99
+
68
100
@abstractmethod
69
101
def execute (self ):
70
102
"""Performs notebook execution,
@@ -74,6 +106,11 @@ def execute(self):
74
106
"""
75
107
pass
76
108
109
+ @abstractmethod
110
+ def execute_workflow (self ):
111
+ """Performs workflow execution"""
112
+ pass
113
+
77
114
@classmethod
78
115
@abstractmethod
79
116
def supported_features (cls ) -> Dict [JobFeature , bool ]:
@@ -98,6 +135,15 @@ def before_start(self):
98
135
)
99
136
session .commit ()
100
137
138
+ def before_start_workflow (self ):
139
+ """Called before start of execute"""
140
+ workflow = self .model
141
+ with self .db_session () as session :
142
+ session .query (Workflow ).filter (Workflow .workflow_id == workflow .workflow_id ).update (
143
+ {"status" : Status .IN_PROGRESS }
144
+ )
145
+ session .commit ()
146
+
101
147
def on_failure (self , e : Exception ):
102
148
"""Called after failure of execute"""
103
149
job = self .model
@@ -109,6 +155,17 @@ def on_failure(self, e: Exception):
109
155
110
156
traceback .print_exc ()
111
157
158
+ def on_failure_workflow (self , e : Exception ):
159
+ """Called after failure of execute"""
160
+ workflow = self .model
161
+ with self .db_session () as session :
162
+ session .query (Workflow ).filter (Workflow .workflow_id == workflow .workflow_id ).update (
163
+ {"status" : Status .FAILED , "status_message" : str (e )}
164
+ )
165
+ session .commit ()
166
+
167
+ traceback .print_exc ()
168
+
112
169
def on_complete (self ):
113
170
"""Called after job is completed"""
114
171
job = self .model
@@ -118,10 +175,40 @@ def on_complete(self):
118
175
)
119
176
session .commit ()
120
177
178
+ def on_complete_workflow (self ):
179
+ workflow = self .model
180
+ with self .db_session () as session :
181
+ session .query (Workflow ).filter (Workflow .workflow_id == workflow .workflow_id ).update (
182
+ {"status" : Status .COMPLETED }
183
+ )
184
+ session .commit ()
185
+
121
186
122
187
class DefaultExecutionManager (ExecutionManager ):
123
188
"""Default execution manager that executes notebooks"""
124
189
190
+ @task (task_run_name = "{task_id}" )
191
+ def execute_task (task_id : str ):
192
+ print (f"Task { task_id } executed" )
193
+ return task_id
194
+
195
+ @flow (task_runner = DaskTaskRunner ())
196
+ def execute_workflow (self ):
197
+ workflow : DescribeWorkflow = self .model
198
+ tasks = {task ["id" ]: task for task in workflow .tasks }
199
+
200
+ # create Prefect tasks, use caching to ensure Prefect tasks are created before wait_for is called on them
201
+ @lru_cache (maxsize = None )
202
+ def make_task (task_id , execute_task ):
203
+ deps = tasks [task_id ]["dependsOn" ]
204
+ return execute_task .submit (
205
+ task_id , wait_for = [make_task (dep_id , execute_task ) for dep_id in deps ]
206
+ )
207
+
208
+ final_tasks = [make_task (task_id , self .execute_task ) for task_id in tasks ]
209
+ for future in as_completed (final_tasks ):
210
+ print (future .result ())
211
+
125
212
def execute (self ):
126
213
job = self .model
127
214
0 commit comments