Skip to content

Commit 5bc5d5c

Browse files
authored
Enable Ray Fast Register (#2606)
Signed-off-by: Jan Fiedler <jan@union.ai>
1 parent df3ab4c commit 5bc5d5c

File tree

1 file changed

+41
-7
lines changed
  • plugins/flytekit-ray/flytekitplugins/ray

1 file changed

+41
-7
lines changed

plugins/flytekit-ray/flytekitplugins/ray/task.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
import base64
22
import json
3+
import os
34
import typing
45
from dataclasses import dataclass
56
from typing import Any, Callable, Dict, Optional
67

78
import yaml
8-
from flytekitplugins.ray.models import HeadGroupSpec, RayCluster, RayJob, WorkerGroupSpec
9+
from flytekitplugins.ray.models import (
10+
HeadGroupSpec,
11+
RayCluster,
12+
RayJob,
13+
WorkerGroupSpec,
14+
)
915
from google.protobuf.json_format import MessageToDict
1016

1117
from flytekit import lazy_module
1218
from flytekit.configuration import SerializationSettings
13-
from flytekit.core.context_manager import ExecutionParameters
19+
from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager
1420
from flytekit.core.python_function_task import PythonFunctionTask
1521
from flytekit.extend import TaskPlugins
1622

@@ -40,6 +46,7 @@ class RayJobConfig:
4046
address: typing.Optional[str] = None
4147
shutdown_after_job_finishes: bool = False
4248
ttl_seconds_after_finished: typing.Optional[int] = None
49+
excludes_working_dir: typing.Optional[typing.List[str]] = None
4350

4451

4552
class RayFunctionTask(PythonFunctionTask):
@@ -50,11 +57,30 @@ class RayFunctionTask(PythonFunctionTask):
5057
_RAY_TASK_TYPE = "ray"
5158

5259
def __init__(self, task_config: RayJobConfig, task_function: Callable, **kwargs):
53-
super().__init__(task_config=task_config, task_type=self._RAY_TASK_TYPE, task_function=task_function, **kwargs)
60+
super().__init__(
61+
task_config=task_config,
62+
task_type=self._RAY_TASK_TYPE,
63+
task_function=task_function,
64+
**kwargs,
65+
)
5466
self._task_config = task_config
5567

5668
def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
57-
ray.init(address=self._task_config.address)
69+
init_params = {"address": self._task_config.address}
70+
71+
ctx = FlyteContextManager.current_context()
72+
if not ctx.execution_state.is_local_execution():
73+
working_dir = os.getcwd()
74+
init_params["runtime_env"] = {
75+
"working_dir": working_dir,
76+
"excludes": ["script_mode.tar.gz", "fast*.tar.gz"],
77+
}
78+
79+
cfg = self._task_config
80+
if cfg.excludes_working_dir:
81+
init_params["runtime_env"]["excludes"].extend(cfg.excludes_working_dir)
82+
83+
ray.init(**init_params)
5884
return user_params
5985

6086
def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]:
@@ -67,12 +93,20 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
6793

6894
ray_job = RayJob(
6995
ray_cluster=RayCluster(
70-
head_group_spec=HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None,
96+
head_group_spec=(
97+
HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None
98+
),
7199
worker_group_spec=[
72-
WorkerGroupSpec(c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params)
100+
WorkerGroupSpec(
101+
c.group_name,
102+
c.replicas,
103+
c.min_replicas,
104+
c.max_replicas,
105+
c.ray_start_params,
106+
)
73107
for c in cfg.worker_node_config
74108
],
75-
enable_autoscaling=cfg.enable_autoscaling if cfg.enable_autoscaling else False,
109+
enable_autoscaling=(cfg.enable_autoscaling if cfg.enable_autoscaling else False),
76110
),
77111
runtime_env=runtime_env,
78112
runtime_env_yaml=runtime_env_yaml,

0 commit comments

Comments
 (0)