Skip to content

Commit

Permalink
Auto-patch marvin
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Apr 7, 2024
1 parent 1ba7e2c commit 3e2854a
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/control_flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pydantic import BaseModel, Field, field_validator

from control_flow.context import ctx
from control_flow.utilities.marvin import patch_marvin

logger = get_logger(__name__)

Expand Down Expand Up @@ -99,7 +100,7 @@ def wrapper(
f'Executing AI flow "{fn.__name__}" on thread "{flow_obj.thread.id}"'
)

with ctx(flow=flow_obj):
with ctx(flow=flow_obj), patch_marvin():
return p_fn(*args, **kwargs)

return wrapper
86 changes: 86 additions & 0 deletions src/control_flow/utilities/marvin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import functools
from contextlib import contextmanager
from typing import Any, Callable

import marvin.ai.text
from marvin.client.openai import AsyncMarvinClient
from marvin.settings import temporary_settings as temporary_marvin_settings
from openai.types.chat import ChatCompletion
from prefect import task as prefect_task

from control_flow.utilities.prefect import create_json_artifact

original_classify_async = marvin.classify_async
original_cast_async = marvin.cast_async
original_extract_async = marvin.extract_async
original_generate_async = marvin.generate_async
original_paint_async = marvin.paint_async
original_speak_async = marvin.speak_async
original_transcribe_async = marvin.transcribe_async


class AsyncControlFlowClient(AsyncMarvinClient):
async def generate_chat(self, **kwargs: Any) -> "ChatCompletion":
super_method = super().generate_chat

@prefect_task(task_run_name="Generate OpenAI chat completion")
async def _generate_chat(**kwargs):
messages = kwargs.get("messages", [])
create_json_artifact(key="prompt", data=messages)
response = await super_method(**kwargs)
create_json_artifact(key="response", data=response)
return response

return _generate_chat(**kwargs)


def generate_task(name: str, original_fn: Callable):
@functools.wraps(marvin.classify_async)
async def wrapper(*args, **kwargs):
@prefect_task(name=name)
async def inner(*args, **kwargs):
create_json_artifact(key="args", data=[args, kwargs])
result = await original_fn(*args, **kwargs)
create_json_artifact(key="result", data=result)
return result

# do this to avoid weirdness with async/sync behavior
return inner(*args, **kwargs)

return wrapper


@contextmanager
def patch_marvin():
with temporary_marvin_settings(default_async_client_cls=AsyncControlFlowClient):
try:
marvin.ai.text.classify_async = generate_task(
"marvin.classify", original_classify_async
)
marvin.ai.text.cast_async = generate_task(
"marvin.cast", original_cast_async
)
marvin.ai.text.extract_async = generate_task(
"marvin.extract", original_extract_async
)
marvin.ai.text.generate_async = generate_task(
"marvin.generate", original_generate_async
)
marvin.ai.images.paint_async = generate_task(
"marvin.paint", original_paint_async
)
marvin.ai.audio.speak_async = generate_task(
"marvin.speak", original_speak_async
)
marvin.ai.audio.transcribe_async = generate_task(
"marvin.transcribe", original_transcribe_async
)
yield
finally:
marvin.ai.text.classify_async = original_classify_async
marvin.ai.text.cast_async = original_cast_async
marvin.ai.text.extract_async = original_extract_async
marvin.ai.text.generate_async = original_generate_async
marvin.ai.images.paint_async = original_paint_async
marvin.ai.audio.speak_async = original_speak_async
marvin.ai.audio.transcribe_async = original_transcribe_async
5 changes: 1 addition & 4 deletions src/control_flow/utilities/prefect.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ def create_json_artifact(
Create a JSON artifact.
"""

if isinstance(data, str):
json_data = data
else:
json_data = TypeAdapter(type(data)).dump_json(data, indent=2).decode()
json_data = TypeAdapter(type(data)).dump_json(data, indent=2).decode()

create_markdown_artifact(
key=key,
Expand Down

0 comments on commit 3e2854a

Please sign in to comment.