Skip to content

Commit

Permalink
Add artifact helper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Apr 7, 2024
1 parent e725e10 commit d34ab1c
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 59 deletions.
83 changes: 24 additions & 59 deletions src/control_flow/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
from openai.types.beta.threads.runs import ToolCall
from prefect import get_client as get_prefect_client
from prefect import task as prefect_task
from prefect.artifacts import ArtifactRequest, create_markdown_artifact
from prefect.context import FlowRunContext
from pydantic import BaseModel, Field, field_validator

from control_flow import settings
from control_flow.context import ctx
from control_flow.utilities.prefect import (
create_json_artifact,
create_markdown_artifact,
create_python_artifact,
)

from .flow import AIFlow
from .task import AITask, TaskStatus
Expand Down Expand Up @@ -163,7 +167,7 @@ async def on_tool_call_created(self, tool_call: ToolCall) -> None:
"""Callback that is fired when a tool call is created"""

if tool_call.type == "function":
task_run_name = "Preparing arguments for tool call..."
task_run_name = "Prepare arguments for tool call"
else:
task_run_name = f"Tool call: {tool_call.type}"

Expand Down Expand Up @@ -194,19 +198,6 @@ async def on_tool_call_done(self, tool_call: ToolCall) -> None:
task_run_id=task_run.id, state=prefect.states.Completed(), force=True
)

async def _create_artifact(markdown: str, key: str, description: str = None):
"""low-level artifact call because we need to provide the task run ID manually"""
await client.create_artifact(
artifact=ArtifactRequest(
type="markdown",
key=key,
description=description,
task_run_id=task_run.id,
flow_run_id=task_run.flow_run_id,
data=markdown,
)
)

# code interpreter is run as a single call, so we can publish a result artifact
if tool_call.type == "code_interpreter":
# images = []
Expand All @@ -215,28 +206,25 @@ async def _create_artifact(markdown: str, key: str, description: str = None):
# image_path = download_temp_file(output.image.file_id)
# images.append(image_path)

await _create_artifact(
markdown=f"```python\n{tool_call.code_interpreter.input}\n```",
create_python_artifact(
key="code",
code=tool_call.code_interpreter.input,
description="Code executed in the code interpreter",
task_run_id=task_run.id,
)
outputs = "\n\n".join(
[
o.model_dump_json(indent=2)
for o in tool_call.code_interpreter.outputs
]
)
await _create_artifact(
markdown=f"```json\n{outputs}\n```",
create_json_artifact(
key="output",
data=tool_call.code_interpreter.outputs,
description="Output from the code interpreter",
task_run_id=task_run.id,
)

elif tool_call.type == "function":
await _create_artifact(
markdown=f"```json\n{json.dumps(json.loads(tool_call.function.arguments), indent=2)}\n```",
create_json_artifact(
key="arguments",
data=json.dumps(json.loads(tool_call.function.arguments), indent=2),
description=f"Arguments for the `{tool_call.function.name}` tool",
task_run_id=task_run.id,
)


Expand Down Expand Up @@ -374,7 +362,7 @@ async def modified_fn(
passed_args = json.dumps(passed_args, indent=2)
except Exception:
pass
await create_markdown_artifact(
create_markdown_artifact(
markdown=TOOL_CALL_FUNCTION_RESULT_TEMPLATE.format(
name=tool.function.name,
description=tool.function.description or "(none provided)",
Expand All @@ -387,7 +375,7 @@ async def modified_fn(

tool.function._python_fn = prefect_task(
modified_fn,
name=f"Tool call: {tool.function.name}",
task_run_name=f"Tool call: {tool.function.name}",
)
final_tools.append(tool)
return final_tools
Expand All @@ -398,7 +386,7 @@ def _get_openai_run_task(self):
This needs to be regenerated each time in case the instructions change.
"""

@prefect_task(name="Execute OpenAI assistant run")
@prefect_task(task_run_name="Run OpenAI assistant")
async def execute_openai_run(
context: dict = None, run_kwargs: dict = None
) -> Run:
Expand All @@ -418,38 +406,15 @@ async def execute_openai_run(
**run_kwargs,
)
await run.run_async()

await create_markdown_artifact(
markdown=Environment.render(
inspect.cleandoc("""
{% for message in run.messages %}
### Message {{ loop.index }}
```json
{{message.model_dump_json(indent=2)}}
```
{% endfor %}
"""),
run=run,
),
create_json_artifact(
key="messages",
data=run.messages,
description="All messages sent and received during the run.",
)
await create_markdown_artifact(
markdown=Environment.render(
inspect.cleandoc("""
{% for step in run.steps %}
### Step {{ loop.index }}
```json
{{step.model_dump_json(indent=2)}}
```
{% endfor %}
"""),
run=run,
),
key="steps",
description="All steps taken during the run.",
create_json_artifact(
key="actions",
data=run.steps,
description="All actions taken by the assistant during the run.",
)
return run

Expand Down
87 changes: 87 additions & 0 deletions src/control_flow/utilities/prefect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Any
from uuid import UUID

from marvin.utilities.asyncio import run_sync
from prefect import get_client as get_prefect_client
from prefect.artifacts import ArtifactRequest
from prefect.context import FlowRunContext, TaskRunContext
from pydantic import TypeAdapter


def create_markdown_artifact(
key: str,
markdown: str,
description: str = None,
task_run_id: UUID = None,
flow_run_id: UUID = None,
) -> None:
"""
Create a Markdown artifact.
"""

tr_context = TaskRunContext.get()
fr_context = FlowRunContext.get()

if tr_context:
task_run_id = task_run_id or tr_context.task_run.id
if fr_context:
flow_run_id = flow_run_id or fr_context.flow_run.id

client = get_prefect_client()
run_sync(
client.create_artifact(
artifact=ArtifactRequest(
key=key,
data=markdown,
description=description,
type="markdown",
task_run_id=task_run_id,
flow_run_id=flow_run_id,
)
)
)


def create_json_artifact(
key: str,
data: Any,
description: str = None,
task_run_id: UUID = None,
flow_run_id: UUID = None,
) -> None:
"""
Create a JSON artifact.
"""

if isinstance(data, str):
json_data = data
else:
json_data = TypeAdapter(type(data)).dump_json(data, indent=2).decode()

create_markdown_artifact(
key=key,
markdown=f"```json\n{json_data}\n```",
description=description,
task_run_id=task_run_id,
flow_run_id=flow_run_id,
)


def create_python_artifact(
key: str,
code: str,
description: str = None,
task_run_id: UUID = None,
flow_run_id: UUID = None,
) -> None:
"""
Create a Python artifact.
"""

create_markdown_artifact(
key=key,
markdown=f"```python\n{code}\n```",
description=description,
task_run_id=task_run_id,
flow_run_id=flow_run_id,
)

0 comments on commit d34ab1c

Please sign in to comment.