forked from Significant-Gravitas/AutoGPT
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Forge/workshop (Significant-Gravitas#5654)
* Added basic memory * Added action history * Deleted placeholder files * adding memstore * Added web search ability * Added web search and reading web pages * remove agent.py changes Signed-off-by: Merwane Hamadi <merwanehamadi@gmail.com> --------- Signed-off-by: Merwane Hamadi <merwanehamadi@gmail.com> Co-authored-by: SwiftyOS <craigswift13@gmail.com>
- Loading branch information
1 parent
f77d383
commit 3bd8ae4
Showing
17 changed files
with
1,358 additions
and
381 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,12 @@ | ||
import os | ||
|
||
from forge.agent import ForgeAgent | ||
from forge.sdk import AgentDB, LocalWorkspace | ||
from forge.sdk import LocalWorkspace | ||
from .db import ForgeDatabase | ||
|
||
database_name = os.getenv("DATABASE_STRING") | ||
workspace = LocalWorkspace(os.getenv("AGENT_WORKSPACE")) | ||
database = AgentDB(database_name, debug_enabled=False) | ||
database = ForgeDatabase(database_name, debug_enabled=False) | ||
agent = ForgeAgent(database=database, workspace=workspace) | ||
|
||
app = agent.get_agent_app() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from .sdk import AgentDB, ForgeLogger, NotFoundError, Base | ||
from sqlalchemy.exc import SQLAlchemyError | ||
|
||
import datetime | ||
from sqlalchemy import ( | ||
Column, | ||
DateTime, | ||
String, | ||
) | ||
import uuid | ||
|
||
LOG = ForgeLogger(__name__) | ||
|
||
class ChatModel(Base): | ||
__tablename__ = "chat" | ||
msg_id = Column(String, primary_key=True, index=True) | ||
task_id = Column(String) | ||
role = Column(String) | ||
content = Column(String) | ||
created_at = Column(DateTime, default=datetime.datetime.utcnow) | ||
modified_at = Column( | ||
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow | ||
) | ||
|
||
class ActionModel(Base): | ||
__tablename__ = "action" | ||
action_id = Column(String, primary_key=True, index=True) | ||
task_id = Column(String) | ||
name = Column(String) | ||
args = Column(String) | ||
created_at = Column(DateTime, default=datetime.datetime.utcnow) | ||
modified_at = Column( | ||
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow | ||
) | ||
|
||
|
||
class ForgeDatabase(AgentDB): | ||
|
||
async def add_chat_history(self, task_id, messages): | ||
for message in messages: | ||
await self.add_chat_message(task_id, message['role'], message['content']) | ||
|
||
async def add_chat_message(self, task_id, role, content): | ||
if self.debug_enabled: | ||
LOG.debug("Creating new task") | ||
try: | ||
with self.Session() as session: | ||
mew_msg = ChatModel( | ||
msg_id=str(uuid.uuid4()), | ||
task_id=task_id, | ||
role=role, | ||
content=content, | ||
) | ||
session.add(mew_msg) | ||
session.commit() | ||
session.refresh(mew_msg) | ||
if self.debug_enabled: | ||
LOG.debug(f"Created new Chat message with task_id: {mew_msg.msg_id}") | ||
return mew_msg | ||
except SQLAlchemyError as e: | ||
LOG.error(f"SQLAlchemy error while creating task: {e}") | ||
raise | ||
except NotFoundError as e: | ||
raise | ||
except Exception as e: | ||
LOG.error(f"Unexpected error while creating task: {e}") | ||
raise | ||
|
||
async def get_chat_history(self, task_id): | ||
if self.debug_enabled: | ||
LOG.debug(f"Getting chat history with task_id: {task_id}") | ||
try: | ||
with self.Session() as session: | ||
if messages := ( | ||
session.query(ChatModel) | ||
.filter(ChatModel.task_id == task_id) | ||
.order_by(ChatModel.created_at) | ||
.all() | ||
): | ||
return [{"role": m.role, "content": m.content} for m in messages] | ||
|
||
else: | ||
LOG.error( | ||
f"Chat history not found with task_id: {task_id}" | ||
) | ||
raise NotFoundError("Chat history not found") | ||
except SQLAlchemyError as e: | ||
LOG.error(f"SQLAlchemy error while getting chat history: {e}") | ||
raise | ||
except NotFoundError as e: | ||
raise | ||
except Exception as e: | ||
LOG.error(f"Unexpected error while getting chat history: {e}") | ||
raise | ||
|
||
async def create_action(self, task_id, name, args): | ||
try: | ||
with self.Session() as session: | ||
new_action = ActionModel( | ||
action_id=str(uuid.uuid4()), | ||
task_id=task_id, | ||
name=name, | ||
args=str(args), | ||
) | ||
session.add(new_action) | ||
session.commit() | ||
session.refresh(new_action) | ||
if self.debug_enabled: | ||
LOG.debug(f"Created new Action with task_id: {new_action.action_id}") | ||
return new_action | ||
except SQLAlchemyError as e: | ||
LOG.error(f"SQLAlchemy error while creating action: {e}") | ||
raise | ||
except NotFoundError as e: | ||
raise | ||
except Exception as e: | ||
LOG.error(f"Unexpected error while creating action: {e}") | ||
raise | ||
|
||
async def get_action_history(self, task_id): | ||
if self.debug_enabled: | ||
LOG.debug(f"Getting action history with task_id: {task_id}") | ||
try: | ||
with self.Session() as session: | ||
if actions := ( | ||
session.query(ActionModel) | ||
.filter(ActionModel.task_id == task_id) | ||
.order_by(ActionModel.created_at) | ||
.all() | ||
): | ||
return [{"name": a.name, "args": a.args} for a in actions] | ||
|
||
else: | ||
LOG.error( | ||
f"Action history not found with task_id: {task_id}" | ||
) | ||
raise NotFoundError("Action history not found") | ||
except SQLAlchemyError as e: | ||
LOG.error(f"SQLAlchemy error while getting action history: {e}") | ||
raise | ||
except NotFoundError as e: | ||
raise | ||
except Exception as e: | ||
LOG.error(f"Unexpected error while getting action history: {e}") | ||
raise |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
|
||
from __future__ import annotations | ||
|
||
import json | ||
import time | ||
from itertools import islice | ||
|
||
from duckduckgo_search import DDGS | ||
|
||
from ..registry import ability | ||
|
||
DUCKDUCKGO_MAX_ATTEMPTS = 3 | ||
|
||
|
||
@ability( | ||
name="web_search", | ||
description="Searches the web", | ||
parameters=[ | ||
{ | ||
"name": "query", | ||
"description": "The search query", | ||
"type": "string", | ||
"required": True, | ||
} | ||
], | ||
output_type="list[str]", | ||
) | ||
async def web_search(agent, task_id: str, query: str) -> str: | ||
"""Return the results of a Google search | ||
Args: | ||
query (str): The search query. | ||
num_results (int): The number of results to return. | ||
Returns: | ||
str: The results of the search. | ||
""" | ||
search_results = [] | ||
attempts = 0 | ||
num_results = 8 | ||
|
||
while attempts < DUCKDUCKGO_MAX_ATTEMPTS: | ||
if not query: | ||
return json.dumps(search_results) | ||
|
||
results = DDGS().text(query) | ||
search_results = list(islice(results, num_results)) | ||
|
||
if search_results: | ||
break | ||
|
||
time.sleep(1) | ||
attempts += 1 | ||
|
||
results = json.dumps(search_results, ensure_ascii=False, indent=4) | ||
return safe_google_results(results) | ||
|
||
|
||
def safe_google_results(results: str | list) -> str: | ||
""" | ||
Return the results of a Google search in a safe format. | ||
Args: | ||
results (str | list): The search results. | ||
Returns: | ||
str: The results of the search. | ||
""" | ||
if isinstance(results, list): | ||
safe_message = json.dumps( | ||
[result.encode("utf-8", "ignore").decode("utf-8") for result in results] | ||
) | ||
else: | ||
safe_message = results.encode("utf-8", "ignore").decode("utf-8") | ||
return safe_message |
Oops, something went wrong.