This repository has been archived by the owner on Sep 13, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 96
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor for super clear agent protocol implemenation (#15)
- Loading branch information
Showing
10 changed files
with
757 additions
and
84 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,64 +1,220 @@ | ||
import asyncio | ||
import os | ||
import typing | ||
|
||
import autogpt.utils | ||
from autogpt.agent_protocol import Agent, Artifact, Step, Task, TaskDB | ||
from fastapi import APIRouter, FastAPI, Response, UploadFile | ||
from fastapi.responses import FileResponse, JSONResponse | ||
from hypercorn.asyncio import serve | ||
from hypercorn.config import Config | ||
|
||
from .db import AgentDB | ||
from .middlewares import AgentMiddleware | ||
from .routes.agent_protocol import base_router | ||
from .schema import Artifact, Status, Step, StepRequestBody, Task, TaskRequestBody | ||
from .utils import run | ||
from .workspace import Workspace | ||
|
||
|
||
class AutoGPT(Agent): | ||
def __init__(self, db: TaskDB, workspace: Workspace) -> None: | ||
super().__init__(db) | ||
class Agent: | ||
def __init__(self, database: AgentDB, workspace: Workspace): | ||
self.db = database | ||
self.workspace = workspace | ||
|
||
async def create_task(self, task: Task) -> None: | ||
print(f"task: {task.input}") | ||
def start(self, port: int = 8000, router: APIRouter = base_router): | ||
""" | ||
Start the agent server. | ||
""" | ||
config = Config() | ||
config.bind = [f"localhost:{port}"] | ||
app = FastAPI( | ||
title="Auto-GPT Forge", | ||
description="Modified version of The Agent Protocol.", | ||
version="v0.4", | ||
) | ||
app.include_router(router) | ||
app.add_middleware(AgentMiddleware, agent=self) | ||
asyncio.run(serve(app, config)) | ||
|
||
async def create_task(self, task_request: TaskRequestBody) -> Task: | ||
""" | ||
Create a task for the agent. | ||
""" | ||
try: | ||
task = await self.db.create_task( | ||
input=task_request.input if task_request.input else None, | ||
additional_input=task_request.additional_input | ||
if task_request.additional_input | ||
else None, | ||
) | ||
print(task) | ||
except Exception as e: | ||
return Response(status_code=500, content=str(e)) | ||
print(task) | ||
return task | ||
|
||
async def list_tasks(self) -> typing.List[str]: | ||
""" | ||
List the IDs of all tasks that the agent has created. | ||
""" | ||
try: | ||
task_ids = [task.task_id for task in await self.db.list_tasks()] | ||
except Exception as e: | ||
return Response(status_code=500, content=str(e)) | ||
return task_ids | ||
|
||
async def get_task(self, task_id: str) -> Task: | ||
""" | ||
Get a task by ID. | ||
""" | ||
if not task_id: | ||
return Response(status_code=400, content="Task ID is required.") | ||
if not isinstance(task_id, str): | ||
return Response(status_code=400, content="Task ID must be a string.") | ||
try: | ||
task = await self.db.get_task(task_id) | ||
except Exception as e: | ||
return Response(status_code=500, content=str(e)) | ||
return task | ||
|
||
async def run_step(self, step: Step) -> Step: | ||
artifacts = autogpt.utils.run(step.input) | ||
for artifact in artifacts: | ||
art = await self.db.create_artifact( | ||
task_id=step.task_id, | ||
file_name=artifact["file_name"], | ||
uri=artifact["uri"], | ||
agent_created=True, | ||
step_id=step.step_id, | ||
async def list_steps(self, task_id: str) -> typing.List[str]: | ||
""" | ||
List the IDs of all steps that the task has created. | ||
""" | ||
if not task_id: | ||
return Response(status_code=400, content="Task ID is required.") | ||
if not isinstance(task_id, str): | ||
return Response(status_code=400, content="Task ID must be a string.") | ||
try: | ||
steps_ids = [step.step_id for step in await self.db.list_steps(task_id)] | ||
except Exception as e: | ||
return Response(status_code=500, content=str(e)) | ||
return steps_ids | ||
|
||
async def create_and_execute_step( | ||
self, task_id: str, step_request: StepRequestBody | ||
) -> Step: | ||
""" | ||
Create a step for the task. | ||
""" | ||
if step_request.input != "y": | ||
step = await self.db.create_step( | ||
task_id=task_id, | ||
input=step_request.input if step_request else None, | ||
additional_properties=step_request.additional_input | ||
if step_request | ||
else None, | ||
) | ||
# utils.run | ||
artifacts = run(step.input) | ||
for artifact in artifacts: | ||
art = await self.db.create_artifact( | ||
task_id=step.task_id, | ||
file_name=artifact["file_name"], | ||
uri=artifact["uri"], | ||
agent_created=True, | ||
step_id=step.step_id, | ||
) | ||
assert isinstance( | ||
art, Artifact | ||
), f"Artifact not instance of Artifact {type(art)}" | ||
step.artifacts.append(art) | ||
step.status = "completed" | ||
else: | ||
steps = await self.db.list_steps(task_id) | ||
artifacts = await self.db.list_artifacts(task_id) | ||
step = steps[-1] | ||
step.artifacts = artifacts | ||
step.output = "No more steps to run." | ||
step.is_last = True | ||
if isinstance(step.status, Status): | ||
step.status = step.status.value | ||
step.output = "Done some work" | ||
return JSONResponse(content=step.dict(), status_code=200) | ||
|
||
async def get_step(self, task_id: str, step_id: str) -> Step: | ||
""" | ||
Get a step by ID. | ||
""" | ||
if not task_id or not step_id: | ||
return Response( | ||
status_code=400, content="Task ID and step ID are required." | ||
) | ||
assert isinstance( | ||
art, Artifact | ||
), f"Artifact not isntance of Artifact {type(art)}" | ||
step.artifacts.append(art) | ||
step.status = "completed" | ||
if not isinstance(task_id, str) or not isinstance(step_id, str): | ||
return Response( | ||
status_code=400, content="Task ID and step ID must be strings." | ||
) | ||
try: | ||
step = await self.db.get_step(task_id, step_id) | ||
except Exception as e: | ||
return Response(status_code=500, content=str(e)) | ||
return step | ||
|
||
async def retrieve_artifact(self, task_id: str, artifact: Artifact) -> bytes: | ||
async def list_artifacts(self, task_id: str) -> typing.List[Artifact]: | ||
""" | ||
Retrieve the artifact data from wherever it is stored and return it as bytes. | ||
List the artifacts that the task has created. | ||
""" | ||
if not artifact.uri.startswith("file://"): | ||
raise NotImplementedError("Loading from uri not implemented") | ||
file_path = artifact.uri.split("file://")[1] | ||
if not self.workspace.exists(file_path): | ||
raise FileNotFoundError(f"File {file_path} not found in workspace") | ||
return self.workspace.read(file_path) | ||
if not task_id: | ||
return Response(status_code=400, content="Task ID is required.") | ||
if not isinstance(task_id, str): | ||
return Response(status_code=400, content="Task ID must be a string.") | ||
try: | ||
artifacts = await self.db.list_artifacts(task_id) | ||
except Exception as e: | ||
return Response(status_code=500, content=str(e)) | ||
return artifacts | ||
|
||
async def save_artifact( | ||
self, task_id: str, artifact: Artifact, data: bytes | ||
async def create_artifact( | ||
self, | ||
task_id: str, | ||
file: UploadFile | None = None, | ||
uri: str | None = None, | ||
) -> Artifact: | ||
""" | ||
Save the artifact data to the agent's workspace, loading from uri if bytes are not available. | ||
Create an artifact for the task. | ||
""" | ||
assert ( | ||
data is not None and artifact.uri is not None | ||
), "Data or Artifact uri must be set" | ||
if not file and not uri: | ||
return Response(status_code=400, content="No file or uri provided") | ||
data = None | ||
if not uri: | ||
file_name = file.filename or str(uuid4()) | ||
try: | ||
data = b"" | ||
while contents := file.file.read(1024 * 1024): | ||
data += contents | ||
except Exception as e: | ||
return Response(status_code=500, content=str(e)) | ||
|
||
if data is not None: | ||
file_path = os.path.join(task_id / artifact.file_name) | ||
file_path = os.path.join(task_id / file_name) | ||
self.write(file_path, data) | ||
artifact.uri = f"file://{file_path}" | ||
self.db.save_artifact(task_id, artifact) | ||
else: | ||
raise NotImplementedError("Loading from uri not implemented") | ||
|
||
artifact = await self.create_artifact( | ||
task_id=task_id, | ||
file_name=file_name, | ||
uri=f"file://{file_path}", | ||
agent_created=False, | ||
) | ||
|
||
return artifact | ||
|
||
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact: | ||
""" | ||
Get an artifact by ID. | ||
""" | ||
artifact = await self.db.get_artifact(task_id, artifact_id) | ||
if not artifact.uri.startswith("file://"): | ||
return Response( | ||
status_code=500, content="Loading from none file uri not implemented" | ||
) | ||
file_path = artifact.uri.split("file://")[1] | ||
if not self.workspace.exists(file_path): | ||
return Response(status_code=500, content="File not found") | ||
retrieved_artifact = self.workspace.read(file_path) | ||
path = artifact.file_name | ||
with open(path, "wb") as f: | ||
f.write(retrieved_artifact) | ||
return FileResponse( | ||
# Note: mimetype is guessed in the FileResponse constructor | ||
path=path, | ||
filename=artifact.file_name, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.