Skip to content

Commit 6c9de64

Browse files
committed
add workflow handler and endpoint
1 parent abedadf commit 6c9de64

File tree

5 files changed

+186
-4
lines changed

5 files changed

+186
-4
lines changed

jupyter_scheduler/executors.py

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,22 @@
44
import tarfile
55
import traceback
66
from abc import ABC, abstractmethod
7-
from typing import Dict
7+
from functools import lru_cache
8+
from typing import Dict, List
89

910
import fsspec
1011
import nbconvert
1112
import nbformat
1213
from nbconvert.preprocessors import CellExecutionError, ExecutePreprocessor
14+
from prefect import flow, task
15+
from prefect.futures import as_completed
16+
from prefect_dask.task_runners import DaskTaskRunner
1317

1418
from jupyter_scheduler.models import DescribeJob, JobFeature, Status
15-
from jupyter_scheduler.orm import Job, create_session
19+
from jupyter_scheduler.orm import Job, Workflow, create_session
1620
from jupyter_scheduler.parameterize import add_parameters
1721
from jupyter_scheduler.utils import get_utc_timestamp
22+
from jupyter_scheduler.workflows import DescribeWorkflow
1823

1924

2025
class ExecutionManager(ABC):
@@ -29,14 +34,29 @@ class ExecutionManager(ABC):
2934
_model = None
3035
_db_session = None
3136

32-
def __init__(self, job_id: str, root_dir: str, db_url: str, staging_paths: Dict[str, str]):
37+
def __init__(
38+
self,
39+
job_id: str,
40+
workflow_id: str,
41+
root_dir: str,
42+
db_url: str,
43+
staging_paths: Dict[str, str],
44+
):
3345
self.job_id = job_id
46+
self.workflow_id = workflow_id
3447
self.staging_paths = staging_paths
3548
self.root_dir = root_dir
3649
self.db_url = db_url
3750

3851
@property
3952
def model(self):
53+
if self.workflow_id:
54+
with self.db_session() as session:
55+
workflow = (
56+
session.query(Workflow).filter(Workflow.workflow_id == self.workflow_id).first()
57+
)
58+
self._model = DescribeWorkflow.from_orm(workflow)
59+
return self._model
4060
if self._model is None:
4161
with self.db_session() as session:
4262
job = session.query(Job).filter(Job.job_id == self.job_id).first()
@@ -65,6 +85,18 @@ def process(self):
6585
else:
6686
self.on_complete()
6787

88+
def process_workflow(self):
89+
90+
self.before_start_workflow()
91+
try:
92+
self.execute_workflow()
93+
except CellExecutionError as e:
94+
self.on_failure_workflow(e)
95+
except Exception as e:
96+
self.on_failure_workflow(e)
97+
else:
98+
self.on_complete_workflow()
99+
68100
@abstractmethod
69101
def execute(self):
70102
"""Performs notebook execution,
@@ -74,6 +106,11 @@ def execute(self):
74106
"""
75107
pass
76108

109+
@abstractmethod
110+
def execute_workflow(self):
111+
"""Performs workflow execution"""
112+
pass
113+
77114
@classmethod
78115
@abstractmethod
79116
def supported_features(cls) -> Dict[JobFeature, bool]:
@@ -98,6 +135,15 @@ def before_start(self):
98135
)
99136
session.commit()
100137

138+
def before_start_workflow(self):
139+
"""Called before start of execute"""
140+
workflow = self.model
141+
with self.db_session() as session:
142+
session.query(Workflow).filter(Workflow.workflow_id == workflow.workflow_id).update(
143+
{"status": Status.IN_PROGRESS}
144+
)
145+
session.commit()
146+
101147
def on_failure(self, e: Exception):
102148
"""Called after failure of execute"""
103149
job = self.model
@@ -109,6 +155,17 @@ def on_failure(self, e: Exception):
109155

110156
traceback.print_exc()
111157

