Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto-patch marvin #5

Merged
merged 1 commit into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/control_flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pydantic import BaseModel, Field, field_validator

from control_flow.context import ctx
from control_flow.utilities.marvin import patch_marvin

logger = get_logger(__name__)

Expand Down Expand Up @@ -99,7 +100,7 @@ def wrapper(
f'Executing AI flow "{fn.__name__}" on thread "{flow_obj.thread.id}"'
)

with ctx(flow=flow_obj):
with ctx(flow=flow_obj), patch_marvin():
return p_fn(*args, **kwargs)

return wrapper
86 changes: 86 additions & 0 deletions src/control_flow/utilities/marvin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import functools
from contextlib import contextmanager
from typing import Any, Callable

import marvin.ai.text
from marvin.client.openai import AsyncMarvinClient
from marvin.settings import temporary_settings as temporary_marvin_settings
from openai.types.chat import ChatCompletion
from prefect import task as prefect_task

from control_flow.utilities.prefect import create_json_artifact

original_classify_async = marvin.classify_async
original_cast_async = marvin.cast_async
original_extract_async = marvin.extract_async
original_generate_async = marvin.generate_async
original_paint_async = marvin.paint_async
original_speak_async = marvin.speak_async
original_transcribe_async = marvin.transcribe_async


class AsyncControlFlowClient(AsyncMarvinClient):
async def generate_chat(self, **kwargs: Any) -> "ChatCompletion":
super_method = super().generate_chat

@prefect_task(task_run_name="Generate OpenAI chat completion")
async def _generate_chat(**kwargs):
messages = kwargs.get("messages", [])
create_json_artifact(key="prompt", data=messages)
response = await super_method(**kwargs)
create_json_artifact(key="response", data=response)
return response

return _generate_chat(**kwargs)


def generate_task(name: str, original_fn: Callable):
@functools.wraps(marvin.classify_async)
async def wrapper(*args, **kwargs):
@prefect_task(name=name)
async def inner(*args, **kwargs):
create_json_artifact(key="args", data=[args, kwargs])
result = await original_fn(*args, **kwargs)
create_json_artifact(key="result", data=result)
return result

# do this to avoid weirdness with async/sync behavior
return inner(*args, **kwargs)

return wrapper


@contextmanager
def patch_marvin():
with temporary_marvin_settings(default_async_client_cls=AsyncControlFlowClient):
try:
marvin.ai.text.classify_async = generate_task(
"marvin.classify", original_classify_async
)
marvin.ai.text.cast_async = generate_task(
"marvin.cast", original_cast_async
)
marvin.ai.text.extract_async = generate_task(
"marvin.extract", original_extract_async
)
marvin.ai.text.generate_async = generate_task(
"marvin.generate", original_generate_async
)
marvin.ai.images.paint_async = generate_task(
"marvin.paint", original_paint_async
)
marvin.ai.audio.speak_async = generate_task(
"marvin.speak", original_speak_async
)
marvin.ai.audio.transcribe_async = generate_task(
"marvin.transcribe", original_transcribe_async
)
yield
finally:
marvin.ai.text.classify_async = original_classify_async
marvin.ai.text.cast_async = original_cast_async
marvin.ai.text.extract_async = original_extract_async
marvin.ai.text.generate_async = original_generate_async
marvin.ai.images.paint_async = original_paint_async
marvin.ai.audio.speak_async = original_speak_async
marvin.ai.audio.transcribe_async = original_transcribe_async
5 changes: 1 addition & 4 deletions src/control_flow/utilities/prefect.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ def create_json_artifact(
Create a JSON artifact.
"""

if isinstance(data, str):
json_data = data
else:
json_data = TypeAdapter(type(data)).dump_json(data, indent=2).decode()
json_data = TypeAdapter(type(data)).dump_json(data, indent=2).decode()

create_markdown_artifact(
key=key,
Expand Down