This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Added workspaces
Swiftyos committed Aug 17, 2023
1 parent b7fa5f6 commit 0bd35ce
Showing 7 changed files with 509 additions and 17 deletions.
30 changes: 23 additions & 7 deletions autogpt/agent.py
@@ -1,12 +1,16 @@
+import os
 import time
 
 import autogpt.utils
 from autogpt.agent_protocol import Agent, Artifact, Step, Task, TaskDB
 
+from .workspace import Workspace
+
 
 class AutoGPT(Agent):
-    def __init__(self, db: TaskDB) -> None:
+    def __init__(self, db: TaskDB, workspace: Workspace) -> None:
         super().__init__(db)
+        self.workspace = workspace
 
     async def create_task(self, task: Task) -> None:
         print(f"task: {task.input}")
@@ -32,21 +36,33 @@ async def run_step(self, step: Step) -> Step:
         print(f"Step completed: {updated_step}")
         return updated_step
 
-    async def retrieve_artifact(self, task_id: str, artifact: Artifact) -> Artifact:
+    async def retrieve_artifact(self, task_id: str, artifact: Artifact) -> bytes:
         """
         Retrieve the artifact data from wherever it is stored and return it as bytes.
         """
-        return artifact
+        if not artifact.uri.startswith("file://"):
+            raise NotImplementedError("Loading from uri not implemented")
+        file_path = artifact.uri.split("file://")[1]
+        if not self.workspace.exists(file_path):
+            raise FileNotFoundError(f"File {file_path} not found in workspace")
+        return self.workspace.read(file_path)
 
-    async def save_artifact(self, task_id: str, artifact: Artifact) -> Artifact:
+    async def save_artifact(
+        self, task_id: str, artifact: Artifact, data: bytes
+    ) -> Artifact:
         """
         Save the artifact data to the agent's workspace, loading from uri if bytes are not available.
         """
         assert (
-            artifact.data is not None and artifact.uri is not None
-        ), "Artifact data or uri must be set"
+            data is not None and artifact.uri is not None
+        ), "Data or Artifact uri must be set"
 
-        if artifact.data is None:
+        if data is not None:
+            file_path = os.path.join(task_id / artifact.file_name)
+            self.write(file_path, data)
+            artifact.uri = f"file://{file_path}"
+            self.db.save_artifact(task_id, artifact)
+        else:
             raise NotImplementedError("Loading from uri not implemented")
 
         return artifact
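
For orientation, the two methods above agree on a simple convention: save_artifact records the workspace-relative location as a file:// URI, and retrieve_artifact strips that scheme back off before asking the workspace for the bytes. A minimal, self-contained sketch of that round trip (the helper names below are illustrative and not part of the commit):

import os

FILE_SCHEME = "file://"

def to_file_uri(task_id: str, file_name: str) -> str:
    # Mirrors save_artifact: place the file under the task's directory
    # and record its location as a file:// URI.
    return FILE_SCHEME + os.path.join(task_id, file_name)

def from_file_uri(uri: str) -> str:
    # Mirrors retrieve_artifact: only file:// URIs are supported.
    if not uri.startswith(FILE_SCHEME):
        raise NotImplementedError("Loading from uri not implemented")
    return uri.split(FILE_SCHEME)[1]

uri = to_file_uri("task-123", "report.txt")
assert from_file_uri(uri) == os.path.join("task-123", "report.txt")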
15 changes: 8 additions & 7 deletions autogpt/agent_protocol/agent.py
@@ -151,7 +151,7 @@ async def upload_agent_task_artifacts(
     agent: Agent = request["agent"]
     if not file and not uri:
         return Response(status_code=400, content="No file or uri provided")
-
+    data = None
     if uri:
         artifact = Artifact(
             task_id=task_id,
@@ -169,10 +169,9 @@ async def upload_agent_task_artifacts(
         artifact = Artifact(
             task_id=task_id,
             file_name=file_name,
-            data=contents,
         )
 
-    artifact = await agent.save_artifact(task_id, artifact)
+    artifact = await agent.save_artifact(task_id, artifact, data)
     agent.db.add_artifact(task_id, artifact)
 
     return artifact
@@ -227,14 +226,16 @@ async def create_task(self, task: Task):
     async def run_step(self, step: Step) -> Step:
         return step
 
-    async def retrieve_artifact(self, task_id: str, artifact: Artifact) -> Artifact:
+    async def retrieve_artifact(self, task_id: str, artifact: Artifact) -> bytes:
         """
         Retrieve the artifact data from wherever it is stored and return it as bytes.
         """
-        return artifact
+        raise NotImplementedError("Retrieve artifact not implemented")
 
-    async def save_artifact(self, task_id: str, artifact: Artifact) -> Artifact:
+    async def save_artifact(
+        self, task_id: str, artifact: Artifact, data: bytes | None = None
+    ) -> Artifact:
         """
         Save the artifact data to the agent's workspace, loading from uri if bytes are not available.
         """
-        return artifact
+        raise NotImplementedError("Save artifact not implemented")
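
The protocol-level change is the contract the base class now exposes: retrieve_artifact returns raw bytes, and save_artifact receives the uploaded bytes as a separate data argument instead of reading them off the Artifact. A minimal sketch of a subclass honouring that contract with the new LocalWorkspace (the class name and wiring are illustrative, not part of the commit):

import os

from autogpt.agent_protocol import Agent, Artifact, TaskDB
from autogpt.workspace import LocalWorkspace


class WorkspaceAgent(Agent):
    def __init__(self, db: TaskDB, workspace: LocalWorkspace) -> None:
        super().__init__(db)
        self.workspace = workspace

    async def retrieve_artifact(self, task_id: str, artifact: Artifact) -> bytes:
        # Return the stored bytes directly, per the new signature.
        return self.workspace.read(artifact.uri.split("file://")[1])

    async def save_artifact(
        self, task_id: str, artifact: Artifact, data: bytes | None = None
    ) -> Artifact:
        if data is None:
            raise NotImplementedError("Loading from uri not implemented")
        # Persist the uploaded bytes and record where they went.
        file_path = os.path.join(task_id, artifact.file_name)
        self.workspace.write(file_path, data)
        artifact.uri = f"file://{file_path}"
        return artifact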
1 change: 0 additions & 1 deletion autogpt/benchmark_integration.py
@@ -1,4 +1,3 @@
-
 from agbenchmark.app import get_artifact, get_skill_tree
 from fastapi import APIRouter
 from fastapi import (
127 changes: 127 additions & 0 deletions autogpt/workspace.py
@@ -0,0 +1,127 @@
import abc
import os
import typing
from pathlib import Path

from google.cloud import storage


class Workspace(abc.ABC):
    @abc.abstractclassmethod
    def __init__(self, base_path: str) -> None:
        self.base_path = base_path

    @abc.abstractclassmethod
    def read(self, path: str) -> bytes:
        pass

    @abc.abstractclassmethod
    def write(self, path: str, data: bytes) -> None:
        pass

    @abc.abstractclassmethod
    def delete(
        self, path: str, directory: bool = False, recursive: bool = False
    ) -> None:
        pass

    @abc.abstractclassmethod
    def exists(self, path: str) -> bool:
        pass

    @abc.abstractclassmethod
    def list(self, path: str) -> typing.List[str]:
        pass


class LocalWorkspace(Workspace):
    def __init__(self, base_path: str):
        self.base_path = Path(base_path).resolve()

    def _resolve_path(self, path: str) -> Path:
        abs_path = (self.base_path / path).resolve()
        if not str(abs_path).startswith(str(self.base_path)):
            raise ValueError("Directory traversal is not allowed!")
        return abs_path

    def read(self, path: str) -> bytes:
        path = self.base_path / path
        with open(self._resolve_path(path), "rb") as f:
            return f.read()

    def write(self, path: str, data: bytes) -> None:
        path = self.base_path / path
        with open(self._resolve_path(path), "wb") as f:
            f.write(data)

    def delete(
        self, path: str, directory: bool = False, recursive: bool = False
    ) -> None:
        path = self.base_path / path
        resolved_path = self._resolve_path(path)
        if directory:
            if recursive:
                os.rmdir(resolved_path)
            else:
                os.removedirs(resolved_path)
        else:
            os.remove(resolved_path)

    def exists(self, path: str) -> bool:
        path = self.base_path / path
        return self._resolve_path(path).exists()

    def list(self, path: str) -> typing.List[str]:
        path = self.base_path / path
        base = self._resolve_path(path)
        return [str(p.relative_to(self.base_path)) for p in base.iterdir()]


class GCSWorkspace(Workspace):
    def __init__(self, base_path: str, bucket_name: str):
        self.client = storage.Client()
        self.bucket_name = bucket_name
        self.base_path = base_path.strip("/")  # Ensure no trailing or leading slash

    def _resolve_path(self, path: str) -> str:
        resolved = os.path.join(self.base_path, path).strip("/")
        if not resolved.startswith(self.base_path):
            raise ValueError("Directory traversal is not allowed!")
        return resolved

    def read(self, path: str) -> bytes:
        path = self.base_path / path
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.get_blob(self._resolve_path(path))
        return blob.download_as_bytes()

    def write(self, path: str, data: bytes) -> None:
        path = self.base_path / path
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.blob(self._resolve_path(path))
        blob.upload_from_string(data)

    def delete(
        self, path: str, directory: bool = False, recursive: bool = False
    ) -> None:
        path = self.base_path / path
        bucket = self.client.get_bucket(self.bucket_name)
        if directory and recursive:
            # Note: GCS doesn't really have directories, so this will just delete all blobs with the given prefix
            blobs = bucket.list_blobs(prefix=self._resolve_path(path))
            bucket.delete_blobs(blobs)
        else:
            blob = bucket.blob(self._resolve_path(path))
            blob.delete()

    def exists(self, path: str) -> bool:
        path = self.base_path / path
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.blob(self._resolve_path(path))
        return blob.exists()

    def list(self, path: str) -> typing.List[str]:
        path = self.base_path / path
        bucket = self.client.get_bucket(self.bucket_name)
        blobs = bucket.list_blobs(prefix=self._resolve_path(path))
        return [blob.name for blob in blobs]
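
As a quick orientation to the API above, here is a minimal usage sketch of the local backend (the base directory is illustrative; note that LocalWorkspace does not create it, so the sketch does that first):

import os

from autogpt.workspace import LocalWorkspace

base = "/tmp/agent_workspace"      # illustrative location, not part of the commit
os.makedirs(base, exist_ok=True)   # LocalWorkspace expects the base directory to exist

ws = LocalWorkspace(base)
ws.write("hello.txt", b"Hello, workspace!")
assert ws.exists("hello.txt")
assert ws.read("hello.txt") == b"Hello, workspace!"
print(ws.list("."))                # ['hello.txt']
ws.delete("hello.txt")
assert not ws.exists("hello.txt")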
46 changes: 46 additions & 0 deletions autogpt/workspace_test.py
@@ -0,0 +1,46 @@
import os

import pytest

# Assuming the classes are defined in a file named workspace.py
from .workspace import LocalWorkspace

# Constants
TEST_BASE_PATH = "/tmp/test_workspace"
TEST_FILE_CONTENT = b"Hello World"


# Setup and Teardown for LocalWorkspace


@pytest.fixture
def setup_local_workspace():
    os.makedirs(TEST_BASE_PATH, exist_ok=True)
    yield
    os.system(f"rm -rf {TEST_BASE_PATH}")  # Cleanup after tests


def test_local_read_write_delete_exists(setup_local_workspace):
    workspace = LocalWorkspace(TEST_BASE_PATH)

    # Write
    workspace.write("test_file.txt", TEST_FILE_CONTENT)

    # Exists
    assert workspace.exists("test_file.txt")

    # Read
    assert workspace.read("test_file.txt") == TEST_FILE_CONTENT

    # Delete
    workspace.delete("test_file.txt")
    assert not workspace.exists("test_file.txt")


def test_local_list(setup_local_workspace):
    workspace = LocalWorkspace(TEST_BASE_PATH)
    workspace.write("test1.txt", TEST_FILE_CONTENT)
    workspace.write("test2.txt", TEST_FILE_CONTENT)

    files = workspace.list(".")
    assert set(files) == set(["test1.txt", "test2.txt"])
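
The fixture and constants above also make it easy to cover the traversal guard in _resolve_path; a sketch of one more test that could sit alongside the committed ones (not part of the commit):

def test_local_directory_traversal_rejected(setup_local_workspace):
    workspace = LocalWorkspace(TEST_BASE_PATH)
    # Paths that resolve outside the base directory are rejected.
    with pytest.raises(ValueError):
        workspace.read("../outside.txt")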

0 comments on commit 0bd35ce