158+
def on_failure_workflow(self, e: Exception):
159+
"""Called after failure of execute"""
160+
workflow = self.model
161+
with self.db_session() as session:
162+
session.query(Workflow).filter(Workflow.workflow_id == workflow.workflow_id).update(
163+
{"status": Status.FAILED, "status_message": str(e)}
164+
)
165+
session.commit()
166+
167+
traceback.print_exc()
168+
112169
def on_complete(self):
113170
"""Called after job is completed"""
114171
job = self.model
@@ -118,10 +175,40 @@ def on_complete(self):
118175
)
119176
session.commit()
120177

178+
def on_complete_workflow(self):
179+
workflow = self.model
180+
with self.db_session() as session:
181+
session.query(Workflow).filter(Workflow.workflow_id == workflow.workflow_id).update(
182+
{"status": Status.COMPLETED}
183+
)
184+
session.commit()
185+
121186

122187
class DefaultExecutionManager(ExecutionManager):
123188
"""Default execution manager that executes notebooks"""
124189

190+
@task(task_run_name="{task_id}")
191+
def execute_task(task_id: str):
192+
print(f"Task {task_id} executed")
193+
return task_id
194+
195+
@flow(task_runner=DaskTaskRunner())
196+
def execute_workflow(self):
197+
workflow: DescribeWorkflow = self.model
198+
tasks = {task["id"]: task for task in workflow.tasks}
199+
200+
# create Prefect tasks, use caching to ensure Prefect tasks are created before wait_for is called on them
201+
@lru_cache(maxsize=None)
202+
def make_task(task_id, execute_task):
203+
deps = tasks[task_id]["dependsOn"]
204+
return execute_task.submit(
205+
task_id, wait_for=[make_task(dep_id, execute_task) for dep_id in deps]
206+
)
207+
208+
final_tasks = [make_task(task_id, self.execute_task) for task_id in tasks]
209+
for future in as_completed(final_tasks):
210+
print(future.result())
211+
125212
def execute(self):
126213
job = self.model
127214

jupyter_scheduler/extension.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from traitlets import Bool, Type, Unicode, default
77

88
from jupyter_scheduler.orm import create_tables
9+
from jupyter_scheduler.workflows import WorkflowHandler
910

1011
from .handlers import (
1112
BatchJobHandler,
@@ -35,6 +36,7 @@ class SchedulerApp(ExtensionApp):
3536
(r"scheduler/job_definitions/%s/jobs" % JOB_DEFINITION_ID_REGEX, JobFromDefinitionHandler),
3637
(r"scheduler/runtime_environments", RuntimeEnvironmentsHandler),
3738
(r"scheduler/config", ConfigHandler),
39+
(r"scheduler/workflows", WorkflowHandler),
3840
]
3941

4042
drop_tables = Bool(False, config=True, help="Drop the database tables before starting.")

jupyter_scheduler/orm.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,16 @@ class Job(CommonColumns, Base):
107107
# Any default values specified for new columns will be ignored during the migration process.
108108

109109

110+
class Workflow(Base):
    """ORM model backing the ``workflows`` table."""

    __tablename__ = "workflows"
    __table_args__ = {"extend_existing": True}
    workflow_id = Column(String(36), primary_key=True, default=generate_uuid)
    tasks = Column(JsonType(1024))
    status = Column(String(64), default=Status.STOPPED)
    # Written by ExecutionManager.on_failure_workflow, which previously targeted
    # a column that did not exist on this table. Nullable, per the migration
    # rule below.
    status_message = Column(String(1024))
    # All new columns added to this table must be nullable to ensure compatibility during database migrations.
    # Any default values specified for new columns will be ignored during the migration process.
118+
119+
110120
class JobDefinition(CommonColumns, Base):
111121
__tablename__ = "job_definitions"
112122
__table_args__ = {"extend_existing": True}

jupyter_scheduler/scheduler.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,13 @@
4040
UpdateJob,
4141
UpdateJobDefinition,
4242
)
43-
from jupyter_scheduler.orm import Job, JobDefinition, create_session
43+
from jupyter_scheduler.orm import Job, JobDefinition, Workflow, create_session
4444
from jupyter_scheduler.utils import (
4545
copy_directory,
4646
create_output_directory,
4747
create_output_filename,
4848
)
49+
from jupyter_scheduler.workflows import CreateWorkflow
4950

5051

