diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 04e480e67b..72bbcb698e 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -204,17 +204,20 @@ async def _create( literals = inputs or {} for k, v in inputs.items(): literals[k] = TypeEngine.to_literal(ctx, v, type(v), self._entity.interface.inputs[k].type) + + literal_map = LiteralMap(literals) + if isinstance(self, PythonFunctionTask): # Write the inputs to a remote file, so that the remote task can read the inputs from this file. path = ctx.file_access.get_random_local_path() - utils.write_proto_to_file(LiteralMap(literals).to_flyte_idl(), path) + utils.write_proto_to_file(literal_map.to_flyte_idl(), path) ctx.file_access.put_data(path, f"{output_prefix}/inputs.pb") task_template = render_task_template(task_template, output_prefix) if self._agent.asynchronous: - res = await self._agent.async_create(grpc_ctx, output_prefix, task_template, inputs) + res = await self._agent.async_create(grpc_ctx, output_prefix, task_template, literal_map) else: - res = self._agent.create(grpc_ctx, output_prefix, task_template, inputs) + res = self._agent.create(grpc_ctx, output_prefix, task_template, literal_map) signal.signal(signal.SIGINT, partial(self.signal_handler, res.resource_meta)) # type: ignore return res