Updated workspace to include task_id in artifacts path
Swiftyos committed Aug 22, 2023
1 parent 22f4dec commit 6f44e1b
Showing 4 changed files with 101 additions and 87 deletions.
82 changes: 44 additions & 38 deletions autogpt/agent.py
@@ -128,12 +128,18 @@ async def create_and_execute_step(
step.artifacts.append(art)
step.status = "completed"
else:
- steps = await self.db.list_steps(task_id)
- artifacts = await self.db.list_artifacts(task_id)
+ steps, steps_pagination = await self.db.list_steps(
+     task_id, page=1, per_page=100
+ )
+ artifacts, artifacts_pagination = await self.db.list_artifacts(
+     task_id, page=1, per_page=100
+ )
step = steps[-1]
step.artifacts = artifacts
step.output = "No more steps to run."
- step.is_last = True
+ # The step is the last one on this page, so checking whether this is the
+ # last page is sufficient to know whether it is the last step.
+ step.is_last = steps_pagination.current == steps_pagination.pages
if isinstance(step.status, Status):
step.status = step.status.value
step.output = "Done some work"
@@ -218,52 +224,52 @@ async def create_artifact(

return artifact

- async def load_from_uri(self, uri: str, artifact_id: str) -> bytes:
+ async def load_from_uri(self, uri: str, task_id: str) -> bytes:
"""
Load file from given URI and return its bytes.
"""
file_path = None
try:
if uri.startswith("file://"):
file_path = uri.split("file://")[1]
- if not self.workspace.exists(file_path):
+ if not self.workspace.exists(task_id, file_path):
return Response(status_code=500, content="File not found")
- return self.workspace.read(file_path)
- elif uri.startswith("s3://"):
-     import boto3
+ return self.workspace.read(task_id, file_path)
+ # elif uri.startswith("s3://"):
+ #     import boto3

s3 = boto3.client("s3")
bucket_name, key_name = uri[5:].split("/", 1)
file_path = "/tmp/" + artifact_id
s3.download_file(bucket_name, key_name, file_path)
with open(file_path, "rb") as f:
return f.read()
elif uri.startswith("gs://"):
from google.cloud import storage
# s3 = boto3.client("s3")
# bucket_name, key_name = uri[5:].split("/", 1)
# file_path = "/tmp/" + task_id
# s3.download_file(bucket_name, key_name, file_path)
# with open(file_path, "rb") as f:
# return f.read()
# elif uri.startswith("gs://"):
# from google.cloud import storage

- storage_client = storage.Client()
- bucket_name, blob_name = uri[5:].split("/", 1)
- bucket = storage_client.bucket(bucket_name)
- blob = bucket.blob(blob_name)
- file_path = "/tmp/" + artifact_id
- blob.download_to_filename(file_path)
- with open(file_path, "rb") as f:
-     return f.read()
- elif uri.startswith("https://"):
-     from azure.storage.blob import BlobServiceClient
+ # storage_client = storage.Client()
+ # bucket_name, blob_name = uri[5:].split("/", 1)
+ # bucket = storage_client.bucket(bucket_name)
+ # blob = bucket.blob(blob_name)
+ # file_path = "/tmp/" + task_id
+ # blob.download_to_filename(file_path)
+ # with open(file_path, "rb") as f:
+ #     return f.read()
+ # elif uri.startswith("https://"):
+ #     from azure.storage.blob import BlobServiceClient

- blob_service_client = BlobServiceClient.from_connection_string(
-     "my_connection_string"
- )
- container_name, blob_name = uri[8:].split("/", 1)
- blob_client = blob_service_client.get_blob_client(
-     container_name, blob_name
- )
- file_path = "/tmp/" + artifact_id
- with open(file_path, "wb") as download_file:
-     download_file.write(blob_client.download_blob().readall())
- with open(file_path, "rb") as f:
-     return f.read()
+ # blob_service_client = BlobServiceClient.from_connection_string(
+ #     "my_connection_string"
+ # )
+ # container_name, blob_name = uri[8:].split("/", 1)
+ # blob_client = blob_service_client.get_blob_client(
+ #     container_name, blob_name
+ # )
+ # file_path = "/tmp/" + task_id
+ # with open(file_path, "wb") as download_file:
+ #     download_file.write(blob_client.download_blob().readall())
+ # with open(file_path, "rb") as f:
+ #     return f.read()
else:
return Response(status_code=500, content="Loading from unsupported uri")
except Exception as e:
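Note: after this change only file:// URIs are actually handled; the S3, GCS and Azure branches are commented out. A minimal sketch of the new task-scoped call, where the agent object and task id are illustrative stand-ins:

    async def demo(agent) -> None:
        task_id = "1234"
        # write(...) now lands under <workspace base>/<task_id>/
        agent.workspace.write(task_id, "output.txt", b"hello")
        data = await agent.load_from_uri("file://output.txt", task_id)
        assert data == b"hello"
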
6 changes: 6 additions & 0 deletions autogpt/utils.py
@@ -164,3 +164,9 @@ def read_webpage(url: str) -> typing.Optional[str]:
except Exception as e:
print(f"Unable to read webpage: {e}")
return contents


+ if __name__ == "__main__":
+     test_messages = [{"role": "user", "content": "Hello, how are you?"}]
+     response = chat_completion_request(test_messages)
+     print(response)
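
Note: the new __main__ block doubles as a quick smoke test: running python autogpt/utils.py sends a single "Hello, how are you?" message through chat_completion_request and prints the response, assuming whatever API credentials the helper relies on are configured in the environment.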
83 changes: 42 additions & 41 deletions autogpt/workspace.py
@@ -12,53 +12,54 @@ def __init__(self, base_path: str) -> None:
self.base_path = base_path

@abc.abstractclassmethod
- def read(self, path: str) -> bytes:
+ def read(self, task_id: str, path: str) -> bytes:
pass

@abc.abstractclassmethod
- def write(self, path: str, data: bytes) -> None:
+ def write(self, task_id: str, path: str, data: bytes) -> None:
pass

@abc.abstractclassmethod
def delete(
-     self, path: str, directory: bool = False, recursive: bool = False
+     self, task_id: str, path: str, directory: bool = False, recursive: bool = False
) -> None:
pass

@abc.abstractclassmethod
- def exists(self, path: str) -> bool:
+ def exists(self, task_id: str, path: str) -> bool:
pass

@abc.abstractclassmethod
- def list(self, path: str) -> typing.List[str]:
+ def list(self, task_id: str, path: str) -> typing.List[str]:
pass


class LocalWorkspace(Workspace):
def __init__(self, base_path: str):
self.base_path = Path(base_path).resolve()

- def _resolve_path(self, path: str) -> Path:
-     abs_path = (self.base_path / path).resolve()
+ def _resolve_path(self, task_id: str, path: str) -> Path:
+     abs_path = (self.base_path / task_id / path).resolve()
if not str(abs_path).startswith(str(self.base_path)):
raise ValueError("Directory traversal is not allowed!")
+ (self.base_path / task_id).mkdir(parents=True, exist_ok=True)
return abs_path

- def read(self, path: str) -> bytes:
-     path = self.base_path / path
-     with open(self._resolve_path(path), "rb") as f:
+ def read(self, task_id: str, path: str) -> bytes:
+     path = self.base_path / task_id / path
+     with open(self._resolve_path(task_id, path), "rb") as f:
return f.read()

- def write(self, path: str, data: bytes) -> None:
-     path = self.base_path / path
-     with open(self._resolve_path(path), "wb") as f:
+ def write(self, task_id: str, path: str, data: bytes) -> None:
+     path = self.base_path / task_id / path
+     with open(self._resolve_path(task_id, path), "wb") as f:
f.write(data)

def delete(
-     self, path: str, directory: bool = False, recursive: bool = False
+     self, task_id: str, path: str, directory: bool = False, recursive: bool = False
) -> None:
- path = self.base_path / path
- resolved_path = self._resolve_path(path)
+ path = self.base_path / task_id / path
+ resolved_path = self._resolve_path(task_id, path)
if directory:
if recursive:
os.rmdir(resolved_path)
@@ -67,14 +68,14 @@ def delete(
else:
os.remove(resolved_path)

- def exists(self, path: str) -> bool:
-     path = self.base_path / path
-     return self._resolve_path(path).exists()
+ def exists(self, task_id: str, path: str) -> bool:
+     path = self.base_path / task_id / path
+     return self._resolve_path(task_id, path).exists()

- def list(self, path: str) -> typing.List[str]:
-     path = self.base_path / path
-     base = self._resolve_path(path)
-     return [str(p.relative_to(self.base_path)) for p in base.iterdir()]
+ def list(self, task_id: str, path: str) -> typing.List[str]:
+     path = self.base_path / task_id / path
+     base = self._resolve_path(task_id, path)
+     return [str(p.relative_to(self.base_path / task_id)) for p in base.iterdir()]


class GCSWorkspace(Workspace):
@@ -83,45 +84,45 @@ def __init__(self, base_path: str, bucket_name: str):
self.bucket_name = bucket_name
self.base_path = base_path.strip("/") # Ensure no trailing or leading slash

- def _resolve_path(self, path: str) -> str:
-     resolved = os.path.join(self.base_path, path).strip("/")
+ def _resolve_path(self, task_id: str, path: str) -> str:
+     resolved = os.path.join(self.base_path, task_id, path).strip("/")
if not resolved.startswith(self.base_path):
raise ValueError("Directory traversal is not allowed!")
return resolved

- def read(self, path: str) -> bytes:
-     path = self.base_path / path
+ def read(self, task_id: str, path: str) -> bytes:
+     path = self.base_path / task_id / path
bucket = self.client.get_bucket(self.bucket_name)
- blob = bucket.get_blob(self._resolve_path(path))
+ blob = bucket.get_blob(self._resolve_path(task_id, path))
return blob.download_as_bytes()

- def write(self, path: str, data: bytes) -> None:
-     path = self.base_path / path
+ def write(self, task_id: str, path: str, data: bytes) -> None:
+     path = self.base_path / task_id / path
bucket = self.client.get_bucket(self.bucket_name)
- blob = bucket.blob(self._resolve_path(path))
+ blob = bucket.blob(self._resolve_path(task_id, path))
blob.upload_from_string(data)

def delete(
-     self, path: str, directory: bool = False, recursive: bool = False
+     self, task_id: str, path: str, directory: bool = False, recursive: bool = False
) -> None:
- path = self.base_path / path
+ path = self.base_path / task_id / path
bucket = self.client.get_bucket(self.bucket_name)
if directory and recursive:
# Note: GCS doesn't really have directories, so this will just delete all blobs with the given prefix
- blobs = bucket.list_blobs(prefix=self._resolve_path(path))
+ blobs = bucket.list_blobs(prefix=self._resolve_path(task_id, path))
bucket.delete_blobs(blobs)
else:
- blob = bucket.blob(self._resolve_path(path))
+ blob = bucket.blob(self._resolve_path(task_id, path))
blob.delete()

- def exists(self, path: str) -> bool:
-     path = self.base_path / path
+ def exists(self, task_id: str, path: str) -> bool:
+     path = self.base_path / task_id / path
bucket = self.client.get_bucket(self.bucket_name)
- blob = bucket.blob(self._resolve_path(path))
+ blob = bucket.blob(self._resolve_path(task_id, path))
return blob.exists()

- def list(self, path: str) -> typing.List[str]:
-     path = self.base_path / path
+ def list(self, task_id: str, path: str) -> typing.List[str]:
+     path = self.base_path / task_id / path
bucket = self.client.get_bucket(self.bucket_name)
- blobs = bucket.list_blobs(prefix=self._resolve_path(path))
+ blobs = bucket.list_blobs(prefix=self._resolve_path(task_id, path))
return [blob.name for blob in blobs]
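
Note: a minimal sketch of the per-task layout LocalWorkspace now enforces, runnable from the repo root (the temp directory is illustrative):

    from autogpt.workspace import LocalWorkspace

    ws = LocalWorkspace("/tmp/demo_workspace")
    # Every path is resolved beneath <base_path>/<task_id>/, which
    # _resolve_path creates on demand, so tasks no longer share files.
    ws.write("task-a", "notes.txt", b"alpha")
    ws.write("task-b", "notes.txt", b"beta")
    assert ws.read("task-a", "notes.txt") == b"alpha"
    assert ws.read("task-b", "notes.txt") == b"beta"
    # Escaping the workspace root is still rejected by _resolve_path:
    try:
        ws.read("task-a", "../../etc/passwd")
    except ValueError:
        pass  # "Directory traversal is not allowed!"

Note that the check guards the workspace root rather than the individual task folder, and that GCSWorkspace joins paths with os.path.join inside _resolve_path rather than the pathlib-style joins its read/write bodies use on the string base_path.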
17 changes: 9 additions & 8 deletions autogpt/workspace_test.py
@@ -8,6 +8,7 @@
# Constants
TEST_BASE_PATH = "/tmp/test_workspace"
TEST_FILE_CONTENT = b"Hello World"
TEST_TASK_ID = "1234"


# Setup and Teardown for LocalWorkspace
@@ -24,23 +25,23 @@ def test_local_read_write_delete_exists(setup_local_workspace):
workspace = LocalWorkspace(TEST_BASE_PATH)

# Write
workspace.write("test_file.txt", TEST_FILE_CONTENT)
workspace.write(TEST_TASK_ID, "test_file.txt", TEST_FILE_CONTENT)

# Exists
assert workspace.exists("test_file.txt")
assert workspace.exists(TEST_TASK_ID, "test_file.txt")

# Read
- assert workspace.read("test_file.txt") == TEST_FILE_CONTENT
+ assert workspace.read(TEST_TASK_ID, "test_file.txt") == TEST_FILE_CONTENT

# Delete
workspace.delete("test_file.txt")
assert not workspace.exists("test_file.txt")
workspace.delete(TEST_TASK_ID, "test_file.txt")
assert not workspace.exists(TEST_TASK_ID, "test_file.txt")


def test_local_list(setup_local_workspace):
workspace = LocalWorkspace(TEST_BASE_PATH)
workspace.write("test1.txt", TEST_FILE_CONTENT)
workspace.write("test2.txt", TEST_FILE_CONTENT)
workspace.write(TEST_TASK_ID, "test1.txt", TEST_FILE_CONTENT)
workspace.write(TEST_TASK_ID, "test2.txt", TEST_FILE_CONTENT)

files = workspace.list(".")
files = workspace.list(TEST_TASK_ID, ".")
assert set(files) == {"test1.txt", "test2.txt"}
