1
1
import base64
2
2
import json
3
+ import os
3
4
import typing
4
5
from dataclasses import dataclass
5
6
from typing import Any , Callable , Dict , Optional
6
7
7
8
import yaml
8
- from flytekitplugins .ray .models import HeadGroupSpec , RayCluster , RayJob , WorkerGroupSpec
9
+ from flytekitplugins .ray .models import (
10
+ HeadGroupSpec ,
11
+ RayCluster ,
12
+ RayJob ,
13
+ WorkerGroupSpec ,
14
+ )
9
15
from google .protobuf .json_format import MessageToDict
10
16
11
17
from flytekit import lazy_module
12
18
from flytekit .configuration import SerializationSettings
13
- from flytekit .core .context_manager import ExecutionParameters
19
+ from flytekit .core .context_manager import ExecutionParameters , FlyteContextManager
14
20
from flytekit .core .python_function_task import PythonFunctionTask
15
21
from flytekit .extend import TaskPlugins
16
22
@@ -40,6 +46,7 @@ class RayJobConfig:
40
46
address : typing .Optional [str ] = None
41
47
shutdown_after_job_finishes : bool = False
42
48
ttl_seconds_after_finished : typing .Optional [int ] = None
49
+ excludes_working_dir : typing .Optional [typing .List [str ]] = None
43
50
44
51
45
52
class RayFunctionTask (PythonFunctionTask ):
@@ -50,11 +57,30 @@ class RayFunctionTask(PythonFunctionTask):
50
57
_RAY_TASK_TYPE = "ray"
51
58
52
59
def __init__(self, task_config: RayJobConfig, task_function: Callable, **kwargs):
    """Create a Ray-backed function task.

    Args:
        task_config: Ray job configuration; also kept on the instance as
            ``_task_config`` for use by the other methods of this class.
        task_function: The user function wrapped by this task.
        **kwargs: Forwarded unchanged to the parent task constructor.
    """
    # The task type is fixed by the class constant; everything else is
    # passed straight through to the base class.
    base_kwargs = dict(
        task_config=task_config,
        task_type=self._RAY_TASK_TYPE,
        task_function=task_function,
    )
    super().__init__(**base_kwargs, **kwargs)
    self._task_config = task_config
55
67
56
68
def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
    """Call ``ray.init`` before the task function runs.

    The configured ``address`` is always passed to ``ray.init``. When the
    current Flyte execution state is not a local execution, a
    ``runtime_env`` is also passed: it names the current working directory
    as ``working_dir`` and excludes two archive patterns plus any entries
    from ``excludes_working_dir`` in the task config.

    Args:
        user_params: Execution parameters handed in by the caller.

    Returns:
        ``user_params``, unchanged.
    """
    task_cfg = self._task_config
    ray_kwargs = {"address": task_cfg.address}

    flyte_ctx = FlyteContextManager.current_context()
    if not flyte_ctx.execution_state.is_local_execution():
        cwd = os.getcwd()
        # Default exclusions — presumably Flyte's own code-distribution
        # archives that should not be re-uploaded; confirm against flytekit.
        excluded = ["script_mode.tar.gz", "fast*.tar.gz"]
        if task_cfg.excludes_working_dir:
            excluded.extend(task_cfg.excludes_working_dir)
        ray_kwargs["runtime_env"] = {
            "working_dir": cwd,
            "excludes": excluded,
        }

    ray.init(**ray_kwargs)
    return user_params
59
85
60
86
def get_custom (self , settings : SerializationSettings ) -> Optional [Dict [str , Any ]]:
@@ -67,12 +93,20 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
67
93
68
94
ray_job = RayJob (
69
95
ray_cluster = RayCluster (
70
- head_group_spec = HeadGroupSpec (cfg .head_node_config .ray_start_params ) if cfg .head_node_config else None ,
96
+ head_group_spec = (
97
+ HeadGroupSpec (cfg .head_node_config .ray_start_params ) if cfg .head_node_config else None
98
+ ),
71
99
worker_group_spec = [
72
- WorkerGroupSpec (c .group_name , c .replicas , c .min_replicas , c .max_replicas , c .ray_start_params )
100
+ WorkerGroupSpec (
101
+ c .group_name ,
102
+ c .replicas ,
103
+ c .min_replicas ,
104
+ c .max_replicas ,
105
+ c .ray_start_params ,
106
+ )
73
107
for c in cfg .worker_node_config
74
108
],
75
- enable_autoscaling = cfg .enable_autoscaling if cfg .enable_autoscaling else False ,
109
+ enable_autoscaling = ( cfg .enable_autoscaling if cfg .enable_autoscaling else False ) ,
76
110
),
77
111
runtime_env = runtime_env ,
78
112
runtime_env_yaml = runtime_env_yaml ,
0 commit comments