5152
class BaseScheduler(LoggingConfigurable):
@@ -111,6 +112,10 @@ def create_job(self, model: CreateJob) -> str:
111112
"""
112113
raise NotImplementedError("must be implemented by subclass")
113114

115+
def create_workflow(self, model: CreateWorkflow) -> str:
116+
"""Creates a new workflow record, may trigger execution of the workflow."""
117+
raise NotImplementedError("must be implemented by subclass")
118+
114119
def update_job(self, job_id: str, model: UpdateJob):
115120
"""Updates job metadata in the persistence store,
116121
for example name, status etc. In case of status
@@ -526,6 +531,27 @@ def create_job(self, model: CreateJob) -> str:
526531

527532
return job_id
528533

534+
def create_workflow(self, model: CreateWorkflow) -> str:
535+
536+
with self.db_session() as session:
537+
538+
workflow = Workflow(**model.dict(exclude_none=True))
539+
540+
session.add(workflow)
541+
session.commit()
542+
543+
execution_manager = self.execution_manager_class(
544+
workflow_id=workflow.workflow_id,
545+
root_dir=self.root_dir,
546+
db_url=self.db_url,
547+
)
548+
execution_manager.process_workflow()
549+
session.commit()
550+
551+
workflow_id = workflow.workflow_id
552+
553+
return workflow_id
554+
529555
def update_job(self, job_id: str, model: UpdateJob):
530556
with self.db_session() as session:
531557
session.query(Job).filter(Job.job_id == job_id).update(model.dict(exclude_none=True))

jupyter_scheduler/workflows.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import json
from typing import List, Optional
3+
4+
from jupyter_server.utils import ensure_async
5+
from tornado.web import HTTPError, authenticated
6+
7+
from jupyter_scheduler.exceptions import (
8+
IdempotencyTokenError,
9+
InputUriError,
10+
SchedulerError,
11+
)
12+
from jupyter_scheduler.handlers import (
13+
APIHandler,
14+
ExtensionHandlerMixin,
15+
JobHandlersMixin,
16+
)
17+
from jupyter_scheduler.models import Status
18+
from jupyter_scheduler.pydantic_v1 import BaseModel, ValidationError
19+
20+
21+
class WorkflowHandler(ExtensionHandlerMixin, JobHandlersMixin, APIHandler):
    """REST handler that creates workflows from POSTed JSON bodies."""

    @authenticated
    async def post(self):
        """Create a workflow from the request body and return its id as JSON.

        Maps validation and scheduler errors to HTTP errors: 409 for
        idempotency-token conflicts, 500 for everything else.
        """
        payload = self.get_json_body()
        try:
            workflow_id = await ensure_async(
                self.scheduler.create_workflow(CreateWorkflow(**payload))
            )
            # Removed the stray `print(payload)` debug statement; the log call
            # already records the request body.
            self.log.info(payload)
        except ValidationError as e:
            self.log.exception(e)
            raise HTTPError(500, str(e)) from e
        except InputUriError as e:
            self.log.exception(e)
            raise HTTPError(500, str(e)) from e
        except IdempotencyTokenError as e:
            self.log.exception(e)
            raise HTTPError(409, str(e)) from e
        except SchedulerError as e:
            self.log.exception(e)
            raise HTTPError(500, str(e)) from e
        except Exception as e:
            self.log.exception(e)
            raise HTTPError(500, "Unexpected error occurred during creation of a workflow.") from e
        else:
            self.finish(json.dumps(dict(workflow_id=workflow_id)))
48+
49+
50+
class CreateWorkflow(BaseModel):
    """Request body for creating a workflow."""

    # NOTE(review): DefaultExecutionManager.execute_workflow indexes each task
    # as task["id"]/task["dependsOn"], which expects mappings — List[str] may
    # be too narrow; confirm the intended element type.
    tasks: List[str]
52+
53+
54+
class DescribeWorkflow(BaseModel):
    """API-facing description of a stored workflow row."""

    workflow_id: str
    # Explicit Optional: the previous `List[str] = None` relied on pydantic
    # v1's implicit-Optional conversion, which is rejected by stricter tooling.
    tasks: Optional[List[str]] = None
    status: Status = Status.CREATED

0 commit comments

Comments
 (0)