Skip to content

Commit

Permalink
fix artifact templates
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Apr 7, 2024
1 parent c6aa552 commit e725e10
Showing 1 changed file with 33 additions and 52 deletions.
85 changes: 33 additions & 52 deletions src/control_flow/agent.py
Original file line number Diff line number Diff line change
@@ -33,35 +33,7 @@

TEMP_THREADS = {}

TOOL_CALL_CODE_INTERPRETER_TEMPLATE = inspect.cleandoc(
"""
## Tool call: code interpreter
### Code
```python
{code}
```
### Result
```json
{result}
```
"""
)

TOOL_CALL_FUNCTION_ARGS_TEMPLATE = inspect.cleandoc(
"""
## Tool call: {name}
### Arguments
```json
{args}
```
"""
)
TOOL_CALL_FUNCTION_RESULT_TEMPLATE = inspect.cleandoc(
"""
## Tool call: {name}
@@ -213,6 +185,7 @@ async def on_tool_call_created(self, tool_call: ToolCall) -> None:

async def on_tool_call_done(self, tool_call: ToolCall) -> None:
"""Callback that is fired when a tool call is done"""

client = get_prefect_client()
task_run = self.tool_calls.get(tool_call.id)
if not task_run:
@@ -221,6 +194,19 @@ async def on_tool_call_done(self, tool_call: ToolCall) -> None:
task_run_id=task_run.id, state=prefect.states.Completed(), force=True
)

async def _create_artifact(markdown: str, key: str, description: str = None):
"""low-level artifact call because we need to provide the task run ID manually"""
await client.create_artifact(
artifact=ArtifactRequest(
type="markdown",
key=key,
description=description,
task_run_id=task_run.id,
flow_run_id=task_run.flow_run_id,
data=markdown,
)
)

# code interpreter is run as a single call, so we can publish a result artifact
if tool_call.type == "code_interpreter":
# images = []
@@ -229,33 +215,29 @@ async def on_tool_call_done(self, tool_call: ToolCall) -> None:
# image_path = download_temp_file(output.image.file_id)
# images.append(image_path)

markdown = TOOL_CALL_CODE_INTERPRETER_TEMPLATE.format(
code=tool_call.code_interpreter.input,
result=json.dumps(
[
o.model_dump(mode="json")
for o in tool_call.code_interpreter.outputs
],
indent=2,
),
await _create_artifact(
markdown=f"```python\n{tool_call.code_interpreter.input}\n```",
key="code",
description="Code executed in the code interpreter",
)
elif tool_call.type == "function":
markdown = TOOL_CALL_FUNCTION_ARGS_TEMPLATE.format(
name=tool_call.function.name,
args=tool_call.function.arguments,
outputs = "\n\n".join(
[
o.model_dump_json(indent=2)
for o in tool_call.code_interpreter.outputs
]
)
await _create_artifact(
markdown=f"```json\n{outputs}\n```",
key="output",
description="Output from the code interpreter",
)

# low level artifact call because we need to provide the task run ID manually
return await client.create_artifact(
artifact=ArtifactRequest(
type="markdown",
key="result",
description="Code interpreter result",
task_run_id=task_run.id,
flow_run_id=task_run.flow_run_id,
data=markdown,
elif tool_call.type == "function":
await _create_artifact(
markdown=f"```json\n{json.dumps(json.loads(tool_call.function.arguments), indent=2)}\n```",
key="arguments",
description=f"Arguments for the `{tool_call.function.name}` tool",
)
)


def talk_to_human(message: str, get_response: bool = True) -> str:
@@ -487,7 +469,6 @@ async def run_async(self, context: dict = None, **run_kwargs) -> list[AITask]:
any(t.status == TaskStatus.PENDING for t in self.tasks)
and counter < settings.max_agent_iterations
):
breakpoint()
openai_run(context=context, run_kwargs=run_kwargs)
counter += 1

0 comments on commit e725e10

Please sign in to comment.