@@ -90,9 +90,6 @@ class Resource:
90
90
outputs : Optional [Union [LiteralMap , typing .Dict [str , Any ]]] = None
91
91
92
92
93
- T = typing .TypeVar ("T" , bound = ResourceMeta )
94
-
95
-
96
93
class AgentBase (ABC ):
97
94
name = "Base Agent"
98
95
@@ -127,7 +124,7 @@ def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], **kwargs
127
124
raise NotImplementedError
128
125
129
126
130
- class AsyncAgentBase (AgentBase , typing . Generic [ T ] ):
127
+ class AsyncAgentBase (AgentBase ):
131
128
"""
132
129
This is the base class for all async agents. It defines the interface that all agents must implement.
133
130
The agent service is responsible for invoking agents. The propeller will communicate with the agent service
@@ -139,7 +136,7 @@ class AsyncAgentBase(AgentBase, typing.Generic[T]):
139
136
140
137
name = "Base Async Agent"
141
138
142
- def __init__ (self , metadata_type : typing . Type [ T ] , ** kwargs ):
139
+ def __init__ (self , metadata_type : ResourceMeta , ** kwargs ):
143
140
super ().__init__ (** kwargs )
144
141
self ._metadata_type = metadata_type
145
142
@@ -148,14 +145,14 @@ def metadata_type(self) -> ResourceMeta:
148
145
return self ._metadata_type
149
146
150
147
@abstractmethod
151
- def create (self , task_template : TaskTemplate , inputs : Optional [LiteralMap ], ** kwargs ) -> T :
148
+ def create (self , task_template : TaskTemplate , inputs : Optional [LiteralMap ], ** kwargs ) -> ResourceMeta :
152
149
"""
153
150
Return a resource meta that can be used to get the status of the task.
154
151
"""
155
152
raise NotImplementedError
156
153
157
154
@abstractmethod
158
- def get (self , resource_meta : T , ** kwargs ) -> Resource :
155
+ def get (self , resource_meta : ResourceMeta , ** kwargs ) -> Resource :
159
156
"""
160
157
Return the status of the task, and return the outputs in some cases. For example, bigquery job
161
158
can't write the structured dataset to the output location, so it returns the output literals to the propeller,
@@ -164,7 +161,7 @@ def get(self, resource_meta: T, **kwargs) -> Resource:
164
161
raise NotImplementedError
165
162
166
163
@abstractmethod
167
- def delete (self , resource_meta : T , ** kwargs ):
164
+ def delete (self , resource_meta : ResourceMeta , ** kwargs ):
168
165
"""
169
166
Delete the task. This call should be idempotent. It should raise an error if fails to delete the task.
170
167
"""
@@ -231,9 +228,7 @@ class SyncAgentExecutorMixin:
231
228
Sending a prompt to ChatGPT and getting a response, or retrieving some metadata from a backend system.
232
229
"""
233
230
234
- T = typing .TypeVar ("T" , "SyncAgentExecutorMixin" , PythonTask )
235
-
236
- def execute (self : T , ** kwargs ) -> LiteralMap :
231
+ def execute (self : PythonTask , ** kwargs ) -> LiteralMap :
237
232
from flytekit .tools .translator import get_serializable
238
233
239
234
ctx = FlyteContext .current_context ()
@@ -250,7 +245,9 @@ def execute(self: T, **kwargs) -> LiteralMap:
250
245
return TypeEngine .dict_to_literal_map (ctx , resource .outputs )
251
246
return resource .outputs
252
247
253
- async def _do (self : T , agent : SyncAgentBase , template : TaskTemplate , inputs : Dict [str , Any ] = None ) -> Resource :
248
+ async def _do (
249
+ self : PythonTask , agent : SyncAgentBase , template : TaskTemplate , inputs : Dict [str , Any ] = None
250
+ ) -> Resource :
254
251
try :
255
252
ctx = FlyteContext .current_context ()
256
253
literal_map = TypeEngine .dict_to_literal_map (ctx , inputs or {}, self .get_input_types ())
@@ -267,12 +264,10 @@ class AsyncAgentExecutorMixin:
267
264
Asynchronous tasks are tasks that take a long time to complete, such as running a query.
268
265
"""
269
266
270
- T = typing .TypeVar ("T" , "AsyncAgentExecutorMixin" , PythonTask )
271
-
272
267
_clean_up_task : coroutine = None
273
268
_agent : AsyncAgentBase = None
274
269
275
- def execute (self : T , ** kwargs ) -> LiteralMap :
270
+ def execute (self : PythonTask , ** kwargs ) -> LiteralMap :
276
271
ctx = FlyteContext .current_context ()
277
272
ss = ctx .serialization_settings or SerializationSettings (ImageConfig ())
278
273
output_prefix = ctx .file_access .get_random_remote_directory ()
@@ -301,7 +296,7 @@ def execute(self: T, **kwargs) -> LiteralMap:
301
296
return resource .outputs
302
297
303
298
async def _create (
304
- self : T , task_template : TaskTemplate , output_prefix : str , inputs : Dict [str , Any ] = None
299
+ self : PythonTask , task_template : TaskTemplate , output_prefix : str , inputs : Dict [str , Any ] = None
305
300
) -> ResourceMeta :
306
301
ctx = FlyteContext .current_context ()
307
302
@@ -322,7 +317,7 @@ async def _create(
322
317
signal .signal (signal .SIGINT , partial (self .signal_handler , resource_meta )) # type: ignore
323
318
return resource_meta
324
319
325
- async def _get (self : T , resource_meta : ResourceMeta ) -> Resource :
320
+ async def _get (self : PythonTask , resource_meta : ResourceMeta ) -> Resource :
326
321
phase = TaskExecution .RUNNING
327
322
328
323
progress = Progress (transient = True )
0 commit comments