Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
Fixed artifacts handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Swiftyos committed Aug 30, 2023
1 parent f56ca3c commit 60dcaa1
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 62 deletions.
2 changes: 1 addition & 1 deletion autogpt/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
workspace = LocalWorkspace(os.getenv("AGENT_WORKSPACE"))
port = os.getenv("PORT")

database = autogpt.sdk.db.AgentDB(database_name, debug_enabled=False)
database = autogpt.sdk.db.AgentDB(database_name, debug_enabled=True)
agent = autogpt.agent.AutoGPTAgent(database=database, workspace=workspace)

agent.start(port=port, router=router)
41 changes: 19 additions & 22 deletions autogpt/sdk/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,37 +179,33 @@ async def list_artifacts(
raise

async def create_artifact(
self,
task_id: str,
artifact_upload: ArtifactUpload,
self, task_id: str, file: UploadFile, relative_path: str
) -> Artifact:
"""
Create an artifact for the task.
"""
data = None
file_name = artifact_upload.file.filename or str(uuid4())
file_name = file.filename or str(uuid4())
try:
data = b""
while contents := artifact_upload.file.file.read(1024 * 1024):
while contents := file.file.read(1024 * 1024):
data += contents
except Exception as e:
raise
# Check if relative path ends with filename
if artifact_upload.relative_path.endswith(file_name):
file_path = os.path.join(task_id, artifact_upload.relative_path)
else:
file_path = os.path.join(task_id, artifact_upload.relative_path, file_name)
# Check if relative path ends with filename
if relative_path.endswith(file_name):
file_path = relative_path
else:
file_path = os.path.join(relative_path, file_name)

self.write(file_path, data)
self.db.save_artifact(task_id, artifact)

artifact = await self.db.create_artifact(
task_id=task_id,
file_name=file_name,
relative_path=artifact_upload.relative_path,
agent_created=False,
)
self.workspace.write(task_id, file_path, data)

artifact = await self.db.create_artifact(
task_id=task_id,
file_name=file_name,
relative_path=relative_path,
agent_created=False,
)
except Exception as e:
raise
return artifact

async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
Expand All @@ -218,7 +214,8 @@ async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
"""
try:
artifact = await self.db.get_artifact(artifact_id)
retrieved_artifact = await self.load_from_uri(artifact.uri, artifact_id)
file_path = os.path.join(artifact.relative_path, artifact.file_name)
retrieved_artifact = self.workspace.read(task_id=task_id, path=file_path)
path = artifact.file_name
with open(path, "wb") as f:
f.write(retrieved_artifact)
Expand Down
11 changes: 7 additions & 4 deletions autogpt/sdk/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from .errors import NotFoundError
from .forge_log import CustomLogger
from .schema import Artifact, Pagination, Status, Step, Task, TaskInput
from .schema import Artifact, Pagination, Status, Step, StepRequestBody, Task, TaskInput

LOG = CustomLogger(__name__)

Expand Down Expand Up @@ -124,6 +124,7 @@ def convert_to_artifact(artifact_model: ArtifactModel) -> Artifact:
modified_at=artifact_model.modified_at,
agent_created=artifact_model.agent_created,
relative_path=artifact_model.relative_path,
file_name=artifact_model.file_name,
)


Expand All @@ -149,7 +150,9 @@ async def create_task(
new_task = TaskModel(
task_id=str(uuid.uuid4()),
input=input,
additional_input=additional_input,
additional_input=additional_input.json()
if additional_input
else {},
)
session.add(new_task)
session.commit()
Expand All @@ -169,7 +172,7 @@ async def create_task(
async def create_step(
self,
task_id: str,
input: str,
input: StepRequestBody,
is_last: bool = False,
additional_input: Optional[Dict[str, Any]] = {},
) -> Step:
Expand All @@ -180,7 +183,7 @@ async def create_step(
new_step = StepModel(
task_id=task_id,
step_id=str(uuid.uuid4()),
name=input.name,
name=input.input,
input=input.input,
status="created",
is_last=is_last,
Expand Down
30 changes: 20 additions & 10 deletions autogpt/sdk/db_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ async def test_convert_to_task():
modified_at=now,
relative_path="file:///path/to/main.py",
agent_created=True,
file_name="main.py",
)
],
)
Expand Down Expand Up @@ -147,6 +148,7 @@ async def test_convert_to_step():
modified_at=now,
relative_path="file:///path/to/main.py",
agent_created=True,
file_name="main.py",
)
],
is_last=False,
Expand All @@ -170,6 +172,7 @@ async def test_convert_to_artifact():
modified_at=now,
relative_path="file:///path/to/main.py",
agent_created=True,
file_name="main.py",
)
artifact = convert_to_artifact(artifact_model)
assert artifact.artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56"
Expand Down Expand Up @@ -208,25 +211,27 @@ async def test_get_task_not_found():
os.remove(db_name.split("///")[1])


@pytest.mark.skip
@pytest.mark.asyncio
async def test_create_and_get_step():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
task = await agent_db.create_task("task_input")
step = await agent_db.create_step(task.task_id, "step_name")
step_input = StepInput(type="python/code")
request = StepRequestBody(input="test_input debug", additional_input=step_input)
step = await agent_db.create_step(task.task_id, request)
step = await agent_db.get_step(task.task_id, step.step_id)
assert step.name == "step_name"
assert step.input == "test_input debug"
os.remove(db_name.split("///")[1])


@pytest.mark.skip
@pytest.mark.asyncio
async def test_updating_step():
db_name = "sqlite:///test_db.sqlite3"
agent_db = AgentDB(db_name)
created_task = await agent_db.create_task("task_input")
created_step = await agent_db.create_step(created_task.task_id, "step_name")
step_input = StepInput(type="python/code")
request = StepRequestBody(input="test_input debug", additional_input=step_input)
created_step = await agent_db.create_step(created_task.task_id, request)
await agent_db.update_step(created_task.task_id, created_step.step_id, "completed")

step = await agent_db.get_step(created_task.task_id, created_step.step_id)
Expand All @@ -243,15 +248,17 @@ async def test_get_step_not_found():
os.remove(db_name.split("///")[1])


@pytest.mark.skip
@pytest.mark.asyncio
async def test_get_artifact():
db_name = "sqlite:///test_db.sqlite3"
db = AgentDB(db_name)

# Given: A task and its corresponding artifact
task = await db.create_task("test_input debug")
step = await db.create_step(task.task_id, "step_name")
step_input = StepInput(type="python/code")
requst = StepRequestBody(input="test_input debug", additional_input=step_input)

step = await db.create_step(task.task_id, requst)

# Create an artifact
artifact = await db.create_artifact(
Expand Down Expand Up @@ -294,16 +301,19 @@ async def test_list_tasks():
os.remove(db_name.split("///")[1])


@pytest.mark.skip
@pytest.mark.asyncio
async def test_list_steps():
db_name = "sqlite:///test_db.sqlite3"
db = AgentDB(db_name)

step_input = StepInput(type="python/code")
requst = StepRequestBody(input="test_input debug", additional_input=step_input)

# Given: A task and multiple steps for that task
task = await db.create_task("test_input")
step1 = await db.create_step(task.task_id, "step_1")
step2 = await db.create_step(task.task_id, "step_2")
step1 = await db.create_step(task.task_id, requst)
requst = StepRequestBody(input="step two", additional_input=step_input)
step2 = await db.create_step(task.task_id, requst)

# When: All steps for the task are fetched
fetched_steps, pagination = await db.list_steps(task.task_id)
Expand Down
2 changes: 1 addition & 1 deletion autogpt/sdk/forge_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def __init__(self, name: str, level: int = logging.NOTSET):
},
root={
"handlers": ["h"],
"level": logging.WARNING,
"level": logging.DEBUG,
},
loggers={
"autogpt": {
Expand Down
44 changes: 24 additions & 20 deletions autogpt/sdk/routes/agent_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,8 @@ async def create_agent_task(request: Request, task_request: TaskRequestBody) ->
status_code=200,
media_type="application/json",
)
except NotFoundError:
return Response(
content=json.dumps({"error": "Task not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
LOG.exception(f"Error whilst trying to create a task: {task_request}")
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
Expand Down Expand Up @@ -171,12 +166,14 @@ async def list_agent_tasks(
media_type="application/json",
)
except NotFoundError:
LOG.exception("Error whilst trying to list tasks")
return Response(
content=json.dumps({"error": "Task not found"}),
content=json.dumps({"error": "Tasks not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
LOG.exception("Error whilst trying to list tasks")
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
Expand Down Expand Up @@ -246,12 +243,14 @@ async def get_agent_task(request: Request, task_id: str) -> Task:
media_type="application/json",
)
except NotFoundError:
LOG.exception(f"Error whilst trying to get task: {task_id}")
return Response(
content=json.dumps({"error": "Task not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
LOG.exception(f"Error whilst trying to get task: {task_id}")
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
Expand Down Expand Up @@ -311,12 +310,14 @@ async def list_agent_task_steps(
media_type="application/json",
)
except NotFoundError:
LOG.exception("Error whilst trying to list steps")
return Response(
content=json.dumps({"error": "Task not found"}),
content=json.dumps({"error": "Steps not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
LOG.exception("Error whilst trying to list steps")
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
Expand Down Expand Up @@ -377,13 +378,14 @@ async def execute_agent_task_step(
media_type="application/json",
)
except NotFoundError:
LOG.exception(f"Error whilst trying to execute a task step: {task_id}")
return Response(
content=json.dumps({"error": f"Task not found {task_id}"}),
status_code=404,
media_type="application/json",
)
except Exception as e:
LOG.exception("Error whilst trying to execute a test")
LOG.exception(f"Error whilst trying to execute a task step: {task_id}")
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
Expand Down Expand Up @@ -423,12 +425,14 @@ async def get_agent_task_step(request: Request, task_id: str, step_id: str) -> S
step = await agent.get_step(task_id, step_id)
return Response(content=step.json(), status_code=200)
except NotFoundError:
LOG.exception(f"Error whilst trying to get step: {step_id}")
return Response(
content=json.dumps({"error": "Task not found"}),
content=json.dumps({"error": "Step not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
LOG.exception(f"Error whilst trying to get step: {step_id}")
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
Expand Down Expand Up @@ -484,12 +488,14 @@ async def list_agent_task_artifacts(
artifacts = await agent.list_artifacts(task_id, page, page_size)
return artifacts
except NotFoundError:
LOG.exception("Error whilst trying to list artifacts")
return Response(
content=json.dumps({"error": "Task not found"}),
content=json.dumps({"error": "Artifacts not found for task_id"}),
status_code=404,
media_type="application/json",
)
except Exception:
LOG.exception("Error whilst trying to list artifacts")
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
Expand All @@ -502,7 +508,7 @@ async def list_agent_task_artifacts(
)
@tracing("Uploading task artifact")
async def upload_agent_task_artifacts(
request: Request, task_id: str, artifact_upload: ArtifactUpload
request: Request, task_id: str, file: UploadFile, relative_path: str
) -> Artifact:
"""
Uploads an artifact for a specific task using a provided file.
Expand All @@ -529,26 +535,22 @@ async def upload_agent_task_artifacts(
}
"""
agent = request["agent"]
if artifact_upload.file is None:

if file is None:
return Response(
content=json.dumps({"error": "File must be specified"}),
status_code=404,
media_type="application/json",
)
try:
artifact = await agent.create_artifact(task_id, artifact_upload)
artifact = await agent.create_artifact(task_id, file, relative_path)
return Response(
content=artifact.json(),
status_code=200,
media_type="application/json",
)
except NotFoundError:
return Response(
content=json.dumps({"error": "Task not found"}),
status_code=404,
media_type="application/json",
)
except Exception:
LOG.exception(f"Error whilst trying to upload artifact: {task_id}")
return Response(
content=json.dumps({"error": "Internal server error"}),
status_code=500,
Expand Down Expand Up @@ -585,6 +587,7 @@ async def download_agent_task_artifact(
try:
return await agent.get_artifact(task_id, artifact_id)
except NotFoundError:
LOG.exception(f"Error whilst trying to download artifact: {task_id}")
return Response(
content=json.dumps(
{
Expand All @@ -595,6 +598,7 @@ async def download_agent_task_artifact(
media_type="application/json",
)
except Exception:
LOG.exception(f"Error whilst trying to download artifact: {task_id}")
return Response(
content=json.dumps(
{
Expand Down
Loading

0 comments on commit 60dcaa1

Please sign in to comment.